talha420 commited on
Commit
bd2cbcd
·
1 Parent(s): 6c0cea2

moved gradcam inside app

Browse files
Files changed (3) hide show
  1. app.py +29 -18
  2. main.py +0 -58
  3. requirements.txt +3 -3
app.py CHANGED
@@ -2,41 +2,41 @@ import gradio as gr
2
  import numpy as np
3
  from app.predictor import predict_disease, compute_ai_risk
4
  from app.gradcam import generate_gradcam
5
- import torch
6
 
7
 
8
  def predict(image):
9
  if image is None:
10
  return "Upload image first", None
11
 
12
- # Prediction
13
  label, confidence, tensor, original = predict_disease(image)
14
 
15
- # GradCAM (FIXED CALL)
16
  heatmap = generate_gradcam(tensor, label)
17
 
 
18
  heatmap = np.array(heatmap)
19
 
20
- # normalize safely
 
 
21
  if heatmap.max() > 1:
22
  heatmap = heatmap / 255.0
23
 
24
  heatmap = np.clip(heatmap, 0, 1)
25
  heatmap = (heatmap * 255).astype(np.uint8)
26
 
 
27
  if len(heatmap.shape) == 2:
28
- heatmap = np.stack([heatmap] * 3, axis=-1)
29
 
30
- # Risk
31
  risk = compute_ai_risk(label, confidence, heatmap)
32
 
33
- report = f"""
34
- 🫀 CardioGuard AI Report
35
-
36
- Disease: {label}
37
- Confidence: {confidence*100:.2f}%
38
- Risk: {risk}
39
- """
40
 
41
  return report, heatmap
42
 
@@ -46,13 +46,24 @@ demo = gr.Blocks()
46
  with demo:
47
  gr.Markdown("# 🫀 CardioGuard AI")
48
 
49
- img = gr.Image(type="pil")
 
 
 
 
 
50
  btn = gr.Button("Analyze")
51
 
52
- out1 = gr.Textbox()
53
- out2 = gr.Image()
 
 
 
54
 
55
- btn.click(predict, img, [out1, out2])
56
 
 
57
 
58
- demo.launch(share=True)
 
 
 
 
2
  import numpy as np
3
  from app.predictor import predict_disease, compute_ai_risk
4
  from app.gradcam import generate_gradcam
 
5
 
6
 
7
  def predict(image):
8
  if image is None:
9
  return "Upload image first", None
10
 
 
11
  label, confidence, tensor, original = predict_disease(image)
12
 
13
+ # GradCAM
14
  heatmap = generate_gradcam(tensor, label)
15
 
16
+ # ensure numpy
17
  heatmap = np.array(heatmap)
18
 
19
+ # safe normalization
20
+ heatmap = heatmap.astype(np.float32)
21
+
22
  if heatmap.max() > 1:
23
  heatmap = heatmap / 255.0
24
 
25
  heatmap = np.clip(heatmap, 0, 1)
26
  heatmap = (heatmap * 255).astype(np.uint8)
27
 
28
+ # ensure RGB format
29
  if len(heatmap.shape) == 2:
30
+ heatmap = np.stack([heatmap]*3, axis=-1)
31
 
 
32
  risk = compute_ai_risk(label, confidence, heatmap)
33
 
34
+ report = (
35
+ "🫀 CardioGuard AI Report\n\n"
36
+ f"Disease: {label}\n"
37
+ f"Confidence: {confidence*100:.2f}%\n"
38
+ f"Risk: {risk}"
39
+ )
 
40
 
41
  return report, heatmap
42
 
 
46
  with demo:
47
  gr.Markdown("# 🫀 CardioGuard AI")
48
 
49
+ with gr.Row():
50
+ img = gr.Image(type="pil")
51
+ out2 = gr.Image()
52
+
53
+ out1 = gr.Textbox(label="Report")
54
+
55
  btn = gr.Button("Analyze")
56
 
57
+ btn.click(
58
+ fn=predict,
59
+ inputs=img,
60
+ outputs=[out1, out2]
61
+ )
62
 
 
63
 
64
+ demo.queue()
65
 
66
+ demo.launch(
67
+ server_name="0.0.0.0",
68
+ server_port=7860
69
+ )
main.py DELETED
@@ -1,58 +0,0 @@
1
- print("STEP 1: main.py starting import")
2
-
3
- from fastapi import FastAPI, UploadFile, File
4
- print("STEP 2: fastapi imported")
5
-
6
- from app.model_loader import load_all_models
7
- print("STEP 3: model_loader imported")
8
-
9
- from app.predictor import run_full_pipeline
10
- print("STEP 4: predictor imported")
11
-
12
-
13
-
14
- from fastapi import FastAPI, UploadFile, File
15
- import numpy as np
16
- import cv2
17
-
18
- from app.model_loader import load_all_models
19
- from app.predictor import run_full_pipeline
20
-
21
- app = FastAPI(title="CardioGuard AI API")
22
-
23
- # 🔥 LOAD MODELS WHEN SERVER STARTS
24
- @app.on_event("startup")
25
- def startup_event():
26
- load_all_models()
27
-
28
-
29
- # 🩺 Health check route
30
- @app.get("/")
31
- def root():
32
- return {"message": "CardioGuard API is running"}
33
-
34
-
35
- # ❤️‍🔥 Prediction Route
36
- @app.post("/predict")
37
- async def predict_xray(file: UploadFile = File(...)):
38
-
39
- image_bytes = await file.read()
40
-
41
- # convert bytes → PIL image
42
- from PIL import Image
43
- import io
44
- image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
45
-
46
- result = run_full_pipeline(image)
47
-
48
- # convert heatmap image → base64 (so frontend can display)
49
- import base64
50
- _, buffer = cv2.imencode(".png", result["heatmap"])
51
- heatmap_base64 = base64.b64encode(buffer).decode("utf-8")
52
-
53
- return {
54
- "disease": result["disease"],
55
- "confidence": result["confidence"],
56
- "risk": result["risk"],
57
- "heatmap": heatmap_base64
58
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
- gradio==4.44.1
 
2
 
3
  fastapi==0.110.0
4
  uvicorn==0.27.1
@@ -6,13 +7,12 @@ uvicorn==0.27.1
6
  torch
7
  torchvision
8
 
9
- numpy
10
  pillow
11
  opencv-python-headless
12
  scikit-learn
13
  joblib
14
 
15
  safetensors>=0.4.0
16
-
17
  huggingface_hub>=0.25.2,<1.0
18
  spaces
 
1
+ gradio==4.29.0
2
+ gradio-client==1.3.0
3
 
4
  fastapi==0.110.0
5
  uvicorn==0.27.1
 
7
  torch
8
  torchvision
9
 
10
+ numpy<2
11
  pillow
12
  opencv-python-headless
13
  scikit-learn
14
  joblib
15
 
16
  safetensors>=0.4.0
 
17
  huggingface_hub>=0.25.2,<1.0
18
  spaces