# rag/model.py
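"""RAG pipeline for medical question answering.

Retrieves context from a persistent Chroma collection built from MedQuAD and
generates answers with a fine-tuned FLAN-T5-small model, attaching a
retrieval-based confidence score to each response.
"""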
import chromadb
import traceback
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from retriever import retrieve
from utils import build_prompt, refine_response
# ============================================================================
# LAZY-LOAD MODELS AND VECTOR STORE (load on first use, not at import)
# ============================================================================
_vector_store = None
_finetuned_llm = None
_base_model = None
def get_vector_store():
"""Load vector store (lazy-loaded on first use)"""
global _vector_store
if _vector_store is None:
db_client = chromadb.PersistentClient(path="./MedQuAD_db")
try:
_vector_store = db_client.get_collection("medical_rag")
        except Exception:
# If collection doesn't exist, create it
_vector_store = db_client.create_collection(name="medical_rag")
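    # Note: no embedding_function is passed, so Chroma falls back to its
    # default embedder; this assumes retrieve() queries with raw text so the
    # same default model embeds both documents and queries.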
return _vector_store
def get_finetuned_llm():
"""Load fine-tuned model (lazy-loaded on first use)"""
global _finetuned_llm
if _finetuned_llm is None:
ft_model_id = "amiraghhh/fine-tuned-flan-t5-small"
ft_tokenizer = AutoTokenizer.from_pretrained(ft_model_id)
ft_model = AutoModelForSeq2SeqLM.from_pretrained(ft_model_id)
_finetuned_llm = pipeline(
"text2text-generation",
model=ft_model,
tokenizer=ft_tokenizer,
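            # T5-family models use the pad token as the decoder start token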
decoder_start_token_id=ft_model.config.pad_token_id
)
return _finetuned_llm
# ============================================================================
# MAIN RAG FUNCTION
# ============================================================================
def rag(user_query):
"""Main RAG function: retrieve context and generate answer.
Takes a question string and returns an answer string with confidence.
Returns: str(generated_answer)"""
try:
# Load models on first use
vector_store = get_vector_store()
finetuned_llm = get_finetuned_llm()
# 1. Check for emergency keywords
emergency_keywords = ["emergency", "severe pain", "bleeding",
"blind", "lose consciousness", "pass out"]
if any(keyword in user_query.lower() for keyword in emergency_keywords):
emergency_msg = """I am an AI and cannot provide medical advice for emergencies.
PLEASE contact emergency services or a medical professional immediately."""
try:
# Still generate answer for context
contexts = retrieve(vector_store, user_query, top_k=3, use_reranking=True)
if not contexts:
return f"{emergency_msg}\n\nNo relevant information found for your query."
prompt = build_prompt(user_query, contexts)
result = finetuned_llm(
prompt,
max_new_tokens=70,
num_beams=3,
early_stopping=True,
do_sample=False,
repetition_penalty=1.4,
eos_token_id=finetuned_llm.tokenizer.eos_token_id
)
answer = result[0]['generated_text'].strip()
answer = refine_response(answer)
                # Calculate confidence (contexts is non-empty here; see the
                # early return above)
                avg_distance = sum(c.get('chroma_distance', 1.0) for c in contexts) / len(contexts)
                confidence_score = max(0, min(100, (1 - avg_distance) * 100))
return f"{emergency_msg}\n\n[Confidence: {confidence_score:.1f}%]\n\n{answer}"
except Exception as e:
return f"{emergency_msg}\n\nError generating answer: {str(e)}"
# 2. Retrieve relevant contexts
contexts = retrieve(vector_store, user_query, top_k=3, use_reranking=True)
if not contexts:
return "I'm not confident about my answer (0%).\n\nCouldn't find relevant information to answer your question."
# 3. Build prompt with context
prompt = build_prompt(user_query, contexts)
# 4. Generate answer
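        # Deterministic beam search (do_sample=False); repetition_penalty
        # nudges the small model away from repeating phrases.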
result = finetuned_llm(
prompt,
max_new_tokens=70,
num_beams=3,
early_stopping=True,
do_sample=False,
repetition_penalty=1.4,
eos_token_id=finetuned_llm.tokenizer.eos_token_id
)
answer = result[0]['generated_text'].strip()
answer = refine_response(answer)
# 5. Calculate confidence score based on retrieval quality
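        # Heuristic: (1 - mean Chroma distance) scaled to a percentage and
        # clamped to [0, 100]; assumes a distance metric where 0 is an exact
        # match, so distances >= 1 collapse to 0% confidence.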
        # (contexts is guaranteed non-empty after the early return above)
        avg_distance = sum(c.get('chroma_distance', 1.0) for c in contexts) / len(contexts)
        confidence_score = max(0, min(100, (1 - avg_distance) * 100))
        # Build final response with confidence
        if confidence_score < 40:
            final_response = f"I'm not confident about my answer ({confidence_score:.1f}%).\n\n{answer}"
        else:
            final_response = f"{answer}\n\n[Confidence: {confidence_score:.1f}%]"
return final_response
except Exception as e:
error_msg = f"ERROR in RAG pipeline: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
print(error_msg)
return error_msg
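

# Minimal manual smoke test, assuming ./MedQuAD_db exists locally and the
# Hugging Face model can be downloaded; the sample query is hypothetical.
if __name__ == "__main__":
    print(rag("What are the symptoms of glaucoma?"))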