Anirudh Esthuri committed
Commit 3a73f5d · 1 Parent(s): 4cd3a4a

Add Gemini 3.0 Pro and Gemini 2.5 Flash model support via AWS Bedrock

Files changed (2):
  1. llm.py +101 -2
  2. model_config.py +16 -0
llm.py CHANGED
@@ -172,6 +172,105 @@ def chat(messages, persona):
         ).strip()
         total_tok = len(text.split())
 
+        return text, dt, total_tok, (total_tok / dt if dt else total_tok)
+    elif provider == "google":
+        print("Using google (Gemini): ", MODEL_STRING)
+        t0 = time.time()
+
+        # Add system prompt for better behavior
+        system_prompt = ""
+
+        # Convert messages to Gemini format
+        # Gemini uses "user" and "model" roles, and content is an array
+        gemini_messages = []
+        for msg in messages:
+            role = msg.get("role", "user")
+            # Gemini uses "model" instead of "assistant"
+            if role == "assistant":
+                role = "model"
+            gemini_messages.append({
+                "role": role,
+                "parts": [{"text": msg["content"]}]
+            })
+
+        try:
+            bedrock_runtime = get_bedrock_client()
+
+            # Use inference profile ARN if available (for provisioned throughput models)
+            # Otherwise use modelId (for on-demand models)
+            invoke_kwargs = {
+                "contentType": "application/json",
+                "accept": "application/json",
+                "body": json.dumps(
+                    {
+                        "contents": gemini_messages,
+                        "generationConfig": {
+                            "maxOutputTokens": 4000,
+                            "temperature": 0.3,
+                        }
+                    }
+                ),
+            }
+
+            # Add system instruction if provided
+            if system_prompt:
+                invoke_kwargs["body"] = json.dumps(
+                    {
+                        "contents": gemini_messages,
+                        "systemInstruction": {
+                            "parts": [{"text": system_prompt}]
+                        },
+                        "generationConfig": {
+                            "maxOutputTokens": 4000,
+                            "temperature": 0.3,
+                        }
+                    }
+                )
+
+            # Check if this model has an inference profile ARN (provisioned throughput)
+            # For provisioned throughput, use the ARN as the modelId
+            if MODEL_STRING in MODEL_TO_INFERENCE_PROFILE_ARN:
+                invoke_kwargs["modelId"] = MODEL_TO_INFERENCE_PROFILE_ARN[MODEL_STRING]
+            else:
+                invoke_kwargs["modelId"] = MODEL_STRING
+
+            response = bedrock_runtime.invoke_model(**invoke_kwargs)
+
+            dt = time.time() - t0
+            body = json.loads(response["body"].read())
+        except ValueError as e:
+            # Re-raise ValueError (credential errors) as-is
+            raise
+        except Exception as e:
+            error_msg = str(e)
+            if "ValidationException" in error_msg and "model identifier is invalid" in error_msg:
+                raise ValueError(
+                    f"Invalid Bedrock model ID: '{MODEL_STRING}'. "
+                    f"Error: {error_msg}. "
+                    "Please verify the model ID is correct and the model is available in your AWS region. "
+                    "Common Gemini model IDs: 'google.gemini-pro-v1' or 'google.gemini-2.0-flash-exp'"
+                ) from e
+            elif "UnrecognizedClientException" in error_msg or "invalid" in error_msg.lower():
+                raise ValueError(
+                    f"AWS Bedrock authentication failed: {error_msg}. "
+                    "Please verify your AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY secrets "
+                    "are correct and have Bedrock access permissions."
+                ) from e
+            raise
+
+        # Extract text from Gemini response
+        # Gemini response format: {"candidates": [{"content": {"parts": [{"text": "..."}]}}]}
+        text = ""
+        if "candidates" in body and len(body["candidates"]) > 0:
+            candidate = body["candidates"][0]
+            if "content" in candidate and "parts" in candidate["content"]:
+                for part in candidate["content"]["parts"]:
+                    if "text" in part:
+                        text += part["text"]
+
+        text = text.strip()
+        total_tok = len(text.split())
+
         return text, dt, total_tok, (total_tok / dt if dt else total_tok)
     elif provider == "deepseek":
         print("Using deepseek: ", MODEL_STRING)
@@ -378,8 +477,8 @@ def check_credentials():
     # print(f"Ollama connection failed: {e}")
     # return False
 
-    # Check if using Bedrock providers (anthropic, meta, mistral, deepseek)
-    bedrock_providers = ["anthropic"]
+    # Check if using Bedrock providers (anthropic, google, meta, mistral, deepseek)
+    bedrock_providers = ["anthropic", "google"]
     if MODEL_TO_PROVIDER.get(MODEL_STRING) in bedrock_providers:
        # Test AWS Bedrock connection by trying to invoke a simple model
        try:
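
For reference, the round trip performed by the new "google" branch can be exercised in isolation with the minimal sketch below. This is not part of the commit: it calls boto3 directly instead of the repo's get_bedrock_client helper, uses a placeholder region, and reuses one of the model IDs added in model_config.py. Whether that model ID is actually invocable on Bedrock in a given account and region, and whether the service accepts this Gemini-style payload, are assumptions carried over from the commit itself.

    # Standalone sketch of the request/response shape the new branch assumes.
    # Assumptions (from the commit, not verified against Bedrock's catalog):
    # "google.gemini-2.5-flash-v1:0" is invocable in this account/region and
    # accepts a Gemini-style contents/generationConfig payload; region is a placeholder.
    import json
    import boto3

    bedrock_runtime = boto3.client("bedrock-runtime", region_name="us-east-1")

    payload = {
        "contents": [
            {"role": "user", "parts": [{"text": "Say hello in one sentence."}]}
        ],
        "generationConfig": {"maxOutputTokens": 4000, "temperature": 0.3},
    }

    response = bedrock_runtime.invoke_model(
        modelId="google.gemini-2.5-flash-v1:0",
        contentType="application/json",
        accept="application/json",
        body=json.dumps(payload),
    )

    # Expected response shape, per the parsing code above:
    # {"candidates": [{"content": {"parts": [{"text": "..."}]}}]}
    body = json.loads(response["body"].read())
    text = "".join(
        part.get("text", "")
        for part in body["candidates"][0]["content"]["parts"]
    )
    print(text.strip())
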
model_config.py CHANGED
@@ -10,6 +10,10 @@ PROVIDER_MODEL_MAP = {
         "anthropic.claude-sonnet-4-5-20250929-v1:0",
         "anthropic.claude-opus-4-20250514-v1:0",
     ],
+    "google": [
+        "google.gemini-3.0-pro-v1:0",
+        "google.gemini-2.5-flash-v1:0",
+    ],
 }
 
 
@@ -28,6 +32,8 @@ MODEL_DISPLAY_NAMES = {
     "anthropic.claude-haiku-4-5-20251001-v1:0": "AWS Bedrock - Anthropic - Claude Haiku 4.5",
     "anthropic.claude-sonnet-4-5-20250929-v1:0": "AWS Bedrock - Anthropic - Claude Sonnet 4.5",
     "anthropic.claude-opus-4-20250514-v1:0": "AWS Bedrock - Anthropic - Claude Opus 4",
+    "google.gemini-3.0-pro-v1:0": "AWS Bedrock - Google - Gemini 3.0 Pro",
+    "google.gemini-2.5-flash-v1:0": "AWS Bedrock - Google - Gemini 2.5 Flash",
 }
 
 MODEL_CHOICES = [model for models in PROVIDER_MODEL_MAP.values() for model in models]
@@ -51,3 +57,13 @@ if sonnet_arn:
 opus_arn = os.getenv("BEDROCK_OPUS_4_ARN", "").strip()
 if opus_arn:
     MODEL_TO_INFERENCE_PROFILE_ARN["anthropic.claude-opus-4-20250514-v1:0"] = opus_arn
+
+# Gemini 3.0 Pro
+gemini_3_arn = os.getenv("BEDROCK_GEMINI_3_ARN", "").strip()
+if gemini_3_arn:
+    MODEL_TO_INFERENCE_PROFILE_ARN["google.gemini-3.0-pro-v1:0"] = gemini_3_arn
+
+# Gemini 2.5 Flash
+gemini_2_5_arn = os.getenv("BEDROCK_GEMINI_2_5_ARN", "").strip()
+if gemini_2_5_arn:
+    MODEL_TO_INFERENCE_PROFILE_ARN["google.gemini-2.5-flash-v1:0"] = gemini_2_5_arn
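
The two BEDROCK_GEMINI_*_ARN environment variables are optional: when one is set, llm.py sends the inference-profile ARN as modelId, otherwise it falls back to the bare model ID for on-demand invocation. The sketch below shows that lookup; resolve_model_id is a hypothetical helper written here for illustration only, not a function in the repository, and the variables must be set (or left unset) before model_config is imported because the module reads them at import time.

    # Sketch only: mirrors the modelId lookup in llm.py's Gemini branch.
    # resolve_model_id is a hypothetical helper, not part of the repository.
    # Export BEDROCK_GEMINI_3_ARN / BEDROCK_GEMINI_2_5_ARN (or leave them unset)
    # before importing model_config, since the values are read at import time.
    from model_config import MODEL_TO_INFERENCE_PROFILE_ARN


    def resolve_model_id(model_string: str) -> str:
        # Prefer a configured inference-profile ARN; otherwise use the plain model ID.
        return MODEL_TO_INFERENCE_PROFILE_ARN.get(model_string, model_string)


    print(resolve_model_id("google.gemini-3.0-pro-v1:0"))
    print(resolve_model_id("google.gemini-2.5-flash-v1:0"))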