"""Gradio demo that chains pretrained models to draft a chest X-ray report.

Pipeline: BLIP caption -> CLIP retrieval over a fixed findings list ->
TinyLlama draft -> two distilgpt2 refinement passes.
"""
import functools
import logging

# Unsloth asks to be imported before transformers so its patches apply.
# It is optional: if it is missing or refuses to import (NOTE(review): e.g. no
# supported GPU -- confirm for the deployed unsloth version), load_models()
# falls back to the plain transformers TinyLlama.
_UNSLOTH_IMPORT_ERROR = None
try:
    from unsloth import FastLanguageModel  # For fast quantized TinyLlama
except Exception as _exc:  # optional dependency boundary
    FastLanguageModel = None
    _UNSLOTH_IMPORT_ERROR = _exc

import gradio as gr
import torch
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BlipForConditionalGeneration,
    BlipProcessor,
    CLIPModel,
    CLIPProcessor,
    GPT2LMHeadModel,
    GPT2Tokenizer,
)

logger = logging.getLogger(__name__)

# Device and dtype setup
device = "cuda" if torch.cuda.is_available() else "cpu"
# dtype passed to Unsloth; None on CUDA lets Unsloth choose.
dtype = None if torch.cuda.is_available() else torch.float32
# Weight dtype for the transformers-loaded models: half precision only on GPU.
model_dtype = torch.float16 if device == "cuda" else torch.float32

_TINYLLAMA_UNSLOTH_ID = "unsloth/tinyllama-bnb-4bit"
_TINYLLAMA_HF_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"


def _ensure_pad_token(tokenizer):
    """Reuse EOS as the pad token when the tokenizer defines none; return it."""
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def _load_tinyllama():
    """Return ``(model, tokenizer)``: Unsloth 4-bit TinyLlama if possible, else transformers."""
    if FastLanguageModel is None:
        logger.info("Unsloth unavailable (%r); using transformers TinyLlama.", _UNSLOTH_IMPORT_ERROR)
    else:
        try:
            model, tokenizer = FastLanguageModel.from_pretrained(
                model_name=_TINYLLAMA_UNSLOTH_ID,
                max_seq_length=2048,
                dtype=dtype,
                load_in_4bit=True,
            )
            FastLanguageModel.for_inference(model)
        except Exception:
            # Best-effort fast path: log why it failed instead of silently swallowing it.
            logger.warning("Unsloth TinyLlama load failed; falling back to transformers.", exc_info=True)
        else:
            return model, tokenizer

    tokenizer = AutoTokenizer.from_pretrained(_TINYLLAMA_HF_ID)
    model = AutoModelForCausalLM.from_pretrained(_TINYLLAMA_HF_ID, torch_dtype=model_dtype).to(device)
    return model, tokenizer


# Loading is expensive; memoize so repeated calls reuse the same objects.
@functools.lru_cache(maxsize=None)
def load_models():
    """Load every model and processor the pipeline needs (once per process).

    Returns:
        Tuple ``(clip_model, clip_proc, blip_model, blip_proc, llama_tok,
        llama_model, gpt2_tok, gpt2_model)``.
    """
    clip_model = CLIPModel.from_pretrained(
        "openai/clip-vit-base-patch32", torch_dtype=model_dtype
    ).to(device)
    clip_proc = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=False)

    blip_model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-base", torch_dtype=model_dtype
    ).to(device)
    blip_proc = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base", use_fast=True)

    llama_model, llama_tok = _load_tinyllama()
    _ensure_pad_token(llama_tok)  # applied on both the Unsloth and fallback paths

    gpt2_tok = _ensure_pad_token(GPT2Tokenizer.from_pretrained("distilgpt2"))
    gpt2_model = GPT2LMHeadModel.from_pretrained("distilgpt2", torch_dtype=model_dtype).to(device)

    return clip_model, clip_proc, blip_model, blip_proc, llama_tok, llama_model, gpt2_tok, gpt2_model
clip_model, clip_proc, blip_model, blip_proc, llama_tok, llama_model, gpt2_tok, gpt2_model = load_models()

# Candidate findings that CLIP ranks against the uploaded image.
REFERENCE_FINDINGS = (
    "There is evidence of right lower lobe consolidation.",
    "No acute cardiopulmonary abnormality.",
    "Mild cardiomegaly with clear lungs.",
    "Findings consistent with pneumonia.",
    "Chronic interstitial changes noted.",
)
# Number of best-matching findings shown in the UI and fed to the LLM.
TOP_K_FINDINGS = 3


