serene-abyss commited on
Commit
c048140
·
verified ·
1 Parent(s): ac42e7e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -66
app.py CHANGED
@@ -3,52 +3,53 @@ import torch.nn.functional as F
3
  from transformers import AutoModelForImageClassification
4
  from torchvision import transforms
5
  from PIL import Image, ImageStat
6
- from fastapi import FastAPI, File, UploadFile, HTTPException
7
  from fastapi.responses import HTMLResponse
8
  import io
9
- import os
10
  import gc
11
 
12
  # ==========================================
13
- # 1. CONFIGURATION (Rural India Context)
14
  # ==========================================
15
  MODELS = {
16
  "lungs": {
17
  "id": "nickmuchi/vit-finetuned-chest-xray-pneumonia",
18
  "desc": "Tuberculosis & Pneumonia (Chest X-Ray)",
19
  "safe": ["NORMAL", "normal", "No Pneumonia"],
20
- "check_color": True # X-Rays must be black & white
 
21
  },
22
  "blood": {
23
  "id": "mrm8488/vit-base-patch16-224-finetuned-malaria-detection",
24
  "desc": "Malaria Screening (Microscopic Slide)",
25
  "safe": ["Uninfected", "uninfected"],
26
- "check_color": False
 
27
  },
28
  "eye": {
29
  "id": "AventIQ-AI/resnet18-cataract-detection-system",
30
  "desc": "Cataract Detection (Smartphone Eye Photo)",
31
  "safe": ["Normal", "normal", "healthy"],
32
- "check_color": False
 
33
  },
34
  "skin": {
35
  "id": "Anwarkh1/Skin_Cancer-Image_Classification",
36
  "desc": "Dermatology & Lesion Analysis",
37
  "safe": ["Benign", "benign", "nv", "bkl"],
38
- "check_color": False
 
39
  }
40
  }
41
 
42
  # ==========================================
43
- # 2. AI ENGINE (The Brain)
44
  # ==========================================
45
  class MedicalEngine:
46
  def __init__(self):
47
- # Force CPU to avoid memory crashes on Free Tier
48
  self.device = "cpu"
49
- print("✅ System Initialized: Medical Engine Ready (Lazy Loading)")
50
 
51
- # Standard Image Transformation
52
  self.transform = transforms.Compose([
53
  transforms.Resize((224, 224)),
54
  transforms.ToTensor(),
@@ -56,59 +57,84 @@ class MedicalEngine:
56
  ])
57
 
58
  def validate_image(self, image, task):
59
- """Guardrail: Rejects obvious bad images (e.g. Selfies for X-Ray)"""
60
- if MODELS[task]["check_color"]:
61
- # Check saturation for X-Rays
62
- stat = ImageStat.Stat(image.convert('HSV'))
63
- saturation = stat.mean[1] # 0 = Gray, 255 = Color
64
- if saturation > 35: # Threshold
65
- return False, "⚠️ Invalid Image: This looks like a color photo. Please upload a Black & White X-Ray."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  return True, ""
67
 
68
  def predict(self, image_bytes, task):
69
- # A. Validation
70
  try:
71
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
72
  except:
73
  return {"error": "File is not a valid image."}
74
 
 
75
  is_valid, msg = self.validate_image(image, task)
76
  if not is_valid:
77
- return {"error": msg, "risk": "INVALID"}
 
 
 
 
 
 
 
78
 
79
- # B. Load Model (On Demand)
80
- print(f"⏳ Downloading/Loading Model for: {task}...")
81
  try:
82
  model_id = MODELS[task]["id"]
83
  model = AutoModelForImageClassification.from_pretrained(model_id)
84
  model.to(self.device)
85
  model.eval()
86
  except Exception as e:
87
- print(f"❌ Model Load Error: {e}")
88
  return {"error": "Failed to load AI model. Try again."}
89
 
90
- # C. Inference
91
  try:
92
  inputs = self.transform(image).unsqueeze(0).to(self.device)
93
  with torch.no_grad():
94
  outputs = model(inputs)
95
  probs = F.softmax(outputs.logits, dim=-1)
96
 
97
- # Extract Results
98
  results = []
99
  for i, score in enumerate(probs[0]):
100
  label = model.config.id2label[i]
101
  results.append({"label": label, "score": float(score)})
102
  results.sort(key=lambda x: x['score'], reverse=True)
103
-
104
  top = results[0]
105
 
106
- # D. Risk Logic
107
  safe_words = MODELS[task]["safe"]
108
  is_safe = any(s.lower() in top["label"].lower() for s in safe_words)
109
 
110
  if top["score"] < 0.5:
