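"""AI Radiology Assistant (Gradio Space).

Pipeline: BLIP captions the uploaded chest X-ray, CLIP retrieves the most
similar reference findings, TinyLlama drafts a report from the caption plus
findings, and two distilgpt2 passes refine the draft before display.
"""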
import gradio as gr
import torch
from PIL import Image
from transformers import (
    CLIPProcessor, CLIPModel,
    BlipProcessor, BlipForConditionalGeneration,
    GPT2Tokenizer, GPT2LMHeadModel,
)
# Unsloth (fast 4-bit TinyLlama) is imported lazily inside load_models() so
# that a missing or broken install falls back to plain Transformers.
# Device and dtype setup
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = None if torch.cuda.is_available() else torch.float32  # None lets Unsloth pick a GPU dtype

# Load all models once at startup (the function is called a single time at
# module import, so no extra caching decorator is needed).
def load_models():
    clip_model = CLIPModel.from_pretrained(
        "openai/clip-vit-base-patch32",
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    ).to(device)
    clip_proc = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=False)
    blip_model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-base",
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    ).to(device)
    blip_proc = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base", use_fast=True)
    # Prefer Unsloth's 4-bit TinyLlama for speed; fall back to the plain
    # Transformers checkpoint if Unsloth is unavailable or fails (e.g. CPU-only).
    try:
        from unsloth import FastLanguageModel  # imported here so a broken install hits the fallback
        model, llama_tok = FastLanguageModel.from_pretrained(
            model_name="unsloth/tinyllama-bnb-4bit",
            max_seq_length=2048,
            dtype=dtype,
            load_in_4bit=True,
        )
        FastLanguageModel.for_inference(model)
    except Exception:
        from transformers import AutoTokenizer, AutoModelForCausalLM
        llama_tok = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
        if llama_tok.pad_token is None:
            llama_tok.pad_token = llama_tok.eos_token
        model = AutoModelForCausalLM.from_pretrained(
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        ).to(device)
    gpt2_tok = GPT2Tokenizer.from_pretrained("distilgpt2")
    if gpt2_tok.pad_token is None:
        gpt2_tok.pad_token = gpt2_tok.eos_token
    gpt2_model = GPT2LMHeadModel.from_pretrained(
        "distilgpt2",
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    ).to(device)
    return clip_model, clip_proc, blip_model, blip_proc, llama_tok, model, gpt2_tok, gpt2_model
clip_model, clip_proc, blip_model, blip_proc, llama_tok, llama_model, gpt2_tok, gpt2_model = load_models()
def generate_report(image):
    if image is None:
        return "Upload an X-ray to get started!", "", ""
    image = image.convert("RGB").resize((224, 224))
    # Step 1: BLIP generates a free-text caption for the image.
    with torch.no_grad():
        blip_inputs = blip_proc(images=image, return_tensors="pt").to(device)
        caption_ids = blip_model.generate(**blip_inputs)
    caption = blip_proc.decode(caption_ids[0], skip_special_tokens=True)
    # Step 2: CLIP retrieves the reference findings most similar to the image.
    reference_findings = [
        "There is evidence of right lower lobe consolidation.",
        "No acute cardiopulmonary abnormality.",
        "Mild cardiomegaly with clear lungs.",
        "Findings consistent with pneumonia.",
        "Chronic interstitial changes noted.",
    ]
    with torch.no_grad():
        img_inputs = clip_proc(images=image, return_tensors="pt").to(device)
        img_features = clip_model.get_image_features(**img_inputs)
        img_features = torch.nn.functional.normalize(img_features, dim=-1)
        text_inputs = clip_proc.tokenizer(
            reference_findings, return_tensors="pt", padding=True, truncation=True
        ).to(device)
        txt_features = clip_model.get_text_features(**text_inputs)
        txt_features = torch.nn.functional.normalize(txt_features, dim=-1)
        # With normalized embeddings, the dot product is cosine similarity.
        similarities = torch.matmul(txt_features, img_features.T).squeeze()
    top_indices = torch.topk(similarities, k=3).indices
    top_findings = [reference_findings[i] for i in top_indices]
    findings_text = "\n".join(f"• {f}" for f in top_findings)
    # Step 3: TinyLlama drafts a report from the caption and retrieved findings.
    llama_prompt = f"🧠 Caption: {caption}\n📚 Retrieved Reports:\n" + "\n".join(f"- {f}" for f in top_findings) + "\n\nGenerate a clinical-style radiology report:"
    with torch.no_grad():
        inputs = llama_tok(llama_prompt, return_tensors="pt").to(device)
        outputs = llama_model.generate(**inputs, max_new_tokens=75, use_cache=True, do_sample=True, temperature=0.7, pad_token_id=llama_tok.eos_token_id)
    # Decode only the new tokens so the prompt is not echoed into the draft.
    draft_report = llama_tok.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    # Steps 4-5: two GPT-2 refinement passes over the draft.
    gpt2_input_1 = f"Caption: {caption}\nDraft: {draft_report}\nRefine this into a structured radiology report:"
    with torch.no_grad():
        gpt2_inputs_1 = gpt2_tok(gpt2_input_1, return_tensors="pt").to(device)
        gpt2_output_1 = gpt2_model.generate(**gpt2_inputs_1, max_new_tokens=100, do_sample=True, temperature=0.7, pad_token_id=gpt2_tok.eos_token_id)
        refined_1 = gpt2_tok.decode(gpt2_output_1[0][gpt2_inputs_1["input_ids"].shape[1]:], skip_special_tokens=True)
        gpt2_input_2 = f"{refined_1}\nRefine the new report:"
        gpt2_inputs_2 = gpt2_tok(gpt2_input_2, return_tensors="pt").to(device)
        gpt2_output_2 = gpt2_model.generate(**gpt2_inputs_2, max_new_tokens=75, do_sample=True, temperature=0.7, pad_token_id=gpt2_tok.eos_token_id)
        refined_2 = gpt2_tok.decode(gpt2_output_2[0][gpt2_inputs_2["input_ids"].shape[1]:], skip_special_tokens=True)
    return f"**Caption:** {caption}", findings_text, refined_2
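# Quick local smoke test (hypothetical file name — any chest X-ray image works):
#     from PIL import Image
#     caption, findings, report = generate_report(Image.open("sample_xray.png"))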
# Gradio Interface
with gr.Blocks(title="🩺 AI Radiology Assistant") as demo:
    gr.Markdown("# 🩺 AI-Powered Chest X-ray Interpretation Tool")
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Chest X-ray (PNG/JPG)")
        caption_out = gr.Textbox(label="🧠 AI Captioning", interactive=False)
    findings_out = gr.Textbox(label="🔍 Top Similar Clinical Findings", interactive=False)
    report_out = gr.Textbox(label="📄 Final Structured Radiology Report", lines=20, interactive=False)
    submit_btn = gr.Button("Generate Report", variant="primary")
    submit_btn.click(
        generate_report,
        inputs=image_input,
        outputs=[caption_out, findings_out, report_out],
    )
if __name__ == "__main__":
    demo.launch()
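    # demo.launch(share=True) would also expose a temporary public URL for debugging.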