Spaces:
Sleeping
Sleeping
File size: 5,449 Bytes
f7ba59a 26ed134 f7ba59a 486cd00 f7ba59a 26ed134 31f8a15 26ed134 cd91996 2d7fbe6 26ed134 2d7fbe6 26ed134 cd91996 e8f694e 26ed134 f7ba59a 26ed134 e8f694e f7ba59a 26ed134 f7ba59a 26ed134 f7ba59a ca6d16c f7ba59a 26ed134 f7ba59a 2d7fbe6 26ed134 f7ba59a 527d49e 26ed134 8580587 f7ba59a 26ed134 f7ba59a 26ed134 f7ba59a 31f8a15 26ed134 31f8a15 26ed134 122dbdd 634212d 122dbdd 634212d 122dbdd 634212d 122dbdd 634212d 122dbdd 634212d 122dbdd 634212d 122dbdd 634212d 122dbdd 634212d 122dbdd 634212d 122dbdd 634212d 122dbdd 634212d 122dbdd 634212d 122dbdd 634212d 122dbdd 634212d 122dbdd 634212d 122dbdd 634212d 122dbdd 634212d 122dbdd 634212d 122dbdd 634212d 122dbdd | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 | import io
import os
import numpy as np
import cv2
import tensorflow as tf
from PIL import Image
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, Response
from huggingface_hub import hf_hub_download, snapshot_download, login
# Module bootstrap. These statements are order-dependent: the environment
# fixes run first, the `mrcnn/` sources are downloaded and patched next, and
# only then is `mrcnn.model` imported.

# 1. Run environment fixes immediately
# NOTE(review): what apply_fixes() changes lives in the local `patches` module
# (not visible here); presumably compatibility shims that must be in place
# before the downloads/imports below — confirm in patches.py.
import patches
patches.apply_fixes()
# 2. Setup Hub and Downloads
# Hugging Face Hub repo providing the `mrcnn/` package sources (downloaded
# below) and the trained weights file FILENAME (downloaded in load_model()).
REPO_ID = "SaniaE/MRCNN_Petrol_Pump_Segmentation"
FILENAME = "mask_rcnn_petrol station_0080.h5"
# Optional Hub access token; os.getenv returns None when the variable is unset.
# NOTE(review): the name read is "HF_Token" (mixed case), not the conventional
# "HF_TOKEN" — confirm it matches the secret configured for the deployment.
token = os.getenv("HF_Token")
# Fetch only the repo's `mrcnn/` directory into the working directory,
# presumably so `import mrcnn.model` below resolves to this copy.
# NOTE(review): this downloads Python source at import time and executes it via
# the import below; consider pinning `revision=` to a known commit.
snapshot_download(repo_id=REPO_ID, allow_patterns=["mrcnn/*"], local_dir=".", token=token)
# Presumably edits the freshly downloaded mrcnn sources in place — TODO confirm.
patches.patch_model_file()
# 3. Import Mask R-CNN after patching
import mrcnn.model as modellib
from model_utils import PredictionConfig, visualize_detections
app = FastAPI()

# CORS: let the browser frontend call this API from any origin. Registered
# exactly once — stacking several CORSMiddleware layers makes each one rewrite
# the same response headers (e.g. duplicated `Vary: Origin`).
# NOTE(review): wildcard origins combined with allow_credentials=True is
# discouraged (current Starlette docs require explicit origins/methods/headers
# for credentialed requests). These endpoints use no cookies or auth, so
# consider dropping credentials or listing the frontend origin explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Permits all origins
    allow_credentials=True,
    allow_methods=["*"],  # Permits all methods
    allow_headers=["*"],  # Permits all headers
)