111
- risk = "UNCERTAIN" # Low confidence
 
112
  elif is_safe:
113
  risk = "LOW"
114
  else:
@@ -118,7 +144,6 @@ class MedicalEngine:
118
  return {"error": f"Prediction Error: {str(e)}"}
119
 
120
  finally:
121
- # E. Cleanup RAM (Critical for Free Tier)
122
  del model
123
  gc.collect()
124
 
@@ -130,22 +155,17 @@ class MedicalEngine:
130
  }
131
 
132
  # ==========================================
133
- # 3. API SERVER (FastAPI)
134
  # ==========================================
135
  app = FastAPI()
136
  engine = MedicalEngine()
137
 
138
  @app.post("/predict/{task}")
139
  async def predict_route(task: str, file: UploadFile = File(...)):
140
- if task not in MODELS:
141
- return {"error": "Invalid Task"}
142
-
143
  content = await file.read()
144
  return engine.predict(content, task)
145
 
146
- # ==========================================
147
- # 4. FRONTEND UI (HTML embedded)
148
- # ==========================================
149
  @app.get("/", response_class=HTMLResponse)
150
  def home():
151
  return """
@@ -207,7 +227,7 @@ def home():
207
  <i class="fas fa-cloud-upload-alt text-3xl text-gray-400 mb-2"></i>
208
  <p class="text-gray-500 text-sm">Tap to upload image</p>
209
  </div>
210
- <img id="preview" class="hidden mx-auto max-h-48 rounded shadow">
211
  </div>
212
  </div>
213
 
@@ -218,7 +238,6 @@ def home():
218
  <div id="loader" class="hidden text-center py-6">
219
  <div class="inline-block animate-spin rounded-full h-8 w-8 border-4 border-blue-500 border-t-transparent"></div>
220
  <p class="text-sm text-gray-500 mt-2 font-semibold">Downloading AI Model & Analyzing...</p>
221
- <p class="text-xs text-gray-400">(This takes ~20s on first run)</p>
222
  </div>
223
 
224
  <div id="result-box" class="hidden mt-6 border-t pt-6">
@@ -244,28 +263,21 @@ def home():
244
  </div>
245
  </div>
246
 
247
- <footer class="text-center p-4 text-xs text-gray-400">
248
- ⚠️ Disclaimer: AI Research Tool. Verify with a Doctor.
249
- </footer>
250
-
251
  <script>
252
  let currTask = null;
253
  let currFile = null;
254
 
255
  function setTask(task) {
256
  currTask = task;
257
-
258
- // Highlight Buttons
259
  document.querySelectorAll('button[id^="btn-"]').forEach(b => b.classList.remove('ring-2', 'ring-blue-400', 'border-blue-500'));
260
  document.getElementById('btn-'+task).classList.add('ring-2', 'ring-blue-400', 'border-blue-500');
261
-
262
- // Reset UI
263
  document.getElementById('header-text').innerHTML = `Upload <span class="uppercase text-blue-600">${task}</span> Image`;
264
  document.getElementById('inputs').classList.remove('opacity-50', 'pointer-events-none');
265
  document.getElementById('result-box').classList.add('hidden');
266
  document.getElementById('run-btn').classList.add('hidden');
267
  document.getElementById('placeholder').classList.remove('hidden');
268
  document.getElementById('preview').classList.add('hidden');
 
269
  currFile = null;
270
  }
271
 
@@ -286,13 +298,8 @@ def home():
286
 
287
  async function analyze() {
288
  if (!currTask || !currFile) return;
289
-
290
- if (!document.getElementById('p-name').value) {
291
- alert("Please enter Patient Name.");
292
- return;
293
- }
294
 
295
- // Show Loader
296
  document.getElementById('run-btn').classList.add('hidden');
297
  document.getElementById('loader').classList.remove('hidden');
298
  document.getElementById('result-box').classList.add('hidden');
@@ -301,21 +308,27 @@ def home():
301
  formData.append("file", currFile);
302
 
303
  try {
304
- // Determine API URL (Handles Hugging Face URL structure)
305
- let url = "/predict/" + currTask;
306
-
307
- let res = await fetch(url, { method: "POST", body: formData });
308
  let data = await res.json();
309
 
310
- if (data.error) {
311
- alert("Error: " + data.error);
312
- resetLoading();
 
 
 
 
 
 
 
 
 
 
 
313
  return;
314
  }
315
 
316
- // Show Results
317
- document.getElementById('loader').classList.add('hidden');
318
- document.getElementById('result-box').classList.remove('hidden');
319
 
320
  document.getElementById('res-label').innerText = data.prediction.label;
321
  document.getElementById('res-conf').innerText = (data.prediction.score * 100).toFixed(1) + "%";
@@ -325,24 +338,25 @@ def home():
325
 
326
  if (data.risk === "HIGH") {
327
  badge.className = "px-3 py-1 rounded text-sm font-bold uppercase bg-red-100 text-red-700";
 
328
  alertBox.classList.remove('hidden');
329
  document.getElementById('alert-text').innerText = "High Risk. Immediate Referral Recommended.";
330
  } else if (data.risk === "MODERATE") {
331
  badge.className = "px-3 py-1 rounded text-sm font-bold uppercase bg-yellow-100 text-yellow-700";
 
332
  alertBox.classList.remove('hidden');
333
  document.getElementById('alert-text').innerText = "Moderate Risk. Consult Doctor.";
334
- } else if (data.risk === "INVALID" || data.risk === "UNCERTAIN") {
335
  badge.className = "px-3 py-1 rounded text-sm font-bold uppercase bg-gray-200 text-gray-700";
 
336
  alertBox.classList.remove('hidden');
337
- document.getElementById('alert-text').innerText = "Image Unclear / Invalid. Retake Photo.";
338
  } else {
339
  badge.className = "px-3 py-1 rounded text-sm font-bold uppercase bg-green-100 text-green-700";
340
  alertBox.classList.add('hidden');
341
  }
342
-
343
  badge.innerText = data.risk + " RISK";
344
 
345
- // Sync Animation
346
  setTimeout(() => {
347
  document.getElementById('sync-msg').innerHTML = "<i class='fas fa-check-circle'></i> Synced!";
348
  document.getElementById('sync-msg').className = "text-green-600 font-bold";
 
3
  from transformers import AutoModelForImageClassification
4
  from torchvision import transforms
5
  from PIL import Image, ImageStat
6
+ from fastapi import FastAPI, File, UploadFile
7
  from fastapi.responses import HTMLResponse
8
  import io
 
9
  import gc
10
 
11
  # ==========================================
12
+ # 1. CONFIGURATION (With STRICT Guardrails)
13
  # ==========================================
14
  MODELS = {
15
  "lungs": {
16
  "id": "nickmuchi/vit-finetuned-chest-xray-pneumonia",
17
  "desc": "Tuberculosis & Pneumonia (Chest X-Ray)",
18
  "safe": ["NORMAL", "normal", "No Pneumonia"],
19
+ # Rule: Saturation must be LOW (Grayscale)
20
+ "guardrails": {"max_sat": 35}
21
  },
22
  "blood": {
23
  "id": "mrm8488/vit-base-patch16-224-finetuned-malaria-detection",
24
  "desc": "Malaria Screening (Microscopic Slide)",
25
  "safe": ["Uninfected", "uninfected"],
26
+ # Rule: Must be bright (Backlit slide)
27
+ "guardrails": {"min_bright": 60}
28
  },
29
  "eye": {
30
  "id": "AventIQ-AI/resnet18-cataract-detection-system",
31
  "desc": "Cataract Detection (Smartphone Eye Photo)",
32
  "safe": ["Normal", "normal", "healthy"],
33
+ # Rule: Saturation must be HIGH (Color photo) - Blocks X-Rays
34
+ "guardrails": {"min_sat": 20}
35
  },
36
  "skin": {
37
  "id": "Anwarkh1/Skin_Cancer-Image_Classification",
38
  "desc": "Dermatology & Lesion Analysis",
39
  "safe": ["Benign", "benign", "nv", "bkl"],
40
+ # Rule: Saturation must be HIGH (Color photo) - Blocks X-Rays
41
+ "guardrails": {"min_sat": 20}
42
  }
43
  }
44
 
45
  # ==========================================
46
+ # 2. AI ENGINE
47
  # ==========================================
48
  class MedicalEngine:
49
  def __init__(self):
 
50
  self.device = "cpu"
51
+ print("✅ System Initialized: Medical Engine Ready")
52
 
 
53
  self.transform = transforms.Compose([
54
  transforms.Resize((224, 224)),
55
  transforms.ToTensor(),
 
57
  ])
58
 
59
  def validate_image(self, image, task):
60
+ """
61
+ Universal Guardrails:
62
+ - Prevents X-Rays in Skin/Eye tabs (Checks Min Saturation)
63
+ - Prevents Selfies in X-Ray tab (Checks Max Saturation)
64
+ """
65
+ rules = MODELS[task].get("guardrails", {})
66
+
67
+ # Convert to HSV (Hue, Saturation, Value)
68
+ # Saturation (Index 1): 0 = Gray, 255 = Color
69
+ # Value (Index 2): 0 = Dark, 255 = Bright
70
+ stat = ImageStat.Stat(image.convert('HSV'))
71
+ avg_sat = stat.mean[1]
72
+ avg_bright = stat.mean[2]
73
+
74
+ # 1. Check Max Saturation (Block Colorful images)
75
+ if "max_sat" in rules and avg_sat > rules["max_sat"]:
76
+ return False, f"⚠️ Invalid Image: Too colorful ({int(avg_sat)}). This looks like a photo, not an X-Ray."
77
+
78
+ # 2. Check Min Saturation (Block Grayscale images)
79
+ if "min_sat" in rules and avg_sat < rules["min_sat"]:
80
+ return False, f"⚠️ Invalid Image: Too gray ({int(avg_sat)}). This looks like an X-Ray/Doc. Please upload a color photo."
81
+
82
+ # 3. Check Min Brightness (Block Dark images)
83
+ if "min_bright" in rules and avg_bright < rules["min_bright"]:
84
+ return False, "⚠️ Invalid Image: Too dark. Microscope slides must be backlit."
85
+
86
  return True, ""
87
 
88
  def predict(self, image_bytes, task):
89
+ # A. Load Image
90
  try:
91
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
92
  except:
93
  return {"error": "File is not a valid image."}
94
 
95
+ # B. Run Validation
96
  is_valid, msg = self.validate_image(image, task)
97
  if not is_valid:
98
+ # Return a special "INVALID" state
99
+ return {
100
+ "task": task,
101
+ "desc": MODELS[task]["desc"],
102
+ "prediction": {"label": "Invalid Image", "score": 0.0},
103
+ "risk": "INVALID",
104
+ "error": msg
105
+ }
106
 
107
+ # C. Load Model
108
+ print(f"⏳ Loading Model for: {task}...")
109
  try:
110
  model_id = MODELS[task]["id"]
111
  model = AutoModelForImageClassification.from_pretrained(model_id)
112
  model.to(self.device)
113
  model.eval()
114
  except Exception as e:
 
115
  return {"error": "Failed to load AI model. Try again."}
116
 
117
+ # D. Inference
118
  try:
119
  inputs = self.transform(image).unsqueeze(0).to(self.device)
120
  with torch.no_grad():
121
  outputs = model(inputs)
122
  probs = F.softmax(outputs.logits, dim=-1)
123
 
 
124
  results = []
125
  for i, score in enumerate(probs[0]):
126
  label = model.config.id2label[i]
127
  results.append({"label": label, "score": float(score)})
128
  results.sort(key=lambda x: x['score'], reverse=True)
 
129
  top = results[0]
130
 
131
+ # E. Risk Logic
132
  safe_words = MODELS[task]["safe"]
133
  is_safe = any(s.lower() in top["label"].lower() for s in safe_words)
134
 
135
  if top["score"] < 0.5:
136
+ risk = "UNCERTAIN"
137
+ top["label"] = "Inconclusive / Unknown"
138
  elif is_safe:
139
  risk = "LOW"
140
  else:
 
144
  return {"error": f"Prediction Error: {str(e)}"}
145
 
146
  finally:
 
147
  del model
148
  gc.collect()
149
 
 
155
  }
156
 
157
  # ==========================================
158
+ # 3. API & UI
159
  # ==========================================
160
  app = FastAPI()
161
  engine = MedicalEngine()
162
 
163
  @app.post("/predict/{task}")
164
  async def predict_route(task: str, file: UploadFile = File(...)):
165
+ if task not in MODELS: return {"error": "Invalid Task"}
 
 
166
  content = await file.read()
167
  return engine.predict(content, task)
168
 
 
 
 
169
  @app.get("/", response_class=HTMLResponse)
170
  def home():
171
  return """
 
