# AI-AVECINNA / tutor.py
import os
import torch
from threading import Thread
from transformers import TextIteratorStreamer
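
# --- Hedged loader sketch (illustrative; not called by this module) ----------
# One plausible way to produce the `model` and `tokenizer` arguments consumed
# by generate_vqa_response() below. The checkpoint id, dtype, and device map
# are assumptions, not part of this file: a text-only MedGemma variant is
# assumed because this module only streams text; a multimodal checkpoint would
# need AutoProcessor / AutoModelForImageTextToText instead.
from transformers import AutoModelForCausalLM, AutoTokenizer


def load_medgemma(model_id="google/medgemma-27b-text-it"):
    """Return (model, tokenizer), or (None, None) if loading fails."""
    token = os.environ.get("HF_TOKEN")  # MedGemma repos are gated on the Hub
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            token=token,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        return model, tokenizer
    except Exception as e:
        print(f"Failed to load {model_id}: {e}")
        return None, None
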
def generate_vqa_response(model, tokenizer, message, history, modality, image):
    """
    Generates an educational clinical scenario and Socratic questions via a
    local LLM. Implemented as a generator so the response can be streamed.
    """
    if model is None or tokenizer is None:
        yield "### Error\nFailed to load the local MedGemma model. Please verify your HF_TOKEN and check the log output."
        return
    if image is None:
        # Text-only query: steer the model toward systemic clinical reasoning.
        focus = "physiological markers, systemic interactions, and clinical diagnostic criteria"
        system_prompt = f"""You are the Clinical Generalist, a highly specialized medical tutoring AI. Your purpose is to facilitate clinical reasoning, not just provide answers.
Follow this structural protocol for EVERY generation:
1. CLINICAL OVERVIEW: Provide a brief overview of the pathophysiology or concepts surrounding the query.
2. SYSTEMIC INVENTORY: Explicitly focus on {focus}.
3. DIFFERENTIAL REASONING: Mention the primary differential diagnosis but immediately contrast it with a 'mimic'.
4. SOCRATIC QUESTIONING: Answer the clinician's query Socratically. Challenge the clinician to justify their reasoning.
Tone: Professional, objective, and Socratic. Ensure you provide a complete answer."""
    else:
        # Image-based query: select modality-specific landmarks for the prompt.
        if modality in ("Chest X-Ray", "X-Ray"):
            focus = "costophrenic angles, hilar shadows, and cardiac silhouette"
        elif modality == "CT Scan":
            focus = "Hounsfield Units (HU), axial cross-sections, and windowing (Lung vs. Soft Tissue)"
        elif modality == "MRI":
            focus = "T1/T2 weighted signals, contrast enhancement, and multi-planar viewing"
        else:
            focus = "key anatomical landmarks"
        system_prompt = f"""You are the NerdMedica Socratic Auditor, a highly specialized medical tutoring AI. Your purpose is to facilitate clinical reasoning, not just provide answers.
Follow this structural protocol for EVERY generation:
1. CLINICAL SCENARIO: Create a brief, realistic 3-sentence patient history (Age, Chief Complaint, Vitals) that matches the pathology seen in the provided {modality}.
2. ANATOMICAL INVENTORY: Explicitly focus on {focus}.
3. DIFFERENTIAL REASONING: Mention the primary finding but immediately contrast it with a 'mimic'.
4. SOCRATIC QUESTIONING: Answer the clinician's query Socratically. Challenge the clinician to justify their diagnosis based on visual evidence.
Tone: Professional, objective, and Socratic. Ensure you provide a complete answer."""
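        # NOTE: the uploaded image is never tokenized or passed to
        # model.generate() below; only the modality label shapes the text
        # prompt. A true VQA path would presumably route the pixels through
        # an AutoProcessor with a multimodal MedGemma checkpoint instead.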
    # Assemble the final prompt: system instructions, optional chat history,
    # then the clinician's current question.
    prompt_content = f"Clinician Question: {message}"

    history_text = ""
    if history:
        for msg in history:
            role = "User" if msg["role"] == "user" else "AI"
            content = msg["content"]
            history_text += f"\n{role}: {content}"

    if history_text:
        formatted_prompt = f"{system_prompt}\n\nChat History:{history_text}\n\n{prompt_content}"
    else:
        formatted_prompt = f"{system_prompt}\n\n{prompt_content}"

    # Pack everything into a single user turn; this avoids depending on a
    # separate system role, which some chat templates do not support.
    messages = [
        {"role": "user", "content": formatted_prompt},
    ]
print("Generating NerdMedica feedback using MedGemma (Streaming)...")
try:
prompt_str = tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=False
)
inputs = tokenizer(prompt_str, return_tensors="pt")
inputs = {k: v.to(model.device) for k, v in inputs.items()}
        # Generation runs in a worker thread so this generator can yield
        # partial output as soon as tokens land in the streamer.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=512,
            temperature=0.1,
            top_p=0.9,
            repetition_penalty=1.2,
            do_sample=True,
        )
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        # Accumulate tokens and re-yield the growing response so callers
        # (e.g. a chat UI) can render the stream incrementally.
        generated_text = ""
        for new_text in streamer:
            generated_text += new_text
            yield generated_text
        thread.join()  # generation is done once the streamer is exhausted
    except Exception as e:
        print(f"Error during MedGemma generation: {e}")
        yield f"### Generation Error\nSystem Recalibrating: VRAM constraint exceeded or model error encountered. Details: {e}"