# Hugging Face Hub commit metadata (author: SaniaE, commit 31f8a15, verified):
# "added check for failed detections"
import io
import os
import numpy as np
import cv2
import tensorflow as tf
from PIL import Image
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, Response
from huggingface_hub import hf_hub_download, snapshot_download, login
# 1. Run environment fixes immediately.
# NOTE: `patches.apply_fixes()` must run BEFORE `import mrcnn` below — it
# patches the environment the Mask R-CNN package depends on.
import patches
patches.apply_fixes()
# 2. Setup Hub and Downloads
REPO_ID = "SaniaE/MRCNN_Petrol_Pump_Segmentation"
# NOTE(review): the filename genuinely contains a space ("petrol station") —
# it must match the artifact name in the Hub repo exactly; confirm before renaming.
FILENAME = "mask_rcnn_petrol station_0080.h5"
# NOTE(review): unusual casing — the conventional Spaces secret name is
# "HF_TOKEN"; verify this matches the secret configured for the Space.
token = os.getenv("HF_Token")
# Pull only the bundled `mrcnn/*` package sources into the working directory
# so they can be patched and imported locally.
snapshot_download(repo_id=REPO_ID, allow_patterns=["mrcnn/*"], local_dir=".", token=token)
patches.patch_model_file()
# 3. Import Mask R-CNN after patching
import mrcnn.model as modellib
from model_utils import PredictionConfig, visualize_detections
app = FastAPI()

# Register CORS exactly once with the full policy. (Previously the
# middleware was added twice — a minimal registration followed by a
# second one with credentials — stacking two redundant CORSMiddleware
# layers on every request.)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],      # Permits all origins
    allow_credentials=True,
    allow_methods=["*"],      # Permits all methods
    allow_headers=["*"],      # Permits all headers
)

# Global state shared by the request handlers.
config = PredictionConfig()
# TF1-style default graph, captured at import time so request-handler
# threads can re-enter it via `graph.as_default()`.
graph = tf.get_default_graph()
model_eval = None  # populated by the startup hook below
@app.on_event("startup")
def load_model():
    """Build the Mask R-CNN inference model once at server startup.

    Downloads the weights from the Hub and stores the ready model in the
    module-level ``model_eval`` used by the request handlers.
    """
    global model_eval
    # Authenticate to the Hub only when a token is configured.
    if token: login(token=token)
    weights_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME, token=token)
    # All model construction must happen inside the same TF1 graph the
    # handlers will later re-enter.
    with graph.as_default():
        model_eval = modellib.MaskRCNN(mode="inference", model_dir=".", config=config)
        model_eval.load_weights(weights_path, by_name=True)
        # TF1/Keras quirk: pre-build the predict function so inference is
        # safe to call from request-handler threads.
        model_eval.keras_model._make_predict_function()
    print("Service Ready.")