227
  <i class="fas fa-cloud-upload-alt text-3xl text-gray-400 mb-2"></i>
228
  <p class="text-gray-500 text-sm">Tap to upload image</p>
229
  </div>
230
+ <img id="preview" class="hidden mx-auto max-h-48 rounded shadow object-contain">
231
  </div>
232
  </div>
233
 
 
238
  <div id="loader" class="hidden text-center py-6">
239
  <div class="inline-block animate-spin rounded-full h-8 w-8 border-4 border-blue-500 border-t-transparent"></div>
240
  <p class="text-sm text-gray-500 mt-2 font-semibold">Downloading AI Model & Analyzing...</p>
 
241
  </div>
242
 
243
  <div id="result-box" class="hidden mt-6 border-t pt-6">
 
263
  </div>
264
  </div>
265
 
 
 
 
 
266
  <script>
267
  let currTask = null;
268
  let currFile = null;
269
 
270
  function setTask(task) {
271
  currTask = task;
 
 
272
  document.querySelectorAll('button[id^="btn-"]').forEach(b => b.classList.remove('ring-2', 'ring-blue-400', 'border-blue-500'));
273
  document.getElementById('btn-'+task).classList.add('ring-2', 'ring-blue-400', 'border-blue-500');
 
 
274
  document.getElementById('header-text').innerHTML = `Upload <span class="uppercase text-blue-600">${task}</span> Image`;
275
  document.getElementById('inputs').classList.remove('opacity-50', 'pointer-events-none');
276
  document.getElementById('result-box').classList.add('hidden');
277
  document.getElementById('run-btn').classList.add('hidden');
278
  document.getElementById('placeholder').classList.remove('hidden');
279
  document.getElementById('preview').classList.add('hidden');
280
+ document.getElementById('preview').src = "";
281
  currFile = null;
282
  }
