talha420 commited on
Commit
b575aee
·
1 Parent(s): 159bb27

moved gradcam inside app

Browse files
Files changed (1) hide show
  1. app.py +20 -21
app.py CHANGED
@@ -5,36 +5,36 @@ from app.gradcam import generate_gradcam
5
 
6
 
7
  def predict(image):
8
- # 1️⃣ Disease prediction
 
9
  label, confidence, tensor, original = predict_disease(image)
10
 
11
- # 2️⃣ GradCAM heatmap
12
  heatmap = generate_gradcam(tensor, original)
13
 
14
- # Safety check (VERY IMPORTANT for Gradio image output)
15
- heatmap = np.array(heatmap)
16
-
17
- # Normalize if needed
18
- if heatmap.max() <= 1:
19
- heatmap = heatmap * 255
20
 
21
- heatmap = heatmap.astype(np.uint8)
 
 
22
 
23
- # 3️⃣ Risk prediction
24
  risk = compute_ai_risk(label, confidence, heatmap)
25
 
26
- report = f"""
27
- 🫀 AI Clinical Report
28
-
29
- Disease: {label}
30
- Confidence: {round(confidence * 100, 2)} %
31
- Risk Level: {risk}
32
- """
33
 
34
  return report, heatmap
35
 
36
 
37
- # 🔴 UI
38
  with gr.Blocks(title="CardioGuard AI") as demo:
39
 
40
  gr.Markdown("# 🫀 CardioGuard AI")
@@ -54,10 +54,9 @@ with gr.Blocks(title="CardioGuard AI") as demo:
54
  outputs=[report_output, heatmap_output]
55
  )
56
 
57
-
58
- # ✅ IMPORTANT FIX FOR SPACES / CLOUD
59
  demo.launch(
60
  server_name="0.0.0.0",
61
  server_port=7860,
62
- share=True
63
  )
 
5
 
6
 
7
  def predict(image):
8
+
9
+ # 1️⃣ Prediction
10
  label, confidence, tensor, original = predict_disease(image)
11
 
12
+ # 2️⃣ GradCAM
13
  heatmap = generate_gradcam(tensor, original)
14
 
15
+ # Ensure proper format (IMPORTANT FIX)
16
+ if heatmap.dtype != np.uint8:
17
+ heatmap = np.clip(heatmap, 0, 1)
18
+ heatmap = (heatmap * 255).astype(np.uint8)
 
 
19
 
20
+ # If grayscale → convert to RGB (prevents Gradio schema crash sometimes)
21
+ if len(heatmap.shape) == 2:
22
+ heatmap = np.stack([heatmap]*3, axis=-1)
23
 
24
+ # 3️⃣ Risk
25
  risk = compute_ai_risk(label, confidence, heatmap)
26
 
27
+ report = (
28
+ f"🫀 AI Clinical Report\n\n"
29
+ f"Disease: {label}\n"
30
+ f"Confidence: {confidence*100:.2f} %\n"
31
+ f"Risk Level: {risk}"
32
+ )
 
33
 
34
  return report, heatmap
35
 
36
 
37
+ # UI
38
  with gr.Blocks(title="CardioGuard AI") as demo:
39
 
40
  gr.Markdown("# 🫀 CardioGuard AI")
 
54
  outputs=[report_output, heatmap_output]
55
  )
56
 
57
+ # IMPORTANT FIX FOR HF / CLOUD ENV
 
58
  demo.launch(
59
  server_name="0.0.0.0",
60
  server_port=7860,
61
+ share=False
62
  )