# Global state shared by the request handlers.
config = PredictionConfig()
# Captured once so model construction and every inference run on the same
# TF1 graph (handlers enter it with `graph.as_default()`).
graph = tf.get_default_graph()
# Assigned by the startup hook load_model(); None until it has run.
model_eval = None
@app.on_event("startup")
def load_model():
    """Fetch the trained weights and build the shared inference model.

    Registered as a FastAPI startup hook. When a Hub token is configured it
    logs in first, then resolves ``FILENAME`` from ``REPO_ID`` to a local path
    via ``hf_hub_download``, builds a Mask R-CNN in inference mode on the
    shared ``graph``, loads the weights into it and stores the instance in
    the module-level ``model_eval``.
    """
    global model_eval

    if token:
        login(token=token)

    downloaded_weights = hf_hub_download(
        repo_id=REPO_ID,
        filename=FILENAME,
        token=token,
    )

    with graph.as_default():
        model_eval = modellib.MaskRCNN(
            mode="inference",
            model_dir=".",
            config=config,
        )
        model_eval.load_weights(downloaded_weights, by_name=True)
        # Pre-build the Keras predict function on this graph (private TF1-era
        # Keras API), presumably so request handlers don't build it lazily.
        model_eval.keras_model._make_predict_function()

    print("Service Ready.")