283
 
 
298
 
299
  async function analyze() {
300
  if (!currTask || !currFile) return;
301
+ if (!document.getElementById('p-name').value) { alert("Please enter Patient Name."); return; }
 
 
 
 
302
 
 
303
  document.getElementById('run-btn').classList.add('hidden');
304
  document.getElementById('loader').classList.remove('hidden');
305
  document.getElementById('result-box').classList.add('hidden');
 
308
  formData.append("file", currFile);
309
 
310
  try {
311
+ let res = await fetch("/predict/" + currTask, { method: "POST", body: formData });
 
 
 
312
  let data = await res.json();
313
 
314
+ document.getElementById('loader').classList.add('hidden');
315
+ document.getElementById('result-box').classList.remove('hidden');
316
+
317
+ if (data.risk === "INVALID") {
318
+ document.getElementById('res-label').innerText = "Image Rejected";
319
+ document.getElementById('res-conf').innerText = "--";
320
+ let badge = document.getElementById('res-badge');
321
+ badge.className = "px-3 py-1 rounded text-sm font-bold uppercase bg-gray-200 text-gray-700";
322
+ badge.innerText = "INVALID";
323
+
324
+ let alertBox = document.getElementById('alert-box');
325
+ alertBox.className = "mt-4 p-3 bg-gray-100 text-gray-800 rounded border border-gray-300 text-sm";
326
+ alertBox.classList.remove('hidden');
327
+ document.getElementById('alert-text').innerText = data.error;
328
  return;
329
  }
330
 
331
+ if (data.error) { alert("Error: " + data.error); resetLoading(); return; }
 
 
332
 
333
  document.getElementById('res-label').innerText = data.prediction.label;
334
  document.getElementById('res-conf').innerText = (data.prediction.score * 100).toFixed(1) + "%";
 
338
 
339
  if (data.risk === "HIGH") {
340
  badge.className = "px-3 py-1 rounded text-sm font-bold uppercase bg-red-100 text-red-700";
341
+ alertBox.className = "mt-4 p-3 bg-red-50 text-red-800 rounded border border-red-200 text-sm";
342
  alertBox.classList.remove('hidden');
343
  document.getElementById('alert-text').innerText = "High Risk. Immediate Referral Recommended.";
344
  } else if (data.risk === "MODERATE") {
345
  badge.className = "px-3 py-1 rounded text-sm font-bold uppercase bg-yellow-100 text-yellow-700";
346
+ alertBox.className = "mt-4 p-3 bg-yellow-50 text-yellow-800 rounded border border-yellow-200 text-sm";
347
  alertBox.classList.remove('hidden');
348
  document.getElementById('alert-text').innerText = "Moderate Risk. Consult Doctor.";
349
+ } else if (data.risk === "UNCERTAIN") {
350
  badge.className = "px-3 py-1 rounded text-sm font-bold uppercase bg-gray-200 text-gray-700";
351
+ alertBox.className = "mt-4 p-3 bg-gray-100 text-gray-800 rounded border border-gray-200 text-sm";
352
  alertBox.classList.remove('hidden');
353
+ document.getElementById('alert-text').innerText = "Image Unclear. Retake Photo.";
354
  } else {
355
  badge.className = "px-3 py-1 rounded text-sm font-bold uppercase bg-green-100 text-green-700";
356
  alertBox.classList.add('hidden');
357
  }
 
358
  badge.innerText = data.risk + " RISK";
359
 
 
360
  setTimeout(() => {
361
  document.getElementById('sync-msg').innerHTML = "<i class='fas fa-check-circle'></i> Synced!";
362
  document.getElementById('sync-msg').className = "text-green-600 font-bold";