"""Question-answering pipeline combining extractive QA with text generation.

An extractive model pulls an answer span (and confidence score) out of the
retrieved passages; a text2text model then writes the final answer from a
prompt chosen by question type.
"""

import re
from typing import Optional

from transformers import pipeline

# Fragments of prompt instructions. If any of them shows up in generated text
# the model has echoed its prompt instead of answering.
_BAD_PHRASES = (
    "your answer must",
    "then explain",
    "instructions:",
    "use only facts",
    "answer:",
    "respond directly",
    "use this format",
    "example format",
    "use exactly",
    "based on the context",
    "you are a medical",
)

# Question fragments treated as a false premise and answered with a fixed "NO".
_FALSE_PREMISES = (
    "smoking improve",
    "smoking help",
    "smoking benefit",
    "alcohol improve",
    "alcohol cure",
    "drugs improve",
)

# Separators used to split a comparison question into its two subjects.
_COMPARISON_SEPARATORS = (" vs ", " versus ")

# Punctuation removed from both ends of a comparison subject.
_TERM_PUNCTUATION = "?.!,;:"

_NEGATIVE_WORDS = (
    "not", "no", "cannot", "does not", "harmful",
    "dangerous", "incorrect", "never", "false", "wrong",
)
_POSITIVE_WORDS = (
    "can", "helps", "reduces", "treats", "effective",
    "improves", "beneficial", "yes", "does", "is used",
)


def _whole_term_patterns(terms):
    """Compile one whole-word/phrase pattern per term."""
    return tuple(re.compile(rf"\b{re.escape(term)}\b") for term in terms)


_NEGATIVE_PATTERNS = _whole_term_patterns(_NEGATIVE_WORDS)
_POSITIVE_PATTERNS = _whole_term_patterns(_POSITIVE_WORDS)

# "YES"/"NO" as a complete leading word, so "NORMAL" or "NOT" do not qualify.
_YES_NO_PREFIX = re.compile(r"(YES|NO)\b")


def _count_matches(patterns, text: str) -> int:
    """Return how many of the patterns occur at least once in text."""
    return sum(1 for pattern in patterns if pattern.search(text))


def _with_period(sentence: str) -> str:
    """Return sentence terminated by exactly one trailing period."""
    return sentence if sentence.endswith(".") else sentence + "."


def _clean_term(term: str) -> str:
    """Trim whitespace and surrounding punctuation from a comparison subject."""
    return term.strip().strip(_TERM_PUNCTUATION).strip()


