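"""Model backend wrapper.

Routes chat requests to the provider mapped to the active model in
MODEL_TO_PROVIDER (OpenAI via the openai SDK, Anthropic and DeepSeek via AWS
Bedrock) and exposes small diagnostics for credentials plus a test chat.
"""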
import json
import os
import time

import boto3
import openai
import requests
from dotenv import load_dotenv

from model_config import MODEL_TO_PROVIDER

# ──────────────────────────────────────────────────────────────
# Load environment variables
# ──────────────────────────────────────────────────────────────
load_dotenv()
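
# Expected .env entries (names taken from the os.getenv calls below; the
# values shown are placeholders):
#   MODEL_API_KEY=sk-...
#   AWS_ACCESS_ID=AKIA...
#   AWS_SECRET_ACCESS_KEY=...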

# ──────────────────────────────────────────────────────────────
# Configuration
# ──────────────────────────────────────────────────────────────
MODEL_STRING = "gpt-4.1-mini"  # we default to gpt-4.1-mini

api_key = os.getenv("MODEL_API_KEY")
client = openai.OpenAI(api_key=api_key)

bedrock_runtime = boto3.client(
    "bedrock-runtime",
    region_name="us-east-1",
    aws_access_key_id=os.getenv("AWS_ACCESS_ID"),
    aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
)
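
# Note: if AWS_ACCESS_ID / AWS_SECRET_ACCESS_KEY are unset, boto3 receives
# None for both and falls back to its default credential chain (shared
# config files, environment variables, instance roles, ...).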

# ──────────────────────────────────────────────────────────────
# Model switcher
# ──────────────────────────────────────────────────────────────
def set_model(model_id: str) -> None:
    global MODEL_STRING
    MODEL_STRING = model_id
    print(f"Model changed to: {model_id}")


def set_provider(provider: str) -> None:
    # PROVIDER is informational only; chat() derives the provider from
    # MODEL_TO_PROVIDER[MODEL_STRING].
    global PROVIDER
    PROVIDER = provider
    print(f"Provider changed to: {provider}")
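
# Example (the model id must be a key in MODEL_TO_PROVIDER; the Claude id
# below is the one used by check_credentials, shown only for illustration):
#   set_model("anthropic.claude-3-haiku-20240307-v1:0")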

# ──────────────────────────────────────────────────────────────
# High-level Chat wrapper
# ──────────────────────────────────────────────────────────────
def chat(messages, persona=None):
    """Send `messages` to the provider mapped to MODEL_STRING.

    Returns (text, latency_seconds, total_tokens, tokens_per_second).
    `persona` is accepted but currently unused.
    """
    provider = MODEL_TO_PROVIDER[MODEL_STRING]

    if provider == "openai":
        print("Using openai: ", MODEL_STRING)
        t0 = time.time()

        # Add system prompt for better behavior (currently empty)
        system_prompt = ""

        # Prepare messages with the system prompt first
        chat_messages = [{"role": "system", "content": system_prompt}]
        for msg in messages:
            chat_messages.append({
                "role": msg["role"],
                "content": msg["content"],
            })

        request_kwargs = {
            "model": MODEL_STRING,
            "messages": chat_messages,
            "max_completion_tokens": 4000,
        }
        # Some newer OpenAI models only support the default temperature.
        if MODEL_STRING not in {"gpt-5-nano", "gpt-5-mini"}:
            request_kwargs["temperature"] = 0.3

        response = client.chat.completions.create(**request_kwargs)
        dt = time.time() - t0

        text = response.choices[0].message.content.strip()
        # Token count from the API when available, word count otherwise
        total_tok = response.usage.total_tokens if response.usage else len(text.split())
        return text, dt, total_tok, (total_tok / dt if dt else total_tok)

    elif provider == "anthropic":
        print("Using anthropic: ", MODEL_STRING)
        t0 = time.time()

        # Add system prompt for better behavior
        system_prompt = ""

        claude_messages = [
            {"role": m["role"], "content": m["content"]} for m in messages
        ]
        response = bedrock_runtime.invoke_model(
            modelId=MODEL_STRING,
            contentType="application/json",
            accept="application/json",
            body=json.dumps(
                {
                    "anthropic_version": "bedrock-2023-05-31",
                    "system": system_prompt,
                    "messages": claude_messages,
                    "max_tokens": 4000,  # Much higher limit for longer responses
                    "temperature": 0.3,  # Lower temperature for more focused responses
                }
            ),
        )
        dt = time.time() - t0

        body = json.loads(response["body"].read())
        text = "".join(
            part["text"] for part in body["content"] if part["type"] == "text"
        ).strip()
        total_tok = len(text.split())
        return text, dt, total_tok, (total_tok / dt if dt else total_tok)

    elif provider == "deepseek":
        print("Using deepseek: ", MODEL_STRING)
        t0 = time.time()

        system_prompt = ""

        ds_messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_prompt}],
            }
        ]
        for msg in messages:
            role = msg.get("role", "user")
            ds_messages.append(
                {
                    "role": role,
                    "content": [{"type": "text", "text": msg["content"]}],
                }
            )
        response = bedrock_runtime.invoke_model(
            modelId=MODEL_STRING,
            contentType="application/json",
            accept="application/json",
            body=json.dumps(
                {
                    "messages": ds_messages,
                    "max_completion_tokens": 500,
                    "temperature": 0.5,
                    "top_p": 0.9,
                }
            ),
        )
        dt = time.time() - t0

        body = json.loads(response["body"].read())
        outputs = body.get("output", [])
        text_chunks = []
        for item in outputs:
            for content in item.get("content", []):
                chunk_text = content.get("text") or content.get("output_text")
                if chunk_text:
                    text_chunks.append(chunk_text)
        text = "".join(text_chunks).strip()
        if not text and "response" in body:
            text = body["response"].get("output_text", "").strip()
        total_tok = len(text.split())
        return text, dt, total_tok, (total_tok / dt if dt else total_tok)

    # elif provider == "meta":
    #     print("Using meta (LLaMA): ", MODEL_STRING)
    #     t0 = time.time()
    #     # Add system prompt for better behavior
    #     system_prompt = ""
    #     # Format conversation properly for Llama3
    #     formatted_prompt = "<|begin_of_text|>"
    #     # Add system prompt
    #     formatted_prompt += "<|start_header_id|>system<|end_header_id|>\n" + system_prompt + "<|eot_id|>\n"
    #     # Add conversation history
    #     for msg in messages:
    #         if msg["role"] == "user":
    #             formatted_prompt += "<|start_header_id|>user<|end_header_id|>\n" + msg["content"] + "<|eot_id|>\n"
    #         elif msg["role"] == "assistant":
    #             formatted_prompt += "<|start_header_id|>assistant<|end_header_id|>\n" + msg["content"] + "<|eot_id|>\n"
    #     # Add final assistant prompt
    #     formatted_prompt += "<|start_header_id|>assistant<|end_header_id|>\n"
    #     response = bedrock_runtime.invoke_model(
    #         modelId=MODEL_STRING,
    #         contentType="application/json",
    #         accept="application/json",
    #         body=json.dumps(
    #             {
    #                 "prompt": formatted_prompt,
    #                 "max_gen_len": 512,  # Shorter responses
    #                 "temperature": 0.3,  # Lower temperature for more focused responses
    #             }
    #         ),
    #     )
    #     dt = time.time() - t0
    #     body = json.loads(response["body"].read())
    #     text = body.get("generation", "").strip()
    #     total_tok = len(text.split())
    #     return text, dt, total_tok, (total_tok / dt if dt else total_tok)

    # elif provider == "mistral":
    #     print("Using mistral: ", MODEL_STRING)
    #     t0 = time.time()
    #     prompt = messages[-1]["content"]
    #     formatted_prompt = f"<s>[INST] {prompt} [/INST]"
    #     response = bedrock_runtime.invoke_model(
    #         modelId=MODEL_STRING,
    #         contentType="application/json",
    #         accept="application/json",
    #         body=json.dumps(
    #             {"prompt": formatted_prompt, "max_tokens": 512, "temperature": 0.5}
    #         ),
    #     )
    #     dt = time.time() - t0
    #     body = json.loads(response["body"].read())
    #     text = body["outputs"][0]["text"].strip()
    #     total_tok = len(text.split())
    #     return text, dt, total_tok, (total_tok / dt if dt else total_tok)

    # elif provider == "ollama":
    #     print("Using ollama: ", MODEL_STRING)
    #     t0 = time.time()
    #     # Format messages for Ollama API with system prompt
    #     ollama_messages = []
    #     # Add system prompt for better behavior
    #     system_prompt = ""
    #     ollama_messages.append({
    #         "role": "system",
    #         "content": system_prompt
    #     })
    #     for msg in messages:
    #         ollama_messages.append({
    #             "role": msg["role"],
    #             "content": msg["content"]
    #         })
    #     # Make request to Ollama API
    #     response = requests.post(
    #         f"{OLLAMA_BASE_URL}/api/chat",
    #         json={
    #             "model": MODEL_STRING,
    #             "messages": ollama_messages,
    #             "stream": False,
    #             "options": {
    #                 "temperature": 0.3,  # Lower temperature for more focused responses
    #                 # "num_predict": 4000,  # Much higher limit for longer responses
    #                 "top_p": 0.9,
    #                 "repeat_penalty": 1.1
    #             }
    #         },
    #         timeout=60
    #     )
    #     dt = time.time() - t0
    #     if response.status_code == 200:
    #         result = response.json()
    #         text = result["message"]["content"].strip()
    #         total_tok = len(text.split())
    #         return text, dt, total_tok, (total_tok / dt if dt else total_tok)
    #     else:
    #         raise Exception(f"Ollama API error: {response.status_code} - {response.text}")
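
    # No branch matched: fail loudly instead of implicitly returning None.
    raise ValueError(f"No handler for provider '{provider}' (model: {MODEL_STRING})")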


# ──────────────────────────────────────────────────────────────
# Diagnostics / CLI test
# ──────────────────────────────────────────────────────────────
def check_credentials():
    # # Check if using Ollama (no API key required)
    # if MODEL_TO_PROVIDER.get(MODEL_STRING) == "ollama":
    #     # Test Ollama connection
    #     try:
    #         response = requests.get(f"{OLLAMA_BASE_URL}/api/tags", timeout=5)
    #         if response.status_code == 200:
    #             print("Ollama connection successful")
    #             return True
    #         else:
    #             print(f"Ollama connection failed: {response.status_code}")
    #             return False
    #     except Exception as e:
    #         print(f"Ollama connection failed: {e}")
    #         return False

    # Check if using Bedrock providers (anthropic, deepseek)
    bedrock_providers = ["anthropic", "deepseek"]
    if MODEL_TO_PROVIDER.get(MODEL_STRING) in bedrock_providers:
        # Test AWS Bedrock connection by trying to invoke a simple model
        try:
            # Try a simple test invocation to verify credentials
            test_response = bedrock_runtime.invoke_model(
                modelId="anthropic.claude-3-haiku-20240307-v1:0",
                contentType="application/json",
                accept="application/json",
                body=json.dumps({
                    "anthropic_version": "bedrock-2023-05-31",
                    "messages": [{"role": "user", "content": "test"}],
                    "max_tokens": 10,
                    "temperature": 0.1
                })
            )
            print("Bedrock connection successful")
            return True
        except Exception as e:
            print(f"Bedrock connection failed: {e}")
            print("Make sure AWS credentials are configured and you have access to Bedrock")
            return False

    # For OpenAI, check the API key
    if MODEL_TO_PROVIDER.get(MODEL_STRING) == "openai":
        required = ["MODEL_API_KEY"]
        missing = [var for var in required if not os.getenv(var)]
        if missing:
            print(f"Missing environment variables: {missing}")
            return False
        return True

    return True


def test_chat():
    print("Testing chat...")
    try:
        test_messages = [
            {
                "role": "user",
                "content": "Hello! Please respond with just 'Test successful'.",
            }
        ]
        text, latency, tokens, tps = chat(test_messages)
        print(f"Test passed! {text} {latency:.2f}s {tokens} ⚡ {tps:.1f} tps")
    except Exception as e:
        print(f"Test failed: {e}")


if __name__ == "__main__":
    print("running diagnostics")
    if check_credentials():
        test_chat()
    print("\nDone.")