# Rag-based-api-task/src/llm_handler.py
from typing import List, Dict, Any, Optional

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    pipeline,
    T5ForConditionalGeneration,
    T5Tokenizer,
)

from config import Config
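
# Note: Config (imported from config.py, not shown here) is assumed to expose
# LLM_MODEL (a Hugging Face model id), HF_CACHE_DIR (model download cache path),
# and USE_GPU (bool). These are the only Config attributes this module reads.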
class LLMHandler:
    """Handle LLM operations for answer generation."""

    def __init__(self, config: Optional[Config] = None):
        self.config = config or Config()
        self.model = None
        self.tokenizer = None
        self.pipeline = None

        # Select device: GPU only if available and enabled in config
        self.device = "cuda" if torch.cuda.is_available() and self.config.USE_GPU else "cpu"
        print(f"🔧 Using device: {self.device}")

        # Load the model eagerly so failures surface at construction time
        self._load_model()
    def _load_model(self):
        """Load the LLM model and tokenizer."""
        try:
            print(f"🤖 Loading model: {self.config.LLM_MODEL}")

            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.config.LLM_MODEL,
                cache_dir=self.config.HF_CACHE_DIR,
            )

            # Load model
            if "flan-t5" in self.config.LLM_MODEL.lower():
                # T5 models: fp16 on GPU, and let accelerate place layers
                self.model = T5ForConditionalGeneration.from_pretrained(
                    self.config.LLM_MODEL,
                    cache_dir=self.config.HF_CACHE_DIR,
                    torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                    device_map="auto" if self.device == "cuda" else None,
                )
            else:
                # Generic sequence-to-sequence models
                self.model = AutoModelForSeq2SeqLM.from_pretrained(
                    self.config.LLM_MODEL,
                    cache_dir=self.config.HF_CACHE_DIR,
                    torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                )

            # Move the model ourselves only when accelerate did not already
            # dispatch it (hf_device_map is set when device_map was used)
            if getattr(self.model, "hf_device_map", None) is None:
                self.model.to(self.device)

            # Create the pipeline. Passing a device is invalid for a model that
            # accelerate has already dispatched, so leave it unset in that case.
            pipeline_device = (
                None
                if getattr(self.model, "hf_device_map", None) is not None
                else (0 if self.device == "cuda" else -1)
            )
            self.pipeline = pipeline(
                "text2text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                device=pipeline_device,
            )
            print("✅ LLM model loaded successfully")
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            # Fall back to a smaller model
            self._load_fallback_model()
    def _load_fallback_model(self):
        """Load a fallback model if the primary model fails."""
        try:
            print("🔄 Loading fallback model: google/flan-t5-small")
            self.tokenizer = T5Tokenizer.from_pretrained(
                "google/flan-t5-small",
                cache_dir=self.config.HF_CACHE_DIR,
            )
            self.model = T5ForConditionalGeneration.from_pretrained(
                "google/flan-t5-small",
                cache_dir=self.config.HF_CACHE_DIR,
            )
            self.model.to(self.device)
            self.pipeline = pipeline(
                "text2text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                device=0 if self.device == "cuda" else -1,
            )
            print("✅ Fallback model loaded successfully")
        except Exception as e:
            print(f"❌ Fallback model also failed: {e}")
            raise
    def generate_answer(self, question: str, context: List[str], max_length: int = 200) -> str:
        """Generate an answer based on the question and retrieved context."""
        try:
            if not context:
                return "I don't have enough context to answer this question."

            # Prepare context (use the top 3 most relevant chunks)
            context_text = "\n\n".join(context[:3])

            # Construct prompt
            prompt = self._construct_prompt(question, context_text)

            # Generate answer. The pad token comes from the tokenizer itself:
            # seq2seq tokenizers such as T5's define their own pad token, so
            # eos_token_id should not be substituted for it.
            response = self.pipeline(
                prompt,
                max_length=max_length,
                min_length=20,
                temperature=0.7,
                do_sample=True,
                top_p=0.9,
                repetition_penalty=1.2,
                num_return_sequences=1,
                pad_token_id=self.tokenizer.pad_token_id,
            )

            # Extract and clean the answer
            answer = response[0]["generated_text"]
            answer = self._clean_answer(answer, prompt)
            return answer
        except Exception as e:
            print(f"❌ Error generating answer: {e}")
            return f"I apologize, but I encountered an error while generating the answer: {str(e)}"
    def _construct_prompt(self, question: str, context: str) -> str:
        """Construct the prompt for the model."""
        # Different prompt templates for different models
        if "flan-t5" in self.config.LLM_MODEL.lower():
            prompt = f"""Answer the following question based on the given context. Be concise and accurate.

Context:
{context}

Question: {question}

Answer:"""
        else:
            prompt = f"""Based on the context below, please answer the question.

Context: {context}

Question: {question}

Answer:"""

        # Truncate if too long (character-based; leaves room for generation)
        max_prompt_length = 1500
        if len(prompt) > max_prompt_length:
            # Truncate the context while keeping the full question
            context_limit = max_prompt_length - len(question) - 100
            truncated_context = context[:context_limit] + "..."
            prompt = f"""Answer the following question based on the given context. Be concise and accurate.

Context:
{truncated_context}

Question: {question}

Answer:"""
        return prompt
    def _clean_answer(self, answer: str, prompt: str) -> str:
        """Clean and post-process the generated answer."""
        # Remove the prompt from the answer if it is echoed back
        if prompt in answer:
            answer = answer.replace(prompt, "").strip()

        # Remove common artifacts
        if "Answer:" in answer:
            answer = answer.split("Answer:")[-1].strip()

        # Remove empty lines and consecutive duplicate lines
        lines = answer.split("\n")
        cleaned_lines = []
        prev_line = ""
        for line in lines:
            line = line.strip()
            if line and line != prev_line:
                cleaned_lines.append(line)
                prev_line = line
        answer = "\n".join(cleaned_lines)

        # Ensure the answer ends on a complete sentence
        if answer and not answer.endswith((".", "!", "?")):
            sentences = answer.split(".")
            if len(sentences) > 1:
                answer = ".".join(sentences[:-1]) + "."

        # Fallback response if the answer is too short or empty
        if not answer or len(answer.strip()) < 10:
            answer = (
                "Based on the provided context, I cannot generate a comprehensive "
                "answer to your question. Please try rephrasing your question or "
                "providing more specific context."
            )
        return answer.strip()
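    # Example of the cleanup above (hypothetical strings): a raw output of
    #   "Answer: Paris is the capital. Of"
    # is stripped to the text after "Answer:" and trimmed at the last complete
    # sentence, yielding "Paris is the capital."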
    def summarize_text(self, text: str, max_length: int = 150) -> str:
        """Summarize the given text."""
        try:
            prompt = f"Summarize the following text concisely:\n\n{text}\n\nSummary:"
            response = self.pipeline(
                prompt,
                max_length=max_length,
                min_length=30,
                temperature=0.5,
                do_sample=True,
                num_return_sequences=1,
            )
            summary = response[0]["generated_text"]
            summary = self._clean_answer(summary, prompt)
            return summary
        except Exception as e:
            print(f"❌ Error summarizing text: {e}")
            return "Unable to generate summary."
    def answer_with_confidence(self, question: str, context: List[str]) -> Dict[str, Any]:
        """Generate an answer with a confidence estimate."""
        try:
            # Build the prompt once; only the sampling temperature varies
            context_text = "\n\n".join(context[:3])
            prompt = self._construct_prompt(question, context_text)

            # Generate candidate answers at several sampling temperatures
            candidates = []
            for temp in [0.5, 0.7, 0.9]:
                response = self.pipeline(
                    prompt,
                    max_length=200,
                    temperature=temp,
                    do_sample=True,
                    num_return_sequences=1,
                )
                answer = self._clean_answer(response[0]["generated_text"], prompt)
                candidates.append(answer)

            # Use the middle-temperature answer as the primary one
            primary_answer = candidates[1]

            # Simple confidence estimate based on candidate consistency
            confidence = self._estimate_confidence(candidates, context)

            return {
                "answer": primary_answer,
                "confidence": confidence,
                "candidates": candidates,
            }
        except Exception as e:
            return {
                "answer": f"Error generating answer: {str(e)}",
                "confidence": 0.0,
                "candidates": [],
            }
    def _estimate_confidence(self, candidates: List[str], context: List[str]) -> float:
        """Estimate confidence from answer consistency and context relevance."""
        if len(candidates) < 2:
            return 0.5

        # Pairwise Jaccard similarity (word overlap) between candidates
        similarities = []
        for i in range(len(candidates)):
            for j in range(i + 1, len(candidates)):
                words1 = set(candidates[i].lower().split())
                words2 = set(candidates[j].lower().split())
                union = words1 | words2
                sim = len(words1 & words2) / len(union) if union else 0.0
                similarities.append(sim)

        # Average pairwise similarity is the confidence proxy
        confidence = sum(similarities) / len(similarities) if similarities else 0.5

        # Blend in context relevance (simple keyword overlap with the answer)
        if context:
            context_words = set(" ".join(context).lower().split())
            answer_words = set(candidates[0].lower().split())
            relevance = len(context_words & answer_words) / len(answer_words) if answer_words else 0.0
            confidence = (confidence + relevance) / 2

        return min(1.0, max(0.0, confidence))
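    # Worked example of the Jaccard proxy above (hypothetical candidates):
    # "paris is the capital" vs "the capital is paris" share all four words,
    # so |A ∩ B| / |A ∪ B| = 4/4 = 1.0; fully disjoint answers score 0.0.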
    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the loaded model."""
        return {
            "model_name": self.config.LLM_MODEL,
            "device": self.device,
            "model_size": sum(p.numel() for p in self.model.parameters()) if self.model else 0,
            "tokenizer_vocab_size": len(self.tokenizer) if self.tokenizer else 0,
        }
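
# Minimal usage sketch (illustrative values; assumes config.py provides a
# working Config as described above):
if __name__ == "__main__":
    handler = LLMHandler()
    chunks = [
        "Retrieval-augmented generation (RAG) retrieves relevant documents "
        "and passes them to a language model as context for answering."
    ]
    result = handler.answer_with_confidence("What is RAG?", chunks)
    print(result["answer"], f"(confidence: {result['confidence']:.2f})")
    print(handler.get_model_info())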