"""LLM model wrapper using HuggingFace Transformers - Production Grade.""" import logging import sys import time from typing import Optional, List, Dict import torch from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline from src.config import ModelConfig # Configure logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', handlers=[logging.StreamHandler(sys.stdout)] ) logger = logging.getLogger(__name__) # Constants MAX_RETRIES = 3 RETRY_DELAY = 2 # seconds GENERATION_TIMEOUT = 60 # seconds class ModelLoadError(Exception): """Custom exception for model loading failures.""" pass class GenerationError(Exception): """Custom exception for text generation failures.""" pass class PhiModel: """Production-grade LLM wrapper using HuggingFace Transformers.""" def __init__(self, config: Optional[ModelConfig] = None): """Initialize the model wrapper. Args: config: Model configuration. Uses defaults if not provided. """ self.config = config or ModelConfig() self._model = None self._tokenizer = None self._pipeline = None self._is_loaded = False @property def model(self): """Lazy load the model with retry logic.""" if self._pipeline is None: self._load_model_with_retry() return self._pipeline @property def is_ready(self) -> bool: """Check if model is loaded and ready.""" return self._is_loaded and self._pipeline is not None def _load_model_with_retry(self) -> None: """Load model with retry logic for production reliability.""" last_error = None for attempt in range(1, MAX_RETRIES + 1): try: logger.info(f"📥 Loading model (attempt {attempt}/{MAX_RETRIES}): {self.config.repo_id}") self._load_model() self._is_loaded = True return except Exception as e: last_error = e logger.warning(f"⚠️ Attempt {attempt} failed: {str(e)[:100]}") if attempt < MAX_RETRIES: logger.info(f"⏳ Retrying in {RETRY_DELAY} seconds...") time.sleep(RETRY_DELAY) logger.error(f"❌ Model loading failed after {MAX_RETRIES} attempts") raise ModelLoadError(f"Failed to load model after {MAX_RETRIES} attempts: {last_error}") def _load_model(self) -> None: """Download and load the model.""" # Load tokenizer logger.info("🔧 Loading tokenizer...") self._tokenizer = AutoTokenizer.from_pretrained( self.config.repo_id, trust_remote_code=True ) # Ensure pad token is set if self._tokenizer.pad_token is None: self._tokenizer.pad_token = self._tokenizer.eos_token # Load model with CPU optimizations logger.info("🔧 Loading model weights...") self._model = AutoModelForCausalLM.from_pretrained( self.config.repo_id, torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True, low_cpu_mem_usage=True ) # Create pipeline for text generation self._pipeline = pipeline( "text-generation", model=self._model, tokenizer=self._tokenizer, max_new_tokens=self.config.max_tokens, temperature=self.config.temperature, do_sample=True, pad_token_id=self._tokenizer.eos_token_id ) logger.info("✅ Model loaded successfully!") def generate(self, prompt: str, max_tokens: Optional[int] = None) -> str: """Generate text completion with error handling. Args: prompt: Input prompt. max_tokens: Maximum tokens to generate. Returns: Generated text. Raises: GenerationError: If generation fails. """ if not prompt or not prompt.strip(): return "Please provide a valid question." # Truncate very long prompts max_prompt_length = 3000 if len(prompt) > max_prompt_length: prompt = prompt[:max_prompt_length] + "..." logger.warning(f"Prompt truncated to {max_prompt_length} characters") try: result = self.model( prompt, max_new_tokens=max_tokens or self.config.max_tokens, temperature=self.config.temperature, do_sample=True, return_full_text=False ) generated = result[0]["generated_text"].strip() # Clean up response - remove any fake dialogue continuation generated = self._clean_response(generated) if not generated: return "I couldn't generate a response. Please try rephrasing your question." return generated except Exception as e: logger.error(f"Generation error: {e}") raise GenerationError(f"Failed to generate response: {str(e)[:100]}") def _clean_response(self, text: str) -> str: """Remove fake dialogue continuations from model output. TinyLlama and similar models sometimes continue generating fake user/assistant dialogue. This method cuts off such continuations. """ if not text: return text # Stop patterns - cut off if model starts generating fake dialogue stop_patterns = [ "\nUser:", "\nuser:", "\nHuman:", "\nhuman:", "\nSystem:", "\nsystem:", "\nAssistant:", "\nassistant:", "\n\nUser", "\n\nHuman", "User's question:", "\n---\n", "<|", "[INST]", "" ] result = text for pattern in stop_patterns: if pattern in result: result = result.split(pattern)[0] # Also check for repeated newlines with potential role markers lines = result.split("\n") cleaned_lines = [] for line in lines: line_lower = line.lower().strip() # Stop if we hit a line that looks like a role marker if line_lower.startswith(("user:", "human:", "system:", "assistant:")): break cleaned_lines.append(line) result = "\n".join(cleaned_lines).strip() return result def generate_safe(self, prompt: str, max_tokens: Optional[int] = None) -> str: """Generate text with fallback on error (never throws). Args: prompt: Input prompt. max_tokens: Maximum tokens to generate. Returns: Generated text or fallback message. """ try: return self.generate(prompt, max_tokens) except Exception as e: logger.error(f"Safe generation fallback: {e}") return "I'm having trouble processing your request right now. Please try again in a moment." def chat( self, messages: List[Dict[str, str]], max_tokens: Optional[int] = None ) -> str: """Generate chat completion. Args: messages: List of message dicts with 'role' and 'content'. max_tokens: Maximum tokens to generate. Returns: Assistant's response. """ if not messages: return "Please provide a message." # Format messages for chat chat_text = "" for msg in messages: role = msg.get("role", "user") content = msg.get("content", "") if role == "system": chat_text += f"System: {content}\n\n" elif role == "user": chat_text += f"User: {content}\n\n" elif role == "assistant": chat_text += f"Assistant: {content}\n\n" chat_text += "Assistant: " return self.generate_safe(chat_text, max_tokens) def chat_with_context( self, query: str, context: str, system_prompt: Optional[str] = None, conversation_history: Optional[str] = None ) -> str: """Generate response with RAG context and conversation history. Args: query: User's question. context: Retrieved context from documents. system_prompt: Optional system prompt. conversation_history: Optional formatted conversation history (last 6 messages). Returns: Generated response. """ if not query or not query.strip(): return "Please ask a question." if system_prompt is None: system_prompt = ( "Your name is Dragon. Always speak in only ENGLISH not any other language. " "You are a friendly and helpful assistant having a natural conversation. " "Answer questions based on the provided document context. " "Be conversational, warm, and helpful - like talking to a knowledgeable friend. " "If you can find relevant information, explain it clearly and naturally. " "If the context doesn't have enough information, kindly ask the user to provide " "more details or suggest what they might be looking for. " "Keep your responses concise but friendly." ) # Handle empty context if not context or not context.strip(): context = "No relevant documents found." # Build message with optional history history_section = "" if conversation_history and conversation_history.strip(): history_section = f"""Previous conversation: {conversation_history} --- """ user_message = f"""{history_section}Here's some information from the documents: {context} User's current question: {query} Please respond naturally and helpfully, considering the conversation context:""" messages = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_message} ] return self.chat(messages)