import os
import io

# Force TensorFlow to use legacy Keras (Keras 2); this must be set before
# tensorflow is imported. Importing tf_keras ensures the package is available.
os.environ['TF_USE_LEGACY_KERAS'] = '1'

import cv2
import numpy as np
import tf_keras as keras
import tensorflow as tf
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import StreamingResponse
from PIL import Image
from huggingface_hub import snapshot_download

from object_detection.utils import label_map_util, config_util
from object_detection.builders import model_builder

app = FastAPI()

# 1. Download Private Models
HF_TOKEN = os.getenv("HF_Token")
REPO_ID = "SaniaE/Car_Damage_Detection"

print("Downloading models from Hugging Face...")
model_dir = snapshot_download(
    repo_id=REPO_ID, 
    token=HF_TOKEN, 
    local_dir="./models_data"
)

# 2. Resolve Model File Paths
PIPELINE_CONFIG = os.path.join(model_dir, "object_detection_model/pipeline.config")
CHECKPOINT_PATH = os.path.join(model_dir, "object_detection_model/ckpt-37")
LABEL_MAP_PATH = os.path.join(model_dir, "object_detection_model/label_map.pbtxt")
CNN_MODEL_PATH = os.path.join(model_dir, "cnn_filter.h5")

# 3. Load Models
# Load CNN Filter
cnn_filter = tf.keras.models.load_model(CNN_MODEL_PATH, compile=False)

# Load Object Detection Model
configs = config_util.get_configs_from_pipeline_file(PIPELINE_CONFIG)
detection_model = model_builder.build(model_config=configs['model'], is_training=False)
ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)
ckpt.restore(CHECKPOINT_PATH).expect_partial()
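
# Note: expect_partial() suppresses warnings about training-only variables
# (e.g. optimizer slots) that are intentionally not restored for inference.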

category_index = label_map_util.create_category_index_from_labelmap(LABEL_MAP_PATH)
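
# Note: category_index maps 1-based label-map ids to {'id', 'name'} dicts,
# while postprocessed detection classes are zero-based -- hence the `+ 1`
# offsets wherever class names are looked up below.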

@tf.function
def detect_fn(image):
    """Run the full TFOD forward pass: preprocess -> predict -> postprocess."""
    image, shapes = detection_model.preprocess(image)
    prediction_dict = detection_model.predict(image, shapes)
    detections = detection_model.postprocess(prediction_dict, shapes)
    return detections
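
# Illustrative shapes of the postprocessed output (standard TFOD keys; the
# maximum number of detections N comes from the pipeline config):
#   detections['detection_boxes']   -> [1, N, 4], normalized [ymin, xmin, ymax, xmax]
#   detections['detection_scores']  -> [1, N], sorted in descending confidence
#   detections['detection_classes'] -> [1, N], zero-based class indices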

@app.get("/")
def read_root():
    return {"status": "Model is Online", "model_repo": REPO_ID}

@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    # Read Image
    contents = await file.read()
    image_pil = Image.open(io.BytesIO(contents)).convert("RGB")
    image_np = np.array(image_pil)
    
    # We need a BGR version for OpenCV drawing
    image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
    height, width, _ = image_cv.shape

    # Step 1: CNN Filter
    img_cnn = image_pil.resize((64, 64))
    x = tf.keras.preprocessing.image.img_to_array(img_cnn)
    x = np.expand_dims(x, axis=0)
    
    cnn_pred = cnn_filter.predict(x)
    is_damage_labels = ['Clear', 'Damaged']
    status = is_damage_labels[np.argmax(cnn_pred)]
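
    # Note: the filter sees raw 0-255 pixel values here; if cnn_filter was
    # trained on rescaled inputs (e.g. x / 255.), the same scaling would be
    # needed before predict() (assumption -- depends on training preprocessing).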

    # Step 2: Object Detection (If damaged)
    if status == 'Damaged':
        input_tensor = tf.convert_to_tensor(np.expand_dims(image_np, 0), dtype=tf.float32)
        detections = detect_fn(input_tensor)

        scores = detections['detection_scores'][0].numpy()
        classes = detections['detection_classes'][0].numpy().astype(int)
        boxes = detections['detection_boxes'][0].numpy()

        for i in range(len(scores)):
            if scores[i] > 0.4:
                # TFOD Boxes are [ymin, xmin, ymax, xmax] in normalized coordinates
                ymin, xmin, ymax, xmax = boxes[i]
                (left, right, top, bottom) = (xmin * width, xmax * width, 
                                              ymin * height, ymax * height)

                # Draw bounding box (teal; note OpenCV colors are BGR)
                cv2.rectangle(image_cv, (int(left), int(top)), (int(right), int(bottom)), (255, 255, 0), 2)

                # Draw Label
                label = f"{category_index.get(classes[i] + 1, {}).get('name', 'unknown')}: {int(scores[i]*100)}%"
                cv2.putText(image_cv, label, (int(left), int(top) - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 0), 2)

    # Encode the image back to JPEG
    _, buffer = cv2.imencode('.jpg', image_cv)
    return StreamingResponse(io.BytesIO(buffer.tobytes()), media_type="image/jpeg")
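
# Example request against /predict (illustrative; assumes the app is served
# locally on port 8000 and that a local image "car.jpg" exists):
#   curl -X POST -F "file=@car.jpg" http://localhost:8000/predict -o annotated.jpg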


def get_top_prediction(detections):
    """Return (index, class_id) for the most confident detection, or (None, None).

    TFOD postprocessing sorts scores in descending order, so index 0 is
    always the top detection.
    """
    scores = detections['detection_scores'][0].numpy()
    if len(scores) > 0 and scores[0] > 0.4:
        return 0, int(detections['detection_classes'][0].numpy()[0])
    return None, None

    
@app.post("/explain")
async def explain(file: UploadFile = File(...)):
    # 1. Prepare Image
    contents = await file.read()
    image_pil = Image.open(io.BytesIO(contents)).convert("RGB")
    image_np = np.array(image_pil).astype(np.float32)
    input_tensor = tf.convert_to_tensor(np.expand_dims(image_np, 0), dtype=tf.float32)

    # 2. Gradient Tape for Saliency
    with tf.GradientTape() as tape:
        tape.watch(input_tensor)
        
        # Manually run the forward pass through the detection model
        image, shapes = detection_model.preprocess(input_tensor)
        prediction_dict = detection_model.predict(image, shapes)
        
        # 'class_predictions_with_background' is standard for TFOD SSD and
        # Faster R-CNN models; after indexing the batch dim it has shape
        # [num_anchors, num_classes + 1], with background scores in column 0
        raw_scores = prediction_dict['class_predictions_with_background'][0]
        
        # We need a reference detection to know which class to compute gradients for
        detections = detection_model.postprocess(prediction_dict, shapes)
        _, top_class = get_top_prediction(detections)

        if top_class is None:
            return {"error": "No object detected with sufficient confidence to explain."}

        # Focus the loss on the max score for that class across all anchors;
        # detection classes are zero-based, so +1 skips the background column
        loss = tf.reduce_max(raw_scores[:, top_class + 1])

    # 3. Compute Gradients
    grads = tape.gradient(loss, input_tensor)
    # Take absolute max across color channels
    saliency = np.max(np.abs(grads.numpy()), axis=-1)[0]

    # 4. Normalize and Create Heatmap
    # Clip to the 5th-95th percentile range to suppress outlier gradients
    v_min, v_max = np.percentile(saliency, (5, 95))
    saliency = np.clip((saliency - v_min) / (v_max - v_min + 1e-8), 0, 1)
    
    # Create the JET heatmap (Blue = low, Red = high)
    heatmap = cv2.applyColorMap(np.uint8(255 * saliency), cv2.COLORMAP_JET)
    
    # 5. Overlay on original image (Convert original to BGR first)
    original_bgr = cv2.cvtColor(image_np.astype(np.uint8), cv2.COLOR_RGB2BGR)
    overlay = cv2.addWeighted(original_bgr, 0.6, heatmap, 0.4, 0)

    # Add text label for what we are explaining
    class_name = category_index.get(top_class + 1, {}).get('name', 'unknown')
    cv2.putText(overlay, f"Explaining: {class_name}", (10, 30), 
                cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)

    # 6. Stream Result
    _, buffer = cv2.imencode('.jpg', overlay)
    return StreamingResponse(io.BytesIO(buffer.tobytes()), media_type="image/jpeg")
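
# Worked example of the percentile normalization above (illustrative numbers):
# with v_min = 0.02 and v_max = 0.35, a raw saliency of 0.185 maps to
# (0.185 - 0.02) / (0.35 - 0.02 + 1e-8) ~= 0.5, i.e. mid-range on the JET
# colormap; values outside [v_min, v_max] clip to 0 or 1.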


@app.post("/explain/tiled")
async def explain_tiled(file: UploadFile = File(...)):
    # 1. Prepare Base Image
    contents = await file.read()
    image_pil = Image.open(io.BytesIO(contents)).convert("RGB")
    image_np = np.array(image_pil).astype(np.float32)
    input_tensor = tf.convert_to_tensor(np.expand_dims(image_np, 0), dtype=tf.float32)
    
    # 2. Get Initial Detections to know what to "Explain"
    detections = detect_fn(input_tensor)
    scores = detections['detection_scores'][0].numpy()
    classes = detections['detection_classes'][0].numpy().astype(int)
    boxes = detections['detection_boxes'][0].numpy()

    # Create the top-left "base" panel, drawing boxes for up to the top 3 detections
    base_image = cv2.cvtColor(image_np.astype(np.uint8), cv2.COLOR_RGB2BGR)
    h_img, w_img, _ = base_image.shape
    
    for i in range(min(len(scores), 3)):
        if scores[i] > 0.4:
            ymin, xmin, ymax, xmax = boxes[i]
            cv2.rectangle(base_image, (int(xmin*w_img), int(ymin*h_img)), 
                          (int(xmax*w_img), int(ymax*h_img)), (255, 255, 0), 2)

    # 3. Generate Saliency Maps for the Top 3 detections
    panels = [base_image]
    
    for i in range(3):
        if i < len(scores) and scores[i] > 0.4:
            target_class = classes[i]
            
            with tf.GradientTape() as tape:
                tape.watch(input_tensor)
                image, shapes = detection_model.preprocess(input_tensor)
                prediction_dict = detection_model.predict(image, shapes)
                raw_scores = prediction_dict['class_predictions_with_background'][0]
                # Target the specific class at its most active anchor; +1 skips
                # the background column (detection classes are zero-based)
                loss = tf.reduce_max(raw_scores[:, target_class + 1])

            grads = tape.gradient(loss, input_tensor)
            saliency = np.max(np.abs(grads.numpy()), axis=-1)[0]
            
            # Normalize and Colorize
            v_min, v_max = np.percentile(saliency, (5, 95))
            saliency = np.clip((saliency - v_min) / (v_max - v_min + 1e-8), 0, 1)
            heatmap = cv2.applyColorMap(np.uint8(255 * saliency), cv2.COLORMAP_JET)
            
            # Overlay
            overlay = cv2.addWeighted(cv2.cvtColor(image_np.astype(np.uint8), cv2.COLOR_RGB2BGR), 0.6, heatmap, 0.4, 0)
            
            # Label the panel
            class_name = category_index.get(target_class + 1, {}).get('name', 'unknown')
            cv2.putText(overlay, f"Top {i+1}: {class_name}", (10, 30), 
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
            panels.append(overlay)
        else:
            # Placeholder for empty slots if fewer than 3 detections exist
            panels.append(np.zeros_like(base_image))

    # 4. Assemble the 2x2 Grid
    # Panels are: [0:Base, 1:Top1, 2:Top2, 3:Top3]
    top_row = np.hstack((panels[0], panels[1]))
    bottom_row = np.hstack((panels[2], panels[3]))
    tiled_output = np.vstack((top_row, bottom_row))

    # 5. Stream Result
    _, buffer = cv2.imencode('.jpg', tiled_output)
    return StreamingResponse(io.BytesIO(buffer.tobytes()), media_type="image/jpeg")
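
# Layout note: every panel shares base_image's shape, so the hstack/vstack
# calls above yield a 2x2 grid of exactly (2*h_img, 2*w_img). E.g. a 480x640
# input produces a 960x1280 tiled image (illustrative numbers).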


@app.post("/explain/global")
async def explain_global(file: UploadFile = File(...)):
    # 1. Read and Prepare Image
    contents = await file.read()
    image_pil = Image.open(io.BytesIO(contents)).convert("RGB")
    image_np = np.array(image_pil).astype(np.float32)
    # Keeping a uint8 copy for the final BGR overlay
    image_bgr = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
    
    input_tensor = tf.convert_to_tensor(np.expand_dims(image_np, 0), dtype=tf.float32)

    # 2. Gradient Tape for Global Activation
    with tf.GradientTape() as tape:
        tape.watch(input_tensor)
        
        # Forward pass
        image, shapes = detection_model.preprocess(input_tensor)
        prediction_dict = detection_model.predict(image, shapes)
        
        # 'class_predictions_with_background' shape: [1, num_anchors, num_classes]
        raw_scores = prediction_dict['class_predictions_with_background'][0]
        
        # Ignore column 0 (background) and look at all damage classes;
        # take the max class score at each anchor, then sum over anchors for the global loss
        foreground_scores = raw_scores[:, 1:] 
        loss = tf.reduce_sum(tf.reduce_max(foreground_scores, axis=-1))

    # 3. Compute and Process Gradients
    grads = tape.gradient(loss, input_tensor)
    saliency = np.max(np.abs(grads.numpy()), axis=-1)[0]

    # 4. Refine Saliency Visualization
    # Clipping to the 5th-95th percentile range ignores "pixel noise" and highlights the actual damage
    v_min, v_max = np.percentile(saliency, (5, 95))
    saliency = np.clip((saliency - v_min) / (v_max - v_min + 1e-8), 0, 1)
    
    # Create the heatmap overlay
    heatmap = cv2.applyColorMap(np.uint8(255 * saliency), cv2.COLORMAP_JET)
    
    # Blend: 60% original image, 40% heatmap, so the attention map reads
    # clearly without washing out the car details
    overlay = cv2.addWeighted(image_bgr, 0.6, heatmap, 0.4, 0)

    # 5. Add Label
    # Teal text (BGR) for the overlay caption
    cv2.putText(overlay, "Global Model Attention", (20, 40), 
                cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 0), 2)

    # 6. Stream Result
    _, buffer = cv2.imencode('.jpg', overlay)
    return StreamingResponse(io.BytesIO(buffer.tobytes()), media_type="image/jpeg")
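

if __name__ == "__main__":
    # Local dev entry point -- a minimal sketch. The port is an assumption;
    # Hugging Face Spaces conventionally serves on 7860.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)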