# hpmor/src/model_chain.py
"""Model chaining logic with Groq fallback."""
import re
from typing import Optional, Dict, Any, List
from enum import Enum
from llama_index.llms.ollama import Ollama
from llama_index.llms.groq import Groq
from llama_index.core.llms import LLM
from litellm import completion
import httpx
from src.config import config
from src.harry_personality import get_harry_prompt
class ModelType(Enum):
    """Model types for routing queries to a backend."""
    LOCAL_SMALL = "local_small"  # small local Ollama model (simple queries)
    LOCAL_LARGE = "local_large"  # larger local Ollama model (moderate queries)
    GROQ_API = "groq"  # hosted Groq API, driven via litellm
class QueryComplexity(Enum):
    """Query complexity levels used by ModelChain to pick a model."""
    SIMPLE = "simple"  # Factual questions, definitions
    MODERATE = "moderate"  # Analysis, reasoning
    COMPLEX = "complex"  # Creative, multi-step reasoning
class ModelChain:
    """Intelligent model routing with fallback to Groq.

    Each query is classified by a lightweight heuristic
    (:meth:`analyze_query_complexity`) and routed to a small or large
    local Ollama model, or to the hosted Groq API.  If the chosen local
    model fails or cannot be initialized, the chain falls back to Groq
    when an API key is configured.
    """

    # Precompiled cues for simple factual lookups.  \b word boundaries
    # keep e.g. "name" from matching inside "filename" or "list" inside
    # "listen".  Compiled once at class-definition time instead of on
    # every analyze_query_complexity() call.
    _SIMPLE_PATTERNS = [re.compile(rf"\b{p}\b") for p in (
        "what is",
        "who is",
        "when did",
        "where is",
        "define",
        "list",
        "name",
        "how many",
        "yes or no",
    )]

    # Precompiled cues for reasoning-heavy or creative requests.
    _COMPLEX_PATTERNS = [re.compile(rf"\b{p}\b") for p in (
        "explain why",
        "analyze",
        "compare and contrast",
        "what would happen if",
        "imagine",
        "create",
        "write a",
        "develop",
        "design",
        "evaluate",
        "critique",
        "synthesize",
    )]

    def __init__(self):
        # Cache of initialized LlamaIndex LLM instances, keyed by ModelType.
        self.models = {}
        self.groq_available = config.has_groq_api()
        # None = not probed yet; set (and cached) by check_ollama_available().
        self._ollama_available = None

    def check_ollama_available(self) -> bool:
        """Check if Ollama is running and available (result is cached)."""
        if self._ollama_available is not None:
            return self._ollama_available
        try:
            # Probe the Ollama HTTP API with a short timeout.
            response = httpx.get(f"{config.ollama_host}/api/tags", timeout=2.0)
            self._ollama_available = response.status_code == 200
            if self._ollama_available:
                print("Ollama is available")
            else:
                print("Ollama is not responding correctly")
        except Exception as e:
            print(f"Ollama not available: {e}")
            self._ollama_available = False
        return self._ollama_available

    def get_model(self, model_type: ModelType) -> Optional[LLM]:
        """Get or lazily initialize a model.

        Returns a LlamaIndex ``Ollama`` instance for local models, the
        sentinel string ``"groq"`` for the Groq API (which is called via
        litellm in :meth:`generate_response`, not through a LlamaIndex
        LLM), or ``None`` when the requested backend is unavailable.
        """
        if model_type in self.models:
            return self.models[model_type]

        if model_type == ModelType.GROQ_API:
            if not self.groq_available:
                print("Groq API key not configured")
                return None
            # Groq is driven through litellm; no LLM object to build here.
            # (The original wrapped this constant return in a try/except
            # that could never fire.)
            return "groq"

        if model_type in (ModelType.LOCAL_SMALL, ModelType.LOCAL_LARGE):
            if not self.check_ollama_available():
                print("Ollama not available, falling back to Groq")
                return None
            model_config = config.get_model_config(model_type.value)
            try:
                model = Ollama(
                    model=model_config["model"],
                    base_url=config.ollama_host,
                    temperature=model_config["temperature"],
                    request_timeout=120.0,
                )
            except Exception as e:
                print(f"Failed to initialize Ollama model: {e}")
                return None
            self.models[model_type] = model
            print(f"Initialized {model_type.value} model: {model_config['model']}")
            return model

        return None

    def analyze_query_complexity(self, query: str, context_size: int = 0) -> QueryComplexity:
        """Analyze query complexity to determine which model to use.

        Complex cues are checked *before* simple ones so that e.g.
        "Analyze ... how many ..." is not misrouted to the small model
        by the weaker factual-lookup match.  Very long queries or large
        contexts are always treated as complex.
        """
        query_lower = query.lower()
        if any(p.search(query_lower) for p in self._COMPLEX_PATTERNS):
            return QueryComplexity.COMPLEX
        if any(p.search(query_lower) for p in self._SIMPLE_PATTERNS):
            return QueryComplexity.SIMPLE
        # Long queries or oversized contexts exceed the local models' sweet spot.
        if len(query.split()) > 50 or context_size > config.max_local_context_size:
            return QueryComplexity.COMPLEX
        # Default to moderate
        return QueryComplexity.MODERATE

    def route_query(
        self,
        query: str,
        context: Optional[str] = None,
        force_model: Optional[ModelType] = None
    ) -> ModelType:
        """Determine which model to use for the query.

        Raises:
            RuntimeError: if neither Ollama nor Groq is available.
        """
        if force_model:
            return force_model

        context_size = len(context) if context else 0
        complexity = self.analyze_query_complexity(query, context_size)
        ollama_available = self.check_ollama_available()

        # Routing: simple -> small local, moderate -> large local,
        # complex -> Groq; each falls back to the other backend.
        if complexity == QueryComplexity.SIMPLE:
            if ollama_available:
                return ModelType.LOCAL_SMALL
            elif self.groq_available:
                return ModelType.GROQ_API
        elif complexity == QueryComplexity.MODERATE:
            if ollama_available:
                return ModelType.LOCAL_LARGE
            elif self.groq_available:
                return ModelType.GROQ_API
        else:  # COMPLEX
            if self.groq_available:
                return ModelType.GROQ_API
            elif ollama_available:
                return ModelType.LOCAL_LARGE

        # Final fallback
        if self.groq_available:
            return ModelType.GROQ_API
        elif ollama_available:
            return ModelType.LOCAL_SMALL
        raise RuntimeError("No models available! Please check Ollama or configure Groq API key.")

    def _groq_completion(self, prompt: str, stream: bool, fallback: bool = False) -> Dict[str, Any]:
        """Call the Groq API via litellm and shape the result dict.

        Extracted because the original duplicated this call verbatim in
        both the primary and fallback paths of generate_response().
        """
        response = completion(
            model=f"groq/{config.groq_model}",
            messages=[{"role": "user", "content": prompt}],
            api_key=config.groq_api_key,
            temperature=0.7,
            max_tokens=2048,
            stream=stream
        )
        if stream:
            result: Dict[str, Any] = {
                "response": response,
                "model_used": ModelType.GROQ_API.value,
                "streaming": True,
            }
        else:
            result = {
                "response": response.choices[0].message.content,
                "model_used": ModelType.GROQ_API.value,
                "tokens_used": response.usage.total_tokens if hasattr(response, 'usage') else None,
            }
        if fallback:
            result["fallback"] = True
        return result

    def generate_response(
        self,
        query: str,
        context: Optional[str] = None,
        force_model: Optional[ModelType] = None,
        stream: bool = False
    ) -> Dict[str, Any]:
        """Generate response using the appropriate model.

        Returns a dict with at least ``response`` and ``model_used``;
        streaming results carry ``streaming``, non-streaming Groq results
        carry ``tokens_used``, and fallback results carry ``fallback``.

        Raises:
            RuntimeError: if every available model fails.
        """
        model_type = self.route_query(query, context, force_model)
        print(f"Using model: {model_type.value}")

        # Prepare prompt with Harry's personality; even without retrieved
        # context the response stays in Harry's voice.
        if context:
            prompt = get_harry_prompt(query, context)
        else:
            prompt = get_harry_prompt(query, "No specific context available - respond based on your general knowledge of HPMOR.")

        try:
            model = self.get_model(model_type)
            if model == "groq":  # Sentinel: use litellm instead of a LlamaIndex LLM
                return self._groq_completion(prompt, stream)
            if model is not None:
                response = model.stream_complete(prompt) if stream else model.complete(prompt)
                return {
                    "response": response,
                    "model_used": model_type.value,
                    "streaming": stream
                }
            # model is None: the primary backend could not be initialized.
            # Fall through to the Groq fallback below — the original only
            # attempted the fallback after an exception, so this case
            # raised "No models available" even with Groq configured.
        except Exception as e:
            print(f"Error with {model_type.value}: {e}")

        if model_type != ModelType.GROQ_API and self.groq_available:
            print("Falling back to Groq API...")
            try:
                return self._groq_completion(prompt, stream, fallback=True)
            except Exception as e2:
                print(f"Fallback to Groq also failed: {e2}")
                raise RuntimeError(f"All models failed. Last error: {e2}")

        raise RuntimeError("No models available for response generation")
def main():
    """Smoke-test model chaining on queries of increasing complexity."""
    chain = ModelChain()

    # (query, complexity we expect the heuristic to detect)
    samples = [
        ("What is Harry's full name?", QueryComplexity.SIMPLE),
        ("Explain Harry's reasoning about magic", QueryComplexity.MODERATE),
        ("Analyze the philosophical implications of Harry's scientific approach to magic", QueryComplexity.COMPLEX),
    ]

    for question, expected in samples:
        print(f"\nQuery: {question}")
        detected = chain.analyze_query_complexity(question)
        print(f"Detected complexity: {detected}")
        print(f"Expected complexity: {expected}")
        try:
            chosen = chain.route_query(question)
            print(f"Selected model: {chosen.value}")
            result = chain.generate_response(question)
            print(f"Model used: {result['model_used']}")
            print(f"Response preview: {str(result['response'])[:200]}...")
        except Exception as exc:
            print(f"Error: {exc}")


if __name__ == "__main__":
    main()