File size: 5,449 Bytes
f7ba59a
26ed134
f7ba59a
486cd00
f7ba59a
 
26ed134
 
31f8a15
26ed134
 
 
cd91996
 
2d7fbe6
26ed134
2d7fbe6
 
 
 
26ed134
cd91996
e8f694e
26ed134
f7ba59a
26ed134
e8f694e
f7ba59a
26ed134
f7ba59a
26ed134
f7ba59a
 
 
 
ca6d16c
 
 
 
 
 
 
 
f7ba59a
 
 
26ed134
f7ba59a
 
 
2d7fbe6
 
 
 
26ed134
f7ba59a
527d49e
 
26ed134
8580587
f7ba59a
 
26ed134
f7ba59a
 
 
 
26ed134
f7ba59a
 
 
31f8a15
 
 
 
 
 
 
 
26ed134
31f8a15
26ed134
 
122dbdd
 
 
 
634212d
 
122dbdd
634212d
122dbdd
 
 
634212d
122dbdd
634212d
122dbdd
634212d
 
122dbdd
634212d
122dbdd
634212d
 
 
122dbdd
634212d
122dbdd
 
 
634212d
 
122dbdd
634212d
 
 
 
 
 
122dbdd
634212d
122dbdd
634212d
122dbdd
 
634212d
122dbdd
 
 
634212d
 
 
122dbdd
634212d
 
122dbdd
 
 
634212d
122dbdd
 
634212d
122dbdd
 
 
634212d
 
 
122dbdd
634212d
122dbdd
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import io
import os

import numpy as np
import cv2
import tensorflow as tf
from PIL import Image
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, Response
from huggingface_hub import hf_hub_download, snapshot_download, login

# 1. Run environment fixes immediately
# NOTE(review): `patches` is a project-local module; apply_fixes() runs before
# any Hub download or Mask R-CNN import. What it actually patches is not
# visible here (presumably runtime/library compatibility shims) — confirm in
# patches.py before reordering anything in this section.
import patches
patches.apply_fixes()

# 2. Setup Hub and Downloads
# Hugging Face repo that hosts both the Mask R-CNN source (under mrcnn/) and
# the trained weights file FILENAME, which load_model() fetches at startup.
REPO_ID = "SaniaE/MRCNN_Petrol_Pump_Segmentation"
FILENAME = "mask_rcnn_petrol station_0080.h5"
# Hub access token from the "HF_Token" environment variable; None when unset.
# NOTE(review): huggingface_hub's conventional variable is HF_TOKEN — confirm
# this casing matches the deployment environment.
token = os.getenv("HF_Token")

# Fetch only the files matching mrcnn/* from the repo into the current
# directory, presumably so `import mrcnn.model` below resolves to this copy.
snapshot_download(repo_id=REPO_ID, allow_patterns=["mrcnn/*"], local_dir=".", token=token)
# Presumably edits the freshly downloaded mrcnn source in place — TODO confirm
# against patches.patch_model_file().
patches.patch_model_file()

# 3. Import Mask R-CNN after patching
# These imports must stay below the download/patch calls above so that the
# patched mrcnn source is what gets loaded.
import mrcnn.model as modellib
from model_utils import PredictionConfig, visualize_detections

app = FastAPI()

# CORS is registered exactly once: stacking a second CORSMiddleware only
# duplicated header processing on every request.
# NOTE: allow_credentials=True with wildcard origins lets any site issue
# credentialed requests — restrict allow_origins for a production deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Permits all origins
    allow_credentials=True,
    allow_methods=["*"],  # Permits all methods
    allow_headers=["*"],  # Permits all headers
)

# Global state
# Inference configuration passed to the Mask R-CNN constructor in load_model().
config = PredictionConfig()
# TF1 default graph, captured once so that startup and the request handlers
# all build and run the model inside the same graph (`graph.as_default()`).
graph = tf.get_default_graph()
# The inference model; None until the startup hook load_model() populates it.
model_eval = None

@app.on_event("startup")
def load_model():
    """Download the trained weights and build the inference model once.

    Runs as the FastAPI startup hook. Builds the Mask R-CNN model in
    inference mode inside the shared ``graph`` and stores it in the
    module-level ``model_eval`` used by every prediction endpoint.
    """
    global model_eval

    if token:
        login(token=token)

    local_weights = hf_hub_download(
        repo_id=REPO_ID,
        filename=FILENAME,
        token=token,
    )

    with graph.as_default():
        model_eval = modellib.MaskRCNN(
            mode="inference",
            model_dir=".",
            config=config,
        )
        model_eval.load_weights(local_weights, by_name=True)
        # Build the Keras predict function eagerly inside this graph, the
        # usual TF1/Keras serving workaround so later calls from request
        # handlers reuse a ready-made function.
        model_eval.keras_model._make_predict_function()

    print("Service Ready.")

@app.get("/")
async def root():
    """Liveness probe: report that the service process is up."""
    payload = {"status": "online"}
    return payload

