# Hugging Face Spaces header (scraped page residue — not code):
# the Space status page reported "Runtime error"; the code below is the app source.
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| from transformers import ( | |
| CLIPProcessor, CLIPModel, | |
| BlipProcessor, BlipForConditionalGeneration, | |
| GPT2Tokenizer, GPT2LMHeadModel | |
| ) | |
| from unsloth import FastLanguageModel # For fast quantized TinyLlama | |
# Pick the compute device, and the dtype hint passed to Unsloth:
# on GPU, None lets Unsloth choose the best dtype; on CPU force fp32.
if torch.cuda.is_available():
    device = "cuda"
    dtype = None
else:
    device = "cpu"
    dtype = torch.float32
| # Load models (Gradio/HF caches 'em) | |
def load_models():
    """Load every model used by the pipeline and move it to `device`.

    Weights are fp16 on GPU and fp32 on CPU. TinyLlama is loaded through
    Unsloth's fast 4-bit path when possible, with a plain-transformers
    fallback.

    Returns:
        Tuple of (clip_model, clip_proc, blip_model, blip_proc,
        llama_tok, model, gpt2_tok, gpt2_model).
    """
    # Hoisted: the same dtype choice was repeated inline four times.
    torch_dtype = torch.float16 if device == "cuda" else torch.float32

    clip_model = CLIPModel.from_pretrained(
        "openai/clip-vit-base-patch32", torch_dtype=torch_dtype
    ).to(device)
    clip_proc = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=False)

    blip_model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-base", torch_dtype=torch_dtype
    ).to(device)
    blip_proc = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base", use_fast=True)

    # Prefer Unsloth's fast quantized TinyLlama; fall back to regular transformers.
    try:
        model, llama_tok = FastLanguageModel.from_pretrained(
            model_name="unsloth/tinyllama-bnb-4bit",
            max_seq_length=2048,
            dtype=dtype,
            load_in_4bit=True,
        )
        FastLanguageModel.for_inference(model)
    except Exception:  # BUGFIX: bare `except:` also swallowed KeyboardInterrupt/SystemExit
        from transformers import AutoTokenizer, AutoModelForCausalLM

        llama_tok = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
        if llama_tok.pad_token is None:
            llama_tok.pad_token = llama_tok.eos_token
        model = AutoModelForCausalLM.from_pretrained(
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype=torch_dtype
        ).to(device)

    gpt2_tok = GPT2Tokenizer.from_pretrained("distilgpt2")
    if gpt2_tok.pad_token is None:
        # GPT-2 ships without a pad token; reuse EOS so batched generate() works.
        gpt2_tok.pad_token = gpt2_tok.eos_token
    gpt2_model = GPT2LMHeadModel.from_pretrained("distilgpt2", torch_dtype=torch_dtype).to(device)

    return clip_model, clip_proc, blip_model, blip_proc, llama_tok, model, gpt2_tok, gpt2_model
