import os
from typing import List, Dict, Any, Optional
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    pipeline,
    T5ForConditionalGeneration,
    T5Tokenizer
)

from config import Config
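
# Note: this module assumes that Config exposes at least LLM_MODEL (a Hugging Face
# seq2seq checkpoint id such as "google/flan-t5-base" — the concrete id is only an
# illustration), HF_CACHE_DIR (a local cache directory), and USE_GPU (a bool).
# These attribute names are taken from how they are used below.
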
class LLMHandler:
    """Handle LLM operations for answer generation"""

    def __init__(self, config: Config = None):
        self.config = config or Config()
        self.model = None
        self.tokenizer = None
        self.pipeline = None

        # Set device
        self.device = "cuda" if torch.cuda.is_available() and self.config.USE_GPU else "cpu"
        print(f"🔧 Using device: {self.device}")

        # Load model
        self._load_model()
    def _load_model(self):
        """Load the LLM model and tokenizer"""
        try:
            print(f"🤖 Loading model: {self.config.LLM_MODEL}")

            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.config.LLM_MODEL,
                cache_dir=self.config.HF_CACHE_DIR
            )

            # Load model
            if "flan-t5" in self.config.LLM_MODEL.lower():
                # T5 models: on GPU, let accelerate place the weights via device_map
                use_device_map = self.device == "cuda"
                self.model = T5ForConditionalGeneration.from_pretrained(
                    self.config.LLM_MODEL,
                    cache_dir=self.config.HF_CACHE_DIR,
                    torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                    device_map="auto" if use_device_map else None
                )
            else:
                # Generic sequence-to-sequence models
                use_device_map = False
                self.model = AutoModelForSeq2SeqLM.from_pretrained(
                    self.config.LLM_MODEL,
                    cache_dir=self.config.HF_CACHE_DIR,
                    torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
                )

            # Move the model manually only when device_map has not placed it already;
            # calling .to() on a model dispatched with device_map="auto" can fail.
            if not use_device_map:
                self.model.to(self.device)

            # Create pipeline; only pass an explicit device when the model was not
            # already placed by accelerate (otherwise the pipeline rejects it).
            self.pipeline = pipeline(
                "text2text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                device=None if use_device_map else (0 if self.device == "cuda" else -1)
            )

            print("✅ LLM model loaded successfully")
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            # Fallback to a simpler model
            self._load_fallback_model()
    def _load_fallback_model(self):
        """Load a fallback model if primary model fails"""
        try:
            print("🔄 Loading fallback model: google/flan-t5-small")

            self.tokenizer = T5Tokenizer.from_pretrained(
                "google/flan-t5-small",
                cache_dir=self.config.HF_CACHE_DIR
            )

            self.model = T5ForConditionalGeneration.from_pretrained(
                "google/flan-t5-small",
                cache_dir=self.config.HF_CACHE_DIR
            )
            self.model.to(self.device)

            self.pipeline = pipeline(
                "text2text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                device=0 if self.device == "cuda" else -1
            )

            print("✅ Fallback model loaded successfully")
        except Exception as e:
            print(f"❌ Fallback model also failed: {e}")
            raise
    def generate_answer(self, question: str, context: List[str], max_length: int = 200) -> str:
        """Generate answer based on question and context"""
        try:
            if not context:
                return "I don't have enough context to answer this question."

            # Prepare context (use top 3 most relevant chunks)
            context_text = "\n\n".join(context[:3])

            # Construct prompt
            prompt = self._construct_prompt(question, context_text)

            # Generate answer; seq2seq tokenizers define their own pad token,
            # so no pad_token_id override is needed here
            response = self.pipeline(
                prompt,
                max_length=max_length,
                min_length=20,
                temperature=0.7,
                do_sample=True,
                top_p=0.9,
                repetition_penalty=1.2,
                num_return_sequences=1
            )

            # Extract and clean answer
            answer = response[0]['generated_text']
            answer = self._clean_answer(answer, prompt)

            return answer
        except Exception as e:
            print(f"❌ Error generating answer: {e}")
            return f"I apologize, but I encountered an error while generating the answer: {str(e)}"
    def _construct_prompt(self, question: str, context: str) -> str:
        """Construct prompt for the model"""
        # Different prompt templates for different models
        if "flan-t5" in self.config.LLM_MODEL.lower():
            prompt = f"""Answer the following question based on the given context. Be concise and accurate.

Context:
{context}

Question: {question}

Answer:"""
        else:
            prompt = f"""Based on the context below, please answer the question.

Context: {context}

Question: {question}

Answer:"""

        # Truncate if too long
        max_prompt_length = 1500  # Leave room for generation
        if len(prompt) > max_prompt_length:
            # Truncate context while keeping question
            context_limit = max_prompt_length - len(question) - 100
            truncated_context = context[:context_limit] + "..."
            prompt = f"""Answer the following question based on the given context. Be concise and accurate.

Context:
{truncated_context}

Question: {question}

Answer:"""

        return prompt
    def _clean_answer(self, answer: str, prompt: str) -> str:
        """Clean and post-process the generated answer"""
        # Remove the prompt from the answer if it's repeated
        if prompt in answer:
            answer = answer.replace(prompt, "").strip()

        # Remove common artifacts
        if "Answer:" in answer:
            answer = answer.split("Answer:")[-1].strip()

        # Remove repetitive patterns
        lines = answer.split('\n')
        cleaned_lines = []
        prev_line = ""
        for line in lines:
            line = line.strip()
            if line and line != prev_line:  # Remove empty lines and duplicates
                cleaned_lines.append(line)
                prev_line = line
        answer = "\n".join(cleaned_lines)

        # Ensure the answer ends properly
        if answer and not answer.endswith(('.', '!', '?')):
            # Find the last complete sentence
            sentences = answer.split('.')
            if len(sentences) > 1:
                answer = '.'.join(sentences[:-1]) + '.'

        # Fallback response if answer is too short or empty
        if not answer or len(answer.strip()) < 10:
            answer = "Based on the provided context, I cannot generate a comprehensive answer to your question. Please try rephrasing your question or providing more specific context."

        return answer.strip()
    def summarize_text(self, text: str, max_length: int = 150) -> str:
        """Summarize given text"""
        try:
            prompt = f"Summarize the following text concisely:\n\n{text}\n\nSummary:"

            response = self.pipeline(
                prompt,
                max_length=max_length,
                min_length=30,
                temperature=0.5,
                do_sample=True,
                num_return_sequences=1
            )

            summary = response[0]['generated_text']
            summary = self._clean_answer(summary, prompt)

            return summary
        except Exception as e:
            print(f"Error summarizing text: {e}")
            return "Unable to generate summary."
    def answer_with_confidence(self, question: str, context: List[str]) -> Dict[str, Any]:
        """Generate answer with confidence estimation"""
        try:
            # Generate multiple candidates
            candidates = []
            for temp in [0.5, 0.7, 0.9]:
                context_text = "\n\n".join(context[:3])
                prompt = self._construct_prompt(question, context_text)

                response = self.pipeline(
                    prompt,
                    max_length=200,
                    temperature=temp,
                    do_sample=True,
                    num_return_sequences=1
                )

                answer = self._clean_answer(response[0]['generated_text'], prompt)
                candidates.append(answer)

            # Use the middle temperature answer as primary
            primary_answer = candidates[1]

            # Simple confidence estimation based on consistency
            confidence = self._estimate_confidence(candidates, context)

            return {
                'answer': primary_answer,
                'confidence': confidence,
                'candidates': candidates
            }
        except Exception as e:
            return {
                'answer': f"Error generating answer: {str(e)}",
                'confidence': 0.0,
                'candidates': []
            }
    def _estimate_confidence(self, candidates: List[str], context: List[str]) -> float:
        """Estimate confidence based on answer consistency and context relevance"""
        if len(candidates) < 2:
            return 0.5

        # Simple similarity check between candidates
        similarities = []
        for i in range(len(candidates)):
            for j in range(i + 1, len(candidates)):
                # Simple word overlap similarity
                words1 = set(candidates[i].lower().split())
                words2 = set(candidates[j].lower().split())
                if len(words1) + len(words2) == 0:
                    sim = 0.0
                else:
                    sim = len(words1.intersection(words2)) / len(words1.union(words2))
                similarities.append(sim)

        # Average similarity as confidence proxy
        confidence = sum(similarities) / len(similarities) if similarities else 0.5

        # Adjust based on context relevance (simple keyword matching)
        if context:
            context_words = set(" ".join(context).lower().split())
            answer_words = set(candidates[0].lower().split())
            relevance = len(context_words.intersection(answer_words)) / len(answer_words) if answer_words else 0
            confidence = (confidence + relevance) / 2

        return min(1.0, max(0.0, confidence))
    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the loaded model"""
        return {
            'model_name': self.config.LLM_MODEL,
            'device': self.device,
            'model_size': sum(p.numel() for p in self.model.parameters()) if self.model else 0,
            'tokenizer_vocab_size': len(self.tokenizer) if self.tokenizer else 0
        }
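

# Illustrative usage sketch: exercises the handler end to end. The question and
# context strings below are hypothetical placeholder data, and this assumes Config
# resolves to a small seq2seq checkpoint that fits on the local device.
if __name__ == "__main__":
    handler = LLMHandler()
    print(handler.get_model_info())

    # Hypothetical context chunks, e.g. retrieved from a vector store upstream
    context_chunks = [
        "The Eiffel Tower is located in Paris, France.",
        "It was completed in 1889 and is about 330 metres tall."
    ]

    answer = handler.generate_answer("Where is the Eiffel Tower?", context_chunks)
    print("Answer:", answer)

    result = handler.answer_with_confidence("How tall is the Eiffel Tower?", context_chunks)
    print(f"Answer: {result['answer']} (confidence: {result['confidence']:.2f})")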