import json
import os
import time
import boto3
import openai
import requests  # used only by the commented-out Ollama path below
from dotenv import load_dotenv
from model_config import MODEL_TO_PROVIDER
# ──────────────────────────────────────────────────────────────
# Load environment variables
# ──────────────────────────────────────────────────────────────
load_dotenv()
# ──────────────────────────────────────────────────────────────
# Configuration
# ──────────────────────────────────────────────────────────────
MODEL_STRING = "gpt-4.1-mini"  # default model; change via set_model()
api_key = os.getenv("MODEL_API_KEY")
client = openai.OpenAI(api_key=api_key)
bedrock_runtime = boto3.client(
    "bedrock-runtime",
    region_name="us-east-1",
    aws_access_key_id=os.getenv("AWS_ACCESS_ID"),
    aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
)
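# Both clients read credentials from the environment. A minimal .env sketch
# (variable names taken from the os.getenv calls above; values are placeholders):
#
#   MODEL_API_KEY=sk-...
#   AWS_ACCESS_ID=AKIA...
#   AWS_SECRET_ACCESS_KEY=...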
# ──────────────────────────────────────────────────────────────
# Model switcher
# ──────────────────────────────────────────────────────────────
def set_model(model_id: str) -> None:
global MODEL_STRING
MODEL_STRING = model_id
print(f"Model changed to: {model_id}")
def set_provider(provider: str) -> None:
    global PROVIDER
    PROVIDER = provider  # note: chat() derives the provider from MODEL_TO_PROVIDER
    print(f"Provider changed to: {provider}")
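# Example (uses the default model id defined above):
#   set_model("gpt-4.1-mini")
#   set_provider("openai")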
# ──────────────────────────────────────────────────────────────
# High-level Chat wrapper
# ──────────────────────────────────────────────────────────────
def chat(messages, persona=None):
    """Route a chat request to the provider mapped to MODEL_STRING.

    Returns a tuple: (text, latency_seconds, token_count, tokens_per_second).
    `persona` is accepted for caller compatibility but is not used yet.
    """
provider = MODEL_TO_PROVIDER[MODEL_STRING]
if provider == "openai":
print("Using openai: ", MODEL_STRING)
t0 = time.time()
        # System prompt is currently empty; fill it in to steer model behavior
        system_prompt = ""
# Prepare messages with system prompt
chat_messages = [{"role": "system", "content": system_prompt}]
for msg in messages:
chat_messages.append({
"role": msg["role"],
"content": msg["content"]
})
request_kwargs = {
"model": MODEL_STRING,
"messages": chat_messages,
"max_completion_tokens": 4000,
}
# Some newer OpenAI models only support the default temperature.
if MODEL_STRING not in {"gpt-5-nano", "gpt-5-mini"}:
request_kwargs["temperature"] = 0.3
response = client.chat.completions.create(**request_kwargs)
dt = time.time() - t0
text = response.choices[0].message.content.strip()
        # Token count: prefer the API-reported usage, fall back to a whitespace split
        total_tok = response.usage.total_tokens if response.usage else len(text.split())
return text, dt, total_tok, (total_tok / dt if dt else total_tok)
elif provider == "anthropic":
print("Using anthropic: ", MODEL_STRING)
t0 = time.time()
        # System prompt is currently empty; fill it in to steer model behavior
        system_prompt = ""
claude_messages = [
{"role": m["role"], "content": m["content"]} for m in messages
]
response = bedrock_runtime.invoke_model(
modelId=MODEL_STRING,
contentType="application/json",
accept="application/json",
body=json.dumps(
{
"anthropic_version": "bedrock-2023-05-31",
"system": system_prompt,
"messages": claude_messages,
"max_tokens": 4000, # Much higher limit for longer responses
"temperature": 0.3, # Lower temperature for more focused responses
}
),
)
dt = time.time() - t0
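        # Parsing below assumes the Bedrock messages-API response shape,
        # roughly {"content": [{"type": "text", "text": "..."}], ...}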
body = json.loads(response["body"].read())
text = "".join(
part["text"] for part in body["content"] if part["type"] == "text"
).strip()
        total_tok = len(text.split())  # rough whitespace-based token estimate
return text, dt, total_tok, (total_tok / dt if dt else total_tok)
elif provider == "deepseek":
print("Using deepseek: ", MODEL_STRING)
t0 = time.time()
        system_prompt = ""
ds_messages = [
{
"role": "system",
"content": [{"type": "text", "text": system_prompt}],
}
]
for msg in messages:
role = msg.get("role", "user")
ds_messages.append(
{
"role": role,
"content": [{"type": "text", "text": msg["content"]}],
}
)
response = bedrock_runtime.invoke_model(
modelId=MODEL_STRING,
contentType="application/json",
accept="application/json",
body=json.dumps(
{
"messages": ds_messages,
"max_completion_tokens": 500,
"temperature": 0.5,
"top_p": 0.9,
}
),
)
dt = time.time() - t0
body = json.loads(response["body"].read())
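        # The DeepSeek response shape varies, so parse defensively: check both
        # "text" and "output_text" per content part, then a top-level fallback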
outputs = body.get("output", [])
text_chunks = []
for item in outputs:
for content in item.get("content", []):
chunk_text = content.get("text") or content.get("output_text")
if chunk_text:
text_chunks.append(chunk_text)
text = "".join(text_chunks).strip()
if not text and "response" in body:
text = body["response"].get("output_text", "").strip()
        total_tok = len(text.split())  # rough whitespace-based token estimate
return text, dt, total_tok, (total_tok / dt if dt else total_tok)
# elif provider == "meta":
# print("Using meta (LLaMA): ", MODEL_STRING)
# t0 = time.time()
# # Add system prompt for better behavior
# system_prompt = ""
# # Format conversation properly for Llama3
# formatted_prompt = "<|begin_of_text|>"
# # Add system prompt
# formatted_prompt += "<|start_header_id|>system<|end_header_id|>\n" + system_prompt + "<|eot_id|>\n"
# # Add conversation history
# for msg in messages:
# if msg["role"] == "user":
# formatted_prompt += "<|start_header_id|>user<|end_header_id|>\n" + msg["content"] + "<|eot_id|>\n"
# elif msg["role"] == "assistant":
# formatted_prompt += "<|start_header_id|>assistant<|end_header_id|>\n" + msg["content"] + "<|eot_id|>\n"
# # Add final assistant prompt
# formatted_prompt += "<|start_header_id|>assistant<|end_header_id|>\n"
# response = bedrock_runtime.invoke_model(
# modelId=MODEL_STRING,
# contentType="application/json",
# accept="application/json",
# body=json.dumps(
# {
# "prompt": formatted_prompt,
# "max_gen_len": 512, # Shorter responses
# "temperature": 0.3, # Lower temperature for more focused responses
# }
# ),
# )
# dt = time.time() - t0
# body = json.loads(response["body"].read())
# text = body.get("generation", "").strip()
# total_tok = len(text.split())
# return text, dt, total_tok, (total_tok / dt if dt else total_tok)
# elif provider == "mistral":
# print("Using mistral: ", MODEL_STRING)
# t0 = time.time()
# prompt = messages[-1]["content"]
# formatted_prompt = f"<s>[INST] {prompt} [/INST]"
# response = bedrock_runtime.invoke_model(
# modelId=MODEL_STRING,
# contentType="application/json",
# accept="application/json",
# body=json.dumps(
# {"prompt": formatted_prompt, "max_tokens": 512, "temperature": 0.5}
# ),
# )
# dt = time.time() - t0
# body = json.loads(response["body"].read())
# text = body["outputs"][0]["text"].strip()
# total_tok = len(text.split())
# return text, dt, total_tok, (total_tok / dt if dt else total_tok)
# elif provider == "ollama":
# print("Using ollama: ", MODEL_STRING)
# t0 = time.time()
# # Format messages for Ollama API with system prompt
# ollama_messages = []
# # Add system prompt for better behavior
# system_prompt = ""
# ollama_messages.append({
# "role": "system",
# "content": system_prompt
# })
# for msg in messages:
# ollama_messages.append({
# "role": msg["role"],
# "content": msg["content"]
# })
# # Make request to Ollama API
# response = requests.post(
# f"{OLLAMA_BASE_URL}/api/chat",
# json={
# "model": MODEL_STRING,
# "messages": ollama_messages,
# "stream": False,
# "options": {
# "temperature": 0.3, # Lower temperature for more focused responses
# # "num_predict": 4000, # Much higher limit for longer responses
# "top_p": 0.9,
# "repeat_penalty": 1.1
# }
# },
# timeout=60
# )
# dt = time.time() - t0
# if response.status_code == 200:
# result = response.json()
# text = result["message"]["content"].strip()
# total_tok = len(text.split())
# return text, dt, total_tok, (total_tok / dt if dt else total_tok)
# else:
    #         raise Exception(f"Ollama API error: {response.status_code} - {response.text}")
    # No provider branch matched: fail loudly instead of implicitly returning None
    raise ValueError(f"Unsupported provider '{provider}' for model: {MODEL_STRING}")
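# Example call (a minimal sketch; assumes MODEL_STRING maps to a configured provider):
#   reply, latency_s, tokens, tps = chat([{"role": "user", "content": "Hello"}])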
# ──────────────────────────────────────────────────────────────
# Diagnostics / CLI test
# ──────────────────────────────────────────────────────────────
def check_credentials():
# # Check if using Ollama (no API key required)
# if MODEL_TO_PROVIDER.get(MODEL_STRING) == "ollama":
# # Test Ollama connection
# try:
# response = requests.get(f"{OLLAMA_BASE_URL}/api/tags", timeout=5)
# if response.status_code == 200:
# print("Ollama connection successful")
# return True
# else:
# print(f"Ollama connection failed: {response.status_code}")
# return False
# except Exception as e:
# print(f"Ollama connection failed: {e}")
# return False
    # Bedrock-backed providers currently enabled (deepseek also calls
    # bedrock_runtime above; extend this list if meta/mistral are re-enabled)
    bedrock_providers = ["anthropic", "deepseek"]
if MODEL_TO_PROVIDER.get(MODEL_STRING) in bedrock_providers:
# Test AWS Bedrock connection by trying to invoke a simple model
try:
            # A cheap test invocation to verify credentials work
            bedrock_runtime.invoke_model(
modelId="anthropic.claude-3-haiku-20240307-v1:0",
contentType="application/json",
accept="application/json",
body=json.dumps({
"anthropic_version": "bedrock-2023-05-31",
"messages": [{"role": "user", "content": "test"}],
"max_tokens": 10,
"temperature": 0.1
})
)
print("Bedrock connection successful")
return True
except Exception as e:
print(f"Bedrock connection failed: {e}")
print("Make sure AWS credentials are configured and you have access to Bedrock")
return False
# For OpenAI, check API key
if MODEL_TO_PROVIDER.get(MODEL_STRING) == "openai":
required = ["MODEL_API_KEY"]
missing = [var for var in required if not os.getenv(var)]
if missing:
print(f"Missing environment variables: {missing}")
return False
return True
return True
def test_chat():
print("Testing chat...")
try:
test_messages = [
{
"role": "user",
"content": "Hello! Please respond with just 'Test successful'.",
}
]
text, latency, tokens, tps = chat(test_messages)
print(f"Test passed! {text} {latency:.2f}s {tokens} ⚑ {tps:.1f} tps")
except Exception as e:
print(f"Test failed: {e}")
if __name__ == "__main__":
print("running diagnostics")
if check_credentials():
test_chat()
print("\nDone.")