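# medimind-api / model.py
# Page-header residue from the repository viewer (not code):
#   uploader avatar caption: "Manikantaperla's picture"
#   latest commit: "initial medimind backend" (d0c827a)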
from transformers import pipeline
class QAPipeline:
    """Answer medical questions from retrieved passages.

    An extractive question-answering model pulls a short answer span (with a
    confidence score) out of the retrieved context, and a text2text model
    writes the final answer from a prompt built by the ``router`` module.
    ``answer`` ties the two together and adds guard rails: hard-coded
    false-premise detection, abstention on very low confidence, and
    per-question-type post-processing.

    Attributes:
        qa_pipeline: extractive QA pipeline, called as
            ``qa_pipeline(question=..., context=...)``.
        generator: text2text pipeline, called as ``generator(prompt)``.
        retriever_ref: optional retriever assigned by the caller after
            construction; only used for comparison questions. It must expose
            ``retrieve(query, top_k=...)`` returning a list of dicts that
            carry an ``"answer"`` entry.
    """

    # Fragments that show the generator echoed its instructions back.
    _BAD_PHRASES = (
        "your answer must",
        "then explain",
        "instructions:",
        "use only facts",
        "answer:",
        "respond directly",
        "use this format",
        "example format",
        "use exactly",
        "based on the context",
        "you are a medical",
    )
    # Question fragments treated as a medically false premise.
    # NOTE(review): "drugs improve" also matches questions about prescribed
    # drugs - confirm that this is intended.
    _FALSE_PREMISES = (
        "smoking improve", "smoking help", "smoking benefit",
        "alcohol improve", "alcohol cure", "drugs improve",
    )
    # Words that, directly in front of a false-premise fragment, turn it into
    # a question about *stopping* the habit ("quitting smoking improve ...").
    _CESSATION_WORDS = (
        "quit", "quitting", "stop", "stopping", "avoid", "avoiding",
        "reducing", "without", "giving up", "not",
    )
    _NEGATIVE_WORDS = (
        "not", "no", "cannot", "does not",
        "harmful", "dangerous", "incorrect",
        "never", "false", "wrong",
    )
    _POSITIVE_WORDS = (
        "can", "helps", "reduces", "treats",
        "effective", "improves", "beneficial",
        "yes", "does", "is used",
    )
    _COMPARISON_SEPARATORS = (" vs ", " versus ")
    # Characters trimmed from each side of a comparison operand.
    _PART_TRIM = " \t\n?!.,;:"
    _BULLET_MARKERS = ("- ", "* ", "\u2022 ")
    _HOW_FALLBACK = "Medical information is limited for this mechanism"

    # Confidence thresholds applied to the extractive model's score.
    _ABSTAIN_BELOW = 0.001
    _YES_NO_ABSTAIN_BELOW = 0.05
    _LOW_CONFIDENCE_BELOW = 0.4
    _VERY_LOW_CONFIDENCE_BELOW = 0.2

    def __init__(self):
        """Load both models and announce success on stdout."""
        self.qa_pipeline = pipeline(
            "question-answering",
            model="deepset/bert-base-cased-squad2",
        )
        self.generator = pipeline(
            "text2text-generation",
            model="google/flan-t5-base",
            max_new_tokens=200,
            no_repeat_ngram_size=3,
        )
        self.retriever_ref = None
        print("Models loaded successfully")

    def extract(self, question: str, context: str) -> dict:
        """Run extractive QA and return ``{"answer": str, "score": float}``.

        A blank question or blank context (empty or whitespace only) returns
        an empty answer with score 0.0 without calling the model.
        """
        if not question or not question.strip() or not context or not context.strip():
            return {"answer": "", "score": 0.0}
        result = self.qa_pipeline(question=question, context=context)
        return {
            "answer": result.get("answer", ""),
            "score": float(result.get("score", 0.0)),
        }

    def generate(self, prompt: str) -> str:
        """Return the stripped text generated for *prompt*, or ``""``."""
        result = self.generator(prompt)
        if not result:
            return ""
        return (result[0].get("generated_text") or "").strip()

    def clean_output(self, text: str, span: str) -> str:
        """Replace generations that leak prompt instructions.

        If *text* contains any known instruction fragment, the extracted
        *span* is returned instead (or ``""`` when the span is missing or at
        most 3 characters long). Otherwise *text* is returned unchanged.
        """
        lowered = text.lower()
        if any(phrase in lowered for phrase in self._BAD_PHRASES):
            return span if span and len(span) > 3 else ""
        return text

    def _as_sentence(self, text: str) -> str:
        """Strip *text* (falling back to a stock phrase) and end it with one period."""
        sentence = (text or "").strip() or self._HOW_FALLBACK
        if not sentence.endswith((".", "!", "?")):
            sentence += "."
        return sentence

    def ensure_how_steps(self, text: str, span: str) -> str:
        """Return *text* shaped as a three-step explanation.

        Text that already contains "Step 1:", "Step 2:" and "Step 3:" is
        returned stripped. Otherwise a three-step template is produced whose
        first step is the span (or the first sentence of the text) and whose
        remaining steps are generic. Step 1 always ends with exactly one
        sentence terminator.
        """
        clean = (text or "").strip()
        if not clean:
            lead = self._as_sentence(span)
            return (
                f"Step 1: {lead}\n"
                "Step 2: This process affects the body through related biological pathways.\n"
                "Step 3: Clinical impact depends on severity and patient-specific factors."
            )
        if all(marker in clean for marker in ("Step 1:", "Step 2:", "Step 3:")):
            return clean
        # Prefer a reasonably long extracted span over the generator's own text.
        lead = span if span and len(span) > 8 else clean.split(".")[0]
        lead = self._as_sentence(lead)
        return (
            f"Step 1: {lead}\n"
            "Step 2: The mechanism continues through biologic or physiologic effects in related tissues.\n"
            "Step 3: The final outcome depends on disease severity, timing, and treatment response."
        )

    def force_bullets(self, text: str) -> str:
        """Turn a comma- or newline-separated list into ``- item`` lines.

        Existing leading bullet markers are dropped first so already-bulleted
        text is not bulleted twice; empty items are skipped.
        """
        items = []
        for piece in (text or "").replace("\n", ",").split(","):
            item = piece.strip()
            if item[:2] in self._BULLET_MARKERS or item in ("-", "*", "\u2022"):
                item = item[1:].strip()
            if item:
                items.append(item)
        return "\n".join(f"- {item}" for item in items)

    @staticmethod
    def _build_result(final_answer, span, score, source, question_type,
                      low_confidence, very_low_confidence) -> dict:
        """Assemble the response dict; confidence is always rounded to 3 places."""
        return {
            "final_answer": final_answer,
            "extracted_span": span,
            "confidence": round(score, 3),
            "source": source,
            "question_type": question_type,
            "low_confidence": low_confidence,
            "very_low_confidence": very_low_confidence,
        }

    def _is_false_premise(self, q_lower: str) -> bool:
        """Tell whether the lowercased question contains a known false premise.

        An occurrence directly preceded by a cessation word (for example
        "quitting smoking improve") does not count: that question is about
        stopping the habit, so the premise is not false.
        """
        for phrase in self._FALSE_PREMISES:
            index = q_lower.find(phrase)
            while index != -1:
                if not q_lower[:index].rstrip().endswith(self._CESSATION_WORDS):
                    return True
                index = q_lower.find(phrase, index + 1)
        return False

    def _split_comparison(self, q_lower: str):
        """Split "a vs b" / "a versus b" into two operands, or return None.

        The word "compare" is removed from the left operand and surrounding
        whitespace and punctuation are trimmed from both, so a trailing "?"
        cannot defeat the later "operand appears in context" check.
        """
        for separator in self._COMPARISON_SEPARATORS:
            if separator in q_lower:
                left, right = q_lower.split(separator, 1)
                parts = (
                    left.replace("compare", "").strip(self._PART_TRIM),
                    right.strip(self._PART_TRIM),
                )
                return parts if all(parts) else None
        return None

    def _answer_comparison(self, question, parts, span, score, source) -> dict:
        """Retrieve context for each operand separately and generate a comparison."""
        first, second = parts
        passages1 = self.retriever_ref.retrieve(first, top_k=2)
        passages2 = self.retriever_ref.retrieve(second, top_k=2)
        if not passages1 or not passages2:
            return self._build_result(
                "I could not find enough information to compare these items. Try asking about each one separately.",
                span, score, source, "comparison", True, True,
            )

        context1 = "\n".join(p.get("answer", "") for p in passages1)
        context2 = "\n".join(p.get("answer", "") for p in passages2)
        comparison_context = (context1 + "\n" + context2).lower()
        # Both operands must literally occur in the retrieved text.
        if any(part not in comparison_context for part in parts):
            return self._build_result(
                "I could not find reliable context to compare these options.",
                "", score, source, "abstained", True, True,
            )

        first_title, second_title = first.title(), second.title()
        combined_context = f"About {first}:\n{context1}\n\nAbout {second}:\n{context2}"
        prompt = "\n".join([
            "You are a medical assistant.",
            "Context:",
            combined_context,
            f"Question: {question}",
            "Compare these two items using ONLY the context above.",
            "Use exactly this format:",
            f"{first_title}: [key facts]",
            f"{second_title}: [key facts]",
            "Key difference: [main difference]",
            f"Answer: {first_title}:",
        ])
        generated = self.clean_output(self.generate(prompt), span)
        # Fallback: first sentence of each operand's context.
        fallback = (
            f"{first_title}: {context1.split('.')[0]}. "
            f"{second_title}: {context2.split('.')[0]}."
        )
        return self._build_result(
            generated if len(generated.strip()) > 20 else fallback,
            span, score, source, "comparison", False, False,
        )

    def _force_yes_no_prefix(self, generated: str, question: str, span: str) -> str:
        """Make sure a yes/no answer starts with a YES or NO verdict.

        A bare verdict in any letter case ("yes", "No.") is kept and extended
        with the span. Text already starting with YES/NO (5+ characters) is
        returned unchanged. Otherwise the verdict is guessed by counting
        negative and positive keywords in the question and span.
        """
        stripped = generated.strip()
        verdict = stripped.rstrip(".!,:;").upper()
        if verdict in ("YES", "NO"):
            explanation = span if span else "based on available medical information"
            return f"{verdict}: {explanation}"
        if stripped.upper().startswith(("YES", "NO")) and len(stripped) >= 5:
            return generated

        # NOTE: plain substring counting (so "no" also matches inside "not").
        # Kept as-is on purpose; tightening it changes which verdict wins.
        context_text = (question + " " + span).lower()
        neg_count = sum(1 for word in self._NEGATIVE_WORDS if word in context_text)
        pos_count = sum(1 for word in self._POSITIVE_WORDS if word in context_text)
        prefix = "NO: " if neg_count > pos_count else "YES: "
        clean_text = span if span and len(span) > 5 else "insufficient medical data found"
        return prefix + clean_text

    def answer(self, question: str, passages: list) -> dict:
        """Answer *question* from the retrieved *passages*.

        Args:
            question: the user's question.
            passages: retrieved passages, best first; each is a dict that may
                carry "answer" and "source" entries. Only the first three
                contribute context. ``None`` is treated as no passages.

        Returns:
            A dict with the keys "final_answer", "extracted_span",
            "confidence" (rounded to 3 places), "source", "question_type",
            "low_confidence" and "very_low_confidence".
        """
        passages = passages or []
        context = "\n".join(p.get("answer", "") for p in passages[:3])
        source = passages[0].get("source", "unknown") if passages else "unknown"

        extracted = self.extract(question, context)
        score = extracted["score"]
        span = extracted["answer"]

        # Function-scope import: the router helpers are only needed here.
        from router import detect_question_type, get_prompt

        question_type = detect_question_type(question)
        q_lower = question.lower()

        # Hard-coded false premise detection.
        if self._is_false_premise(q_lower):
            return self._build_result(
                "NO: This premise is medically incorrect. "
                + (span if span else "Smoking, alcohol and drug abuse are harmful to health."),
                span, score, source, "yes_no", False, False,
            )

        # Abstain if confidence is too low.
        if score < self._ABSTAIN_BELOW:
            return self._build_result(
                "I could not find reliable medical information for this question. Try rephrasing or asking something more specific.",
                "", score, source, "abstained", True, True,
            )

        # Comparison questions: split the operands and retrieve for each.
        if question_type == "comparison" and self.retriever_ref is not None:
            parts = self._split_comparison(q_lower)
            if parts:
                return self._answer_comparison(question, parts, span, score, source)

        low = score < self._LOW_CONFIDENCE_BELOW
        very_low = score < self._VERY_LOW_CONFIDENCE_BELOW

        # YES/NO with forced prefix.
        if question_type == "yes_no" and span:
            if score < self._YES_NO_ABSTAIN_BELOW:
                return self._build_result(
                    "I could not find reliable medical information for a safe YES/NO answer. Try rephrasing the question.",
                    "", score, source, "abstained", True, True,
                )
            prompt = get_prompt(question_type, question, context)
            generated = self.clean_output(self.generate(prompt), span)
            generated = self._force_yes_no_prefix(generated, question, span)
            return self._build_result(
                generated, span, score, source, question_type, low, very_low,
            )

        # All other types: generate, then shape the text per question type.
        prompt = get_prompt(question_type, question, context)
        generated = self.clean_output(self.generate(prompt), span)
        if question_type == "list":
            generated = self.force_bullets(generated)
        if question_type == "how":
            generated = self.ensure_how_steps(generated, span)
        # Fall back to the span for very short generations, but never swap a
        # non-empty generation for an empty span.
        use_generated = len(generated.strip()) > 10 or not span
        return self._build_result(
            generated if use_generated else span,
            span, score, source, question_type, low, very_low,
        )
if __name__ == "__main__":
    # Manual smoke test: load the models and the pickled retriever, then
    # answer one sample question and print the resulting dict.
    from retriever import MedicalRetriever

    qa = QAPipeline()
    med_retriever = MedicalRetriever.load("artifacts/retriever.pkl")
    sample_question = "Does aspirin reduce fever?"
    retrieved = med_retriever.retrieve(sample_question)
    print(qa.answer(sample_question, retrieved))