# my-voice-agent / llm_handler.py
# Author: jblast94 — "Update llm_handler.py" (commit 34f8426, verified)
# llm_handler.py (Refactored)
import json
import os
from typing import Iterator, Optional

import requests
# A simple registry to define providers and their models.
# NOTE: a misplaced closing brace previously nested "openrouter" and
# "huggingface" *inside* the "anthropic" entry; all three providers are
# intended to be (and now are) sibling top-level keys.
PROVIDER_CONFIG = {
    "anthropic": {
        "models": {
            "claude-3-5-sonnet-20241022": {"provider": "anthropic", "api_url": "https://api.anthropic.com/v1/messages"},
            "claude-3-haiku-20240307": {"provider": "anthropic", "api_url": "https://api.anthropic.com/v1/messages"},
        },
    },
    "openrouter": {
        "models": {
            "anthropic/claude-3-opus-20240229": {"provider": "anthropic", "api_url": "https://api.anthropic.com/v1/messages"},
            "google/gemini-2.0-flash-exp": {"provider": "google", "api_url": "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash-exp:generateContent"},
        },
    },
    "huggingface": {
        "models": {
            "mistralai/Mixtral-8x7B": {"provider": "huggingface", "api_url": "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B"},
            "meta-llama/Meta-Llama-3.1-8B-Instruct": {"provider": "huggingface", "api_url": "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3.1-8B-Instruct"},
        },
    },
}
class LLMHandler:
"""
Handle LLM interactions via OpenRouter.
- Uses OPENROUTER_API_KEY (required).
- Default model from PREFERRED_MODEL or google/gemini-2.0-flash-exp.
- Supports dynamic override from UI (model_override).
"""
def __init__(self, model_override: str | None = None):
self.openrouter_key = os.getenv("OPENROUTER_API_KEY")
if not self.openrouter_key:
raise ValueError(
"OPENROUTER_API_KEY is not set. Configure it in your Space secrets."
)
default_model = os.getenv("PREFERRED_MODEL", "google/gemini-2.0-flash-exp")
self.model_id = model_override or default_model
def set_model(self, model_name: str):
"""Update active model at runtime."""
if model_name:
self.model_id = model_name
def generate_streaming(self, prompt: str, model: Optional[str] = None) -> Iterator[str]:
"""
Generate a streaming response using OpenRouter chat completions.
"""
model_to_use = model or self.model_id
print(f"[LLMHandler] Using OpenRouter model: {model_to_use}")
try:
yield from self._call_openrouter_streaming(prompt, model_to_use)
except Exception as e:
error_msg = f"Error during generation with OpenRouter: {str(e)}"
print(error_msg)
yield error_msg
def _call_anthropic_streaming(self, prompt: str, api_url: str) -> Iterator[str]:
headers = {
"x-api-key": self.anthropic_key,
"anthropic-version": "2023-06-01",
"content-type": "application/json"
}
data = {"model": self.model_id, "max_tokens": 2000, "stream": True, "messages": [{"role": "user", "content": prompt}]}
response = requests.post(api_url, headers=headers, json=data, stream=True, timeout=60)
response.raise_for_status()
for line in response.iter_lines():
line = line.decode('utf-8')
if line.startswith("data: "):
line = line[6:]
if line == "[DONE]":
break
try:
chunk = json.loads(line)
if chunk.get("type") == "content_block_delta" and chunk.get("delta", {}).get("text"):
yield chunk["delta"]["text"]
except json.JSONDecodeError:
continue
def _call_huggingface_streaming(self, prompt: str, api_url: str) -> Iterator[str]:
headers = {"Authorization": f"Bearer {self.hf_token}", "Content-Type": "application/json"}
data = {"inputs": prompt, "parameters": {"max_new_tokens": 2000, "stream": True}}
response = requests.post(api_url, headers=headers, json=data, stream=True, timeout=60)
if response.status_code == 503: # Model loading
yield "Model is loading, please try again in a few moments..."
return
response.raise_for_status()
for line in response.iter_lines():
if line:
try:
chunk = json.loads(line.decode('utf-8'))
if 'token' in chunk:
text = chunk.get('token', {}).get('text', '')
if text:
yield text
except json.JSONDecodeError:
continue
def _call_openrouter_streaming(self, prompt: str, model_id: str) -> Iterator[str]:
"""
Stream completions from OpenRouter using OpenAI-compatible SSE.
"""
api_url = "https://openrouter.ai/api/v1/chat/completions"
headers = {
"Authorization": f"Bearer {self.openrouter_key}",
"Content-Type": "application/json",
"HTTP-Referer": "https://huggingface.co/spaces/jblast94/my-voice-agent",
"X-Title": "my-voice-agent",
}
data = {
"model": model_id,
"stream": True,
"messages": [
{"role": "system", "content": "You are a helpful, friendly AI assistant."},
{"role": "user", "content": prompt},
],
}
with requests.post(api_url, headers=headers, json=data, stream=True, timeout=60) as response:
response.raise_for_status()
for raw_line in response.iter_lines():
if not raw_line:
continue
if raw_line.startswith(b"data: "):
payload = raw_line[6:]
if payload == b"[DONE]":
break
try:
chunk = json.loads(payload.decode("utf-8"))
except Exception:
continue
choices = chunk.get("choices") or []
if not choices:
continue
delta = choices[0].get("delta") or {}
content_piece = delta.get("content")
if content_piece:
yield content_piece
def get_provider_info(self) -> dict:
"""Get information about the current provider configuration."""
return {
"provider": "openrouter",
"model": self.model_id,
"requires": ["OPENROUTER_API_KEY"],
}