Spaces:
Running
Running
| from transformers import pipeline | |
class QAPipeline:
    """Answer medical questions from retrieved passages.

    An extractive question-answering model picks a short answer span (with a
    confidence score) out of the retrieved passages, and a text2text
    generator phrases the final answer from a prompt built by the ``router``
    module. Guard rails sit in between: a hard-coded false-premise check,
    abstaining on very low scores, forcing a YES/NO prefix on yes/no
    questions, and discarding generations that echo the prompt.

    Attributes:
        qa_pipeline: Extractive question-answering pipeline.
        generator: text2text-generation pipeline.
        retriever_ref: Optional retriever for the two-sided comparison path.
            It stays ``None`` until a caller assigns an object exposing
            ``retrieve(query, top_k=...)``; while it is ``None`` comparison
            questions are handled by the generic path.
    """

    # Fragments that show the generator echoed its instructions instead of
    # answering; matched case-insensitively as substrings.
    _BAD_PHRASES = (
        "your answer must",
        "then explain",
        "instructions:",
        "use only facts",
        "answer:",
        "respond directly",
        "use this format",
        "example format",
        "use exactly",
        "based on the context",
        "you are a medical",
    )

    # Question fragments whose premise is always rejected with a "NO".
    _FALSE_PREMISES = (
        "smoking improve", "smoking help", "smoking benefit",
        "alcohol improve", "alcohol cure", "drugs improve",
    )

    # Cue words/phrases for the YES/NO fallback heuristic (whole-word match).
    _NEGATIVE_WORDS = (
        "not", "no", "cannot", "does not",
        "harmful", "dangerous", "incorrect",
        "never", "false", "wrong",
    )
    _POSITIVE_WORDS = (
        "can", "helps", "reduces", "treats",
        "effective", "improves", "beneficial",
        "yes", "does", "is used",
    )

    _COMPARISON_SEPARATORS = (" vs ", " versus ")

    # Score thresholds applied to the extractive model's confidence.
    _ABSTAIN_SCORE = 0.001
    _YES_NO_MIN_SCORE = 0.05
    _LOW_CONFIDENCE = 0.4
    _VERY_LOW_CONFIDENCE = 0.2

    def __init__(self):
        """Load the extractive QA and text2text generation pipelines."""
        self.qa_pipeline = pipeline(
            "question-answering",
            model="deepset/bert-base-cased-squad2",
        )
        self.generator = pipeline(
            "text2text-generation",
            model="google/flan-t5-base",
            max_new_tokens=200,
            no_repeat_ngram_size=3,
        )
        self.retriever_ref = None
        print("Models loaded successfully")

    def extract(self, question: str, context: str) -> dict:
        """Extract an answer span for ``question`` from ``context``.

        Returns:
            ``{"answer": str, "score": float}``. A blank question or context
            yields an empty answer with score ``0.0`` without calling the
            model, because the QA pipeline rejects empty inputs.
        """
        if not question or not question.strip() or not context or not context.strip():
            return {"answer": "", "score": 0.0}
        result = self.qa_pipeline(question=question, context=context)
        return {
            "answer": result.get("answer", ""),
            "score": float(result.get("score", 0.0)),
        }

    def generate(self, prompt: str) -> str:
        """Run the generator on ``prompt`` and return its stripped text.

        Returns an empty string when the generator produces no output.
        """
        result = self.generator(prompt)
        if not result:
            return ""
        return result[0].get("generated_text", "").strip()

    def clean_output(self, text: str, span: str) -> str:
        """Replace a generation that echoes prompt instructions.

        If ``text`` contains any known instruction fragment, it is discarded
        in favour of ``span`` (or ``""`` when the span is missing or three
        characters or shorter). Otherwise ``text`` is returned unchanged.
        """
        lowered = text.lower()
        if any(phrase in lowered for phrase in self._BAD_PHRASES):
            return span if span and len(span) > 3 else ""
        return text

    @staticmethod
    def _as_sentence(text: str) -> str:
        """Strip ``text`` and make sure it ends with terminal punctuation."""
        text = text.strip()
        if text and not text.endswith((".", "!", "?")):
            text += "."
        return text

    def ensure_how_steps(self, text: str, span: str) -> str:
        """Return ``text`` as a three-step explanation.

        Text that already contains "Step 1:", "Step 2:" and "Step 3:" is
        returned stripped. Otherwise a three-step template is built around
        the extracted span (or the first sentence of ``text``), using a
        generic sentence when neither is available.
        """
        default_sentence = "Medical information is limited for this mechanism"
        clean = text.strip()
        if not clean:
            # Terminal punctuation is added exactly once, so a span or
            # default that already ends in "." no longer yields "..".
            base = self._as_sentence((span or "").strip() or default_sentence)
            return (
                f"Step 1: {base}\n"
                "Step 2: This process affects the body through related biological pathways.\n"
                "Step 3: Clinical impact depends on severity and patient-specific factors."
            )
        if "Step 1:" in clean and "Step 2:" in clean and "Step 3:" in clean:
            return clean
        # Prefer the extracted span when it is long enough to be informative.
        sentence = span if span and len(span) > 8 else clean.split(".")[0]
        sentence = self._as_sentence(sentence) or self._as_sentence(default_sentence)
        return (
            f"Step 1: {sentence}\n"
            "Step 2: The mechanism continues through biologic or physiologic effects in related tissues.\n"
            "Step 3: The final outcome depends on disease severity, timing, and treatment response."
        )

    def force_bullets(self, text: str) -> str:
        """Turn a comma-separated string into one "- item" line per item.

        Empty items are dropped; an empty input yields an empty string.
        """
        items = (item.strip() for item in text.split(","))
        return "\n".join(f"- {item}" for item in items if item)

    @staticmethod
    def _make_result(final_answer, span, confidence, source, question_type,
                     low_confidence, very_low_confidence) -> dict:
        """Assemble the response dict shared by every branch of ``answer``."""
        return {
            "final_answer": final_answer,
            "extracted_span": span,
            "confidence": confidence,
            "source": source,
            "question_type": question_type,
            "low_confidence": low_confidence,
            "very_low_confidence": very_low_confidence,
        }

    @staticmethod
    def _padded_words(text: str) -> str:
        """Lower-case ``text``, drop punctuation and pad with spaces.

        The result lets callers test whole words or phrases with
        ``f" {phrase} " in padded``.
        """
        normalized = "".join(ch if ch.isalnum() else " " for ch in text.lower())
        return " " + " ".join(normalized.split()) + " "

    def _split_comparison(self, q_lower: str):
        """Split a lower-cased "A vs B" question into ``[A, B]``.

        Returns ``None`` when no separator is present or either side is
        empty. Surrounding punctuation (such as a trailing "?") is removed
        so each side can later be looked up verbatim in retrieved text.
        """
        for separator in self._COMPARISON_SEPARATORS:
            if separator in q_lower:
                left, right = q_lower.split(separator, 1)
                parts = [
                    left.replace("compare", "").strip(" \t\n?.!,;:"),
                    right.strip(" \t\n?.!,;:"),
                ]
                return parts if all(parts) else None
        return None

    def _answer_comparison(self, question: str, parts: list, span: str,
                           score: float, source: str) -> dict:
        """Retrieve context for both sides of a comparison and generate it."""
        first, second = parts
        confidence = round(score, 3)
        passages1 = self.retriever_ref.retrieve(first, top_k=2)
        passages2 = self.retriever_ref.retrieve(second, top_k=2)
        if not passages1 or not passages2:
            return self._make_result(
                "I could not find enough information to compare these items. Try asking about each one separately.",
                span, confidence, source, "comparison", True, True,
            )

        # .get() mirrors how the main path reads passages and avoids a
        # KeyError on a passage without an "answer" field.
        context1 = "\n".join(p.get("answer", "") for p in passages1)
        context2 = "\n".join(p.get("answer", "") for p in passages2)

        # Both items must literally appear in the retrieved text; otherwise
        # the generator would be comparing things it has no facts about.
        comparison_context = (context1 + "\n" + context2).lower()
        if first not in comparison_context or second not in comparison_context:
            return self._make_result(
                "I could not find reliable context to compare these options.",
                "", confidence, source, "abstained", True, True,
            )

        first_label = first.title()
        second_label = second.title()
        combined_context = f"About {first}:\n{context1}\n\nAbout {second}:\n{context2}"
        # The prompt ends with the first label so the generator continues
        # directly in the requested format.
        prompt = "\n".join([
            "You are a medical assistant.",
            "Context:",
            combined_context,
            f"Question: {question}",
            "Compare these two items using ONLY the context above.",
            "Use exactly this format:",
            f"{first_label}: [key facts]",
            f"{second_label}: [key facts]",
            "Key difference: [main difference]",
            f"Answer: {first_label}:",
        ])
        generated = self.clean_output(self.generate(prompt), span)
        fallback = (
            f"{first_label}: {context1.split('.')[0]}. "
            f"{second_label}: {context2.split('.')[0]}."
        )
        return self._make_result(
            generated if len(generated.strip()) > 20 else fallback,
            span, confidence, source, "comparison", False, False,
        )

    def _resolve_yes_no(self, generated: str, question: str, span: str) -> str:
        """Return ``generated`` with a guaranteed YES/NO prefix.

        A generation whose first word is "yes" or "no" (any case) is kept;
        a bare verdict gets the extracted span appended as explanation.
        Anything else is replaced by a verdict guessed from cue words in the
        question and span, followed by the span.
        """
        text = generated.strip()

        # Only a whole leading word counts, so "Normal ..." or "Not ..." is
        # not mistaken for a NO verdict.
        end = 0
        while end < len(text) and text[end].isalpha():
            end += 1
        verdict = text[:end].upper()
        if verdict in ("YES", "NO"):
            if text[end:].strip(" \t\n:,.;!-"):
                return text
            explanation = span if span else "based on available medical information"
            return f"{verdict}: {explanation}"

        # Heuristic fallback: count whole-word cue matches so that e.g.
        # "cancer" does not count as the positive cue "can".
        evidence = self._padded_words(question + " " + (span or ""))
        neg_count = sum(1 for w in self._NEGATIVE_WORDS if f" {w} " in evidence)
        pos_count = sum(1 for w in self._POSITIVE_WORDS if f" {w} " in evidence)
        prefix = "NO: " if neg_count > pos_count else "YES: "
        clean_text = span if span and len(span) > 5 else "insufficient medical data found"
        return prefix + clean_text

    def answer(self, question: str, passages: list) -> dict:
        """Answer ``question`` using the retrieved ``passages``.

        Args:
            question: The user's question.
            passages: Retrieved passages as dicts; the "answer" text of the
                first three forms the context and the first passage's
                "source" is reported back.

        Returns:
            A dict with keys ``final_answer``, ``extracted_span``,
            ``confidence``, ``source``, ``question_type``,
            ``low_confidence`` and ``very_low_confidence``.
        """
        # Imported locally, as in the original code.
        from router import detect_question_type, get_prompt

        passages = passages or []
        context = "\n".join(p.get("answer", "") for p in passages[:3])
        source = passages[0].get("source", "unknown") if passages else "unknown"

        extracted = self.extract(question, context)
        score = extracted["score"]
        span = extracted["answer"]
        question_type = detect_question_type(question)
        q_lower = question.lower()

        # Hard-coded false-premise detection.
        if any(fp in q_lower for fp in self._FALSE_PREMISES):
            return self._make_result(
                "NO: This premise is medically incorrect. "
                + (span if span else "Smoking, alcohol and drug abuse are harmful to health."),
                span, round(score, 3), source, "yes_no", False, False,
            )

        # Abstain if confidence is too low; the raw score is reported here
        # because rounding to three decimals would collapse it to 0.0.
        if score < self._ABSTAIN_SCORE:
            return self._make_result(
                "I could not find reliable medical information for this question. Try rephrasing or asking something more specific.",
                "", score, source, "abstained", True, True,
            )

        # Comparison questions: retrieve separately for each side.
        if question_type == "comparison" and self.retriever_ref is not None:
            parts = self._split_comparison(q_lower)
            if parts is not None:
                return self._answer_comparison(question, parts, span, score, source)

        low = score < self._LOW_CONFIDENCE
        very_low = score < self._VERY_LOW_CONFIDENCE

        # YES/NO questions get a forced verdict prefix.
        if question_type == "yes_no" and span:
            if score < self._YES_NO_MIN_SCORE:
                return self._make_result(
                    "I could not find reliable medical information for a safe YES/NO answer. Try rephrasing the question.",
                    "", score, source, "abstained", True, True,
                )
            prompt = get_prompt(question_type, question, context)
            generated = self.clean_output(self.generate(prompt), span)
            return self._make_result(
                self._resolve_yes_no(generated, question, span),
                span, round(score, 3), source, question_type, low, very_low,
            )

        # All other types: generate from the router's prompt.
        prompt = get_prompt(question_type, question, context)
        generated = self.clean_output(self.generate(prompt), span)
        if question_type == "list":
            generated = self.force_bullets(generated)
        if question_type == "how":
            generated = self.ensure_how_steps(generated, span)
        return self._make_result(
            generated if len(generated.strip()) > 10 else span,
            span, round(score, 3), source, question_type, low, very_low,
        )
if __name__ == "__main__":
    # Manual smoke test: load the models and a saved retriever, then answer
    # one sample yes/no question and print the raw response dict.
    from retriever import MedicalRetriever

    qa = QAPipeline()
    medical_retriever = MedicalRetriever.load("artifacts/retriever.pkl")
    sample_question = "Does aspirin reduce fever?"
    print(qa.answer(sample_question, medical_retriever.retrieve(sample_question)))