# Load everything once at import time so each Gradio request reuses the cached models.
clip_model, clip_proc, blip_model, blip_proc, llama_tok, llama_model, gpt2_tok, gpt2_model = load_models()
def generate_report(image):
    """Run the full chest X-ray interpretation pipeline on one image.

    Pipeline: BLIP captioning -> CLIP retrieval of similar reference
    findings -> TinyLlama draft report -> two DistilGPT-2 refinement passes.

    Args:
        image: PIL image from Gradio, or None when nothing was uploaded.

    Returns:
        Tuple of three strings: (caption markdown, bulleted findings,
        final refined report).
    """
    if image is None:
        return "Upload an X-ray to get started!", "", ""

    # Normalise to the 224x224 RGB input both BLIP and CLIP expect.
    image = image.convert("RGB").resize((224, 224))

    # Step 1: BLIP caption.
    with torch.no_grad():
        blip_inputs = blip_proc(images=image, return_tensors="pt").to(device)
        caption_ids = blip_model.generate(**blip_inputs)
    caption = blip_proc.decode(caption_ids[0], skip_special_tokens=True)

    # Step 2: retrieve the closest findings from a small reference bank via CLIP.
    reference_findings = [
        "There is evidence of right lower lobe consolidation.",
        "No acute cardiopulmonary abnormality.",
        "Mild cardiomegaly with clear lungs.",
        "Findings consistent with pneumonia.",
        "Chronic interstitial changes noted.",
    ]
    with torch.no_grad():
        img_inputs = clip_proc(images=image, return_tensors="pt").to(device)
        img_features = torch.nn.functional.normalize(
            clip_model.get_image_features(**img_inputs), dim=-1
        )
        text_inputs = clip_proc.tokenizer(
            reference_findings, return_tensors="pt", padding=True, truncation=True
        ).to(device)
        txt_features = torch.nn.functional.normalize(
            clip_model.get_text_features(**text_inputs), dim=-1
        )
        # Normalized features -> matmul is cosine similarity; one score per finding.
        similarities = torch.matmul(txt_features, img_features.T).squeeze()
    # Guard against a reference bank shorter than the hard-coded top-3.
    k = min(3, len(reference_findings))
    top_indices = torch.topk(similarities, k=k).indices
    top_findings = [reference_findings[i] for i in top_indices]
    findings_text = "\n".join(f"• {f}" for f in top_findings)

    # Step 3: TinyLlama drafts a report from caption + retrieved findings.
    llama_prompt = (
        f"🧠 Caption: {caption}\n📚 Retrieved Reports:\n"
        + "\n".join(f"- {f}" for f in top_findings)
        + "\n\nGenerate a clinical-style radiology report:"
    )
    with torch.no_grad():
        inputs = llama_tok(llama_prompt, return_tensors="pt").to(device)
        outputs = llama_model.generate(
            **inputs, max_new_tokens=75, use_cache=True, do_sample=True, temperature=0.7
        )
    # BUGFIX: decode only the newly generated tokens. Decoding outputs[0]
    # whole echoed the entire prompt back into the draft (and, downstream,
    # into the final report shown to the user).
    prompt_len = inputs["input_ids"].shape[1]
    draft_report = llama_tok.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    def _refine(prompt, max_new_tokens):
        # One DistilGPT-2 refinement pass; returns only the continuation text.
        # BUGFIX: the second pass previously appeared to run outside
        # torch.no_grad(), building a needless autograd graph at inference.
        with torch.no_grad():
            enc = gpt2_tok(prompt, return_tensors="pt").to(device)
            out = gpt2_model.generate(
                **enc,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.7,
                pad_token_id=gpt2_tok.eos_token_id,
            )
        return gpt2_tok.decode(out[0][enc["input_ids"].shape[1]:], skip_special_tokens=True)

    # Steps 4-5: two refinement passes over the draft.
    refined_1 = _refine(
        f"Caption: {caption}\nDraft: {draft_report}\nRefine this into a structured radiology report:",
        max_new_tokens=100,
    )
    refined_2 = _refine(f"{refined_1}\nRefine the new report:", max_new_tokens=75)

    return f"**Caption:** {caption}", findings_text, refined_2
# Gradio interface: single page wiring the pipeline to a "Generate Report" button.
with gr.Blocks(title="🩺 AI Radiology Assistant") as demo:
    gr.Markdown("# 🩺 AI-Powered Chest X-ray Interpretation Tool")
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Chest X-ray (PNG/JPG)")
    # NOTE(review): original indentation was lost in extraction — the output
    # widgets are assumed to sit at Blocks level below the upload row; confirm
    # against the deployed layout.
    caption_out = gr.Textbox(label="🧠 AI Captioning", interactive=False)
    findings_out = gr.Textbox(label="🔍 Top Similar Clinical Findings", interactive=False)
    report_out = gr.Textbox(label="📄 Final Structured Radiology Report", lines=20, interactive=False)
    submit_btn = gr.Button("Generate Report", variant="primary")
    # generate_report returns three strings matching the three output widgets in order.
    submit_btn.click(
        generate_report,
        inputs=image_input,
        outputs=[caption_out, findings_out, report_out]
    )

# Launch the server when run as a script (e.g. on a Hugging Face Space).
if __name__ == "__main__":
    demo.launch()