import logging
import asyncio
from typing import Dict, Optional

from .models_config import LLM_CONFIG

logger = logging.getLogger(__name__)


class LLMRouter:
    """Route inference requests to local models when available, falling back to the HF Inference API."""

    def __init__(self, hf_token: Optional[str], use_local_models: bool = True):
        self.hf_token = hf_token
        self.health_status: Dict[str, bool] = {}
        self.use_local_models = use_local_models
        self.local_loader = None

        logger.info("LLMRouter initialized")
        if hf_token:
            logger.info("HF token available")
        else:
            logger.warning("No HF token provided")

        # Initialize the local model loader if local inference is requested;
        # fall back to the Inference API when transformers/torch are unavailable.
        if self.use_local_models:
            try:
                from .local_model_loader import LocalModelLoader
                self.local_loader = LocalModelLoader()
                logger.info("✓ Local model loader initialized (GPU-based inference)")
                logger.info("Models will be loaded on-demand for faster startup")
            except Exception as e:
                logger.warning(f"Could not initialize local model loader: {e}. Falling back to API.")
                logger.warning("This is normal if transformers/torch are not available")
                self.use_local_models = False
                self.local_loader = None

    async def route_inference(self, task_type: str, prompt: str, **kwargs):
        """
        Smart routing based on task specialization.
        Tries local models first, then falls back to the HF Inference API if needed.
        """
        logger.info(f"Routing inference for task: {task_type}")
        model_config = self._select_model(task_type)
        logger.info(f"Selected model: {model_config['model_id']}")

        # Prefer local inference when the loader is available.
        if self.use_local_models and self.local_loader:
            try:
                # Embedding requests take a dedicated path; everything else is chat-style generation.
                if task_type == "embedding_generation":
                    result = await self._call_local_embedding(model_config, prompt, **kwargs)
                else:
                    result = await self._call_local_model(model_config, prompt, task_type, **kwargs)

                if result is not None:
                    logger.info(f"Inference complete for {task_type} (local model)")
                    return result
                else:
                    logger.warning("Local model returned None, falling back to API")
            except Exception as e:
                logger.warning(f"Local model inference failed: {e}. Falling back to API.")
                logger.debug("Exception details:", exc_info=True)

        # Remote path: check model health, swap in a fallback if needed, then call the API.
        logger.info("Using HF Inference API")
        if not await self._is_model_healthy(model_config["model_id"]):
            logger.warning(f"Model {model_config['model_id']} unhealthy, using fallback")
            model_config = self._get_fallback_model(task_type)
            logger.info(f"Fallback model: {model_config['model_id']}")

        result = await self._call_hf_endpoint(model_config, prompt, task_type, **kwargs)
        logger.info(f"Inference complete for {task_type}")
        return result

    async def _call_local_model(self, model_config: dict, prompt: str, task_type: str, **kwargs) -> Optional[str]:
        """Call a local chat model for inference."""
        if not self.local_loader:
            return None

        model_id = model_config["model_id"]
        max_tokens = kwargs.get('max_tokens', 512)
        temperature = kwargs.get('temperature', 0.7)

        try:
            # Lazily load the chat model the first time it is requested.
            if model_id not in self.local_loader.loaded_models:
                logger.info(f"Loading model {model_id} on demand...")
                self.local_loader.load_chat_model(model_id, load_in_8bit=False)

            # Wrap the raw prompt in a single-turn chat message.
            messages = [{"role": "user", "content": prompt}]

            # Run the blocking generation call in a worker thread so the event loop stays responsive.
            result = await asyncio.to_thread(
                self.local_loader.generate_chat_completion,
                model_id=model_id,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature
            )

            logger.info(f"Local model {model_id} generated response (length: {len(result)})")
            logger.info("=" * 80)
            logger.info("LOCAL MODEL RESPONSE:")
            logger.info("=" * 80)
            logger.info(f"Model: {model_id}")
            logger.info(f"Task Type: {task_type}")
            logger.info(f"Response Length: {len(result)} characters")
            logger.info("-" * 40)
            logger.info("FULL RESPONSE CONTENT:")
            logger.info("-" * 40)
            logger.info(result)
            logger.info("-" * 40)
            logger.info("END OF RESPONSE")
            logger.info("=" * 80)

            return result

        except Exception as e:
            logger.error(f"Error calling local model: {e}", exc_info=True)
            return None

    async def _call_local_embedding(self, model_config: dict, text: str, **kwargs) -> Optional[list]:
        """Call a local embedding model."""
        if not self.local_loader:
            return None

        model_id = model_config["model_id"]

        try:
            # Lazily load the embedding model the first time it is requested.
            if model_id not in self.local_loader.loaded_embedding_models:
                logger.info(f"Loading embedding model {model_id} on demand...")
                self.local_loader.load_embedding_model(model_id)

            # Run the blocking embedding call in a worker thread.
            embedding = await asyncio.to_thread(
                self.local_loader.get_embedding,
                model_id=model_id,
                text=text
            )

            logger.info(f"Local embedding model {model_id} generated vector (dim: {len(embedding)})")
            return embedding

        except Exception as e:
            logger.error(f"Error calling local embedding model: {e}", exc_info=True)
            return None

    def _select_model(self, task_type: str) -> dict:
        """Map a task type to its specialist model configuration."""
        model_map = {
            "intent_classification": LLM_CONFIG["models"]["classification_specialist"],
            "embedding_generation": LLM_CONFIG["models"]["embedding_specialist"],
            "safety_check": LLM_CONFIG["models"]["safety_checker"],
            "general_reasoning": LLM_CONFIG["models"]["reasoning_primary"],
            "response_synthesis": LLM_CONFIG["models"]["reasoning_primary"]
        }
        return model_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"])

    async def _is_model_healthy(self, model_id: str) -> bool:
        """
        Check whether a model is healthy and available.
        Models are marked healthy by default; actual availability is verified at API call time.
        """
        # Return the cached status if we have one.
        if model_id in self.health_status:
            return self.health_status[model_id]

        # Optimistically assume the model is healthy until an API call proves otherwise.
        self.health_status[model_id] = True
        return True

    def _get_fallback_model(self, task_type: str) -> dict:
        """
        Get the fallback model configuration for the given task type.
        """
        fallback_map = {
            "intent_classification": LLM_CONFIG["models"]["reasoning_primary"],
            "embedding_generation": LLM_CONFIG["models"]["embedding_specialist"],
            "safety_check": LLM_CONFIG["models"]["reasoning_primary"],
            "general_reasoning": LLM_CONFIG["models"]["reasoning_primary"],
            "response_synthesis": LLM_CONFIG["models"]["reasoning_primary"]
        }
        return fallback_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"])

    async def _call_hf_endpoint(self, model_config: dict, prompt: str, task_type: str, **kwargs):
        """
        Call the Hugging Face Chat Completions API.
        Uses the chat completions protocol with retry logic and exponential backoff.
        """
        # Retry and timeout settings (overridable via kwargs).
        max_retries = kwargs.get('max_retries', 3)
        initial_delay = kwargs.get('initial_delay', 1.0)
        max_delay = kwargs.get('max_delay', 16.0)
        timeout = kwargs.get('timeout', 30)

        try:
            import requests
            from requests.exceptions import Timeout, RequestException, ConnectionError as RequestsConnectionError

            model_id = model_config["model_id"]

            # Chat completions endpoint on the HF router (OpenAI-style request/response format).
            api_url = "https://router.huggingface.co/v1/chat/completions"

            logger.info(f"Calling HF Chat Completions API for model: {model_id}")
            logger.debug(f"Prompt length: {len(prompt)}")
            logger.info("=" * 80)
            logger.info("LLM API REQUEST - COMPLETE PROMPT:")
            logger.info("=" * 80)
            logger.info(f"Model: {model_id}")
            logger.info(f"Task Type: {task_type}")
            logger.info(f"Prompt Length: {len(prompt)} characters")
            logger.info("-" * 40)
            logger.info("FULL PROMPT CONTENT:")
            logger.info("-" * 40)
            logger.info(prompt)
            logger.info("-" * 40)
            logger.info("END OF PROMPT")
            logger.info("=" * 80)

            # Generation parameters (overridable via kwargs).
            max_tokens = kwargs.get('max_tokens', 512)
            temperature = kwargs.get('temperature', 0.7)

            payload = {
                "model": model_id,
                "messages": [
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                "max_tokens": max_tokens,
                "temperature": temperature,
                "stream": False
            }

            headers = {
                "Authorization": f"Bearer {self.hf_token}",
                "Content-Type": "application/json"
            }

            # Retry loop with exponential backoff.
            last_exception = None
            for attempt in range(max_retries + 1):
                try:
                    if attempt > 0:
                        # Back off exponentially: initial_delay, then 2x, 4x, ..., capped at max_delay.
                        delay = min(initial_delay * (2 ** (attempt - 1)), max_delay)
                        logger.warning(f"Retry attempt {attempt}/{max_retries} after {delay:.1f}s delay (exponential backoff)")
                        await asyncio.sleep(delay)

                    logger.info(f"Sending request to: {api_url} (attempt {attempt + 1}/{max_retries + 1})")
                    logger.debug(f"Payload: {payload}")

                    # Run the blocking HTTP call in a worker thread so the event loop stays responsive.
                    response = await asyncio.to_thread(
                        requests.post, api_url, json=payload, headers=headers, timeout=timeout
                    )

                    if response.status_code == 200:
                        result = response.json()
                        logger.debug(f"Raw response: {result}")

                        if 'choices' in result and len(result['choices']) > 0:
                            generated_text = result['choices'][0]['message']['content']

                            if not generated_text or generated_text.strip() == "":
                                logger.warning("Empty or invalid response, using fallback")
                                return None

                            if attempt > 0:
                                logger.info(f"Successfully retrieved response after {attempt} retry attempts")

                            logger.info(f"HF API returned response (length: {len(generated_text)})")
                            logger.info("=" * 80)
                            logger.info("COMPLETE LLM API RESPONSE:")
                            logger.info("=" * 80)
                            logger.info(f"Model: {model_id}")
                            logger.info(f"Task Type: {task_type}")
                            logger.info(f"Response Length: {len(generated_text)} characters")
                            logger.info("-" * 40)
                            logger.info("FULL RESPONSE CONTENT:")
                            logger.info("-" * 40)
                            logger.info(generated_text)
                            logger.info("-" * 40)
                            logger.info("END OF LLM RESPONSE")
                            logger.info("=" * 80)
                            return generated_text
                        else:
                            logger.error(f"Unexpected response format: {result}")
                            return None
                    elif response.status_code == 503:
                        # A 503 means the model is still loading on the HF side: retry, then try the fallback model.
                        if attempt < max_retries:
                            logger.warning(f"Model loading (503), will retry (attempt {attempt + 1}/{max_retries + 1})")
                            last_exception = Exception("Model loading (503)")
                            continue
                        else:
                            logger.warning(f"Model loading (503) after {max_retries} retries, trying fallback model")
                            fallback_config = self._get_fallback_model(task_type)
                            # Guard against recursing forever when the fallback is the same model.
                            if fallback_config["model_id"] == model_id:
                                logger.error("Fallback model is the same as the failing model; giving up")
                                return None
                            return await self._call_hf_endpoint(fallback_config, prompt, task_type, **kwargs)
                    else:
                        # Non-retryable HTTP error.
                        logger.error(f"HF API error: {response.status_code} - {response.text}")
                        return None

                except Timeout as e:
                    last_exception = e
                    if attempt < max_retries:
                        logger.warning(f"Request timeout (attempt {attempt + 1}/{max_retries + 1}): {str(e)}")
                        continue
                    else:
                        logger.error(f"Request timeout after {max_retries} retries: {str(e)}")
                        # Persistent timeouts: try the fallback model instead.
                        logger.warning("Attempting fallback model due to persistent timeout")
                        fallback_config = self._get_fallback_model(task_type)
                        if fallback_config["model_id"] == model_id:
                            logger.error("Fallback model is the same as the failing model; giving up")
                            return None
                        return await self._call_hf_endpoint(fallback_config, prompt, task_type, **kwargs)

                except (RequestsConnectionError, RequestException) as e:
                    last_exception = e
                    if attempt < max_retries:
                        logger.warning(f"Connection error (attempt {attempt + 1}/{max_retries + 1}): {str(e)}")
                        continue
                    else:
                        logger.error(f"Connection error after {max_retries} retries: {str(e)}")
                        # Persistent connection errors: try the fallback model instead.
                        logger.warning("Attempting fallback model due to persistent connection error")
                        fallback_config = self._get_fallback_model(task_type)
                        if fallback_config["model_id"] == model_id:
                            logger.error("Fallback model is the same as the failing model; giving up")
                            return None
                        return await self._call_hf_endpoint(fallback_config, prompt, task_type, **kwargs)

            # All retries exhausted without a successful response.
            if last_exception:
                logger.error(f"Failed after {max_retries} retries. Last error: {last_exception}")
                return None

        except ImportError:
            logger.warning("requests library not available, using mock response")
            return f"[Mock] Response to: {prompt[:100]}..."
        except Exception as e:
            logger.error(f"Error calling HF endpoint: {e}", exc_info=True)
            return None

    async def get_available_models(self):
        """
        Get the list of available model names for testing.
        """
        return list(LLM_CONFIG["models"].keys())

    async def health_check(self):
        """
        Perform a health check on all configured models.
        """
        health_status = {}
        for model_name, model_config in LLM_CONFIG["models"].items():
            model_id = model_config["model_id"]
            is_healthy = await self._is_model_healthy(model_id)
            health_status[model_name] = {
                "model_id": model_id,
                "healthy": is_healthy
            }

        return health_status

    def prepare_context_for_llm(self, raw_context: Dict, max_tokens: int = 4000) -> str:
        """Smart context windowing for LLM calls."""
        # Load a tokenizer once (cached on the instance) for accurate token counting;
        # fall back to a character-count estimate if transformers is unavailable.
        try:
            from transformers import AutoTokenizer

            if not hasattr(self, 'tokenizer'):
                try:
                    self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
                except Exception as e:
                    logger.warning(f"Could not load tokenizer: {e}, using character count estimation")
                    self.tokenizer = None
        except ImportError:
            logger.warning("transformers library not available, using character count estimation")
            self.tokenizer = None

        # Context elements ordered from highest to lowest priority.
        priority_elements = [
            ('current_query', 1.0),
            ('recent_interactions', 0.8),
            ('user_preferences', 0.6),
            ('session_summary', 0.4),
            ('historical_context', 0.2)
        ]

        formatted_context = []
        total_tokens = 0

        # Map each context element to its content in the raw context dict.
        element_key_map = {
            'current_query': raw_context.get('user_input', ''),
            'recent_interactions': raw_context.get('interaction_contexts', []),
            'user_preferences': raw_context.get('preferences', {}),
            'session_summary': raw_context.get('session_context', {}),
            'historical_context': raw_context.get('user_context', '')
        }

        for element, priority in priority_elements:
            content = element_key_map.get(element, '')

            # Normalize dicts and lists to plain text (lists capped at 10 items).
            if isinstance(content, dict):
                content = str(content)
            elif isinstance(content, list):
                content = "\n".join([str(item) for item in content[:10]])

            if not content:
                continue

            # Count tokens with the tokenizer when available; otherwise estimate ~4 characters per token.
            if self.tokenizer:
                try:
                    tokens = len(self.tokenizer.encode(content))
                except Exception:
                    tokens = len(content) // 4
            else:
                tokens = len(content) // 4

            if total_tokens + tokens <= max_tokens:
                formatted_context.append(f"=== {element.upper()} ===\n{content}")
                total_tokens += tokens
            elif priority > 0.5:
                # High-priority content that does not fit is truncated to the remaining budget.
                available = max_tokens - total_tokens
                if available > 100:
                    truncated = self._truncate_to_tokens(content, available)
                    formatted_context.append(f"=== {element.upper()} (TRUNCATED) ===\n{truncated}")
                break

        return "\n\n".join(formatted_context)

    def _truncate_to_tokens(self, content: str, max_tokens: int) -> str:
        """Truncate content to fit within a token limit."""
        if not self.tokenizer:
            # No tokenizer available: estimate ~4 characters per token.
            max_chars = max_tokens * 4
            if len(content) <= max_chars:
                return content
            return content[:max_chars - 3] + "..."

        try:
            # Tokenize, trim to the budget, and decode back to text.
            tokens = self.tokenizer.encode(content)
            if len(tokens) <= max_tokens:
                return content

            truncated_tokens = tokens[:max_tokens - 3]
            truncated_text = self.tokenizer.decode(truncated_tokens)
            return truncated_text + "..."
        except Exception as e:
            logger.warning(f"Error truncating with tokenizer: {e}, using character truncation")
            max_chars = max_tokens * 4
            if len(content) <= max_chars:
                return content
            return content[:max_chars - 3] + "..."
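
if __name__ == "__main__":
    # Minimal usage sketch, not part of the router itself: run the module from within its
    # package (e.g. via `python -m ...`) so the relative imports above resolve. It assumes a
    # hypothetical HF_TOKEN environment variable holds a Hugging Face API token, and uses one
    # of the task types handled by _select_model.
    import os

    async def _demo():
        # Force the hosted-API path for the sketch; pass use_local_models=True to try local inference.
        router = LLMRouter(hf_token=os.environ.get("HF_TOKEN"), use_local_models=False)
        answer = await router.route_inference(
            "general_reasoning",
            "Summarize the trade-offs between local and hosted inference.",
            max_tokens=256,
            temperature=0.3,
        )
        print(answer)

    asyncio.run(_demo())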