AIOmarRehan commited on
Commit
aeb005f
·
verified ·
1 Parent(s): d605acd

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +60 -65
app/main.py CHANGED
@@ -1,65 +1,60 @@
1
- # ---------------------------
2
- # File: app/main.py
3
- # ---------------------------
4
- from fastapi import FastAPI, UploadFile, File, Query
5
- from fastapi.responses import JSONResponse, StreamingResponse
6
- from PIL import Image
7
- import io
8
- import numpy as np
9
- import traceback
10
-
11
- # Import the model utilities
12
- from app.model import predict, gradcam, CLASS_NAMES
13
-
14
- app = FastAPI(title="Brain Tumor MRI Classifier (InceptionV3 + Grad-CAM)")
15
-
16
- @app.post("/predict")
17
- async def predict_image(file: UploadFile = File(...)):
18
- try:
19
- contents = await file.read()
20
- pil_img = Image.open(io.BytesIO(contents)).convert("RGB")
21
- label, confidence, probs = predict(pil_img)
22
- return JSONResponse({
23
- "predicted_label": label,
24
- "confidence": round(confidence, 3),
25
- "probabilities": {k: round(v, 6) for k, v in probs.items()}
26
- })
27
- except Exception as e:
28
- tb = traceback.format_exc()
29
- return JSONResponse({"error": str(e), "trace": tb}, status_code=500)
30
-
31
- @app.post("/gradcam")
32
- async def gradcam_image(file: UploadFile = File(...), interpolant: float = Query(0.5, ge=0.0, le=1.0)):
33
- """
34
- Returns a PNG image (overlay) produced by gradcam().
35
- `interpolant` controls mixing (0..1).
36
- """
37
- try:
38
- contents = await file.read()
39
- pil_img = Image.open(io.BytesIO(contents)).convert("RGB")
40
-
41
- # Compute overlay (this calls the optimized gradcam in model.py)
42
- overlay = gradcam(pil_img, interpolant=float(interpolant))
43
-
44
- # Ensure correct dtype and shape
45
- overlay = np.asarray(overlay).astype("uint8")
46
- if overlay.ndim == 2:
47
- overlay = np.stack([overlay] * 3, axis=-1)
48
-
49
- # Convert to PNG bytes
50
- buf = io.BytesIO()
51
- Image.fromarray(overlay).save(buf, format="PNG")
52
- buf.seek(0)
53
- return StreamingResponse(buf, media_type="image/png")
54
-
55
- except Exception as e:
56
- tb = traceback.format_exc()
57
- return JSONResponse({"error": str(e), "trace": tb}, status_code=500)
58
-
59
- # Optional health endpoint
60
- @app.get("/health")
61
- async def health():
62
- return {"status": "ok", "classes": CLASS_NAMES}
63
- # ---------------------------
64
- # End of app/main.py
65
- # ---------------------------
 
1
+ # File: app/main.py
2
+ from fastapi import FastAPI, UploadFile, File, Query
3
+ from fastapi.responses import JSONResponse, StreamingResponse
4
+ from PIL import Image
5
+ import io
6
+ import numpy as np
7
+ import traceback
8
+
9
+ # Import the model utilities
10
+ from app.model import predict, gradcam, CLASS_NAMES
11
+
12
+ app = FastAPI(title="Brain Tumor MRI Classifier (InceptionV3 + Grad-CAM)")
13
+
14
+ @app.post("/predict")
15
+ async def predict_image(file: UploadFile = File(...)):
16
+ try:
17
+ contents = await file.read()
18
+ pil_img = Image.open(io.BytesIO(contents)).convert("RGB")
19
+ label, confidence, probs = predict(pil_img)
20
+ return JSONResponse({
21
+ "predicted_label": label,
22
+ "confidence": round(confidence, 3),
23
+ "probabilities": {k: round(v, 6) for k, v in probs.items()}
24
+ })
25
+ except Exception as e:
26
+ tb = traceback.format_exc()
27
+ return JSONResponse({"error": str(e), "trace": tb}, status_code=500)
28
+
29
+ @app.post("/gradcam")
30
+ async def gradcam_image(file: UploadFile = File(...), interpolant: float = Query(0.5, ge=0.0, le=1.0)):
31
+ """
32
+ Returns a PNG image (overlay) produced by gradcam().
33
+ `interpolant` controls mixing (0..1).
34
+ """
35
+ try:
36
+ contents = await file.read()
37
+ pil_img = Image.open(io.BytesIO(contents)).convert("RGB")
38
+
39
+ # Compute overlay (this calls the optimized gradcam in model.py)
40
+ overlay = gradcam(pil_img, interpolant=float(interpolant))
41
+
42
+ # Ensure correct dtype and shape
43
+ overlay = np.asarray(overlay).astype("uint8")
44
+ if overlay.ndim == 2:
45
+ overlay = np.stack([overlay] * 3, axis=-1)
46
+
47
+ # Convert to PNG bytes
48
+ buf = io.BytesIO()
49
+ Image.fromarray(overlay).save(buf, format="PNG")
50
+ buf.seek(0)
51
+ return StreamingResponse(buf, media_type="image/png")
52
+
53
+ except Exception as e:
54
+ tb = traceback.format_exc()
55
+ return JSONResponse({"error": str(e), "trace": tb}, status_code=500)
56
+
57
+ # Optional health endpoint
58
+ @app.get("/health")
59
+ async def health():
60
+ return {"status": "ok", "classes": CLASS_NAMES}