"""Model chaining logic with Groq fallback.""" import re from typing import Optional, Dict, Any, List from enum import Enum from llama_index.llms.ollama import Ollama from llama_index.llms.groq import Groq from llama_index.core.llms import LLM from litellm import completion import httpx from src.config import config from src.harry_personality import get_harry_prompt class ModelType(Enum): """Model types for routing.""" LOCAL_SMALL = "local_small" LOCAL_LARGE = "local_large" GROQ_API = "groq" class QueryComplexity(Enum): """Query complexity levels.""" SIMPLE = "simple" # Factual questions, definitions MODERATE = "moderate" # Analysis, reasoning COMPLEX = "complex" # Creative, multi-step reasoning class ModelChain: """Intelligent model routing with fallback to Groq.""" def __init__(self): self.models = {} self.groq_available = config.has_groq_api() # Initialize models lazily self._ollama_available = None def check_ollama_available(self) -> bool: """Check if Ollama is running and available.""" if self._ollama_available is not None: return self._ollama_available try: # Try to connect to Ollama response = httpx.get(f"{config.ollama_host}/api/tags", timeout=2.0) self._ollama_available = response.status_code == 200 if self._ollama_available: print("Ollama is available") else: print("Ollama is not responding correctly") except Exception as e: print(f"Ollama not available: {e}") self._ollama_available = False return self._ollama_available def get_model(self, model_type: ModelType) -> Optional[LLM]: """Get or initialize a model.""" if model_type in self.models: return self.models[model_type] if model_type == ModelType.GROQ_API: if not self.groq_available: print("Groq API key not configured") return None try: # For groq/compound model, we'll use litellm # Return a wrapper that uses litellm return "groq" # Special marker for litellm usage except Exception as e: print(f"Failed to initialize Groq: {e}") return None elif model_type in [ModelType.LOCAL_SMALL, ModelType.LOCAL_LARGE]: if not self.check_ollama_available(): print("Ollama not available, falling back to Groq") return None model_config = config.get_model_config(model_type.value) try: model = Ollama( model=model_config["model"], base_url=config.ollama_host, temperature=model_config["temperature"], request_timeout=120.0, ) self.models[model_type] = model print(f"Initialized {model_type.value} model: {model_config['model']}") return model except Exception as e: print(f"Failed to initialize Ollama model: {e}") return None return None def analyze_query_complexity(self, query: str, context_size: int = 0) -> QueryComplexity: """Analyze query complexity to determine which model to use.""" query_lower = query.lower() # Simple queries - factual questions simple_patterns = [ r"what is", r"who is", r"when did", r"where is", r"define", r"list", r"name", r"how many", r"yes or no", ] # Complex queries - requiring reasoning or creativity complex_patterns = [ r"explain why", r"analyze", r"compare and contrast", r"what would happen if", r"imagine", r"create", r"write a", r"develop", r"design", r"evaluate", r"critique", r"synthesize", ] # Check for simple patterns for pattern in simple_patterns: if re.search(pattern, query_lower): return QueryComplexity.SIMPLE # Check for complex patterns for pattern in complex_patterns: if re.search(pattern, query_lower): return QueryComplexity.COMPLEX # Check query length and context size if len(query.split()) > 50 or context_size > config.max_local_context_size: return QueryComplexity.COMPLEX # Default to moderate return QueryComplexity.MODERATE def route_query( self, query: str, context: Optional[str] = None, force_model: Optional[ModelType] = None ) -> ModelType: """Determine which model to use for the query.""" if force_model: return force_model context_size = len(context) if context else 0 complexity = self.analyze_query_complexity(query, context_size) # Check Ollama availability ollama_available = self.check_ollama_available() # Routing logic if complexity == QueryComplexity.SIMPLE: if ollama_available: return ModelType.LOCAL_SMALL elif self.groq_available: return ModelType.GROQ_API elif complexity == QueryComplexity.MODERATE: if ollama_available: return ModelType.LOCAL_LARGE elif self.groq_available: return ModelType.GROQ_API else: # COMPLEX if self.groq_available: return ModelType.GROQ_API elif ollama_available: return ModelType.LOCAL_LARGE # Final fallback if self.groq_available: return ModelType.GROQ_API elif ollama_available: return ModelType.LOCAL_SMALL else: raise RuntimeError("No models available! Please check Ollama or configure Groq API key.") def generate_response( self, query: str, context: Optional[str] = None, force_model: Optional[ModelType] = None, stream: bool = False ) -> Dict[str, Any]: """Generate response using appropriate model.""" # Determine which model to use model_type = self.route_query(query, context, force_model) print(f"Using model: {model_type.value}") # Prepare prompt with Harry's personality if context: prompt = get_harry_prompt(query, context) else: # Even without context, use Harry's voice prompt = get_harry_prompt(query, "No specific context available - respond based on your general knowledge of HPMOR.") # Try primary model try: model = self.get_model(model_type) if model == "groq": # Special handling for Groq via litellm # Use litellm for Groq response = completion( model=f"groq/{config.groq_model}", messages=[{"role": "user", "content": prompt}], api_key=config.groq_api_key, temperature=0.7, max_tokens=2048, stream=stream ) if stream: return { "response": response, "model_used": model_type.value, "streaming": True } else: return { "response": response.choices[0].message.content, "model_used": model_type.value, "tokens_used": response.usage.total_tokens if hasattr(response, 'usage') else None } elif model: # Use LlamaIndex model if stream: response = model.stream_complete(prompt) else: response = model.complete(prompt) return { "response": response, "model_used": model_type.value, "streaming": stream } except Exception as e: print(f"Error with {model_type.value}: {e}") # Try fallback if model_type != ModelType.GROQ_API and self.groq_available: print("Falling back to Groq API...") model_type = ModelType.GROQ_API try: response = completion( model=f"groq/{config.groq_model}", messages=[{"role": "user", "content": prompt}], api_key=config.groq_api_key, temperature=0.7, max_tokens=2048, stream=stream ) if stream: return { "response": response, "model_used": model_type.value, "streaming": True, "fallback": True } else: return { "response": response.choices[0].message.content, "model_used": model_type.value, "tokens_used": response.usage.total_tokens if hasattr(response, 'usage') else None, "fallback": True } except Exception as e2: print(f"Fallback to Groq also failed: {e2}") raise RuntimeError(f"All models failed. Last error: {e2}") raise RuntimeError("No models available for response generation") def main(): """Test model chaining.""" chain = ModelChain() # Test queries of different complexities test_queries = [ ("What is Harry's full name?", QueryComplexity.SIMPLE), ("Explain Harry's reasoning about magic", QueryComplexity.MODERATE), ("Analyze the philosophical implications of Harry's scientific approach to magic", QueryComplexity.COMPLEX), ] for query, expected_complexity in test_queries: print(f"\nQuery: {query}") complexity = chain.analyze_query_complexity(query) print(f"Detected complexity: {complexity}") print(f"Expected complexity: {expected_complexity}") try: model_type = chain.route_query(query) print(f"Selected model: {model_type.value}") # Generate response result = chain.generate_response(query) print(f"Model used: {result['model_used']}") print(f"Response preview: {str(result['response'])[:200]}...") except Exception as e: print(f"Error: {e}") if __name__ == "__main__": main()