@app.get("/")
async def root():
    """Health-check endpoint: confirms the service is reachable."""
    payload = {"status": "online"}
    return payload
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Run Mask R-CNN on an uploaded image.

    Returns the annotated image as a JPEG stream, or an empty 204
    response when nothing is detected (the frontend checks for
    ``response.status === 204``).
    """
    # Decode the upload into an RGB numpy array.
    contents = await file.read()
    image = Image.open(io.BytesIO(contents)).convert("RGB")
    image_np = np.array(image)

    # Inference must run inside the TF1 graph captured at import time.
    with graph.as_default():
        results = model_eval.detect([image_np], verbose=0)

    # results[0] contains 'rois', 'masks', 'class_ids', and 'scores'
    if len(results[0]['scores']) == 0:
        # Returning a 204 status triggers 'response.status === 204' in your React code
        return Response(status_code=204)

    # Process Image (Only happens if segments exist)
    processed_img = visualize_detections(image_np, results)

    # Return Stream. Check the imencode success flag (previously discarded)
    # and copy with .tobytes() for consistency with /explain/occlusion.
    ok, buffer = cv2.imencode('.jpg', processed_img)
    if not ok:
        # Encoding failure is unexpected; surface it as a server error
        # instead of streaming garbage.
        return Response(status_code=500)
    return StreamingResponse(io.BytesIO(buffer.tobytes()), media_type="image/jpeg")
@app.post("/explain/occlusion")
async def explain_occlusion(file: UploadFile = File(...)):
    """Occlusion-sensitivity explanation for the top detection.

    Slides a gray patch over a downsampled copy of the image, measures the
    drop in the top detection score for each patch position, and streams
    back a JET-colormap heatmap composited over the image as a JPEG.
    Returns a JSON error object instead when nothing is detected.
    """
    # 1. Load and Downsample
    # Resizing to 256x256 makes the math much lighter for CPU
    contents = await file.read()
    image_pil = Image.open(io.BytesIO(contents)).convert("RGB").resize((256, 256))
    image_np = np.array(image_pil)
    h, w, _ = image_np.shape

    # 2. Get Baseline
    with graph.as_default():
        baseline_res = model_eval.detect([image_np], verbose=0)[0]
    # NOTE(review): this returns a 200 with a JSON body, unlike /predict's
    # 204 for the no-detection case — confirm the frontend expects this shape.
    if len(baseline_res['scores']) == 0:
        return {"error": "No target detected for occlusion analysis."}
    baseline_score = baseline_res['scores'][0]

    # 3. Fast Parameters
    # Patch size 64 with stride 64 means only 16 total inferences (4x4 grid)
    # This is the "Safety Zone" for preventing timeouts
    patch_size = 64
    stride = 64
    sensitivity_map = np.zeros((h, w), dtype=np.float32)

    # 4. Optimized Inference Loop
    print("Starting Fast Occlusion Analysis...")
    with graph.as_default():
        for y in range(0, h, stride):
            for x in range(0, w, stride):
                # Ensure patch doesn't exceed image bounds
                y_end = min(y + patch_size, h)
                x_end = min(x + patch_size, w)
                img_occ = image_np.copy()
                img_occ[y:y_end, x:x_end, :] = 128  # Neutral Gray
                # Single inference
                res = model_eval.detect([img_occ], verbose=0)[0]
                # Calculate drop relative to the unoccluded baseline.
                # NOTE(review): compares only the top score of each run;
                # assumes the same object stays the top detection — confirm.
                current_score = res['scores'][0] if len(res['scores']) > 0 else 0
                drop = max(0, baseline_score - current_score)
                # Fill map
                sensitivity_map[y:y_end, x:x_end] = drop
                print(f"Processed patch at ({x}, {y}) - Drop: {drop:.4f}")

    # 5. Smooth the Heatmap
    # Since we have big blocks, a slight Gaussian blur makes it look professional.
    # Normalize/blur only when some drop was measured; otherwise the zero map
    # flows straight into the colormap below.
    if np.max(sensitivity_map) > 0:
        sensitivity_map = (sensitivity_map - np.min(sensitivity_map)) / (np.max(sensitivity_map) - np.min(sensitivity_map) + 1e-8)
        sensitivity_map = cv2.GaussianBlur(sensitivity_map, (15, 15), 0)
    heatmap = cv2.applyColorMap(np.uint8(255 * sensitivity_map), cv2.COLORMAP_JET)

    # 6. Final Composite
    original_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
    overlay = cv2.addWeighted(original_bgr, 0.6, heatmap, 0.4, 0)
    # Label it as a "Sample" to justify the lower resolution
    cv2.putText(overlay, "XAI: Occlusion Sensitivity (Fast Sample)", (10, 20),
                cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)

    # 7. Stream
    _, buffer = cv2.imencode('.jpg', overlay)
    return StreamingResponse(io.BytesIO(buffer.tobytes()), media_type="image/jpeg")