import os
from typing import List, Dict, Any, Optional

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    pipeline,
    T5ForConditionalGeneration,
    T5Tokenizer
)

from config import Config


class LLMHandler:
    """Handle LLM operations for answer generation"""

    def __init__(self, config: Config = None):
        self.config = config or Config()
        self.model = None
        self.tokenizer = None
        self.pipeline = None

        # Set device
        self.device = "cuda" if torch.cuda.is_available() and self.config.USE_GPU else "cpu"
        print(f"🔧 Using device: {self.device}")

        # Load model
        self._load_model()

    def _load_model(self):
        """Load the LLM model and tokenizer"""
        try:
            print(f"🤖 Loading model: {self.config.LLM_MODEL}")

            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.config.LLM_MODEL,
                cache_dir=self.config.HF_CACHE_DIR
            )

            # Load model
            if "flan-t5" in self.config.LLM_MODEL.lower():
                # T5 models
                self.model = T5ForConditionalGeneration.from_pretrained(
                    self.config.LLM_MODEL,
                    cache_dir=self.config.HF_CACHE_DIR,
                    torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                    device_map="auto" if self.device == "cuda" else None
                )
            else:
                # Generic sequence-to-sequence models
                self.model = AutoModelForSeq2SeqLM.from_pretrained(
                    self.config.LLM_MODEL,
                    cache_dir=self.config.HF_CACHE_DIR,
                    torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
                )

            # Move the model manually only when accelerate did not already
            # dispatch it via device_map (accelerate sets `hf_device_map`)
            if not hasattr(self.model, "hf_device_map"):
                self.model.to(self.device)

            # Create pipeline; leave device placement to accelerate when a
            # device_map was used, otherwise pin the device explicitly
            pipeline_device = None if hasattr(self.model, "hf_device_map") else (0 if self.device == "cuda" else -1)
            self.pipeline = pipeline(
                "text2text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                device=pipeline_device
            )

            print("✅ LLM model loaded successfully")

        except Exception as e:
            print(f"❌ Error loading model: {e}")
            # Fall back to a simpler model
            self._load_fallback_model()

    def _load_fallback_model(self):
        """Load a fallback model if the primary model fails"""
        try:
            print("🔄 Loading fallback model: google/flan-t5-small")

            self.tokenizer = T5Tokenizer.from_pretrained(
                "google/flan-t5-small",
                cache_dir=self.config.HF_CACHE_DIR
            )
            self.model = T5ForConditionalGeneration.from_pretrained(
                "google/flan-t5-small",
                cache_dir=self.config.HF_CACHE_DIR
            )
            self.model.to(self.device)

            self.pipeline = pipeline(
                "text2text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                device=0 if self.device == "cuda" else -1
            )

            print("✅ Fallback model loaded successfully")

        except Exception as e:
            print(f"❌ Fallback model also failed: {e}")
            raise

    def generate_answer(self, question: str, context: List[str], max_length: int = 200) -> str:
        """Generate an answer based on the question and retrieved context"""
        try:
            if not context:
                return "I don't have enough context to answer this question."
            # Prepare context (use the top 3 most relevant chunks)
            context_text = "\n\n".join(context[:3])

            # Construct prompt
            prompt = self._construct_prompt(question, context_text)

            # Generate answer; prefer the tokenizer's pad token and fall back
            # to the EOS token only if no pad token is defined
            response = self.pipeline(
                prompt,
                max_length=max_length,
                min_length=20,
                temperature=0.7,
                do_sample=True,
                top_p=0.9,
                repetition_penalty=1.2,
                num_return_sequences=1,
                pad_token_id=self.tokenizer.pad_token_id
                if self.tokenizer.pad_token_id is not None
                else self.tokenizer.eos_token_id
            )

            # Extract and clean answer
            answer = response[0]['generated_text']
            answer = self._clean_answer(answer, prompt)

            return answer

        except Exception as e:
            print(f"❌ Error generating answer: {e}")
            return f"I apologize, but I encountered an error while generating the answer: {str(e)}"

    def _construct_prompt(self, question: str, context: str) -> str:
        """Construct the prompt for the model"""
        # Different prompt templates for different models
        if "flan-t5" in self.config.LLM_MODEL.lower():
            prompt = f"""Answer the following question based on the given context. Be concise and accurate.

Context: {context}

Question: {question}

Answer:"""
        else:
            prompt = f"""Based on the context below, please answer the question.

Context: {context}

Question: {question}

Answer:"""

        # Truncate if too long
        max_prompt_length = 1500  # Leave room for generation
        if len(prompt) > max_prompt_length:
            # Truncate the context while keeping the question intact
            context_limit = max_prompt_length - len(question) - 100
            truncated_context = context[:context_limit] + "..."
            prompt = f"""Answer the following question based on the given context. Be concise and accurate.

Context: {truncated_context}

Question: {question}

Answer:"""

        return prompt

    def _clean_answer(self, answer: str, prompt: str) -> str:
        """Clean and post-process the generated answer"""
        # Remove the prompt from the answer if it is repeated
        if prompt in answer:
            answer = answer.replace(prompt, "").strip()

        # Remove common artifacts
        if "Answer:" in answer:
            answer = answer.split("Answer:")[-1].strip()

        # Remove empty lines and consecutive duplicate lines
        lines = answer.split('\n')
        cleaned_lines = []
        prev_line = ""
        for line in lines:
            line = line.strip()
            if line and line != prev_line:
                cleaned_lines.append(line)
                prev_line = line
        answer = "\n".join(cleaned_lines)

        # Ensure the answer ends on a complete sentence
        if answer and not answer.endswith(('.', '!', '?')):
            sentences = answer.split('.')
            if len(sentences) > 1:
                answer = '.'.join(sentences[:-1]) + '.'

        # Fallback response if the answer is too short or empty
        if not answer or len(answer.strip()) < 10:
            answer = (
                "Based on the provided context, I cannot generate a comprehensive answer "
                "to your question. Please try rephrasing your question or providing more "
                "specific context."
            )

        return answer.strip()

    def summarize_text(self, text: str, max_length: int = 150) -> str:
        """Summarize the given text"""
        try:
            prompt = f"Summarize the following text concisely:\n\n{text}\n\nSummary:"

            response = self.pipeline(
                prompt,
                max_length=max_length,
                min_length=30,
                temperature=0.5,
                do_sample=True,
                num_return_sequences=1
            )

            summary = response[0]['generated_text']
            summary = self._clean_answer(summary, prompt)

            return summary

        except Exception as e:
            print(f"Error summarizing text: {e}")
            return "Unable to generate summary."
    def answer_with_confidence(self, question: str, context: List[str]) -> Dict[str, Any]:
        """Generate an answer with a confidence estimate"""
        try:
            # Generate multiple candidates at different temperatures
            candidates = []
            for temp in [0.5, 0.7, 0.9]:
                context_text = "\n\n".join(context[:3])
                prompt = self._construct_prompt(question, context_text)

                response = self.pipeline(
                    prompt,
                    max_length=200,
                    temperature=temp,
                    do_sample=True,
                    num_return_sequences=1
                )

                answer = self._clean_answer(response[0]['generated_text'], prompt)
                candidates.append(answer)

            # Use the middle-temperature answer as the primary answer
            primary_answer = candidates[1]

            # Simple confidence estimation based on consistency
            confidence = self._estimate_confidence(candidates, context)

            return {
                'answer': primary_answer,
                'confidence': confidence,
                'candidates': candidates
            }

        except Exception as e:
            return {
                'answer': f"Error generating answer: {str(e)}",
                'confidence': 0.0,
                'candidates': []
            }

    def _estimate_confidence(self, candidates: List[str], context: List[str]) -> float:
        """Estimate confidence based on answer consistency and context relevance"""
        if len(candidates) < 2:
            return 0.5

        # Simple similarity check between candidates
        similarities = []
        for i in range(len(candidates)):
            for j in range(i + 1, len(candidates)):
                # Word-overlap (Jaccard) similarity
                words1 = set(candidates[i].lower().split())
                words2 = set(candidates[j].lower().split())
                if len(words1) + len(words2) == 0:
                    sim = 0.0
                else:
                    sim = len(words1.intersection(words2)) / len(words1.union(words2))
                similarities.append(sim)

        # Average similarity as a confidence proxy
        confidence = sum(similarities) / len(similarities) if similarities else 0.5

        # Adjust based on context relevance (simple keyword matching)
        if context:
            context_words = set(" ".join(context).lower().split())
            answer_words = set(candidates[0].lower().split())
            relevance = len(context_words.intersection(answer_words)) / len(answer_words) if answer_words else 0
            confidence = (confidence + relevance) / 2

        return min(1.0, max(0.0, confidence))

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the loaded model"""
        return {
            'model_name': self.config.LLM_MODEL,
            'device': self.device,
            'model_size': sum(p.numel() for p in self.model.parameters()) if self.model else 0,
            'tokenizer_vocab_size': len(self.tokenizer) if self.tokenizer else 0
        }
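

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). It assumes config.py provides a
# Config class exposing LLM_MODEL, HF_CACHE_DIR and USE_GPU as used above;
# the sample question and context strings are hypothetical.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = LLMHandler()

    sample_context = [
        "The Transformer architecture relies on self-attention instead of recurrence.",
        "It was introduced in the 2017 paper 'Attention Is All You Need'."
    ]

    result = handler.answer_with_confidence(
        "What mechanism does the Transformer use instead of recurrence?",
        sample_context
    )
    print(f"Answer: {result['answer']}")
    print(f"Confidence: {result['confidence']:.2f}")
    print(f"Model info: {handler.get_model_info()}")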