HusainNaqvijobs committed on
Commit
2c51f91
·
verified ·
1 Parent(s): 6304ad6
Files changed (1) hide show
  1. app.py +125 -0
app.py CHANGED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from PIL import Image
4
+ from transformers import (
5
+ CLIPProcessor, CLIPModel,
6
+ BlipProcessor, BlipForConditionalGeneration,
7
+ GPT2Tokenizer, GPT2LMHeadModel
8
+ )
9
+ from unsloth import FastLanguageModel # For fast quantized TinyLlama
10
+
11
+ # Device and dtype setup
12
+ device = "cuda" if torch.cuda.is_available() else "cpu"
13
+ dtype = None if torch.cuda.is_available() else torch.float32
14
+
15
+ # Load models (Gradio/HF caches 'em)
16
+ @gr.cache
17
+ def load_models():
18
+ clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=torch.float16 if device == "cuda" else torch.float32).to(device)
19
+ clip_proc = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=False)
20
+ blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float16 if device == "cuda" else torch.float32).to(device)
21
+ blip_proc = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base", use_fast=True)
22
+
23
+ # Unsloth for fast TinyLlama (fallback to regular if issues)
24
+ try:
25
+ model, llama_tok = FastLanguageModel.from_pretrained(
26
+ model_name="unsloth/tinyllama-bnb-4bit",
27
+ max_seq_length=2048,
28
+ dtype=dtype,
29
+ load_in_4bit=True,
30
+ )
31
+ FastLanguageModel.for_inference(model)
32
+ except:
33
+ # Fallback to regular
34
+ from transformers import AutoTokenizer, AutoModelForCausalLM
35
+ llama_tok = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
36
+ if llama_tok.pad_token is None:
37
+ llama_tok.pad_token = llama_tok.eos_token
38
+ model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype=torch.float16 if device == "cuda" else torch.float32).to(device)
39
+
40
+ gpt2_tok = GPT2Tokenizer.from_pretrained("distilgpt2")
41
+ if gpt2_tok.pad_token is None:
42
+ gpt2_tok.pad_token = gpt2_tok.eos_token
43
+ gpt2_model = GPT2LMHeadModel.from_pretrained("distilgpt2", torch_dtype=torch.float16 if device == "cuda" else torch.float32).to(device)
44
+
45
+ return clip_model, clip_proc, blip_model, blip_proc, llama_tok, model, gpt2_tok, gpt2_model
46
+
47
+ clip_model, clip_proc, blip_model, blip_proc, llama_tok, llama_model, gpt2_tok, gpt2_model = load_models()
48
+
49
+ def generate_report(image):
50
+ if image is None:
51
+ return "Upload an X-ray to get started!", "", ""
52
+
53
+ image = image.convert("RGB").resize((224, 224))
54
+
55
+ # Step 1: Caption
56
+ with torch.no_grad():
57
+ blip_inputs = blip_proc(images=image, return_tensors="pt").to(device)
58
+ caption_ids = blip_model.generate(**blip_inputs)
59
+ caption = blip_proc.decode(caption_ids[0], skip_special_tokens=True)
60
+
61
+ # Step 2: Findings
62
+ reference_findings = [
63
+ "There is evidence of right lower lobe consolidation.",
64
+ "No acute cardiopulmonary abnormality.",
65
+ "Mild cardiomegaly with clear lungs.",
66
+ "Findings consistent with pneumonia.",
67
+ "Chronic interstitial changes noted."
68
+ ]
69
+ with torch.no_grad():
70
+ img_inputs = clip_proc(images=image, return_tensors="pt").to(device)
71
+ img_features = clip_model.get_image_features(**img_inputs)
72
+ img_features = torch.nn.functional.normalize(img_features, dim=-1)
73
+
74
+ text_inputs = clip_proc.tokenizer(reference_findings, return_tensors="pt", padding=True, truncation=True).to(device)
75
+ txt_features = clip_model.get_text_features(**text_inputs)
76
+ txt_features = torch.nn.functional.normalize(txt_features, dim=-1)
77
+
78
+ similarities = torch.matmul(txt_features, img_features.T).squeeze()
79
+ top_indices = torch.topk(similarities, k=3).indices
80
+ top_findings = [reference_findings[i] for i in top_indices]
81
+
82
+ findings_text = f"• {top_findings[0]}\n• {top_findings[1]}\n• {top_findings[2]}"
83
+
84
+ # Step 3: Draft
85
+ llama_prompt = f"🧠 Caption: {caption}\n📚 Retrieved Reports:\n" + "\n".join(f"- {f}" for f in top_findings) + "\n\nGenerate a clinical-style radiology report:"
86
+ with torch.no_grad():
87
+ inputs = llama_tok(llama_prompt, return_tensors="pt").to(device)
88
+ outputs = llama_model.generate(**inputs, max_new_tokens=75, use_cache=True, do_sample=True, temperature=0.7)
89
+ draft_report = llama_tok.decode(outputs[0], skip_special_tokens=True)
90
+
91
+ # Steps 4-5: Refine
92
+ gpt2_input_1 = f"Caption: {caption}\nDraft: {draft_report}\nRefine this into a structured radiology report:"
93
+ with torch.no_grad():
94
+ gpt2_inputs_1 = gpt2_tok(gpt2_input_1, return_tensors="pt").to(device)
95
+ gpt2_output_1 = gpt2_model.generate(**gpt2_inputs_1, max_new_tokens=100, do_sample=True, temperature=0.7, pad_token_id=gpt2_tok.eos_token_id)
96
+ refined_1 = gpt2_tok.decode(gpt2_output_1[0], skip_special_tokens=True)
97
+
98
+ gpt2_input_2 = f"{refined_1}\nRefine the new report:"
99
+ gpt2_inputs_2 = gpt2_tok(gpt2_input_2, return_tensors="pt").to(device)
100
+ gpt2_output_2 = gpt2_model.generate(**gpt2_inputs_2, max_new_tokens=75, do_sample=True, temperature=0.7, pad_token_id=gpt2_tok.eos_token_id)
101
+ refined_2 = gpt2_tok.decode(gpt2_output_2[0], skip_special_tokens=True)
102
+
103
+ return f"**Caption:** {caption}", findings_text, refined_2
104
+
105
+ # Gradio Interface
106
+ with gr.Blocks(title="🩺 AI Radiology Assistant") as demo:
107
+ gr.Markdown("# 🩺 AI-Powered Chest X-ray Interpretation Tool")
108
+
109
+ with gr.Row():
110
+ image_input = gr.Image(type="pil", label="Upload Chest X-ray (PNG/JPG)")
111
+
112
+ caption_out = gr.Textbox(label="🧠 AI Captioning", interactive=False)
113
+ findings_out = gr.Textbox(label="🔍 Top Similar Clinical Findings", interactive=False)
114
+ report_out = gr.Textbox(label="📄 Final Structured Radiology Report", lines=20, interactive=False)
115
+
116
+ submit_btn = gr.Button("Generate Report", variant="primary")
117
+
118
+ submit_btn.click(
119
+ generate_report,
120
+ inputs=image_input,
121
+ outputs=[caption_out, findings_out, report_out]
122
+ )
123
+
124
+ if __name__ == "__main__":
125
+ demo.launch()