import chromadb
import traceback
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from retriever import retrieve
from utils import build_prompt, refine_response

# ============================================================================
# LAZY-LOAD MODELS AND VECTOR STORE (load on first use, not at import)
# ============================================================================
_vector_store = None
_finetuned_llm = None
_base_model = None


def get_vector_store():
    """Load the Chroma vector store (lazy-loaded on first use)."""
    global _vector_store
    if _vector_store is None:
        db_client = chromadb.PersistentClient(path="./MedQuAD_db")
        try:
            _vector_store = db_client.get_collection("medical_rag")
        except Exception:
            # If the collection doesn't exist yet, create it (chromadb's
            # get_or_create_collection would also collapse this into one call)
            _vector_store = db_client.create_collection(name="medical_rag")
    return _vector_store


def get_finetuned_llm():
    """Load the fine-tuned model (lazy-loaded on first use)."""
    global _finetuned_llm
    if _finetuned_llm is None:
        ft_model_id = "amiraghhh/fine-tuned-flan-t5-small"
        ft_tokenizer = AutoTokenizer.from_pretrained(ft_model_id)
        ft_model = AutoModelForSeq2SeqLM.from_pretrained(ft_model_id)
        _finetuned_llm = pipeline(
            "text2text-generation",
            model=ft_model,
            tokenizer=ft_tokenizer,
            # T5-family models use the pad token as the decoder start token
            decoder_start_token_id=ft_model.config.pad_token_id,
        )
    return _finetuned_llm
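

# The module declares _base_model above but never loads it. Below is a
# minimal sketch of a matching lazy loader, assuming the base checkpoint is
# google/flan-t5-small (the model the fine-tuned checkpoint appears to be
# derived from). The checkpoint name is an assumption, not from the original.
def get_base_model():
    """Load the base (non-fine-tuned) model for comparison (lazy-loaded)."""
    global _base_model
    if _base_model is None:
        base_model_id = "google/flan-t5-small"  # assumed base checkpoint
        base_tokenizer = AutoTokenizer.from_pretrained(base_model_id)
        base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_id)
        _base_model = pipeline(
            "text2text-generation",
            model=base_model,
            tokenizer=base_tokenizer,
        )
    return _base_model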


# ============================================================================
# MAIN RAG FUNCTION
# ============================================================================
def _generate_answer(vector_store, finetuned_llm, user_query):
    """Retrieve contexts, generate an answer, and score retrieval confidence.

    Returns (answer, confidence_score, contexts); answer is None when no
    relevant contexts were retrieved.
    """
    contexts = retrieve(vector_store, user_query, top_k=3, use_reranking=True)
    if not contexts:
        return None, 0.0, contexts
    prompt = build_prompt(user_query, contexts)
    result = finetuned_llm(
        prompt,
        max_new_tokens=70,
        num_beams=3,
        early_stopping=True,
        do_sample=False,
        repetition_penalty=1.4,
        eos_token_id=finetuned_llm.tokenizer.eos_token_id,
    )
    answer = refine_response(result[0]["generated_text"].strip())
    # Confidence heuristic: map the average Chroma distance of the retrieved
    # contexts onto a 0-100 scale (smaller distance -> higher confidence)
    avg_distance = sum(c.get("chroma_distance", 1.0) for c in contexts) / len(contexts)
    confidence_score = max(0.0, min(100.0, (1 - avg_distance) * 100))
    return answer, confidence_score, contexts


def rag(user_query):
    """Main RAG function: retrieve context and generate an answer.

    Takes a question string and returns an answer string annotated with a
    confidence score.
    """
    try:
        # Load models and the vector store on first use
        vector_store = get_vector_store()
        finetuned_llm = get_finetuned_llm()

        # 1. Check for emergency keywords; if found, prepend a safety notice
        emergency_keywords = ["emergency", "severe pain", "bleeding",
                              "blind", "lose consciousness", "pass out"]
        if any(keyword in user_query.lower() for keyword in emergency_keywords):
            emergency_msg = ("I am an AI and cannot provide medical advice for emergencies.\n"
                             "PLEASE contact emergency services or a medical professional immediately.")
            try:
                # Still generate an answer for context
                answer, confidence_score, contexts = _generate_answer(
                    vector_store, finetuned_llm, user_query)
                if not contexts:
                    return f"{emergency_msg}\n\nNo relevant information found for your query."
                return f"{emergency_msg}\n\n[Confidence: {confidence_score:.1f}%]\n\n{answer}"
            except Exception as e:
                return f"{emergency_msg}\n\nError generating answer: {e}"

        # 2.-4. Retrieve relevant contexts, build the prompt, generate an
        # answer, and compute the confidence score
        answer, confidence_score, contexts = _generate_answer(
            vector_store, finetuned_llm, user_query)
        if not contexts:
            return ("I'm not confident about my answer (0%).\n\n"
                    "Couldn't find relevant information to answer your question.")

        # 5. Build the final response, flagging low-confidence answers
        if confidence_score < 40:
            return f"I'm not confident about my answer ({confidence_score:.1f}%).\n\n{answer}"
        return f"{answer}\n\n[Confidence: {confidence_score:.1f}%]"
    except Exception as e:
        error_msg = f"ERROR in RAG pipeline: {e}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_msg)
        return error_msg
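

# Illustrative smoke test (not in the original file): run a sample query
# through the pipeline when the module is executed directly. The example
# question is an assumption for demonstration purposes only.
if __name__ == "__main__":
    print(rag("What are the symptoms of glaucoma?"))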