| """ | |
| LLM Backend for Project Echo - Supports multiple providers | |
| """ | |
| import os | |
| import requests | |
| import json | |
| from typing import List, Dict, Optional | |
| from enum import Enum | |
| # Try to import transformers for local model loading | |
| try: | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM | |
| import torch | |
| TRANSFORMERS_AVAILABLE = True | |
| except ImportError: | |
| TRANSFORMERS_AVAILABLE = False | |


class LLMProvider(Enum):
    """Supported LLM providers"""
    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    HUGGINGFACE = "huggingface"
    LM_STUDIO = "lm_studio"


class LLMBackend:
    """
    Unified interface for multiple LLM providers.
    Supports OpenAI, Anthropic, HuggingFace Inference API, and LM Studio.
    """

    def __init__(self, provider: Optional[LLMProvider] = None, api_key: Optional[str] = None, model: Optional[str] = None):
        """
        Initialize LLM backend with specified provider.

        Args:
            provider: LLM provider to use (defaults to env var or HUGGINGFACE)
            api_key: API key for the provider (reads from env if not provided)
            model: Model name to use (provider-specific defaults if not provided)
        """
        # Determine provider
        if provider is None:
            provider_str = os.getenv("LLM_PROVIDER", "huggingface").lower()
            self.provider = LLMProvider(provider_str)
        else:
            self.provider = provider

        # Set API key
        if api_key:
            self.api_key = api_key
        else:
            if self.provider == LLMProvider.OPENAI:
                self.api_key = os.getenv("OPENAI_API_KEY")
            elif self.provider == LLMProvider.ANTHROPIC:
                self.api_key = os.getenv("ANTHROPIC_API_KEY")
            elif self.provider == LLMProvider.HUGGINGFACE:
                self.api_key = os.getenv("HUGGINGFACE_API_KEY")
            else:
                self.api_key = None

        # Set model
        if model:
            self.model = model
        else:
            self.model = self._get_default_model()

        # Set API endpoint
        self.api_url = self._get_api_url()

        # Cache for local models (transformers)
        self.tokenizer = None
        self.local_model = None
        self.device = None

    def _get_default_model(self) -> str:
        """Get default model for each provider with fallback chain"""
        defaults = {
            LLMProvider.OPENAI: "gpt-4o-mini",
            LLMProvider.ANTHROPIC: "claude-3-5-sonnet-20241022",
            # Preferred: Mistral-7B (better instruction following, higher quality)
            # Fallback chain for HF Inference API if primary is gated/unavailable
            LLMProvider.HUGGINGFACE: "mistralai/Mistral-7B-Instruct-v0.1",
            LLMProvider.LM_STUDIO: "google/gemma-3-27b",
        }
        return os.getenv("LLM_MODEL", defaults[self.provider])

    def get_fallback_models(self) -> List[str]:
        """Get fallback model chain for HF Inference API"""
        if self.provider == LLMProvider.HUGGINGFACE:
            return [
                "mistralai/Mistral-7B-Instruct-v0.1",    # Primary
                "mistralai/Mixtral-8x7B-Instruct-v0.1",  # Fallback 1: Better quality
                "google/gemma-7b-it",                    # Fallback 2: Smaller, faster
                "microsoft/phi-2",                       # Fallback 3: Original
            ]
        return [self.model]

    def _get_api_url(self) -> str:
        """Get API URL for each provider"""
        if self.provider == LLMProvider.OPENAI:
            return "https://api.openai.com/v1/chat/completions"
        elif self.provider == LLMProvider.ANTHROPIC:
            return "https://api.anthropic.com/v1/messages"
        elif self.provider == LLMProvider.HUGGINGFACE:
            # HuggingFace endpoint - allow override via env variable
            # Default uses old endpoint (works until Nov 1, 2025)
            default_url = f"https://api-inference.huggingface.co/models/{self.model}"
            return os.getenv("HF_INFERENCE_ENDPOINT", default_url)
        elif self.provider == LLMProvider.LM_STUDIO:
            return os.getenv("LM_STUDIO_URL", "http://192.168.1.245:1234/v1/chat/completions")

    def generate(self,
                 messages: List[Dict[str, str]],
                 max_tokens: int = 1000,
                 temperature: float = 0.7,
                 json_mode: bool = False) -> str:
        """
        Generate completion from messages.

        Args:
            messages: List of message dicts with 'role' and 'content'
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            json_mode: Whether to request JSON output (supported by some providers)

        Returns:
            Generated text response
        """
        try:
            if self.provider == LLMProvider.OPENAI:
                return self._generate_openai(messages, max_tokens, temperature, json_mode)
            elif self.provider == LLMProvider.ANTHROPIC:
                return self._generate_anthropic(messages, max_tokens, temperature)
            elif self.provider == LLMProvider.HUGGINGFACE:
                return self._generate_huggingface(messages, max_tokens, temperature)
            elif self.provider == LLMProvider.LM_STUDIO:
                return self._generate_lm_studio(messages, max_tokens, temperature)
            else:
                raise ValueError(f"Unsupported provider: {self.provider}")
        except Exception as e:
            raise Exception(f"LLM generation failed: {str(e)}") from e

    def _generate_openai(self, messages, max_tokens, temperature, json_mode) -> str:
        """Generate using OpenAI API"""
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        payload = {
            "model": self.model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature
        }
        if json_mode:
            payload["response_format"] = {"type": "json_object"}

        response = requests.post(self.api_url, headers=headers, json=payload, timeout=60)
        response.raise_for_status()
        data = response.json()
        return data["choices"][0]["message"]["content"]

    def _generate_anthropic(self, messages, max_tokens, temperature) -> str:
        """Generate using Anthropic API"""
        headers = {
            "x-api-key": self.api_key,
            "anthropic-version": "2023-06-01",
            "Content-Type": "application/json"
        }

        # Convert messages format (extract system message if present)
        system_message = None
        converted_messages = []
        for msg in messages:
            if msg["role"] == "system":
                system_message = msg["content"]
            else:
                converted_messages.append(msg)

        payload = {
            "model": self.model,
            "messages": converted_messages,
            "max_tokens": max_tokens,
            "temperature": temperature
        }
        if system_message:
            payload["system"] = system_message

        response = requests.post(self.api_url, headers=headers, json=payload, timeout=60)
        response.raise_for_status()
        data = response.json()
        return data["content"][0]["text"]

    def _load_local_model(self):
        """Load model locally using transformers"""
        if not TRANSFORMERS_AVAILABLE:
            raise Exception("transformers library not available. Install with: pip install transformers torch")

        if self.local_model is not None:
            return  # Already loaded

        print(f"Loading model {self.model} locally...")

        # Determine device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.model)

        # Load model (T5 models use Seq2SeqLM, others use CausalLM)
        if "t5" in self.model.lower() or "flan" in self.model.lower():
            self.local_model = AutoModelForSeq2SeqLM.from_pretrained(
                self.model,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                low_cpu_mem_usage=True
            )
        else:
            self.local_model = AutoModelForCausalLM.from_pretrained(
                self.model,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                low_cpu_mem_usage=True
            )

        self.local_model = self.local_model.to(self.device)
        print("Model loaded successfully!")

    def _generate_huggingface(self, messages, max_tokens, temperature) -> str:
        """Generate using local transformers model with fallback chain"""
        # Try to load and generate with fallback chain
        fallback_models = self.get_fallback_models()
        last_error = None

        for model_to_try in fallback_models:
            original_model = self.model
            try:
                # Temporarily set model for this attempt
                self.model = model_to_try
                self.tokenizer = None    # Reset tokenizer cache
                self.local_model = None  # Reset model cache

                # Load model if not already loaded
                self._load_local_model()

                # Convert messages to prompt
                prompt = self._messages_to_prompt(messages)

                # Tokenize input
                inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
                inputs = inputs.to(self.device)

                # Generate
                with torch.no_grad():
                    outputs = self.local_model.generate(
                        **inputs,
                        max_new_tokens=max_tokens,
                        temperature=temperature,
                        do_sample=temperature > 0,
                        top_p=0.9,
                        pad_token_id=self.tokenizer.eos_token_id
                    )

                # Decode output
                generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

                # For T5 models, the output is just the generated text.
                # For causal models, we need to remove the input prompt.
                if "t5" not in self.model.lower() and "flan" not in self.model.lower():
                    # Remove the input prompt from the output
                    if generated_text.startswith(prompt):
                        generated_text = generated_text[len(prompt):].strip()

                # Success! Update the default model for future use
                self.model = model_to_try
                print(f"✓ Successfully using model: {model_to_try}")
                return generated_text

            except Exception as e:
                last_error = e
                print(f"⚠ Model {model_to_try} failed: {str(e)[:100]}")
                self.model = original_model  # Restore original
                continue

        # All fallbacks failed
        raise Exception(f"All HuggingFace models failed. Last error: {str(last_error)}")

    def _generate_lm_studio(self, messages, max_tokens, temperature) -> str:
        """Generate using LM Studio local API"""
        payload = {
            "model": self.model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature
        }
        response = requests.post(self.api_url, json=payload, timeout=60)
        response.raise_for_status()
        data = response.json()
        return data["choices"][0]["message"]["content"]

    def _messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
        """Convert message format to a simple prompt"""
        prompt_parts = []
        for msg in messages:
            role = msg["role"].capitalize()
            content = msg["content"]
            prompt_parts.append(f"{role}: {content}")
        prompt_parts.append("Assistant:")
        return "\n\n".join(prompt_parts)