import traceback

import chromadb
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

from retriever import retrieve
from utils import build_prompt, refine_response


# Module-level caches so the vector store and models are lazy-loaded only once.
_vector_store = None
_finetuned_llm = None
_base_model = None


def get_vector_store():
    """Load vector store (lazy-loaded on first use)."""
    global _vector_store
    if _vector_store is None:
        db_client = chromadb.PersistentClient(path="./MedQuAD_db")
        try:
            _vector_store = db_client.get_collection("medical_rag")
        except Exception:
            # Collection does not exist yet, so create an empty one.
            _vector_store = db_client.create_collection(name="medical_rag")
    return _vector_store


def get_finetuned_llm():
    """Load fine-tuned model (lazy-loaded on first use)."""
    global _finetuned_llm
    if _finetuned_llm is None:
        ft_model_id = "amiraghhh/fine-tuned-flan-t5-small"
        ft_tokenizer = AutoTokenizer.from_pretrained(ft_model_id)
        ft_model = AutoModelForSeq2SeqLM.from_pretrained(ft_model_id)

        _finetuned_llm = pipeline(
            "text2text-generation",
            model=ft_model,
            tokenizer=ft_tokenizer,
            decoder_start_token_id=ft_model.config.pad_token_id,
        )
    return _finetuned_llm


def rag(user_query):
    """Main RAG function: retrieve context and generate an answer.

    Takes a question string and returns an answer string with a confidence estimate.

    Returns: str(generated_answer)
    """
    try:
        vector_store = get_vector_store()
        finetuned_llm = get_finetuned_llm()

        # Queries that look like emergencies get a safety notice prepended to the answer.
        emergency_keywords = ["emergency", "severe pain", "bleeding",
                              "blind", "lose consciousness", "pass out"]

        if any(keyword in user_query.lower() for keyword in emergency_keywords):
            emergency_msg = (
                "I am an AI and cannot provide medical advice for emergencies.\n"
                "PLEASE contact emergency services or a medical professional immediately."
            )

            try:
                contexts = retrieve(vector_store, user_query, top_k=3, use_reranking=True)

                if not contexts:
                    return f"{emergency_msg}\n\nNo relevant information found for your query."

                prompt = build_prompt(user_query, contexts)
                result = finetuned_llm(
                    prompt,
                    max_new_tokens=70,
                    num_beams=3,
                    early_stopping=True,
                    do_sample=False,
                    repetition_penalty=1.4,
                    eos_token_id=finetuned_llm.tokenizer.eos_token_id,
                )

                answer = result[0]['generated_text'].strip()
                answer = refine_response(answer)

                # contexts is non-empty here (early return above), so map the average
                # retrieval distance to a rough 0-100 confidence score.
                avg_distance = sum(c.get('chroma_distance', 1.0) for c in contexts) / len(contexts)
                confidence_score = (1 - avg_distance) * 100
                confidence_score = max(0, min(100, confidence_score))

                return f"{emergency_msg}\n\n[Confidence: {confidence_score:.1f}%]\n\n{answer}"

            except Exception as e:
                return f"{emergency_msg}\n\nError generating answer: {str(e)}"

        # Normal (non-emergency) path.
        contexts = retrieve(vector_store, user_query, top_k=3, use_reranking=True)

        if not contexts:
            return ("I'm not confident about my answer (0%).\n\n"
                    "Couldn't find relevant information to answer your question.")

        prompt = build_prompt(user_query, contexts)

        result = finetuned_llm(
            prompt,
            max_new_tokens=70,
            num_beams=3,
            early_stopping=True,
            do_sample=False,
            repetition_penalty=1.4,
            eos_token_id=finetuned_llm.tokenizer.eos_token_id,
        )

        answer = result[0]['generated_text'].strip()
        answer = refine_response(answer)

        if contexts:
            # Map the average retrieval distance to a rough 0-100 confidence score.
            avg_distance = sum(c.get('chroma_distance', 1.0) for c in contexts) / len(contexts)
            confidence_score = (1 - avg_distance) * 100
            confidence_score = max(0, min(100, confidence_score))

            if confidence_score < 40:
                final_response = f"I'm not confident about my answer ({confidence_score:.1f}%).\n\n{answer}"
            else:
                final_response = f"{answer}\n\n[Confidence: {confidence_score:.1f}%]"
        else:
            final_response = "I'm not confident about my answer (0%).\n\n" + answer

        return final_response

    except Exception as e:
        error_msg = f"ERROR in RAG pipeline: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_msg)
        return error_msg
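

# Illustrative usage only: a minimal manual smoke test for the rag() pipeline.
# The sample question below is a hypothetical example added for demonstration;
# any question string works, and the models are loaded lazily on the first call.
if __name__ == "__main__":
    sample_question = "What are the symptoms of glaucoma?"  # hypothetical example query
    print(rag(sample_question))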