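"""AI Radiology Assistant (demo).

Pipeline: BLIP captions the uploaded chest X-ray, CLIP retrieves the most
similar reference findings, TinyLlama drafts a report from the caption and
findings, and DistilGPT-2 refines the draft in two passes. Served with Gradio.
"""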
import gradio as gr
import torch
from PIL import Image
from transformers import (
    CLIPProcessor, CLIPModel,
    BlipProcessor, BlipForConditionalGeneration,
    GPT2Tokenizer, GPT2LMHeadModel
)

# Device and dtype setup
device = "cuda" if torch.cuda.is_available() else "cpu"
model_dtype = torch.float16 if device == "cuda" else torch.float32  # weight dtype for the transformers models
dtype = None if device == "cuda" else torch.float32  # None lets Unsloth autodetect

# Load all models once at module import (Gradio has no built-in cache
# decorator; loading at startup keeps the weights resident for the app's lifetime)
def load_models():
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=model_dtype).to(device)
    clip_proc = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=False)
    blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=model_dtype).to(device)
    blip_proc = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base", use_fast=True)
    
    # Unsloth for fast 4-bit TinyLlama; fall back to plain transformers if
    # unsloth is not installed or the quantized load fails (the import lives
    # inside the try so a missing package cannot crash the whole app)
    try:
        from unsloth import FastLanguageModel
        model, llama_tok = FastLanguageModel.from_pretrained(
            model_name="unsloth/tinyllama-bnb-4bit",
            max_seq_length=2048,
            dtype=dtype,
            load_in_4bit=True,
        )
        FastLanguageModel.for_inference(model)  # enable optimized inference kernels
    except Exception:
        # Fallback: plain TinyLlama via transformers
        from transformers import AutoTokenizer, AutoModelForCausalLM
        llama_tok = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
        if llama_tok.pad_token is None:
            llama_tok.pad_token = llama_tok.eos_token
        model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype=model_dtype).to(device)
    
    gpt2_tok = GPT2Tokenizer.from_pretrained("distilgpt2")
    if gpt2_tok.pad_token is None:
        gpt2_tok.pad_token = gpt2_tok.eos_token
    gpt2_model = GPT2LMHeadModel.from_pretrained("distilgpt2", torch_dtype=model_dtype).to(device)
    
    return clip_model, clip_proc, blip_model, blip_proc, llama_tok, model, gpt2_tok, gpt2_model

clip_model, clip_proc, blip_model, blip_proc, llama_tok, llama_model, gpt2_tok, gpt2_model = load_models()

def generate_report(image):
    if image is None:
        return "Upload an X-ray to get started!", "", ""
    
    image = image.convert("RGB").resize((224, 224))
    
    # Step 1: Caption the image with BLIP
    with torch.no_grad():
        blip_inputs = blip_proc(images=image, return_tensors="pt").to(device)
        # Cast pixel values to the model's weight dtype (fp16 on GPU) to avoid a dtype mismatch
        blip_inputs["pixel_values"] = blip_inputs["pixel_values"].to(model_dtype)
        caption_ids = blip_model.generate(**blip_inputs)
        caption = blip_proc.decode(caption_ids[0], skip_special_tokens=True)
    
    # Step 2: Retrieve the most similar reference findings via CLIP image-text similarity
    reference_findings = [
        "There is evidence of right lower lobe consolidation.",
        "No acute cardiopulmonary abnormality.",
        "Mild cardiomegaly with clear lungs.",
        "Findings consistent with pneumonia.",
        "Chronic interstitial changes noted."
    ]
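    # NOTE: a small in-memory reference set for demo purposes; a production
    # system would presumably retrieve candidates from an indexed report corpus.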
    with torch.no_grad():
        img_inputs = clip_proc(images=image, return_tensors="pt").to(device)
        img_inputs["pixel_values"] = img_inputs["pixel_values"].to(model_dtype)  # match weight dtype
        img_features = clip_model.get_image_features(**img_inputs)
        img_features = torch.nn.functional.normalize(img_features, dim=-1)

        text_inputs = clip_proc.tokenizer(reference_findings, return_tensors="pt", padding=True, truncation=True).to(device)
        txt_features = clip_model.get_text_features(**text_inputs)
        txt_features = torch.nn.functional.normalize(txt_features, dim=-1)

        # Dot products of L2-normalized embeddings are cosine similarities
        similarities = torch.matmul(txt_features, img_features.T).squeeze()
        top_indices = torch.topk(similarities, k=3).indices.tolist()
        top_findings = [reference_findings[i] for i in top_indices]
    
    findings_text = "\n".join(f"• {f}" for f in top_findings)
    
    # Step 3: Draft a report with TinyLlama from the caption and retrieved findings
    llama_prompt = f"🧠 Caption: {caption}\n📚 Retrieved Reports:\n" + "\n".join(f"- {f}" for f in top_findings) + "\n\nGenerate a clinical-style radiology report:"
    with torch.no_grad():
        inputs = llama_tok(llama_prompt, return_tensors="pt").to(device)
        outputs = llama_model.generate(**inputs, max_new_tokens=75, use_cache=True, do_sample=True, temperature=0.7)
        # Decode only the newly generated tokens so the prompt is not echoed back
        draft_report = llama_tok.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    
    # Steps 4-5: Two refinement passes with DistilGPT-2
    gpt2_input_1 = f"Caption: {caption}\nDraft: {draft_report}\nRefine this into a structured radiology report:"
    with torch.no_grad():
        gpt2_inputs_1 = gpt2_tok(gpt2_input_1, return_tensors="pt").to(device)
        gpt2_output_1 = gpt2_model.generate(**gpt2_inputs_1, max_new_tokens=100, do_sample=True, temperature=0.7, pad_token_id=gpt2_tok.eos_token_id)
        # Keep only the continuation, not the echoed prompt
        refined_1 = gpt2_tok.decode(gpt2_output_1[0][gpt2_inputs_1["input_ids"].shape[1]:], skip_special_tokens=True)

        gpt2_input_2 = f"{refined_1}\nRefine the new report:"
        gpt2_inputs_2 = gpt2_tok(gpt2_input_2, return_tensors="pt").to(device)
        gpt2_output_2 = gpt2_model.generate(**gpt2_inputs_2, max_new_tokens=75, do_sample=True, temperature=0.7, pad_token_id=gpt2_tok.eos_token_id)
        refined_2 = gpt2_tok.decode(gpt2_output_2[0][gpt2_inputs_2["input_ids"].shape[1]:], skip_special_tokens=True)
    
    return f"**Caption:** {caption}", findings_text, refined_2

# Gradio Interface
with gr.Blocks(title="🩺 AI Radiology Assistant") as demo:
    gr.Markdown("# 🩺 AI-Powered Chest X-ray Interpretation Tool")
    
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Chest X-ray (PNG/JPG)")
    
    caption_out = gr.Textbox(label="🧠 AI Captioning", interactive=False)
    findings_out = gr.Textbox(label="🔍 Top Similar Clinical Findings", interactive=False)
    report_out = gr.Textbox(label="📄 Final Structured Radiology Report", lines=20, interactive=False)
    
    submit_btn = gr.Button("Generate Report", variant="primary")
    
    submit_btn.click(
        generate_report,
        inputs=image_input,
        outputs=[caption_out, findings_out, report_out]
    )

if __name__ == "__main__":
    demo.launch()
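
# Headless usage sketch (hypothetical image path; assumes the models above
# loaded successfully):
#
#   caption_md, findings, report = generate_report(Image.open("chest_xray.png"))
#   print(report)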