gayatrikancherla commited on
Commit
9615eb6
·
verified ·
1 Parent(s): c57bc22

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +133 -33
app.py CHANGED
@@ -1,45 +1,145 @@
1
- from flask import Flask, request, jsonify
2
- import openai
3
  import os
 
 
 
 
 
 
4
 
5
- app = Flask(__name__)
 
 
 
6
 
7
- # Load API key from environment variable
8
- openai.api_key = os.getenv("OPENAI_API_KEY")
 
 
9
 
10
- @app.route('/')
11
- def index():
12
- return jsonify({"message": "Radiology AI API is running."})
13
 
14
- @app.route('/analyze', methods=['POST'])
15
- def analyze_scan():
16
- scan_type = request.form.get('scan_type', 'medical scan')
17
- findings_input = request.form.get('findings', '')
 
 
 
18
 
19
- # Prompt focuses on expanding only the Results
20
- prompt = f"""
21
- You are a professional radiology AI assistant.
22
- A {scan_type} scan was uploaded. Findings extracted: {findings_input}.
 
 
 
 
 
23
 
24
- Write a structured report with two sections:
 
 
 
 
 
 
 
25
 
26
- 1. Results → Must be elaborated into at least 2–3 sentences (more than 3 lines), written in a formal radiology report style. Expand with detail, clinical context, and interpretation. Avoid bullet points.
 
 
 
 
27
 
28
- 2. Recommendations → Keep concise, 1–2 lines, with direct clinical advice, follow-up, or specialist referral if appropriate.
29
- """
 
 
 
 
 
30
 
31
- response = openai.ChatCompletion.create(
32
- model="gpt-4o-mini",
33
- messages=[
34
- {"role": "system", "content": "You are a helpful medical radiology AI assistant."},
35
- {"role": "user", "content": prompt}
36
- ],
37
- max_tokens=600,
38
- temperature=0.7
39
- )
40
 
41
- report_text = response['choices'][0]['message']['content'].strip()
42
- return jsonify({"report": report_text})
 
 
 
 
 
 
 
 
 
 
 
43
 