class QAPipeline:
    """Answer questions from retrieved passages.

    Attributes:
        qa_pipeline: extractive question-answering pipeline.
        generator: text2text-generation pipeline.
        retriever_ref: optional retriever assigned by the caller; it is used
            only for comparison questions and must provide
            ``retrieve(query, top_k=...)``.
    """

    def __init__(self):
        self.qa_pipeline = pipeline(
            "question-answering",
            model="deepset/bert-base-cased-squad2",
        )
        self.generator = pipeline(
            "text2text-generation",
            model="google/flan-t5-base",
            max_new_tokens=200,
            no_repeat_ngram_size=3,
        )
        self.retriever_ref = None
        print("Models loaded successfully")

    def extract(self, question: str, context: str) -> dict:
        """Run extractive QA.

        Returns:
            ``{"answer": str, "score": float}``; an empty answer with score
            0.0 when ``context`` is empty.
        """
        if not context:
            return {"answer": "", "score": 0.0}
        result = self.qa_pipeline(question=question, context=context)
        return {
            "answer": result.get("answer", ""),
            "score": float(result.get("score", 0.0)),
        }

    def generate(self, prompt: str) -> str:
        """Return the stripped generated text for prompt, or "" if none."""
        result = self.generator(prompt)
        if not result:
            return ""
        return result[0].get("generated_text", "").strip()

    def clean_output(self, text: str, span: str) -> str:
        """Replace generated text that echoes prompt instructions.

        When ``text`` contains any known instruction fragment, the extracted
        ``span`` is returned instead (or "" if the span is 3 characters or
        shorter). Otherwise ``text`` is returned unchanged.
        """
        lowered = text.lower()
        if any(phrase in lowered for phrase in _BAD_PHRASES):
            return span if span and len(span) > 3 else ""
        return text

    def ensure_how_steps(self, text: str, span: str) -> str:
        """Force an answer into a three-step "Step N:" layout.

        Text that already contains "Step 1:", "Step 2:" and "Step 3:" is
        returned stripped. Otherwise the first step is built from ``span``
        (or the first sentence of ``text``) and followed by two generic steps.
        """
        clean = text.strip()
        if not clean:
            base = span if span else "Medical information is limited for this mechanism."
            # _with_period avoids the ".." produced when base already ends
            # with a period.
            return (
                f"Step 1: {_with_period(base)}\n"
                "Step 2: This process affects the body through related biological pathways.\n"
                "Step 3: Clinical impact depends on severity and patient-specific factors."
            )

        if "Step 1:" in clean and "Step 2:" in clean and "Step 3:" in clean:
            return clean

        sentence = span if span and len(span) > 8 else clean.split(".")[0].strip()
        if not sentence:
            sentence = "Medical information is limited for this mechanism"
        return (
            f"Step 1: {_with_period(sentence)}\n"
            "Step 2: The mechanism continues through biologic or physiologic effects in related tissues.\n"
            "Step 3: The final outcome depends on disease severity, timing, and treatment response."
        )

    def force_bullets(self, text: str) -> str:
        """Turn comma-separated text into one "- item" line per non-empty item."""
        items = (item.strip() for item in text.split(","))
        return "\n".join(f"- {item}" for item in items if item)

    def _answer_comparison(
        self, question: str, q_lower: str, span: str, score: float, source: str
    ) -> Optional[dict]:
        """Answer an "A vs B" question by retrieving context for each side.

        Returns None when the question cannot be split into two non-empty
        subjects, so the caller can fall back to the generic path.
        """
        terms = None
        for separator in _COMPARISON_SEPARATORS:
            if separator in q_lower:
                left, right = q_lower.split(separator, 1)
                # Strip punctuation so "b?" in "a vs b?" becomes "b";
                # otherwise the subject is never found in the retrieved text.
                terms = [_clean_term(left.replace("compare", "")), _clean_term(right)]
                break
        if not terms or not all(terms):
            return None

        first, second = terms
        passages1 = self.retriever_ref.retrieve(first, top_k=2)
        passages2 = self.retriever_ref.retrieve(second, top_k=2)
        if not passages1 or not passages2:
            return {
                "final_answer": "I could not find enough information to compare these items. Try asking about each one separately.",
                "extracted_span": span,
                "confidence": round(score, 3),
                "source": source,
                "question_type": "comparison",
                "low_confidence": True,
                "very_low_confidence": True,
            }

        context1 = "\n".join(p.get("answer", "") for p in passages1)
        context2 = "\n".join(p.get("answer", "") for p in passages2)

        # Abstain unless both subjects are literally mentioned in the
        # retrieved text.
        comparison_context = (context1 + "\n" + context2).lower()
        if any(term not in comparison_context for term in terms):
            return {
                "final_answer": "I could not find reliable context to compare these options.",
                "extracted_span": "",
                "confidence": round(score, 3),
                "source": source,
                "question_type": "abstained",
                "low_confidence": True,
                "very_low_confidence": True,
            }

        first_title = first.title()
        second_title = second.title()
        combined_context = f"About {first}:\n{context1}\n\nAbout {second}:\n{context2}"
        # NOTE(review): prompt segments are joined with single spaces, which
        # is how the text appears in the original; confirm whether line
        # breaks between the segments were intended.
        prompt = (
            "You are a medical assistant. "
            f"Context: {combined_context} "
            f"Question: {question} "
            "Compare these two items using ONLY the context above. "
            "Use exactly this format: "
            f"{first_title}: [key facts] "
            f"{second_title}: [key facts] "
            "Key difference: [main difference] "
            f"Answer: {first_title}:"
        )
        generated = self.clean_output(self.generate(prompt), span)
        # Used when generation is too short: first sentence of each context.
        fallback = (
            f"{first_title}: {context1.split('.')[0]}. "
            f"{second_title}: {context2.split('.')[0]}."
        )
        return {
            "final_answer": generated if len(generated.strip()) > 20 else fallback,
            "extracted_span": span,
            "confidence": round(score, 3),
            "source": source,
            "question_type": "comparison",
            "low_confidence": False,
            "very_low_confidence": False,
        }

    def _normalize_yes_no(self, generated: str, question: str, span: str) -> str:
        """Make sure a yes/no answer starts with "YES" or "NO".

        A bare yes/no (any case, optional trailing "." or "!") is upper-cased
        and extended with the extracted span. Text that already starts with
        the whole word YES or NO is kept. Anything else gets a prefix chosen
        by counting positive and negative terms in the question and span.
        """
        stripped = generated.strip()

        bare = stripped.rstrip(".!").upper()
        if bare in ("YES", "NO"):
            explanation = span if span else "based on available medical information"
            return bare + ": " + explanation

        if _YES_NO_PREFIX.match(stripped.upper()) and len(stripped) >= 5:
            return generated

        evidence = (question + " " + span).lower()
        neg_count = _count_matches(_NEGATIVE_PATTERNS, evidence)
        pos_count = _count_matches(_POSITIVE_PATTERNS, evidence)
        prefix = "NO: " if neg_count > pos_count else "YES: "
        clean_text = span if span and len(span) > 5 else "insufficient medical data found"
        return prefix + clean_text

    def answer(self, question: str, passages: list) -> dict:
        """Answer question from the retrieved passages.

        Args:
            question: the user's question.
            passages: retrieved passage dicts; the "answer" text of the first
                three forms the context and the first passage's "source" is
                reported.

        Returns:
            A dict with keys ``final_answer``, ``extracted_span``,
            ``confidence``, ``source``, ``question_type``,
            ``low_confidence`` and ``very_low_confidence``.
        """
        context = "\n".join(p.get("answer", "") for p in passages[:3])
        source = passages[0].get("source", "unknown") if passages else "unknown"

        extracted = self.extract(question, context)
        score = extracted["score"]
        span = extracted["answer"]

        # Kept function-local, as in the original file.
        from router import detect_question_type, get_prompt

        question_type = detect_question_type(question)
        q_lower = question.lower()

        # Hardcoded false premise detection.
        if any(premise in q_lower for premise in _FALSE_PREMISES):
            return {
                "final_answer": "NO: This premise is medically incorrect. "
                + (span if span else "Smoking, alcohol and drug abuse are harmful to health."),
                "extracted_span": span,
                "confidence": round(score, 3),
                "source": source,
                "question_type": "yes_no",
                "low_confidence": False,
                "very_low_confidence": False,
            }

        # Abstain if confidence is too low.
        if score < 0.001:
            return {
                "final_answer": "I could not find reliable medical information for this question. Try rephrasing or asking something more specific.",
                "extracted_span": "",
                "confidence": score,
                "source": source,
                "question_type": "abstained",
                "low_confidence": True,
                "very_low_confidence": True,
            }

        # Comparison questions: split into two subjects and retrieve for each.
        if question_type == "comparison" and self.retriever_ref is not None:
            comparison = self._answer_comparison(question, q_lower, span, score, source)
            if comparison is not None:
                return comparison

        # YES/NO questions with a forced prefix.
        if question_type == "yes_no" and span:
            if score < 0.05:
                return {
                    "final_answer": "I could not find reliable medical information for a safe YES/NO answer. Try rephrasing the question.",
                    "extracted_span": "",
                    "confidence": score,
                    "source": source,
                    "question_type": "abstained",
                    "low_confidence": True,
                    "very_low_confidence": True,
                }

            prompt = get_prompt(question_type, question, context)
            generated = self.clean_output(self.generate(prompt), span)
            generated = self._normalize_yes_no(generated, question, span)
            return {
                "final_answer": generated,
                "extracted_span": span,
                "confidence": round(score, 3),
                "source": source,
                "question_type": question_type,
                "low_confidence": score < 0.4,
                "very_low_confidence": score < 0.2,
            }

        # All other types: generate from the type-specific prompt.
        prompt = get_prompt(question_type, question, context)
        generated = self.clean_output(self.generate(prompt), span)

        if question_type == "list":
            generated = self.force_bullets(generated)
        if question_type == "how":
            generated = self.ensure_how_steps(generated, span)

        return {
            # Very short generations are replaced by the extracted span.
            "final_answer": generated if len(generated.strip()) > 10 else span,
            "extracted_span": span,
            "confidence": round(score, 3),
            "source": source,
            "question_type": question_type,
            "low_confidence": score < 0.4,
            "very_low_confidence": score < 0.2,
        }


if __name__ == "__main__":
    from retriever import MedicalRetriever

    pipeline_instance = QAPipeline()
    retriever = MedicalRetriever.load("artifacts/retriever.pkl")
    question = "Does aspirin reduce fever?"
    passages = retriever.retrieve(question)
    result = pipeline_instance.answer(question, passages)
    print(result)