import io
import os

import numpy as np
import cv2
import tensorflow as tf
from PIL import Image
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, Response
from huggingface_hub import hf_hub_download, snapshot_download, login

# 1. Run environment fixes immediately
import patches
patches.apply_fixes()
# 2. Setup Hub and Downloads
REPO_ID = "SaniaE/MRCNN_Petrol_Pump_Segmentation"
FILENAME = "mask_rcnn_petrol station_0080.h5"
token = os.getenv("HF_Token")

snapshot_download(repo_id=REPO_ID, allow_patterns=["mrcnn/*"], local_dir=".", token=token)
patches.patch_model_file()

# 3. Import Mask R-CNN after patching
import mrcnn.model as modellib
from model_utils import PredictionConfig, visualize_detections
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],      # Permits all origins
    allow_credentials=True,
    allow_methods=["*"],      # Permits all methods
    allow_headers=["*"],      # Permits all headers
)

# Global state
config = PredictionConfig()
graph = tf.get_default_graph()
model_eval = None

@app.on_event("startup")  # assumed wiring: load the weights once when the service starts
def load_model():
    global model_eval
    if token:
        login(token=token)
    weights_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME, token=token)
    with graph.as_default():
        model_eval = modellib.MaskRCNN(mode="inference", model_dir=".", config=config)
        model_eval.load_weights(weights_path, by_name=True)
        # Pre-build the Keras predict function so the TF1 graph is ready for request threads
        model_eval.keras_model._make_predict_function()
    print("Service Ready.")

@app.get("/")  # health-check route; path assumed
async def root():
    return {"status": "online"}

@app.post("/predict")  # route path assumed
async def predict(file: UploadFile = File(...)):
    # Read the uploaded image
    contents = await file.read()
    image = Image.open(io.BytesIO(contents)).convert("RGB")
    image_np = np.array(image)

    # Inference
    with graph.as_default():
        results = model_eval.detect([image_np], verbose=0)

    # results[0] contains 'rois', 'masks', 'class_ids', and 'scores'.
    # If nothing was detected, return 204 so the React client's
    # `response.status === 204` branch is triggered.
    if len(results[0]['scores']) == 0:
        return Response(status_code=204)

    # Process the image (only happens if segments exist)
    processed_img = visualize_detections(image_np, results)

    # Return the JPEG as a stream
    _, buffer = cv2.imencode('.jpg', processed_img)
    return StreamingResponse(io.BytesIO(buffer.tobytes()), media_type="image/jpeg")

@app.post("/explain")  # route path assumed
async def explain_occlusion(file: UploadFile = File(...)):
    # 1. Load and Downsample
    # Resizing to 256x256 makes the math much lighter on CPU
    contents = await file.read()
    image_pil = Image.open(io.BytesIO(contents)).convert("RGB").resize((256, 256))
    image_np = np.array(image_pil)
    h, w, _ = image_np.shape

    # 2. Get Baseline
    with graph.as_default():
        baseline_res = model_eval.detect([image_np], verbose=0)[0]
    if len(baseline_res['scores']) == 0:
        return {"error": "No target detected for occlusion analysis."}
    baseline_score = baseline_res['scores'][0]

    # 3. Fast Parameters
    # Patch size 64 with stride 64 means only 16 total inferences (a 4x4 grid).
    # This is the "Safety Zone" for preventing timeouts.
    patch_size = 64
    stride = 64
    sensitivity_map = np.zeros((h, w), dtype=np.float32)

    # 4. Optimized Inference Loop
    print("Starting Fast Occlusion Analysis...")
    with graph.as_default():
        for y in range(0, h, stride):
            for x in range(0, w, stride):
                # Ensure the patch doesn't exceed image bounds
                y_end = min(y + patch_size, h)
                x_end = min(x + patch_size, w)

                img_occ = image_np.copy()
                img_occ[y:y_end, x:x_end, :] = 128  # Neutral gray

                # Single inference on the occluded image
                res = model_eval.detect([img_occ], verbose=0)[0]

                # Calculate the score drop relative to the baseline
                current_score = res['scores'][0] if len(res['scores']) > 0 else 0
                drop = max(0, baseline_score - current_score)

                # Fill the map for this patch
                sensitivity_map[y:y_end, x:x_end] = drop
                print(f"Processed patch at ({x}, {y}) - Drop: {drop:.4f}")

    # 5. Smooth the Heatmap
    # The map is made of large blocks, so a slight Gaussian blur softens the hard edges
    if np.max(sensitivity_map) > 0:
        sensitivity_map = (sensitivity_map - np.min(sensitivity_map)) / (np.max(sensitivity_map) - np.min(sensitivity_map) + 1e-8)
        sensitivity_map = cv2.GaussianBlur(sensitivity_map, (15, 15), 0)
    heatmap = cv2.applyColorMap(np.uint8(255 * sensitivity_map), cv2.COLORMAP_JET)

    # 6. Final Composite
    original_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
    overlay = cv2.addWeighted(original_bgr, 0.6, heatmap, 0.4, 0)
    # Label it as a "Sample" to justify the lower resolution
    cv2.putText(overlay, "XAI: Occlusion Sensitivity (Fast Sample)", (10, 20),
                cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)

    # 7. Stream
    _, buffer = cv2.imencode('.jpg', overlay)
    return StreamingResponse(io.BytesIO(buffer.tobytes()), media_type="image/jpeg")