44
- if __name__ == '__main__':
45
- app.run(host="0.0.0.0", port=5000, debug=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import numpy as np
3
+ import torch
4
+ from PIL import Image
5
+ import torchxrayvision as xrv
6
+ from transformers import pipeline
7
+ import gradio as gr
8
 
9
+ # -----------------------------
10
+ # Setup
11
+ # -----------------------------
12
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
 
14
+ def load_models():
15
+ chest_model = xrv.models.get_model("densenet121-res224-all").to(DEVICE).eval()
16
+ clip = pipeline("zero-shot-image-classification", model="openai/clip-vit-base-patch32")
17
+ return chest_model, clip
18
 
19
+ CHEST_MODEL, CLIP = load_models()
20
+ LABELS = CHEST_MODEL.pathologies
 
21
 
22
+ # -----------------------------
23
+ # Helper Functions
24
+ # -----------------------------
25
+ def detect_scan_type(pil_img):
26
+ candidates = ["chest", "bone", "brain", "abdomen", "dental", "cardiac"]
27
+ result = CLIP(pil_img, candidate_labels=candidates)
28
+ return max(result, key=lambda x: x['score'])['label']
29
 
30
+ def preprocess_chest(pil_img):
31
+ if pil_img.mode != "L":
32
+ pil_img = pil_img.convert("L")
33
+ img = np.array(pil_img).astype(np.float32)
34
+ img = xrv.datasets.normalize(img, 255)
35
+ img = img[None, ...]
36
+ img = xrv.datasets.XRayCenterCrop()(img)
37
+ img = xrv.datasets.XRayResizer(224)(img)
38
+ return torch.from_numpy(img).unsqueeze(0).to(DEVICE)
39
 
40
+ def analyse_chest(pil_img):
41
+ x = preprocess_chest(pil_img)
42
+ with torch.no_grad():
43
+ output = CHEST_MODEL(x)
44
+ probs = torch.sigmoid(output)[0].cpu().numpy() * 100
45
+ findings = [(LABELS[i], p) for i, p in enumerate(probs)]
46
+ findings = sorted(findings, key=lambda x: x[1], reverse=True)
47
+ return [(cond, conf) for cond, conf in findings if conf > 40][:7]
48
 
49
+ # -----------------------------
50
+ # Clinical Interpretation
51
+ # -----------------------------
52
+ def interpret_findings(findings):
53
+ strong, moderate, mild = [], [], []
54
 
55
+ for cond, conf in findings:
56
+ if conf >= 75:
57
+ strong.append((cond, conf))
58
+ elif conf >= 55:
59
+ moderate.append((cond, conf))
60
+ elif conf >= 40:
61
+ mild.append((cond, conf))
62
 
63
+ return strong, moderate, mild
 
 
 
 
 
 
 
 
64
 
65
+ def recommendations_for_condition(cond):
66
+ recs = {
67
+ "Fibrosis": "Pulmonology referral, consider HRCT chest.",
68
+ "Infiltration": "Suggest clinical correlation (cough, fever, infection signs). Consider antibiotics and follow-up imaging.",
69
+ "Mass": "Oncology referral; biopsy or CT chest may be warranted.",
70
+ "Nodule": "Follow Fleischner guidelines; consider repeat imaging or biopsy.",
71
+ "Pleural_Thickening": "Consider TB or asbestos exposure; recommend pulmonology referral.",
72
+ "Consolidation": "Likely pneumonia; recommend clinical correlation, antibiotics, and repeat X-ray in 6–8 weeks.",
73
+ "Effusion": "Evaluate cardiac/renal function; consider thoracentesis if large.",
74
+ "Atelectasis": "Encourage physiotherapy, bronchoscopy if persistent.",
75
+ "Cardiomegaly": "Cardiology referral, ECG, and echocardiogram recommended."
76
+ }
77
+ return recs.get(cond, "Correlation with symptoms and specialist referral advised.")
78
 
79
+ def format_report(scan_type, findings):
80
+ if scan_type == "chest":
81
+ if not findings:
82
+ return "✅ **Chest X-ray Report**\n\nNormal chest. No significant abnormalities detected."
83
+
84
+ strong, moderate, mild = interpret_findings(findings)
85
+
86
+ report = "## 🫁 Chest X-ray Report\n\n"
87
+
88
+ # Summary
89
+ if strong:
90
+ report += f"**Overall Impression:** ⚠️ Abnormal – Strong evidence of **{', '.join([c for c, _ in strong])}**.\n\n"
91
+ elif moderate:
92
+ report += f"**Overall Impression:** ❗ Suspicious – Possible **{', '.join([c for c, _ in moderate])}**.\n\n"
93
+ else:
94
+ report += "**Overall Impression:** ✅ No major abnormalities detected.\n\n"
95
+
96
+ # Key Findings
97
+ report += "### 🔍 Key Findings\n"
98
+ if strong:
99
+ report += "\n**Significant Abnormalities:**\n"
100
+ for cond, conf in strong:
101
+ report += f"- {cond} ({conf:.1f}%) → {recommendations_for_condition(cond)}\n"
102
+ if moderate:
103
+ report += "\n**Possible Abnormalities:**\n"
104
+ for cond, conf in moderate:
105
+ report += f"- {cond} ({conf:.1f}%) → {recommendations_for_condition(cond)}\n"
106
+ if mild:
107
+ report += "\n**Minor / Non-specific Findings:**\n"
108
+ for cond, conf in mild:
109
+ report += f"- {cond} ({conf:.1f}%) → Monitor clinically.\n"
110
+
111
+ return report
112
+
113
+ # ------------------ Other Scans -------------------
114
+ elif scan_type == "bone":
115
+ return "🦴 **Bone X-ray Report**\n\n- Possible fracture or degenerative changes.\n- Recommendation: Orthopedic consultation."
116
+ elif scan_type == "brain":
117
+ return "🧠 **Brain MRI/CT Report**\n\n- Possible tumor, lesion, or stroke signs.\n- Recommendation: Neurology referral."
118
+ elif scan_type == "abdomen":
119
+ return "🩺 **Abdomen Scan Report**\n\n- Possible mass or opacity detected.\n- Recommendation: Gastroenterology referral."
120
+ elif scan_type == "dental":
121
+ return "😁 **Dental X-ray Report**\n\n- Possible cavities or gum disease.\n- Recommendation: Dentist referral."
122
+ elif scan_type == "cardiac":
123
+ return "❤️ **Cardiac Scan Report**\n\n- Possible cardiomegaly or valve abnormality.\n- Recommendation: Cardiology referral."
124
+ else:
125
+ return "⚠️ Unknown scan type. Please upload a valid medical scan."
126
+
127
+ # -----------------------------
128
+ # Gradio Interface
129
+ # -----------------------------
130
+ def analyse_scan(image):
131
+ pil_img = image.convert("RGB")
132
+ scan_type = detect_scan_type(pil_img)
133
+ findings = analyse_chest(pil_img) if scan_type == "chest" else None
134
+ return format_report(scan_type, findings)
135
+
136
+ demo = gr.Interface(
137
+ fn=analyse_scan,
138
+ inputs=gr.Image(type="pil"),
139
+ outputs="markdown",
140
+ title="🩻 Universal Radiology AI",
141
+ description="Upload a scan (Chest, Bone, Brain, Abdomen, Dental, Cardiac) to get an AI-generated structured radiology report. ⚠️ Not a medical device."
142
+ )
143
+
144
+ if __name__ == "__main__":
145
+ demo.launch()