import os
import io

import cv2
import numpy as np

# Legacy Keras must be selected before TensorFlow is imported.
os.environ['TF_USE_LEGACY_KERAS'] = '1'
import tf_keras as keras  # noqa: F401  (legacy Keras package; tf.keras resolves to it)
import tensorflow as tf
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import StreamingResponse
from PIL import Image
from huggingface_hub import snapshot_download
from object_detection.utils import label_map_util, config_util
from object_detection.builders import model_builder

app = FastAPI()

# 1. Download Private Models
HF_TOKEN = os.getenv("HF_Token")
REPO_ID = "SaniaE/Car_Damage_Detection"

print("Downloading models from Hugging Face...")
model_dir = snapshot_download(
    repo_id=REPO_ID,
    token=HF_TOKEN,
    local_dir="./models_data"
)

PIPELINE_CONFIG = os.path.join(model_dir, "object_detection_model/pipeline.config")
CHECKPOINT_PATH = os.path.join(model_dir, "object_detection_model/ckpt-37")
LABEL_MAP_PATH = os.path.join(model_dir, "object_detection_model/label_map.pbtxt")
CNN_MODEL_PATH = os.path.join(model_dir, "cnn_filter.h5")

# 2. Load Models
# Load the CNN damage/no-damage filter
cnn_filter = tf.keras.models.load_model(CNN_MODEL_PATH, compile=False)

# Load the object detection model
configs = config_util.get_configs_from_pipeline_file(PIPELINE_CONFIG)
detection_model = model_builder.build(model_config=configs['model'], is_training=False)
ckpt = tf.train.Checkpoint(model=detection_model)
ckpt.restore(CHECKPOINT_PATH).expect_partial()
category_index = label_map_util.create_category_index_from_labelmap(LABEL_MAP_PATH)


@tf.function
def detect_fn(image):
    """Runs the full preprocess -> predict -> postprocess detection pipeline."""
    image, shapes = detection_model.preprocess(image)
    prediction_dict = detection_model.predict(image, shapes)
    detections = detection_model.postprocess(prediction_dict, shapes)
    return detections


@app.get("/")
def read_root():
    return {"status": "Model is Online", "model_repo": REPO_ID}


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    # Read the uploaded image
    contents = await file.read()
    image_pil = Image.open(io.BytesIO(contents)).convert("RGB")
    image_np = np.array(image_pil)

    # OpenCV draws on BGR images, so keep a converted copy
    image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
    height, width, _ = image_cv.shape

    # Step 1: CNN filter decides whether the car is damaged at all
    img_cnn = image_pil.resize((64, 64))
    x = tf.keras.preprocessing.image.img_to_array(img_cnn)
    x = np.expand_dims(x, axis=0)
    cnn_pred = cnn_filter.predict(x)
    is_damage_labels = ['Clear', 'Damaged']
    status = is_damage_labels[np.argmax(cnn_pred)]

    # Step 2: Object detection (only if the filter says "Damaged")
    if status == 'Damaged':
        input_tensor = tf.convert_to_tensor(np.expand_dims(image_np, 0), dtype=tf.float32)
        detections = detect_fn(input_tensor)
        scores = detections['detection_scores'][0].numpy()
        classes = detections['detection_classes'][0].numpy().astype(int)
        boxes = detections['detection_boxes'][0].numpy()

        for i in range(len(scores)):
            if scores[i] > 0.4:
                # TFOD boxes are [ymin, xmin, ymax, xmax] in normalized coordinates
                ymin, xmin, ymax, xmax = boxes[i]
                (left, right, top, bottom) = (xmin * width, xmax * width,
                                              ymin * height, ymax * height)

                # Draw the bounding box ((255, 255, 0) is cyan in BGR)
                cv2.rectangle(image_cv, (int(left), int(top)),
                              (int(right), int(bottom)), (255, 255, 0), 2)

                # Draw the label; detection classes are 0-based, the label map is 1-based
                label = f"{category_index.get(classes[i] + 1, {}).get('name', 'unknown')}: {int(scores[i] * 100)}%"
                cv2.putText(image_cv, label, (int(left), int(top) - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 0), 2)

    # Encode the annotated image back to JPEG
    _, buffer = cv2.imencode('.jpg', image_cv)
    return StreamingResponse(io.BytesIO(buffer.tobytes()), media_type="image/jpeg")
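
# Usage sketch (not part of the service itself): with the app served locally by
# uvicorn on the default port 8000, the endpoints above can be exercised with
# curl. "car.jpg" is a placeholder file name.
#
#   curl http://localhost:8000/
#   curl -X POST "http://localhost:8000/predict" \
#        -F "file=@car.jpg" --output annotated.jpg
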
def get_top_prediction(detections):
    """Extracts the index and class ID of the most confident detection."""
    scores = detections['detection_scores'][0].numpy()
    if len(scores) > 0 and scores[0] > 0.4:
        # TFOD returns detections sorted by score, so index 0 is the top one
        return 0, int(detections['detection_classes'][0].numpy()[0])
    return None, None


@app.post("/explain")
async def explain(file: UploadFile = File(...)):
    # 1. Prepare the image
    contents = await file.read()
    image_pil = Image.open(io.BytesIO(contents)).convert("RGB")
    image_np = np.array(image_pil).astype(np.float32)
    input_tensor = tf.convert_to_tensor(np.expand_dims(image_np, 0), dtype=tf.float32)

    # 2. Gradient tape for saliency
    with tf.GradientTape() as tape:
        tape.watch(input_tensor)

        # Manually run the forward pass through the detection model
        image, shapes = detection_model.preprocess(input_tensor)
        prediction_dict = detection_model.predict(image, shapes)

        # 'class_predictions_with_background' is standard for TFOD SSD/Faster R-CNN
        # models; it usually has shape [1, num_anchors, num_classes + 1], with
        # column 0 reserved for the background class
        raw_scores = prediction_dict['class_predictions_with_background'][0]

        # We need a reference detection to know which class to compute gradients for
        detections = detection_model.postprocess(prediction_dict, shapes)
        _, top_class = get_top_prediction(detections)
        if top_class is None:
            return {"error": "No object detected with sufficient confidence to explain."}

        # Focus the loss on the max score for that class across all anchors;
        # postprocessed classes are 0-based, so offset by 1 to skip the
        # background column in the raw scores
        loss = tf.reduce_max(raw_scores[:, top_class + 1])

    # 3. Compute gradients and take the absolute max across color channels
    grads = tape.gradient(loss, input_tensor)
    saliency = np.max(np.abs(grads.numpy()), axis=-1)[0]

    # 4. Normalize and create the heatmap
    # Clipping to the 5th-95th percentile range reduces noise/outliers
    v_min, v_max = np.percentile(saliency, (5, 95))
    saliency = np.clip((saliency - v_min) / (v_max - v_min + 1e-8), 0, 1)

    # Create the JET heatmap (blue = low, red = high)
    heatmap = cv2.applyColorMap(np.uint8(255 * saliency), cv2.COLORMAP_JET)

    # 5. Overlay on the original image (converted to BGR first)
    original_bgr = cv2.cvtColor(image_np.astype(np.uint8), cv2.COLOR_RGB2BGR)
    overlay = cv2.addWeighted(original_bgr, 0.6, heatmap, 0.4, 0)

    # Label the overlay with the class being explained
    class_name = category_index.get(top_class + 1, {}).get('name', 'unknown')
    cv2.putText(overlay, f"Explaining: {class_name}", (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)

    # 6. Stream the result
    _, buffer = cv2.imencode('.jpg', overlay)
    return StreamingResponse(io.BytesIO(buffer.tobytes()), media_type="image/jpeg")
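
# Client sketch for /explain (an illustration, assuming `requests` is installed
# and the API is reachable at http://localhost:8000). The endpoint returns JSON
# when nothing clears the 0.4 threshold and a JPEG otherwise, so the client
# checks the Content-Type header. "damaged_car.jpg" is a placeholder path.
#
# import requests
# with open("damaged_car.jpg", "rb") as f:
#     resp = requests.post("http://localhost:8000/explain", files={"file": f})
# if resp.headers.get("content-type", "").startswith("image/"):
#     with open("saliency.jpg", "wb") as out:
#         out.write(resp.content)
# else:
#     print(resp.json())  # e.g. {"error": "No object detected ..."}
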
@app.post("/explain/tiled")
async def explain_tiled(file: UploadFile = File(...)):
    # 1. Prepare the base image
    contents = await file.read()
    image_pil = Image.open(io.BytesIO(contents)).convert("RGB")
    image_np = np.array(image_pil).astype(np.float32)
    input_tensor = tf.convert_to_tensor(np.expand_dims(image_np, 0), dtype=tf.float32)

    # 2. Get initial detections to know what to explain
    detections = detect_fn(input_tensor)
    scores = detections['detection_scores'][0].numpy()
    classes = detections['detection_classes'][0].numpy().astype(int)
    boxes = detections['detection_boxes'][0].numpy()

    # Create the top-left "base" panel with up to three boxes drawn
    base_image = cv2.cvtColor(image_np.astype(np.uint8), cv2.COLOR_RGB2BGR)
    h_img, w_img, _ = base_image.shape
    for i in range(min(len(scores), 3)):
        if scores[i] > 0.4:
            ymin, xmin, ymax, xmax = boxes[i]
            cv2.rectangle(base_image,
                          (int(xmin * w_img), int(ymin * h_img)),
                          (int(xmax * w_img), int(ymax * h_img)),
                          (255, 255, 0), 2)

    # 3. Generate saliency maps for the top 3 detections
    panels = [base_image]
    for i in range(3):
        if i < len(scores) and scores[i] > 0.4:
            target_class = classes[i]
            with tf.GradientTape() as tape:
                tape.watch(input_tensor)
                image, shapes = detection_model.preprocess(input_tensor)
                prediction_dict = detection_model.predict(image, shapes)
                raw_scores = prediction_dict['class_predictions_with_background'][0]
                # Target the class at its most active anchor; offset by 1 to
                # skip the background column in the raw scores
                loss = tf.reduce_max(raw_scores[:, target_class + 1])

            grads = tape.gradient(loss, input_tensor)
            saliency = np.max(np.abs(grads.numpy()), axis=-1)[0]

            # Normalize and colorize
            v_min, v_max = np.percentile(saliency, (5, 95))
            saliency = np.clip((saliency - v_min) / (v_max - v_min + 1e-8), 0, 1)
            heatmap = cv2.applyColorMap(np.uint8(255 * saliency), cv2.COLORMAP_JET)

            # Overlay on a BGR copy of the original
            overlay = cv2.addWeighted(
                cv2.cvtColor(image_np.astype(np.uint8), cv2.COLOR_RGB2BGR),
                0.6, heatmap, 0.4, 0)

            # Label the panel
            class_name = category_index.get(target_class + 1, {}).get('name', 'unknown')
            cv2.putText(overlay, f"Top {i + 1}: {class_name}", (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
            panels.append(overlay)
        else:
            # Placeholder for empty slots if fewer than 3 detections exist
            panels.append(np.zeros_like(base_image))

    # 4. Assemble the 2x2 grid; panels are [0: base, 1: top-1, 2: top-2, 3: top-3]
    top_row = np.hstack((panels[0], panels[1]))
    bottom_row = np.hstack((panels[2], panels[3]))
    tiled_output = np.vstack((top_row, bottom_row))

    # 5. Stream the result
    _, buffer = cv2.imencode('.jpg', tiled_output)
    return StreamingResponse(io.BytesIO(buffer.tobytes()), media_type="image/jpeg")
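
# Usage sketch for the tiled view (same assumptions as above: local uvicorn
# server, placeholder file name). The response is a single JPEG twice the
# input's width and height: a 2x2 grid of the base detections plus up to
# three saliency panels.
#
#   curl -X POST "http://localhost:8000/explain/tiled" \
#        -F "file=@damaged_car.jpg" --output tiled.jpg
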
@app.post("/explain/global")
async def explain_global(file: UploadFile = File(...)):
    # 1. Read and prepare the image
    contents = await file.read()
    image_pil = Image.open(io.BytesIO(contents)).convert("RGB")
    image_np = np.array(image_pil).astype(np.float32)
    # Keep a uint8 BGR copy for the final overlay
    image_bgr = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
    input_tensor = tf.convert_to_tensor(np.expand_dims(image_np, 0), dtype=tf.float32)

    # 2. Gradient tape for global activation
    with tf.GradientTape() as tape:
        tape.watch(input_tensor)

        # Forward pass
        image, shapes = detection_model.preprocess(input_tensor)
        prediction_dict = detection_model.predict(image, shapes)

        # 'class_predictions_with_background' shape: [1, num_anchors, num_classes + 1]
        raw_scores = prediction_dict['class_predictions_with_background'][0]

        # Ignore column 0 (background) and look at all damage classes: take the
        # max score at each anchor, then sum them for the global loss
        foreground_scores = raw_scores[:, 1:]
        loss = tf.reduce_sum(tf.reduce_max(foreground_scores, axis=-1))

    # 3. Compute and process gradients
    grads = tape.gradient(loss, input_tensor)
    saliency = np.max(np.abs(grads.numpy()), axis=-1)[0]

    # 4. Refine the saliency visualization: clipping to the 5th-95th percentile
    # range suppresses pixel noise and highlights the actual damage regions
    v_min, v_max = np.percentile(saliency, (5, 95))
    saliency = np.clip((saliency - v_min) / (v_max - v_min + 1e-8), 0, 1)

    # Create the heatmap overlay
    heatmap = cv2.applyColorMap(np.uint8(255 * saliency), cv2.COLORMAP_JET)

    # Blend 60% original image with 40% heatmap so the car details stay visible
    overlay = cv2.addWeighted(image_bgr, 0.6, heatmap, 0.4, 0)

    # 5. Add a label ((255, 255, 0) is cyan in BGR)
    cv2.putText(overlay, "Global Model Attention", (20, 40),
                cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 0), 2)

    # 6. Stream the result
    _, buffer = cv2.imencode('.jpg', overlay)
    return StreamingResponse(io.BytesIO(buffer.tobytes()), media_type="image/jpeg")
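
# Local entry point, a sketch rather than part of the original deployment:
# assumes uvicorn is installed; host and port are illustrative defaults.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)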