AnirudhEsthuri-MV committed on
Commit
94aee85
·
1 Parent(s): f2a2584

Update llm.py

Files changed (1)
  1. llm.py +118 -122
llm.py CHANGED
@@ -21,15 +21,11 @@ api_key = os.getenv("MODEL_API_KEY")
 client = openai.OpenAI(api_key=api_key)
 bedrock_runtime = boto3.client(
     "bedrock-runtime",
-    region_name="us-west-2",
-    aws_access_key_id="AWS_ACCESS_KEY_ID",
-    aws_secret_access_key="AWS_SECRET_ACCESS_KEY"
 )

-# Ollama configuration
-OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
-
-
 # ──────────────────────────────────────────────────────────────
 # Model switcher
 # ──────────────────────────────────────────────────────────────
@@ -170,140 +166,140 @@ def chat(messages, persona):
         total_tok = len(text.split())

         return text, dt, total_tok, (total_tok / dt if dt else total_tok)
-    elif provider == "meta":
-        print("Using meta (LLaMA): ", MODEL_STRING)
-        t0 = time.time()

-        # Add system prompt for better behavior
-        system_prompt = ""

-        # Format conversation properly for Llama3
-        formatted_prompt = "<|begin_of_text|>"

-        # Add system prompt
-        formatted_prompt += "<|start_header_id|>system<|end_header_id|>\n" + system_prompt + "<|eot_id|>\n"

-        # Add conversation history
-        for msg in messages:
-            if msg["role"] == "user":
-                formatted_prompt += "<|start_header_id|>user<|end_header_id|>\n" + msg["content"] + "<|eot_id|>\n"
-            elif msg["role"] == "assistant":
-                formatted_prompt += "<|start_header_id|>assistant<|end_header_id|>\n" + msg["content"] + "<|eot_id|>\n"

-        # Add final assistant prompt
-        formatted_prompt += "<|start_header_id|>assistant<|end_header_id|>\n"
-
-        response = bedrock_runtime.invoke_model(
-            modelId=MODEL_STRING,
-            contentType="application/json",
-            accept="application/json",
-            body=json.dumps(
-                {
-                    "prompt": formatted_prompt,
-                    "max_gen_len": 512,  # Shorter responses
-                    "temperature": 0.3,  # Lower temperature for more focused responses
-                }
-            ),
-        )
-
-        dt = time.time() - t0
-        body = json.loads(response["body"].read())
-        text = body.get("generation", "").strip()
-        total_tok = len(text.split())
-
-        return text, dt, total_tok, (total_tok / dt if dt else total_tok)
-    elif provider == "mistral":
-        print("Using mistral: ", MODEL_STRING)
-        t0 = time.time()
-
-        prompt = messages[-1]["content"]
-        formatted_prompt = f"<s>[INST] {prompt} [/INST]"
-
-        response = bedrock_runtime.invoke_model(
-            modelId=MODEL_STRING,
-            contentType="application/json",
-            accept="application/json",
-            body=json.dumps(
-                {"prompt": formatted_prompt, "max_tokens": 512, "temperature": 0.5}
-            ),
-        )
-
-        dt = time.time() - t0
-        body = json.loads(response["body"].read())
-
-        text = body["outputs"][0]["text"].strip()
-        total_tok = len(text.split())
-
-        return text, dt, total_tok, (total_tok / dt if dt else total_tok)
-    elif provider == "ollama":
-        print("Using ollama: ", MODEL_STRING)
-        t0 = time.time()

-        # Format messages for Ollama API with system prompt
-        ollama_messages = []

-        # Add system prompt for better behavior
-        system_prompt = ""
-        ollama_messages.append({
-            "role": "system",
-            "content": system_prompt
-        })

-        for msg in messages:
-            ollama_messages.append({
-                "role": msg["role"],
-                "content": msg["content"]
-            })

-        # Make request to Ollama API
-        response = requests.post(
-            f"{OLLAMA_BASE_URL}/api/chat",
-            json={
-                "model": MODEL_STRING,
-                "messages": ollama_messages,
-                "stream": False,
-                "options": {
-                    "temperature": 0.3,  # Lower temperature for more focused responses
-                    # "num_predict": 4000,  # Much higher limit for longer responses
-                    "top_p": 0.9,
-                    "repeat_penalty": 1.1
-                }
-            },
-            timeout=60
-        )

-        dt = time.time() - t0

-        if response.status_code == 200:
-            result = response.json()
-            text = result["message"]["content"].strip()
-            total_tok = len(text.split())
-            return text, dt, total_tok, (total_tok / dt if dt else total_tok)
-        else:
-            raise Exception(f"Ollama API error: {response.status_code} - {response.text}")


 # ──────────────────────────────────────────────────────────────
 # Diagnostics / CLI test
 # ──────────────────────────────────────────────────────────────
 def check_credentials():
-    # Check if using Ollama (no API key required)
-    if MODEL_TO_PROVIDER.get(MODEL_STRING) == "ollama":
-        # Test Ollama connection
-        try:
-            response = requests.get(f"{OLLAMA_BASE_URL}/api/tags", timeout=5)
-            if response.status_code == 200:
-                print("Ollama connection successful")
-                return True
-            else:
-                print(f"Ollama connection failed: {response.status_code}")
-                return False
-        except Exception as e:
-            print(f"Ollama connection failed: {e}")
-            return False

     # Check if using Bedrock providers (anthropic, meta, mistral, deepseek)
-    bedrock_providers = ["anthropic", "meta", "mistral", "deepseek"]
     if MODEL_TO_PROVIDER.get(MODEL_STRING) in bedrock_providers:
         # Test AWS Bedrock connection by trying to invoke a simple model
         try:
 
 client = openai.OpenAI(api_key=api_key)
 bedrock_runtime = boto3.client(
     "bedrock-runtime",
+    region_name="us-east-1",
+    aws_access_key_id=os.getenv("AWS_ACCESS_ID"),
+    aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
 )

 # ──────────────────────────────────────────────────────────────
 # Model switcher
 # ──────────────────────────────────────────────────────────────
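With the keys now read from environment variables, the explicit keyword arguments are actually optional: boto3 resolves credentials through its default chain (AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY environment variables, ~/.aws/credentials, or an attached IAM role). A minimal sketch of that alternative, assuming the standard AWS variables are exported (this is an aside, not part of the commit):

    import os
    import boto3

    # With no keys passed, boto3 falls back to its default credential chain.
    region = os.getenv("AWS_REGION", "us-east-1")  # default mirrors the commit's region
    bedrock_runtime = boto3.client("bedrock-runtime", region_name=region)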
 
         total_tok = len(text.split())

         return text, dt, total_tok, (total_tok / dt if dt else total_tok)
+    # elif provider == "meta":
+    #     print("Using meta (LLaMA): ", MODEL_STRING)
+    #     t0 = time.time()

+    #     # Add system prompt for better behavior
+    #     system_prompt = ""

+    #     # Format conversation properly for Llama3
+    #     formatted_prompt = "<|begin_of_text|>"

+    #     # Add system prompt
+    #     formatted_prompt += "<|start_header_id|>system<|end_header_id|>\n" + system_prompt + "<|eot_id|>\n"

+    #     # Add conversation history
+    #     for msg in messages:
+    #         if msg["role"] == "user":
+    #             formatted_prompt += "<|start_header_id|>user<|end_header_id|>\n" + msg["content"] + "<|eot_id|>\n"
+    #         elif msg["role"] == "assistant":
+    #             formatted_prompt += "<|start_header_id|>assistant<|end_header_id|>\n" + msg["content"] + "<|eot_id|>\n"

+    #     # Add final assistant prompt
+    #     formatted_prompt += "<|start_header_id|>assistant<|end_header_id|>\n"
+
+    #     response = bedrock_runtime.invoke_model(
+    #         modelId=MODEL_STRING,
+    #         contentType="application/json",
+    #         accept="application/json",
+    #         body=json.dumps(
+    #             {
+    #                 "prompt": formatted_prompt,
+    #                 "max_gen_len": 512,  # Shorter responses
+    #                 "temperature": 0.3,  # Lower temperature for more focused responses
+    #             }
+    #         ),
+    #     )
+
+    #     dt = time.time() - t0
+    #     body = json.loads(response["body"].read())
+    #     text = body.get("generation", "").strip()
+    #     total_tok = len(text.split())
+
+    #     return text, dt, total_tok, (total_tok / dt if dt else total_tok)
+    # elif provider == "mistral":
+    #     print("Using mistral: ", MODEL_STRING)
+    #     t0 = time.time()
+
+    #     prompt = messages[-1]["content"]
+    #     formatted_prompt = f"<s>[INST] {prompt} [/INST]"
+
+    #     response = bedrock_runtime.invoke_model(
+    #         modelId=MODEL_STRING,
+    #         contentType="application/json",
+    #         accept="application/json",
+    #         body=json.dumps(
+    #             {"prompt": formatted_prompt, "max_tokens": 512, "temperature": 0.5}
+    #         ),
+    #     )
+
+    #     dt = time.time() - t0
+    #     body = json.loads(response["body"].read())
+
+    #     text = body["outputs"][0]["text"].strip()
+    #     total_tok = len(text.split())
+
+    #     return text, dt, total_tok, (total_tok / dt if dt else total_tok)
+    # elif provider == "ollama":
+    #     print("Using ollama: ", MODEL_STRING)
+    #     t0 = time.time()

+    #     # Format messages for Ollama API with system prompt
+    #     ollama_messages = []

+    #     # Add system prompt for better behavior
+    #     system_prompt = ""
+    #     ollama_messages.append({
+    #         "role": "system",
+    #         "content": system_prompt
+    #     })

+    #     for msg in messages:
+    #         ollama_messages.append({
+    #             "role": msg["role"],
+    #             "content": msg["content"]
+    #         })

+    #     # Make request to Ollama API
+    #     response = requests.post(
+    #         f"{OLLAMA_BASE_URL}/api/chat",
+    #         json={
+    #             "model": MODEL_STRING,
+    #             "messages": ollama_messages,
+    #             "stream": False,
+    #             "options": {
+    #                 "temperature": 0.3,  # Lower temperature for more focused responses
+    #                 # "num_predict": 4000,  # Much higher limit for longer responses
+    #                 "top_p": 0.9,
+    #                 "repeat_penalty": 1.1
+    #             }
+    #         },
+    #         timeout=60
+    #     )

+    #     dt = time.time() - t0

+    #     if response.status_code == 200:
+    #         result = response.json()
+    #         text = result["message"]["content"].strip()
+    #         total_tok = len(text.split())
+    #         return text, dt, total_tok, (total_tok / dt if dt else total_tok)
+    #     else:
+    #         raise Exception(f"Ollama API error: {response.status_code} - {response.text}")

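With the meta, mistral, and ollama branches commented out, the Bedrock-hosted Anthropic path is what remains of chat()'s dispatch; its branch sits above this hunk and is not shown. For orientation only, a hedged sketch of what a Claude call through Bedrock's Messages API generally looks like (the model ID and payload values below are assumptions, not code from this file):

    import json

    body = json.dumps({
        "anthropic_version": "bedrock-2023-05-31",  # required field for Bedrock's Anthropic Messages API
        "max_tokens": 512,
        "messages": [{"role": "user", "content": "Hello"}],
    })
    response = bedrock_runtime.invoke_model(
        modelId="anthropic.claude-3-haiku-20240307-v1:0",  # assumed model ID
        contentType="application/json",
        accept="application/json",
        body=body,
    )
    text = json.loads(response["body"].read())["content"][0]["text"]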
 # ──────────────────────────────────────────────────────────────
 # Diagnostics / CLI test
 # ──────────────────────────────────────────────────────────────
 def check_credentials():
+    # # Check if using Ollama (no API key required)
+    # if MODEL_TO_PROVIDER.get(MODEL_STRING) == "ollama":
+    #     # Test Ollama connection
+    #     try:
+    #         response = requests.get(f"{OLLAMA_BASE_URL}/api/tags", timeout=5)
+    #         if response.status_code == 200:
+    #             print("Ollama connection successful")
+    #             return True
+    #         else:
+    #             print(f"Ollama connection failed: {response.status_code}")
+    #             return False
+    #     except Exception as e:
+    #         print(f"Ollama connection failed: {e}")
+    #         return False

     # Check if using Bedrock providers (anthropic, meta, mistral, deepseek)
+    bedrock_providers = ["anthropic"]
     if MODEL_TO_PROVIDER.get(MODEL_STRING) in bedrock_providers:
         # Test AWS Bedrock connection by trying to invoke a simple model
         try:
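The body of this try block falls outside the hunk, so the actual check is not visible here. One common lightweight way to verify Bedrock connectivity is a cheap control-plane call; a hedged sketch under that assumption, not the code in this file:

    import boto3

    def bedrock_reachable(region: str = "us-east-1") -> bool:
        try:
            # The "bedrock" control-plane client (distinct from "bedrock-runtime")
            # exposes list_foundation_models, an authenticated call that fails
            # fast on bad credentials without paying for a model invocation.
            boto3.client("bedrock", region_name=region).list_foundation_models()
            return True
        except Exception as e:  # broad except mirrors the file's own style
            print(f"Bedrock connection failed: {e}")
            return False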