Commit 3a73f5d · Anirudh Esthuri committed · Parent: 4cd3a4a

Add Gemini 3.0 Pro and Gemini 2.5 Flash model support via AWS Bedrock

Files changed:
- llm.py (+101, -2)
- model_config.py (+16, -0)
llm.py
CHANGED
@@ -172,6 +172,105 @@ def chat(messages, persona):
         ).strip()
         total_tok = len(text.split())
 
+        return text, dt, total_tok, (total_tok / dt if dt else total_tok)
+    elif provider == "google":
+        print("Using google (Gemini): ", MODEL_STRING)
+        t0 = time.time()
+
+        # Add system prompt for better behavior
+        system_prompt = ""
+
+        # Convert messages to Gemini format
+        # Gemini uses "user" and "model" roles, and content is an array
+        gemini_messages = []
+        for msg in messages:
+            role = msg.get("role", "user")
+            # Gemini uses "model" instead of "assistant"
+            if role == "assistant":
+                role = "model"
+            gemini_messages.append({
+                "role": role,
+                "parts": [{"text": msg["content"]}]
+            })
+
+        try:
+            bedrock_runtime = get_bedrock_client()
+
+            # Use inference profile ARN if available (for provisioned throughput models)
+            # Otherwise use modelId (for on-demand models)
+            invoke_kwargs = {
+                "contentType": "application/json",
+                "accept": "application/json",
+                "body": json.dumps(
+                    {
+                        "contents": gemini_messages,
+                        "generationConfig": {
+                            "maxOutputTokens": 4000,
+                            "temperature": 0.3,
+                        }
+                    }
+                ),
+            }
+
+            # Add system instruction if provided
+            if system_prompt:
+                invoke_kwargs["body"] = json.dumps(
+                    {
+                        "contents": gemini_messages,
+                        "systemInstruction": {
+                            "parts": [{"text": system_prompt}]
+                        },
+                        "generationConfig": {
+                            "maxOutputTokens": 4000,
+                            "temperature": 0.3,
+                        }
+                    }
+                )
+
+            # Check if this model has an inference profile ARN (provisioned throughput)
+            # For provisioned throughput, use the ARN as the modelId
+            if MODEL_STRING in MODEL_TO_INFERENCE_PROFILE_ARN:
+                invoke_kwargs["modelId"] = MODEL_TO_INFERENCE_PROFILE_ARN[MODEL_STRING]
+            else:
+                invoke_kwargs["modelId"] = MODEL_STRING
+
+            response = bedrock_runtime.invoke_model(**invoke_kwargs)
+
+            dt = time.time() - t0
+            body = json.loads(response["body"].read())
+        except ValueError as e:
+            # Re-raise ValueError (credential errors) as-is
+            raise
+        except Exception as e:
+            error_msg = str(e)
+            if "ValidationException" in error_msg and "model identifier is invalid" in error_msg:
+                raise ValueError(
+                    f"Invalid Bedrock model ID: '{MODEL_STRING}'. "
+                    f"Error: {error_msg}. "
+                    "Please verify the model ID is correct and the model is available in your AWS region. "
+                    "Common Gemini model IDs: 'google.gemini-pro-v1' or 'google.gemini-2.0-flash-exp'"
+                ) from e
+            elif "UnrecognizedClientException" in error_msg or "invalid" in error_msg.lower():
+                raise ValueError(
+                    f"AWS Bedrock authentication failed: {error_msg}. "
+                    "Please verify your AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY secrets "
+                    "are correct and have Bedrock access permissions."
+                ) from e
+            raise
+
+        # Extract text from Gemini response
+        # Gemini response format: {"candidates": [{"content": {"parts": [{"text": "..."}]}}]}
+        text = ""
+        if "candidates" in body and len(body["candidates"]) > 0:
+            candidate = body["candidates"][0]
+            if "content" in candidate and "parts" in candidate["content"]:
+                for part in candidate["content"]["parts"]:
+                    if "text" in part:
+                        text += part["text"]
+
+        text = text.strip()
+        total_tok = len(text.split())
+
         return text, dt, total_tok, (total_tok / dt if dt else total_tok)
     elif provider == "deepseek":
         print("Using deepseek: ", MODEL_STRING)

@@ -378,8 +477,8 @@ def check_credentials():
         # print(f"Ollama connection failed: {e}")
         # return False
 
-    # Check if using Bedrock providers (anthropic, meta, mistral, deepseek)
-    bedrock_providers = ["anthropic"]
+    # Check if using Bedrock providers (anthropic, google, meta, mistral, deepseek)
+    bedrock_providers = ["anthropic", "google"]
     if MODEL_TO_PROVIDER.get(MODEL_STRING) in bedrock_providers:
         # Test AWS Bedrock connection by trying to invoke a simple model
         try:
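For reviewers, here is a minimal offline sketch of the two conversions the new "google" branch performs: mapping OpenAI-style message dicts to Gemini's "contents"/"parts" request format, and extracting text from a "candidates" response. The sample messages and the mocked response body are invented for illustration only; nothing below calls Bedrock, so get_bedrock_client is not needed.

# Offline sketch (invented sample data; no AWS call is made).
import json

messages = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi there."},
    {"role": "user", "content": "Summarize AWS Bedrock in one line."},
]

# (1) Request: Gemini uses "model" instead of "assistant" and wraps
# each message's content in a "parts" array.
gemini_messages = []
for msg in messages:
    role = msg.get("role", "user")
    if role == "assistant":
        role = "model"
    gemini_messages.append({"role": role, "parts": [{"text": msg["content"]}]})

request_body = json.dumps({
    "contents": gemini_messages,
    "generationConfig": {"maxOutputTokens": 4000, "temperature": 0.3},
})

# (2) Response: a hand-written body matching the format noted in the diff,
# {"candidates": [{"content": {"parts": [{"text": "..."}]}}]}
body = {
    "candidates": [
        {"content": {"parts": [{"text": "Bedrock is AWS's managed "},
                               {"text": "foundation-model service."}]}}
    ]
}
text = ""
if "candidates" in body and len(body["candidates"]) > 0:
    candidate = body["candidates"][0]
    if "content" in candidate and "parts" in candidate["content"]:
        for part in candidate["content"]["parts"]:
            if "text" in part:
                text += part["text"]
print(text.strip())  # -> Bedrock is AWS's managed foundation-model service.

Note the diff marks the first return as added and the last as context; that is just how the differ aligned the insertion, since the new google branch ends with the same return statement the anthropic branch already had.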
model_config.py
CHANGED
@@ -10,6 +10,10 @@ PROVIDER_MODEL_MAP = {
         "anthropic.claude-sonnet-4-5-20250929-v1:0",
         "anthropic.claude-opus-4-20250514-v1:0",
     ],
+    "google": [
+        "google.gemini-3.0-pro-v1:0",
+        "google.gemini-2.5-flash-v1:0",
+    ],
 }
 
 
@@ -28,6 +32,8 @@ MODEL_DISPLAY_NAMES = {
     "anthropic.claude-haiku-4-5-20251001-v1:0": "AWS Bedrock - Anthropic - Claude Haiku 4.5",
     "anthropic.claude-sonnet-4-5-20250929-v1:0": "AWS Bedrock - Anthropic - Claude Sonnet 4.5",
     "anthropic.claude-opus-4-20250514-v1:0": "AWS Bedrock - Anthropic - Claude Opus 4",
+    "google.gemini-3.0-pro-v1:0": "AWS Bedrock - Google - Gemini 3.0 Pro",
+    "google.gemini-2.5-flash-v1:0": "AWS Bedrock - Google - Gemini 2.5 Flash",
 }
 
 MODEL_CHOICES = [model for models in PROVIDER_MODEL_MAP.values() for model in models]
@@ -51,3 +57,13 @@ if sonnet_arn:
 opus_arn = os.getenv("BEDROCK_OPUS_4_ARN", "").strip()
 if opus_arn:
     MODEL_TO_INFERENCE_PROFILE_ARN["anthropic.claude-opus-4-20250514-v1:0"] = opus_arn
+
+# Gemini 3.0 Pro
+gemini_3_arn = os.getenv("BEDROCK_GEMINI_3_ARN", "").strip()
+if gemini_3_arn:
+    MODEL_TO_INFERENCE_PROFILE_ARN["google.gemini-3.0-pro-v1:0"] = gemini_3_arn
+
+# Gemini 2.5 Flash
+gemini_2_5_arn = os.getenv("BEDROCK_GEMINI_2_5_ARN", "").strip()
+if gemini_2_5_arn:
+    MODEL_TO_INFERENCE_PROFILE_ARN["google.gemini-2.5-flash-v1:0"] = gemini_2_5_arn
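Usage note: the new MODEL_TO_INFERENCE_PROFILE_ARN entries only take effect when the BEDROCK_GEMINI_3_ARN / BEDROCK_GEMINI_2_5_ARN environment variables (Space secrets) are set; otherwise the bare model ID is used for on-demand invocation. A minimal sketch of the resulting modelId resolution follows; the ARN value is a placeholder, not a real inference profile.

# Sketch of how llm.py resolves modelId after this change.
# BEDROCK_GEMINI_3_ARN is set to a placeholder purely for demonstration.
import os

os.environ["BEDROCK_GEMINI_3_ARN"] = (
    "arn:aws:bedrock:us-east-1:123456789012:inference-profile/placeholder"
)

# Same pattern model_config.py uses to populate the mapping:
MODEL_TO_INFERENCE_PROFILE_ARN = {}
gemini_3_arn = os.getenv("BEDROCK_GEMINI_3_ARN", "").strip()
if gemini_3_arn:
    MODEL_TO_INFERENCE_PROFILE_ARN["google.gemini-3.0-pro-v1:0"] = gemini_3_arn

for model in ("google.gemini-3.0-pro-v1:0", "google.gemini-2.5-flash-v1:0"):
    # Same check the new "google" branch in llm.py performs:
    if model in MODEL_TO_INFERENCE_PROFILE_ARN:
        model_id = MODEL_TO_INFERENCE_PROFILE_ARN[model]  # provisioned throughput
    else:
        model_id = model  # on-demand fallback (no secret set)
    print(model, "->", model_id)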