import chromadb
import traceback
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

from retriever import retrieve
from utils import build_prompt, refine_response

# ============================================================================
# LAZY-LOAD MODELS AND VECTOR STORE (load on first use, not at import)
# ============================================================================
_vector_store = None
_finetuned_llm = None
_base_model = None


def get_vector_store():
    """Load the ChromaDB vector store (lazy-loaded on first use)."""
    global _vector_store
    if _vector_store is None:
        db_client = chromadb.PersistentClient(path="./MedQuAD_db")
        try:
            _vector_store = db_client.get_collection("medical_rag")
        except Exception:
            # If the collection doesn't exist yet, create it
            _vector_store = db_client.create_collection(name="medical_rag")
    return _vector_store


def get_finetuned_llm():
    """Load the fine-tuned generation model (lazy-loaded on first use)."""
    global _finetuned_llm
    if _finetuned_llm is None:
        ft_model_id = "amiraghhh/fine-tuned-flan-t5-small"
        ft_tokenizer = AutoTokenizer.from_pretrained(ft_model_id)
        ft_model = AutoModelForSeq2SeqLM.from_pretrained(ft_model_id)
        _finetuned_llm = pipeline(
            "text2text-generation",
            model=ft_model,
            tokenizer=ft_tokenizer,
            decoder_start_token_id=ft_model.config.pad_token_id,
        )
    return _finetuned_llm


# ============================================================================
# MAIN RAG FUNCTION
# ============================================================================
def rag(user_query):
    """Main RAG function: retrieve context and generate an answer.

    Takes a question string and returns an answer string annotated with a
    retrieval-based confidence score.
    """
    try:
        # Load models and vector store on first use
        vector_store = get_vector_store()
        finetuned_llm = get_finetuned_llm()

        # 1. Check for emergency keywords
        emergency_keywords = [
            "emergency", "severe pain", "bleeding", "blind",
            "lose consciousness", "pass out",
        ]
        if any(keyword in user_query.lower() for keyword in emergency_keywords):
            emergency_msg = (
                "I am an AI and cannot provide medical advice for emergencies. "
                "PLEASE contact emergency services or a medical professional immediately."
            )
            try:
                # Still generate an answer for context
                contexts = retrieve(vector_store, user_query, top_k=3, use_reranking=True)
                if not contexts:
                    return f"{emergency_msg}\n\nNo relevant information found for your query."

                prompt = build_prompt(user_query, contexts)
                result = finetuned_llm(
                    prompt,
                    max_new_tokens=70,
                    num_beams=3,
                    early_stopping=True,
                    do_sample=False,
                    repetition_penalty=1.4,
                    eos_token_id=finetuned_llm.tokenizer.eos_token_id,
                )
                answer = result[0]['generated_text'].strip()
                answer = refine_response(answer)

                # Calculate confidence from average retrieval distance
                if contexts:
                    avg_distance = sum(c.get('chroma_distance', 1.0) for c in contexts) / len(contexts)
                    confidence_score = (1 - avg_distance) * 100
                    confidence_score = max(0, min(100, confidence_score))
                else:
                    confidence_score = 0

                return f"{emergency_msg}\n\n[Confidence: {confidence_score:.1f}%]\n\n{answer}"
            except Exception as e:
                return f"{emergency_msg}\n\nError generating answer: {str(e)}"

        # 2. Retrieve relevant contexts
        contexts = retrieve(vector_store, user_query, top_k=3, use_reranking=True)
        if not contexts:
            return "I'm not confident about my answer (0%).\n\nCouldn't find relevant information to answer your question."

        # 3. Build prompt with context
        prompt = build_prompt(user_query, contexts)

        # 4. Generate answer
        result = finetuned_llm(
            prompt,
            max_new_tokens=70,
            num_beams=3,
            early_stopping=True,
            do_sample=False,
            repetition_penalty=1.4,
            eos_token_id=finetuned_llm.tokenizer.eos_token_id,
        )
        answer = result[0]['generated_text'].strip()
        answer = refine_response(answer)

        # 5. Calculate confidence score based on retrieval quality
        if contexts:
            avg_distance = sum(c.get('chroma_distance', 1.0) for c in contexts) / len(contexts)
            confidence_score = (1 - avg_distance) * 100
            confidence_score = max(0, min(100, confidence_score))

            # Build final response with confidence
            if confidence_score < 40:
                final_response = f"I'm not confident about my answer ({confidence_score:.1f}%).\n\n{answer}"
            else:
                final_response = f"{answer}\n\n[Confidence: {confidence_score:.1f}%]"
        else:
            final_response = "I'm not confident about my answer (0%).\n\n" + answer

        return final_response

    except Exception as e:
        error_msg = f"ERROR in RAG pipeline: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_msg)
        return error_msg
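

# ============================================================================
# EXAMPLE USAGE (illustrative sketch)
# ============================================================================
# A minimal sketch of calling rag() directly, assuming the ChromaDB store at
# ./MedQuAD_db has already been populated and the fine-tuned model can be
# downloaded; the sample question below is illustrative only.
if __name__ == "__main__":
    sample_query = "What are the symptoms of high blood pressure?"
    print(rag(sample_query))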