|
|
"""Model chaining logic with Groq fallback.""" |
|
|
|
|
|
import re |
|
|
from typing import Optional, Dict, Any, List |
|
|
from enum import Enum |
|
|
|
|
|
from llama_index.llms.ollama import Ollama |
|
|
from llama_index.llms.groq import Groq |
|
|
from llama_index.core.llms import LLM |
|
|
from litellm import completion |
|
|
import httpx |
|
|
|
|
|
from src.config import config |
|
|
from src.harry_personality import get_harry_prompt |
|
|
|
|
|
|
|
|
class ModelType(Enum):
    """Model types for routing."""
    LOCAL_SMALL = "local_small"  # small local Ollama model (fast, simple queries)
    LOCAL_LARGE = "local_large"  # larger local Ollama model (moderate queries)
    GROQ_API = "groq"            # hosted Groq API (complex queries / fallback)
|
|
|
|
|
|
|
|
class QueryComplexity(Enum):
    """Query complexity levels."""
    SIMPLE = "simple"      # short factual lookups ("what is", "who is", ...)
    MODERATE = "moderate"  # default bucket when no pattern matches
    COMPLEX = "complex"    # analysis/creative queries or very long inputs
|
|
|
|
|
|
|
|
class ModelChain:
    """Intelligent model routing with fallback to Groq.

    Routes each query to a local Ollama model (small or large) or the
    Groq API, based on a heuristic complexity analysis and on which
    backends are actually reachable. Local-model failures at generation
    time fall back to Groq when an API key is configured.
    """

    # Precompiled patterns hinting a query is answerable by the small local
    # model. The \b anchors prevent false hits inside longer words (the
    # previous bare substrings matched "list" in "specialist", "name" in
    # "filename", etc.). Compiled once at class creation instead of being
    # re-scanned on every call.
    _SIMPLE_PATTERNS = [re.compile(p) for p in (
        r"\bwhat is\b",
        r"\bwho is\b",
        r"\bwhen did\b",
        r"\bwhere is\b",
        r"\bdefine\b",
        r"\blist\b",
        r"\bname\b",
        r"\bhow many\b",
        r"\byes or no\b",
    )]

    # Precompiled patterns hinting a query needs the most capable model.
    _COMPLEX_PATTERNS = [re.compile(p) for p in (
        r"\bexplain why\b",
        r"\banalyze\b",
        r"\bcompare and contrast\b",
        r"\bwhat would happen if\b",
        r"\bimagine\b",
        r"\bcreate\b",
        r"\bwrite a\b",
        r"\bdevelop\b",
        r"\bdesign\b",
        r"\bevaluate\b",
        r"\bcritique\b",
        r"\bsynthesize\b",
    )]

    def __init__(self):
        # Cache of initialized LLM instances, keyed by ModelType.
        self.models: Dict[ModelType, Any] = {}
        # Whether a Groq API key is configured (checked once at startup).
        self.groq_available = config.has_groq_api()
        # None = not probed yet; True/False once check_ollama_available ran.
        self._ollama_available: Optional[bool] = None

    def check_ollama_available(self) -> bool:
        """Return True if Ollama is running; the probe result is cached.

        Issues a single short-timeout GET against the Ollama tags endpoint
        the first time it is called, then reuses the cached answer.
        """
        if self._ollama_available is not None:
            return self._ollama_available

        try:
            response = httpx.get(f"{config.ollama_host}/api/tags", timeout=2.0)
            self._ollama_available = response.status_code == 200
            if self._ollama_available:
                print("Ollama is available")
            else:
                print("Ollama is not responding correctly")
        except Exception as e:
            # Connection refused / timeout etc. — treat as unavailable.
            print(f"Ollama not available: {e}")
            self._ollama_available = False

        return self._ollama_available

    def get_model(self, model_type: ModelType) -> Optional[LLM]:
        """Get or lazily initialize the model for *model_type*.

        Returns:
            - the sentinel string ``"groq"`` for GROQ_API (Groq requests go
              through ``litellm.completion``, not a llama_index LLM object;
              generate_response checks for this sentinel),
            - an Ollama LLM instance for the local types, or
            - None when the requested backend is not usable.
        """
        if model_type in self.models:
            return self.models[model_type]

        if model_type == ModelType.GROQ_API:
            if not self.groq_available:
                print("Groq API key not configured")
                return None
            # No client object to build here — litellm handles the HTTP
            # call itself — so just signal availability with the sentinel.
            return "groq"

        if model_type in (ModelType.LOCAL_SMALL, ModelType.LOCAL_LARGE):
            if not self.check_ollama_available():
                print("Ollama not available, falling back to Groq")
                return None

            model_config = config.get_model_config(model_type.value)
            try:
                model = Ollama(
                    model=model_config["model"],
                    base_url=config.ollama_host,
                    temperature=model_config["temperature"],
                    request_timeout=120.0,
                )
            except Exception as e:
                print(f"Failed to initialize Ollama model: {e}")
                return None
            self.models[model_type] = model
            print(f"Initialized {model_type.value} model: {model_config['model']}")
            return model

        return None

    def analyze_query_complexity(self, query: str, context_size: int = 0) -> QueryComplexity:
        """Analyze query complexity to determine which model to use.

        Simple patterns win over complex ones (checked first); queries
        matching neither default to MODERATE unless they are very long or
        carry a large retrieval context.
        """
        query_lower = query.lower()

        for pattern in self._SIMPLE_PATTERNS:
            if pattern.search(query_lower):
                return QueryComplexity.SIMPLE

        for pattern in self._COMPLEX_PATTERNS:
            if pattern.search(query_lower):
                return QueryComplexity.COMPLEX

        # Long prompts or large contexts exceed what the local models
        # handle well, so escalate them to COMPLEX.
        if len(query.split()) > 50 or context_size > config.max_local_context_size:
            return QueryComplexity.COMPLEX

        return QueryComplexity.MODERATE

    def route_query(
        self,
        query: str,
        context: Optional[str] = None,
        force_model: Optional[ModelType] = None
    ) -> ModelType:
        """Determine which model to use for the query.

        Raises:
            RuntimeError: when neither Ollama nor Groq is available.
        """
        if force_model:
            return force_model

        context_size = len(context) if context else 0
        complexity = self.analyze_query_complexity(query, context_size)
        ollama_available = self.check_ollama_available()

        # Preferred backend per complexity tier, with the other backend as
        # a secondary choice when the preferred one is down.
        if complexity == QueryComplexity.SIMPLE:
            if ollama_available:
                return ModelType.LOCAL_SMALL
            elif self.groq_available:
                return ModelType.GROQ_API
        elif complexity == QueryComplexity.MODERATE:
            if ollama_available:
                return ModelType.LOCAL_LARGE
            elif self.groq_available:
                return ModelType.GROQ_API
        else:
            if self.groq_available:
                return ModelType.GROQ_API
            elif ollama_available:
                return ModelType.LOCAL_LARGE

        # Last-resort defaults (reached only when the tier above could not
        # pick a backend).
        if self.groq_available:
            return ModelType.GROQ_API
        elif ollama_available:
            return ModelType.LOCAL_SMALL
        else:
            raise RuntimeError("No models available! Please check Ollama or configure Groq API key.")

    def _call_groq(self, prompt: str, stream: bool):
        """Issue one Groq chat completion through litellm and return it raw."""
        return completion(
            model=f"groq/{config.groq_model}",
            messages=[{"role": "user", "content": prompt}],
            api_key=config.groq_api_key,
            temperature=0.7,
            max_tokens=2048,
            stream=stream
        )

    @staticmethod
    def _groq_result(response, model_type: ModelType, stream: bool,
                     fallback: bool = False) -> Dict[str, Any]:
        """Package a litellm response into the standard result dict."""
        if stream:
            result: Dict[str, Any] = {
                "response": response,
                "model_used": model_type.value,
                "streaming": True
            }
        else:
            result = {
                "response": response.choices[0].message.content,
                "model_used": model_type.value,
                "tokens_used": response.usage.total_tokens if hasattr(response, 'usage') else None
            }
        if fallback:
            result["fallback"] = True
        return result

    def generate_response(
        self,
        query: str,
        context: Optional[str] = None,
        force_model: Optional[ModelType] = None,
        stream: bool = False
    ) -> Dict[str, Any]:
        """Generate a response using the appropriate model.

        Returns a dict with at least ``response`` and ``model_used``; adds
        ``streaming``/``tokens_used`` depending on the path, and
        ``fallback: True`` when the local model failed and Groq answered.

        Raises:
            RuntimeError: when every available backend failed.
        """
        model_type = self.route_query(query, context, force_model)
        print(f"Using model: {model_type.value}")

        if context:
            prompt = get_harry_prompt(query, context)
        else:
            prompt = get_harry_prompt(query, "No specific context available - respond based on your general knowledge of HPMOR.")

        try:
            model = self.get_model(model_type)

            if model == "groq":
                # Sentinel from get_model: route through litellm.
                response = self._call_groq(prompt, stream)
                return self._groq_result(response, model_type, stream)
            elif model:
                # Local Ollama model.
                if stream:
                    response = model.stream_complete(prompt)
                else:
                    response = model.complete(prompt)
                return {
                    "response": response,
                    "model_used": model_type.value,
                    "streaming": stream
                }
        except Exception as e:
            print(f"Error with {model_type.value}: {e}")

            # A local-model failure falls back to Groq when possible.
            if model_type != ModelType.GROQ_API and self.groq_available:
                print("Falling back to Groq API...")
                model_type = ModelType.GROQ_API
                try:
                    response = self._call_groq(prompt, stream)
                    return self._groq_result(response, model_type, stream, fallback=True)
                except Exception as e2:
                    print(f"Fallback to Groq also failed: {e2}")
                    raise RuntimeError(f"All models failed. Last error: {e2}") from e2

        # Reached when get_model returned None (and no fallback applied).
        raise RuntimeError("No models available for response generation")
|
|
|
|
|
|
|
|
def main():
    """Test model chaining."""
    chain = ModelChain()

    # (query, expected complexity) pairs covering all three tiers.
    test_queries = [
        ("What is Harry's full name?", QueryComplexity.SIMPLE),
        ("Explain Harry's reasoning about magic", QueryComplexity.MODERATE),
        ("Analyze the philosophical implications of Harry's scientific approach to magic", QueryComplexity.COMPLEX),
    ]

    for query, expected_complexity in test_queries:
        print(f"\nQuery: {query}")

        # Compare the heuristic's verdict against the expectation.
        detected = chain.analyze_query_complexity(query)
        print(f"Detected complexity: {detected}")
        print(f"Expected complexity: {expected_complexity}")

        # Routing and generation need a live backend, so failures here are
        # reported rather than aborting the remaining queries.
        try:
            chosen = chain.route_query(query)
            print(f"Selected model: {chosen.value}")

            result = chain.generate_response(query)
            print(f"Model used: {result['model_used']}")
            print(f"Response preview: {str(result['response'])[:200]}...")
        except Exception as e:
            print(f"Error: {e}")
|
|
|
|
|
|
|
|
# Run the model-chaining smoke test when executed as a script.
if __name__ == "__main__":
    main()