| import chromadb |
| import traceback |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline |
| from retriever import retrieve |
| from utils import build_prompt, refine_response |
|
|
|
|
| |
| |
| |
|
|
# Module-level caches for lazily-initialised heavy resources. They start as None
# and are filled in on first use by get_vector_store() / get_finetuned_llm().
# _base_model is not referenced in the functions visible here.
_vector_store = _finetuned_llm = _base_model = None
|
|
def get_vector_store():
    """Return the ``medical_rag`` Chroma collection (lazy-loaded on first use).

    On the first call a persistent Chroma client is opened at ``./MedQuAD_db``
    and the ``medical_rag`` collection is fetched, or created empty if it does
    not exist yet. The collection is cached in the module-level
    ``_vector_store``, so later calls return the same object without touching
    the database again.

    Returns:
        The Chroma collection object for ``medical_rag``.
    """
    global _vector_store
    if _vector_store is None:
        db_client = chromadb.PersistentClient(path="./MedQuAD_db")
        # One call instead of get_collection + bare ``except:`` + create_collection.
        # The bare except also caught KeyboardInterrupt and real failures
        # (corrupt DB, permissions), hiding them behind an empty new collection.
        _vector_store = db_client.get_or_create_collection(name="medical_rag")
    return _vector_store
|
|
def get_finetuned_llm():
    """Return the text2text-generation pipeline for the fine-tuned FLAN-T5 model.

    The tokenizer and seq2seq model ``amiraghhh/fine-tuned-flan-t5-small`` are
    loaded with ``from_pretrained`` only on the first call. The resulting
    pipeline is cached in the module-level ``_finetuned_llm`` and reused by
    every later call.

    Returns:
        A ``transformers`` text2text-generation pipeline.
    """
    global _finetuned_llm
    if _finetuned_llm is not None:
        return _finetuned_llm

    model_name = "amiraghhh/fine-tuned-flan-t5-small"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    # decoder_start_token_id is set to the model's pad token id; presumably the
    # pipeline forwards it to generation - verify against the transformers version.
    _finetuned_llm = pipeline(
        "text2text-generation",
        model=seq2seq_model,
        tokenizer=tokenizer,
        decoder_start_token_id=seq2seq_model.config.pad_token_id,
    )
    return _finetuned_llm
|
|
|
|
| |
| |
| |
|
|
# Case-insensitive substrings that route a query to the emergency response.
_EMERGENCY_KEYWORDS = (
    "emergency",
    "severe pain",
    "bleeding",
    "blind",
    "lose consciousness",
    "pass out",
)

_EMERGENCY_MSG = (
    "I am an AI and cannot provide medical advice for emergencies.\n"
    "PLEASE contact emergency services or a medical professional immediately."
)

# Answers scoring below this confidence (0-100) are prefixed with a warning.
_LOW_CONFIDENCE_THRESHOLD = 40


def _is_emergency(query):
    """Return True if *query* contains any emergency keyword (case-insensitive)."""
    lowered = query.lower()
    return any(keyword in lowered for keyword in _EMERGENCY_KEYWORDS)


def _confidence_from_contexts(contexts):
    """Map the mean ``chroma_distance`` of non-empty *contexts* to a 0-100 score.

    A context without ``chroma_distance`` counts as distance 1.0 (score 0).
    The result is clamped to [0, 100], so averages above 1.0 score 0.
    """
    avg_distance = sum(c.get("chroma_distance", 1.0) for c in contexts) / len(contexts)
    return max(0, min(100, (1 - avg_distance) * 100))


def _generate_answer(vector_store, finetuned_llm, user_query):
    """Retrieve context for *user_query* and generate a refined answer.

    Returns:
        ``(answer, confidence)``, or ``None`` when retrieval finds no contexts.
    """
    contexts = retrieve(vector_store, user_query, top_k=3, use_reranking=True)
    if not contexts:
        return None

    prompt = build_prompt(user_query, contexts)
    result = finetuned_llm(
        prompt,
        max_new_tokens=70,
        num_beams=3,
        early_stopping=True,
        do_sample=False,
        repetition_penalty=1.4,
        eos_token_id=finetuned_llm.tokenizer.eos_token_id,
    )
    answer = refine_response(result[0]["generated_text"].strip())
    return answer, _confidence_from_contexts(contexts)


def rag(user_query):
    """Main RAG entry point: retrieve context and generate an answer.

    Args:
        user_query: The user's question.

    Returns:
        A user-facing string.

        * Emergency queries (see ``_EMERGENCY_KEYWORDS``) always start with the
          emergency notice, even if loading the model or vector store fails.
        * Other queries return the answer with a confidence annotation. When
          confidence is below ``_LOW_CONFIDENCE_THRESHOLD``, a warning prefix is
          used instead.
        * Pipeline failures on the non-emergency path return an
          ``"ERROR in RAG pipeline: ..."`` message.
    """
    if not isinstance(user_query, str) or not user_query.strip():
        return "Please enter a question."

    # Check for emergencies BEFORE loading any heavy resources, so a failure
    # to load the DB/model can never suppress the emergency notice.
    if _is_emergency(user_query):
        try:
            generated = _generate_answer(
                get_vector_store(), get_finetuned_llm(), user_query
            )
        except Exception as e:
            print(f"ERROR in RAG pipeline: {e}\n\nTraceback:\n{traceback.format_exc()}")
            return f"{_EMERGENCY_MSG}\n\nError generating answer: {e}"
        if generated is None:
            return f"{_EMERGENCY_MSG}\n\nNo relevant information found for your query."
        answer, confidence = generated
        return f"{_EMERGENCY_MSG}\n\n[Confidence: {confidence:.1f}%]\n\n{answer}"

    try:
        generated = _generate_answer(get_vector_store(), get_finetuned_llm(), user_query)
        if generated is None:
            return "I'm not confident about my answer (0%).\n\nCouldn't find relevant information to answer your question."

        answer, confidence = generated
        if confidence < _LOW_CONFIDENCE_THRESHOLD:
            return f"I'm not confident about my answer ({confidence:.1f}%).\n\n{answer}"
        return f"{answer}\n\n[Confidence: {confidence:.1f}%]"

    except Exception as e:
        # NOTE(review): the full traceback is returned to the caller (likely the
        # end user). Consider logging it only and returning a generic message.
        error_msg = f"ERROR in RAG pipeline: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_msg)
        return error_msg
|
|