def _generate_continuation(tokenizer, model, prompt, max_new_tokens, **generate_kwargs):
    """Sample a continuation of ``prompt`` and return only the newly generated text.

    Decoder-only ``generate()`` returns prompt + continuation, so the prompt
    tokens are sliced off before decoding. Otherwise each stage echoes its own
    instructions into the next one and into the final report.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            pad_token_id=pad_id,
            **generate_kwargs,
        )
    prompt_len = inputs["input_ids"].shape[-1]
    return tokenizer.decode(output_ids[0, prompt_len:], skip_special_tokens=True).strip()


def generate_report(image):
    """Run caption -> retrieval -> draft -> refine on one uploaded X-ray.

    Args:
        image: PIL image from the Gradio ``Image`` component, or ``None``.

    Returns:
        ``(caption_text, findings_bullets, final_report)`` for the three
        output textboxes.
    """
    if image is None:
        return "Upload an X-ray to get started!", "", ""

    # No manual resize: each processor resizes/crops to its own model's input
    # size (presumably 224 for this CLIP and 384 for this BLIP -- verify in the
    # processor configs). A 224x224 pre-resize distorted the aspect ratio and
    # made BLIP upsample.
    image = image.convert("RGB")

    # Step 1: Caption. Pixel values are cast to the model dtype (fp16 on CUDA).
    with torch.no_grad():
        blip_inputs = blip_proc(images=image, return_tensors="pt").to(device, blip_model.dtype)
        caption_ids = blip_model.generate(**blip_inputs)
    caption = blip_proc.decode(caption_ids[0], skip_special_tokens=True)

    # Step 2: Findings -- cosine similarity between image and each finding.
    findings = list(REFERENCE_FINDINGS)
    with torch.no_grad():
        img_inputs = clip_proc(images=image, return_tensors="pt").to(device, clip_model.dtype)
        img_features = torch.nn.functional.normalize(clip_model.get_image_features(**img_inputs), dim=-1)
        text_inputs = clip_proc.tokenizer(
            findings, return_tensors="pt", padding=True, truncation=True
        ).to(device)
        txt_features = torch.nn.functional.normalize(clip_model.get_text_features(**text_inputs), dim=-1)
    similarities = torch.matmul(txt_features, img_features.T).squeeze(-1)
    k = min(TOP_K_FINDINGS, len(findings))
    top_findings = [findings[i] for i in torch.topk(similarities, k=k).indices.tolist()]
    findings_text = "\n".join(f"• {f}" for f in top_findings)

    # Step 3: Draft with TinyLlama.
    retrieved = "\n".join(f"- {f}" for f in top_findings)
    llama_prompt = (
        f"🧠 Caption: {caption}\n📚 Retrieved Reports:\n{retrieved}"
        "\n\nGenerate a clinical-style radiology report:"
    )
    draft_report = _generate_continuation(
        llama_tok, llama_model, llama_prompt, max_new_tokens=75, use_cache=True
    )

    # Steps 4-5: Two distilgpt2 refinement passes.
    refined_1 = _generate_continuation(
        gpt2_tok,
        gpt2_model,
        f"Caption: {caption}\nDraft: {draft_report}\nRefine this into a structured radiology report:",
        max_new_tokens=100,
    )
    refined_2 = _generate_continuation(
        gpt2_tok, gpt2_model, f"{refined_1}\nRefine the new report:", max_new_tokens=75
    )
    # Sampling can hit EOS immediately; fall back to the last non-empty stage.
    final_report = refined_2 or refined_1 or draft_report

    return f"**Caption:** {caption}", findings_text, final_report


# Gradio Interface
with gr.Blocks(title="🩺 AI Radiology Assistant") as demo:
    gr.Markdown("# 🩺 AI-Powered Chest X-ray Interpretation Tool")
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Chest X-ray (PNG/JPG)")
        with gr.Column():
            caption_out = gr.Textbox(label="🧠 AI Captioning", interactive=False)
            findings_out = gr.Textbox(label="🔍 Top Similar Clinical Findings", interactive=False)
    report_out = gr.Textbox(label="📄 Final Structured Radiology Report", lines=20, interactive=False)
    submit_btn = gr.Button("Generate Report", variant="primary")
    submit_btn.click(
        generate_report,
        inputs=image_input,
        outputs=[caption_out, findings_out, report_out],
    )

if __name__ == "__main__":
    demo.launch()