@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Run Mask R-CNN on an uploaded image and return the rendered detections.

    Args:
        file: The uploaded image file (any format Pillow can decode).

    Returns:
        ``204 No Content`` when the model detects nothing (the React client
        checks ``response.status === 204``); otherwise a JPEG stream of the
        image as drawn by ``visualize_detections``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image,
            500 if the rendered result cannot be JPEG-encoded.
    """
    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except OSError as exc:
        # PIL.UnidentifiedImageError and truncated-file errors are OSErrors;
        # they are a client problem, not a server fault.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from exc
    image_np = np.array(image)

    with graph.as_default():
        results = model_eval.detect([image_np], verbose=0)

    # results[0] holds 'rois', 'masks', 'class_ids' and 'scores';
    # an empty 'scores' array means nothing was detected.
    if len(results[0]['scores']) == 0:
        return Response(status_code=204)

    processed_img = visualize_detections(image_np, results)

    # NOTE(review): cv2.imencode interprets 3-channel input as BGR; confirm
    # whether visualize_detections returns RGB or BGR.
    ok, buffer = cv2.imencode('.jpg', processed_img)
    if not ok:
        raise HTTPException(status_code=500, detail="Failed to encode result image.")
    return StreamingResponse(io.BytesIO(buffer.tobytes()), media_type="image/jpeg")

def _occlusion_sensitivity(image_np, baseline_score, patch_size=64, stride=64, fill_value=128):
    """Return a float32 (h, w) map of the top-score drop caused by graying out each patch.

    Each stride-spaced patch is replaced by ``fill_value``, the model is re-run,
    and the drop ``max(0, baseline_score - occluded_top_score)`` is written into
    that patch's area. An occluded image with no detections counts as score 0.
    """
    h, w = image_np.shape[:2]
    sensitivity_map = np.zeros((h, w), dtype=np.float32)

    print("Starting Fast Occlusion Analysis...")
    with graph.as_default():
        for y in range(0, h, stride):
            for x in range(0, w, stride):
                # Clip the patch to the image bounds.
                y_end = min(y + patch_size, h)
                x_end = min(x + patch_size, w)

                img_occ = image_np.copy()
                img_occ[y:y_end, x:x_end, :] = fill_value  # neutral gray

                res = model_eval.detect([img_occ], verbose=0)[0]
                # max() so this does not rely on detections being score-sorted.
                current_score = float(np.max(res['scores'])) if len(res['scores']) > 0 else 0.0
                drop = max(0.0, baseline_score - current_score)

                sensitivity_map[y:y_end, x:x_end] = drop
                print(f"Processed patch at ({x}, {y}) - Drop: {drop:.4f}")
    return sensitivity_map


def _render_occlusion_overlay(image_np, sensitivity_map):
    """Normalise, blur and colour-map the sensitivity map, then blend it over the RGB image.

    Returns a BGR uint8 image ready for ``cv2.imencode``.
    """
    peak = float(sensitivity_map.max())
    if peak > 0:
        low = float(sensitivity_map.min())
        span = peak - low
        if span > 1e-8:
            sensitivity_map = (sensitivity_map - low) / span
        else:
            # Every patch caused the same drop: show uniform full heat instead of
            # min-max collapsing the whole map to zero ("nothing matters").
            sensitivity_map = sensitivity_map / peak

    # The coarse patch grid produces hard block edges; a light blur smooths them.
    sensitivity_map = cv2.GaussianBlur(sensitivity_map, (15, 15), 0)
    heat_u8 = np.uint8(np.clip(255 * sensitivity_map, 0, 255))
    heatmap = cv2.applyColorMap(heat_u8, cv2.COLORMAP_JET)

    original_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
    overlay = cv2.addWeighted(original_bgr, 0.6, heatmap, 0.4, 0)

    # Labelled as a "Sample" to signal the reduced 256x256 resolution.
    cv2.putText(overlay, "XAI: Occlusion Sensitivity (Fast Sample)", (10, 20),
                cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)
    return overlay


@app.post("/explain/occlusion")
async def explain_occlusion(file: UploadFile = File(...)):
    """Return an occlusion-sensitivity heatmap for the model's top detection.

    The upload is resized to 256x256 to keep the CPU cost low. A 64x64 patch
    with stride 64 yields a 4x4 grid, i.e. 16 occluded inferences plus one
    baseline.

    Returns:
        A JSON ``{"error": ...}`` body (status 200) when the baseline image has
        no detection; otherwise a JPEG stream of the heatmap blended over the
        resized image.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image,
            500 if the overlay cannot be JPEG-encoded.
    """
    contents = await file.read()
    try:
        image_pil = Image.open(io.BytesIO(contents)).convert("RGB").resize((256, 256))
    except OSError as exc:
        # PIL.UnidentifiedImageError and truncated-file errors are OSErrors.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from exc
    image_np = np.array(image_pil)

    with graph.as_default():
        baseline_res = model_eval.detect([image_np], verbose=0)[0]

    if len(baseline_res['scores']) == 0:
        return {"error": "No target detected for occlusion analysis."}

    baseline_score = float(np.max(baseline_res['scores']))

    # 64/64 keeps the request inside a safe latency budget on CPU.
    sensitivity_map = _occlusion_sensitivity(image_np, baseline_score, patch_size=64, stride=64)
    overlay = _render_occlusion_overlay(image_np, sensitivity_map)

    ok, buffer = cv2.imencode('.jpg', overlay)
    if not ok:
        raise HTTPException(status_code=500, detail="Failed to encode result image.")
    return StreamingResponse(io.BytesIO(buffer.tobytes()), media_type="image/jpeg")