import json
import os
import time

import boto3
import openai
import requests  # used by the (currently commented-out) Ollama branch
from dotenv import load_dotenv

from model_config import MODEL_TO_PROVIDER

# ──────────────────────────────────────────────────────────────
# Load environment variables
load_dotenv()
# ──────────────────────────────────────────────────────────────

# ──────────────────────────────────────────────────────────────
# Configuration
# ──────────────────────────────────────────────────────────────
MODEL_STRING = "gpt-4.1-mini"  # we default to gpt-4.1-mini

api_key = os.getenv("MODEL_API_KEY")
client = openai.OpenAI(api_key=api_key)

bedrock_runtime = boto3.client(
    "bedrock-runtime",
    region_name="us-east-1",
    aws_access_key_id=os.getenv("AWS_ACCESS_ID"),
    aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
)

# ──────────────────────────────────────────────────────────────
# Model switcher
# ──────────────────────────────────────────────────────────────
def set_model(model_id: str) -> None:
    global MODEL_STRING
    MODEL_STRING = model_id
    print(f"Model changed to: {model_id}")


def set_provider(provider: str) -> None:
    global PROVIDER
    PROVIDER = provider
    print(f"Provider changed to: {provider}")


# ──────────────────────────────────────────────────────────────
# High-level Chat wrapper
# ──────────────────────────────────────────────────────────────
def chat(messages, persona=None):
    # `persona` is accepted but currently unused; it defaults to None so
    # callers (e.g. test_chat) can invoke chat() with messages alone.
    provider = MODEL_TO_PROVIDER[MODEL_STRING]

    if provider == "openai":
        print("Using openai: ", MODEL_STRING)
        t0 = time.time()

        # Add system prompt for better behavior
        system_prompt = ""

        # Prepare messages with system prompt
        chat_messages = [{"role": "system", "content": system_prompt}]
        for msg in messages:
            chat_messages.append({
                "role": msg["role"],
                "content": msg["content"],
            })

        request_kwargs = {
            "model": MODEL_STRING,
            "messages": chat_messages,
            "max_completion_tokens": 4000,
        }
        # Some newer OpenAI models only support the default temperature.
        if MODEL_STRING not in {"gpt-5-nano", "gpt-5-mini"}:
            request_kwargs["temperature"] = 0.3

        response = client.chat.completions.create(**request_kwargs)
        dt = time.time() - t0
        text = response.choices[0].message.content.strip()

        # Use the reported token count when available; fall back to a word count
        total_tok = response.usage.total_tokens if response.usage else len(text.split())
        return text, dt, total_tok, (total_tok / dt if dt else total_tok)

    elif provider == "anthropic":
        print("Using anthropic: ", MODEL_STRING)
        t0 = time.time()

        # Add system prompt for better behavior
        system_prompt = ""

        claude_messages = [
            {"role": m["role"], "content": m["content"]} for m in messages
        ]

        response = bedrock_runtime.invoke_model(
            modelId=MODEL_STRING,
            contentType="application/json",
            accept="application/json",
            body=json.dumps(
                {
                    "anthropic_version": "bedrock-2023-05-31",
                    "system": system_prompt,
                    "messages": claude_messages,
                    "max_tokens": 4000,  # much higher limit for longer responses
                    "temperature": 0.3,  # lower temperature for more focused responses
                }
            ),
        )
        dt = time.time() - t0
        body = json.loads(response["body"].read())
        text = "".join(
            part["text"] for part in body["content"] if part["type"] == "text"
        ).strip()

        # No usage data is parsed from this response, so approximate
        # the token count with a word count.
        total_tok = len(text.split())
        return text, dt, total_tok, (total_tok / dt if dt else total_tok)

    elif provider == "deepseek":
        print("Using deepseek: ", MODEL_STRING)
        t0 = time.time()

        system_prompt = ""

        ds_messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_prompt}],
            }
        ]
        for msg in messages:
            role = msg.get("role", "user")
            ds_messages.append(
                {
                    "role": role,
                    "content": [{"type": "text", "text": msg["content"]}],
                }
            )

        response = bedrock_runtime.invoke_model(
            modelId=MODEL_STRING,
            contentType="application/json",
            accept="application/json",
            body=json.dumps(
                {
                    "messages": ds_messages,
                    "max_completion_tokens": 500,
                    "temperature": 0.5,
                    "top_p": 0.9,
                }
            ),
        )
        dt = time.time() - t0
        body = json.loads(response["body"].read())

        # Collect text chunks from the structured output
        outputs = body.get("output", [])
        text_chunks = []
        for item in outputs:
            for content in item.get("content", []):
                chunk_text = content.get("text") or content.get("output_text")
                if chunk_text:
                    text_chunks.append(chunk_text)
        text = "".join(text_chunks).strip()

        # Fall back to the flat response field if the structured output was empty
        if not text and "response" in body:
            text = body["response"].get("output_text", "").strip()

        total_tok = len(text.split())
        return text, dt, total_tok, (total_tok / dt if dt else total_tok)

    # elif provider == "meta":
    #     print("Using meta (LLaMA): ", MODEL_STRING)
    #     t0 = time.time()
    #     # Add system prompt for better behavior
    #     system_prompt = ""
    #     # Format conversation properly for Llama3
    #     formatted_prompt = "<|begin_of_text|>"
    #     # Add system prompt
    #     formatted_prompt += "<|start_header_id|>system<|end_header_id|>\n" + system_prompt + "<|eot_id|>\n"
    #     # Add conversation history
    #     for msg in messages:
    #         if msg["role"] == "user":
    #             formatted_prompt += "<|start_header_id|>user<|end_header_id|>\n" + msg["content"] + "<|eot_id|>\n"
    #         elif msg["role"] == "assistant":
    #             formatted_prompt += "<|start_header_id|>assistant<|end_header_id|>\n" + msg["content"] + "<|eot_id|>\n"
    #     # Add final assistant prompt
    #     formatted_prompt += "<|start_header_id|>assistant<|end_header_id|>\n"
    #     response = bedrock_runtime.invoke_model(
    #         modelId=MODEL_STRING,
    #         contentType="application/json",
    #         accept="application/json",
    #         body=json.dumps(
    #             {
    #                 "prompt": formatted_prompt,
    #                 "max_gen_len": 512,  # shorter responses
    #                 "temperature": 0.3,  # lower temperature for more focused responses
    #             }
    #         ),
    #     )
    #     dt = time.time() - t0
    #     body = json.loads(response["body"].read())
    #     text = body.get("generation", "").strip()
    #     total_tok = len(text.split())
    #     return text, dt, total_tok, (total_tok / dt if dt else total_tok)

    # elif provider == "mistral":
    #     print("Using mistral: ", MODEL_STRING)
    #     t0 = time.time()
    #     prompt = messages[-1]["content"]
    #     formatted_prompt = f"[INST] {prompt} [/INST]"
    #     response = bedrock_runtime.invoke_model(
    #         modelId=MODEL_STRING,
    #         contentType="application/json",
    #         accept="application/json",
    #         body=json.dumps(
    #             {"prompt": formatted_prompt, "max_tokens": 512, "temperature": 0.5}
    #         ),
    #     )
    #     dt = time.time() - t0
    #     body = json.loads(response["body"].read())
    #     text = body["outputs"][0]["text"].strip()
    #     total_tok = len(text.split())
    #     return text, dt, total_tok, (total_tok / dt if dt else total_tok)

    # elif provider == "ollama":
    #     # NOTE: re-enabling this branch requires defining OLLAMA_BASE_URL.
    #     print("Using ollama: ", MODEL_STRING)
    #     t0 = time.time()
    #     # Format messages for Ollama API with system prompt
    #     ollama_messages = []
    #     # Add system prompt for better behavior
    #     system_prompt = ""
    #     ollama_messages.append({
    #         "role": "system",
    #         "content": system_prompt,
    #     })
    #     for msg in messages:
    #         ollama_messages.append({
    #             "role": msg["role"],
    #             "content": msg["content"],
    #         })
    #     # Make request to Ollama API
    #     response = requests.post(
    #         f"{OLLAMA_BASE_URL}/api/chat",
    #         json={
    #             "model": MODEL_STRING,
    #             "messages": ollama_messages,
    #             "stream": False,
    #             "options": {
    #                 "temperature": 0.3,  # lower temperature for more focused responses
    #                 # "num_predict": 4000,  # much higher limit for longer responses
    #                 "top_p": 0.9,
    #                 "repeat_penalty": 1.1,
    #             },
    #         },
    #         timeout=60,
    #     )
    #     dt = time.time() - t0
    #     if response.status_code == 200:
    #         result = response.json()
    #         text = result["message"]["content"].strip()
    #         total_tok = len(text.split())
    #         return text, dt, total_tok, (total_tok / dt if dt else total_tok)
    #     else:
    #         raise Exception(f"Ollama API error: {response.status_code} - {response.text}")

    # No branch matched: fail loudly instead of implicitly returning None
    raise ValueError(f"Unknown provider '{provider}' for model '{MODEL_STRING}'")


# ──────────────────────────────────────────────────────────────
# Diagnostics / CLI test
# ──────────────────────────────────────────────────────────────
def check_credentials():
    # # Check if using Ollama (no API key required)
    # if MODEL_TO_PROVIDER.get(MODEL_STRING) == "ollama":
    #     # Test Ollama connection
    #     try:
    #         response = requests.get(f"{OLLAMA_BASE_URL}/api/tags", timeout=5)
    #         if response.status_code == 200:
    #             print("Ollama connection successful")
    #             return True
    #         else:
    #             print(f"Ollama connection failed: {response.status_code}")
    #             return False
    #     except Exception as e:
    #         print(f"Ollama connection failed: {e}")
    #         return False

    # Check if using Bedrock providers (anthropic, meta, mistral, deepseek)
    bedrock_providers = ["anthropic"]
    if MODEL_TO_PROVIDER.get(MODEL_STRING) in bedrock_providers:
        # Test AWS Bedrock connection by trying to invoke a simple model
        try:
            # A simple test invocation to verify credentials; the response
            # itself is discarded.
            bedrock_runtime.invoke_model(
                modelId="anthropic.claude-3-haiku-20240307-v1:0",
                contentType="application/json",
                accept="application/json",
                body=json.dumps({
                    "anthropic_version": "bedrock-2023-05-31",
                    "messages": [{"role": "user", "content": "test"}],
                    "max_tokens": 10,
                    "temperature": 0.1,
                }),
            )
            print("Bedrock connection successful")
            return True
        except Exception as e:
            print(f"Bedrock connection failed: {e}")
            print("Make sure AWS credentials are configured and you have access to Bedrock")
            return False

    # For OpenAI, check API key
    if MODEL_TO_PROVIDER.get(MODEL_STRING) == "openai":
        required = ["MODEL_API_KEY"]
        missing = [var for var in required if not os.getenv(var)]
        if missing:
            print(f"Missing environment variables: {missing}")
            return False
        return True

    return True
def test_chat():
    print("Testing chat...")
    try:
        test_messages = [
            {
                "role": "user",
                "content": "Hello! Please respond with just 'Test successful'.",
            }
        ]
        text, latency, tokens, tps = chat(test_messages)
        print(f"Test passed! {text} {latency:.2f}s {tokens} ⚡ {tps:.1f} tps")
    except Exception as e:
        print(f"Test failed: {e}")


if __name__ == "__main__":
    print("running diagnostics")
    if check_credentials():
        test_chat()
    print("\nDone.")
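
# ──────────────────────────────────────────────────────────────
# Example usage (sketch)
# ──────────────────────────────────────────────────────────────
# A minimal sketch of how another script might drive this module.
# The module name "llm_client" is an assumption (use this file's actual
# name), and any model ID passed to set_model must have an entry in
# MODEL_TO_PROVIDER:
#
#     from llm_client import chat, set_model  # hypothetical module name
#
#     set_model("gpt-4.1-mini")
#     text, latency, tokens, tps = chat(
#         [{"role": "user", "content": "Name one AWS region."}]
#     )
#     print(f"{text} ({latency:.2f}s, {tps:.1f} tok/s)")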