# llm_handler.py (Refactored)
import json
import os
import requests
from typing import Iterator, Optional

# A simple registry mapping each provider to its models and API endpoints.
# ("openrouter" and "huggingface" are top-level providers, siblings of
# "anthropic"; OpenRouter models all share its OpenAI-compatible endpoint.)
PROVIDER_CONFIG = {
    "anthropic": {
        "models": {
            "claude-3-5-sonnet-20241022": {"provider": "anthropic", "api_url": "https://api.anthropic.com/v1/messages"},
            "claude-3-haiku-20240307": {"provider": "anthropic", "api_url": "https://api.anthropic.com/v1/messages"},
        }
    },
    "openrouter": {
        "models": {
            "anthropic/claude-3-opus-20240229": {"provider": "anthropic", "api_url": "https://openrouter.ai/api/v1/chat/completions"},
            "google/gemini-2.0-flash-exp": {"provider": "google", "api_url": "https://openrouter.ai/api/v1/chat/completions"},
        }
    },
    "huggingface": {
        "models": {
            "mistralai/Mixtral-8x7B": {"provider": "huggingface", "api_url": "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B"},
            "meta-llama/Meta-Llama-3.1-8B-Instruct": {"provider": "huggingface", "api_url": "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3.1-8B-Instruct"},
        }
    },
}
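
# A minimal lookup sketch over the registry above; resolve_model() is a
# hypothetical helper (not called by LLMHandler), assuming callers pass the
# model id exactly as it appears under a provider's "models" map.
def resolve_model(model_id: str) -> Optional[dict]:
    """Return the registry entry for model_id, searching every provider."""
    for provider_cfg in PROVIDER_CONFIG.values():
        entry = provider_cfg.get("models", {}).get(model_id)
        if entry:
            return entry
    return None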

class LLMHandler:
    """
    Handle LLM interactions via OpenRouter.

    - Uses OPENROUTER_API_KEY (required).
    - Default model from PREFERRED_MODEL or google/gemini-2.0-flash-exp.
    - Supports dynamic override from UI (model_override).
    """

    def __init__(self, model_override: Optional[str] = None):
        self.openrouter_key = os.getenv("OPENROUTER_API_KEY")
        # Optional credentials for the direct Anthropic / Hugging Face helpers
        # below; the env var names here are conventional assumptions.
        self.anthropic_key = os.getenv("ANTHROPIC_API_KEY")
        self.hf_token = os.getenv("HF_TOKEN")

        if not self.openrouter_key:
            raise ValueError(
                "OPENROUTER_API_KEY is not set. Configure it in your Space secrets."
            )

        default_model = os.getenv("PREFERRED_MODEL", "google/gemini-2.0-flash-exp")
        self.model_id = model_override or default_model

    def set_model(self, model_name: str):
        """Update active model at runtime."""
        if model_name:
            self.model_id = model_name

    def generate_streaming(self, prompt: str, model: Optional[str] = None) -> Iterator[str]:
        """
        Generate a streaming response using OpenRouter chat completions.
        """
        model_to_use = model or self.model_id
        print(f"[LLMHandler] Using OpenRouter model: {model_to_use}")

        try:
            yield from self._call_openrouter_streaming(prompt, model_to_use)
        except Exception as e:
            error_msg = f"Error during generation with OpenRouter: {str(e)}"
            print(error_msg)
            yield error_msg
    
    def _call_anthropic_streaming(self, prompt: str, api_url: str) -> Iterator[str]:
        """Stream text deltas from the Anthropic Messages API (SSE)."""
        if not self.anthropic_key:
            raise ValueError("ANTHROPIC_API_KEY is not set.")
        headers = {
            "x-api-key": self.anthropic_key,
            "anthropic-version": "2023-06-01",
            "content-type": "application/json"
        }
        data = {"model": self.model_id, "max_tokens": 2000, "stream": True, "messages": [{"role": "user", "content": prompt}]}
        
        response = requests.post(api_url, headers=headers, json=data, stream=True, timeout=60)
        response.raise_for_status()
        for line in response.iter_lines():
            line = line.decode('utf-8')
            if line.startswith("data: "):
                line = line[6:]
                if line == "[DONE]":
                    break
                try:
                    chunk = json.loads(line)
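                    # Text arrives in "content_block_delta" events under delta.text.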
                    if chunk.get("type") == "content_block_delta" and chunk.get("delta", {}).get("text"):
                        yield chunk["delta"]["text"]
                except json.JSONDecodeError:
                    continue
    
    def _call_huggingface_streaming(self, prompt: str, api_url: str) -> Iterator[str]:
        """Stream generated tokens from the Hugging Face Inference API."""
        if not self.hf_token:
            raise ValueError("HF_TOKEN is not set.")
        headers = {"Authorization": f"Bearer {self.hf_token}", "Content-Type": "application/json"}
        data = {"inputs": prompt, "parameters": {"max_new_tokens": 2000, "stream": True}}
        
        response = requests.post(api_url, headers=headers, json=data, stream=True, timeout=60)
        
        if response.status_code == 503: # Model loading
            yield "Model is loading, please try again in a few moments..."
            return
            
        response.raise_for_status()
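        # The Inference API streams one JSON object per line; generated text
        # arrives under token.text.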
        for line in response.iter_lines():
            if line:
                try:
                    chunk = json.loads(line.decode('utf-8'))
                    if 'token' in chunk:
                        text = chunk.get('token', {}).get('text', '')
                        if text:
                            yield text
                except json.JSONDecodeError:
                    continue
    
    def _call_openrouter_streaming(self, prompt: str, model_id: str) -> Iterator[str]:
        """
        Stream completions from OpenRouter using OpenAI-compatible SSE.
        """
        api_url = "https://openrouter.ai/api/v1/chat/completions"
        headers = {
            "Authorization": f"Bearer {self.openrouter_key}",
            "Content-Type": "application/json",
            "HTTP-Referer": "https://huggingface.co/spaces/jblast94/my-voice-agent",
            "X-Title": "my-voice-agent",
        }
        data = {
            "model": model_id,
            "stream": True,
            "messages": [
                {"role": "system", "content": "You are a helpful, friendly AI assistant."},
                {"role": "user", "content": prompt},
            ],
        }

        with requests.post(api_url, headers=headers, json=data, stream=True, timeout=60) as response:
            response.raise_for_status()
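            # OpenAI-compatible SSE: each event is a line "data: {json}",
            # closed by a final "data: [DONE]" sentinel.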
            for raw_line in response.iter_lines():
                if not raw_line:
                    continue
                if raw_line.startswith(b"data: "):
                    payload = raw_line[6:]
                    if payload == b"[DONE]":
                        break
                    try:
                        chunk = json.loads(payload.decode("utf-8"))
                    except Exception:
                        continue

                    choices = chunk.get("choices") or []
                    if not choices:
                        continue
                    delta = choices[0].get("delta") or {}
                    content_piece = delta.get("content")
                    if content_piece:
                        yield content_piece

    def get_provider_info(self) -> dict:
        """Get information about the current provider configuration."""
        return {
            "provider": "openrouter",
            "model": self.model_id,
            "requires": ["OPENROUTER_API_KEY"],
        }
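
# Usage sketch (assumes OPENROUTER_API_KEY is exported in the environment;
# the override model id below is only an example):
if __name__ == "__main__":
    handler = LLMHandler(model_override="anthropic/claude-3-haiku")
    print(handler.get_provider_info())
    for piece in handler.generate_streaming("Say hello in one sentence."):
        print(piece, end="", flush=True)
    print()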