@app.get("/")
async def root():
    """Liveness check: report that the API process is up."""
    payload = {"status": "online"}
    return payload
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Run Mask R-CNN on an uploaded image and return the annotated result.

    Args:
        file: Uploaded image in any format Pillow can decode; it is converted
            to 3-channel RGB before inference.

    Returns:
        A JPEG ``StreamingResponse`` with the detections drawn by
        ``visualize_detections``, or an empty 204 response when the model
        finds no instances (the frontend checks for 204).

    Raises:
        HTTPException: 503 if the model has not been loaded, 400 if the upload
            cannot be decoded as an image, 500 if JPEG encoding fails.
    """
    from fastapi import HTTPException  # local import keeps this block self-contained

    if model_eval is None:
        # load_model() assigns the model at startup; without it there is
        # nothing to run (previously surfaced as an opaque AttributeError/500).
        raise HTTPException(status_code=503, detail="Model is not loaded yet.")

    # Read and decode the image; undecodable uploads are a client error.
    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err
    image_np = np.array(image)

    # Inference on the shared graph the model was built in.
    # NOTE: detect() is synchronous, so it blocks the event loop while it runs.
    with graph.as_default():
        results = model_eval.detect([image_np], verbose=0)

    # results[0] contains 'rois', 'masks', 'class_ids', and 'scores'.
    if len(results[0]['scores']) == 0:
        # 204 lets the React client distinguish "nothing found" from an error.
        return Response(status_code=204)

    processed_img = visualize_detections(image_np, results)
    # NOTE(review): cv2.imencode interprets 3-channel arrays as BGR. The
    # occlusion endpoint converts RGB->BGR before encoding; whether
    # visualize_detections returns RGB or BGR is not visible here — confirm.
    ok, buffer = cv2.imencode('.jpg', processed_img)
    if not ok:
        raise HTTPException(status_code=500, detail="Failed to encode the result image.")
    return StreamingResponse(io.BytesIO(buffer.tobytes()), media_type="image/jpeg")
@app.post("/explain/occlusion")
async def explain_occlusion(file: UploadFile = File(...)):
    """Return an occlusion-sensitivity heatmap for the uploaded image.

    1. The image is resized to 256x256 and detection is run once to get a
       baseline top score.
    2. Each 64x64 tile is greyed out in turn and detection is re-run.
    3. The drop in the top score (clamped at 0) is written into that tile of
       the sensitivity map.
    4. The map is normalised, blurred, colour-mapped with JET and blended
       over the resized image.

    Each occluded run is compared by its *top* score, regardless of which
    instance produced it.

    Args:
        file: Uploaded image in any format Pillow can decode.

    Returns:
        A JPEG ``StreamingResponse`` with the heatmap overlay, or the JSON body
        ``{"error": ...}`` (HTTP 200) when nothing is detected in the baseline.

    Raises:
        HTTPException: 503 if the model has not been loaded, 400 if the upload
            cannot be decoded as an image, 500 if JPEG encoding fails.
    """
    from fastapi import HTTPException  # local import keeps this block self-contained

    if model_eval is None:
        raise HTTPException(status_code=503, detail="Model is not loaded yet.")

    # 1. Load and downsample: 256x256 keeps the repeated inference light on CPU.
    contents = await file.read()
    try:
        image_pil = Image.open(io.BytesIO(contents)).convert("RGB").resize((256, 256))
    except (OSError, Image.DecompressionBombError) as err:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err
    image_np = np.array(image_pil)
    h, w, _ = image_np.shape

    # 2. Baseline detection on the unoccluded image.
    with graph.as_default():
        baseline_res = model_eval.detect([image_np], verbose=0)[0]
    if len(baseline_res['scores']) == 0:
        # NOTE(review): /predict signals "no detections" with a 204; this
        # endpoint returns a 200 JSON error body, which callers may rely on.
        return {"error": "No target detected for occlusion analysis."}
    # Index 0 is treated as the strongest detection — assumes detect() returns
    # scores sorted descending; TODO confirm against mrcnn.
    baseline_score = baseline_res['scores'][0]

    # 3. Patch size 64 with stride 64 gives a 4x4 grid (16 extra inferences),
    # kept small to avoid request timeouts.
    patch_size = 64
    stride = 64
    sensitivity_map = np.zeros((h, w), dtype=np.float32)

    # 4. Occlusion loop. detect() is synchronous, so this blocks the event
    # loop for the whole analysis.
    print("Starting Fast Occlusion Analysis...")
    with graph.as_default():
        for y in range(0, h, stride):
            for x in range(0, w, stride):
                # Clip the patch to the image bounds.
                y_end = min(y + patch_size, h)
                x_end = min(x + patch_size, w)
                img_occ = image_np.copy()
                img_occ[y:y_end, x:x_end, :] = 128  # Neutral Gray
                res = model_eval.detect([img_occ], verbose=0)[0]
                # Losing every detection counts as a score of 0.
                current_score = res['scores'][0] if len(res['scores']) > 0 else 0
                drop = max(0, baseline_score - current_score)
                sensitivity_map[y:y_end, x:x_end] = drop
                print(f"Processed patch at ({x}, {y}) - Drop: {drop:.4f}")

    # 5. Normalise to [0, 1], then blur to soften the blocky tiles.
    peak = np.max(sensitivity_map)
    if peak > 0:
        floor = np.min(sensitivity_map)
        if peak > floor:
            sensitivity_map = (sensitivity_map - floor) / (peak - floor + 1e-8)
        else:
            # Every tile caused the same non-zero drop. Min-max scaling would
            # collapse this to all zeros ("nothing matters"), so render it as
            # uniformly sensitive instead.
            sensitivity_map = np.ones_like(sensitivity_map)
    sensitivity_map = cv2.GaussianBlur(sensitivity_map, (15, 15), 0)
    heatmap = cv2.applyColorMap(np.uint8(255 * sensitivity_map), cv2.COLORMAP_JET)

    # 6. Blend over the image; OpenCV colour maps and encoding work in BGR.
    original_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
    overlay = cv2.addWeighted(original_bgr, 0.6, heatmap, 0.4, 0)
    # Label it as a "Sample" to flag the reduced resolution.
    cv2.putText(overlay, "XAI: Occlusion Sensitivity (Fast Sample)", (10, 20),
                cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)

    # 7. Encode and stream.
    ok, buffer = cv2.imencode('.jpg', overlay)
    if not ok:
        raise HTTPException(status_code=500, detail="Failed to encode the heatmap image.")
    return StreamingResponse(io.BytesIO(buffer.tobytes()), media_type="image/jpeg")