"""Kitchen hygiene classification API.

Serves an EfficientNet-B0 based classifier over FastAPI and, for the
predicted class, a Grad-CAM heatmap overlay plus bounding boxes derived from
that heatmap.
"""
import base64
import io
import json
from typing import Dict, List

import cv2
import matplotlib

# The backend has to be selected before pyplot is imported; 'Agg' renders
# off-screen and needs no display.
matplotlib.use('Agg')

import matplotlib.pyplot as plt  # noqa: E402  (must follow matplotlib.use)
import numpy as np
import torch
import torch.nn as nn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import HTMLResponse, JSONResponse
from PIL import Image
from torchvision import models, transforms

# ============================================================================
# CONSTANTS
# ============================================================================
IMG_SIZE = 224
NUM_CLASSES = 4
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
UNHYGIENIC_CLASSES = [1, 2, 3]  # Adjust based on your class indices

# Preprocessing shared by every endpoint: resize to a fixed square input and
# normalise with the standard ImageNet channel mean/std. Built once instead of
# once per request.
_PREPROCESS = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


# ============================================================================
# BOUNDING BOX DETECTION MODULE
# ============================================================================
class BoundingBoxDetector:
    """Detects and localizes problem regions using attention maps.

    Boxes are drawn around the connected regions of a heatmap whose values
    exceed the ``percentile``-th percentile, after the heatmap has been
    resized to the original image size.

    Args:
        threshold: kept for backward compatibility. NOTE(review): not read by
            ``get_bboxes_from_heatmap``; the cut-off is ``percentile``.
        min_area: kept for backward compatibility. NOTE(review): not read by
            ``get_bboxes_from_heatmap``; contours are filtered with
            ``min_contour_area``.
        max_boxes: maximum number of boxes returned.
        percentile: heatmap percentile used as the binarisation cut-off.
        min_contour_area: contours with a smaller area (pixels) are dropped.
        max_box_fraction: boxes wider *and* taller than this fraction of the
            image are dropped as uninformative.
    """

    def __init__(self, threshold=0.2, min_area=15, max_boxes=15,
                 percentile=85, min_contour_area=100, max_box_fraction=0.9):
        self.threshold = threshold
        self.min_area = min_area
        self.max_boxes = max_boxes
        self.percentile = percentile
        self.min_contour_area = min_contour_area
        self.max_box_fraction = max_box_fraction

    def get_bboxes_from_heatmap(self, heatmap, orig_width, orig_height):
        """Extract bounding boxes from an attention heatmap.

        Args:
            heatmap: 2-D numpy array; it is resized to the target size here.
            orig_width: width of the image the boxes refer to.
            orig_height: height of the image the boxes refer to.

        Returns:
            Up to ``max_boxes`` dicts with int ``x``, ``y``, ``width``,
            ``height``, ``area`` and float ``confidence`` (mean normalised
            heat inside the box), sorted by confidence, highest first.
        """
        heatmap = cv2.resize(heatmap, (orig_width, orig_height))

        # Rescale to [0, 1]; the epsilon keeps a constant map from dividing by 0.
        heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)

        cutoff = np.percentile(heatmap, self.percentile)
        binary = (heatmap > cutoff).astype(np.uint8) * 255

        # OpenCV 3 returns (image, contours, hierarchy) and OpenCV 4 returns
        # (contours, hierarchy); [-2] selects the contours in both.
        contours = cv2.findContours(
            binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

        bboxes = []
        for contour in contours:
            area = cv2.contourArea(contour)
            if area < self.min_contour_area:
                continue

            x, y, w, h = cv2.boundingRect(contour)

            # Reject boxes that cover (almost) the whole image.
            if (w > self.max_box_fraction * orig_width
                    and h > self.max_box_fraction * orig_height):
                continue

            confidence = heatmap[y:y + h, x:x + w].mean()
            bboxes.append({
                'x': int(x),
                'y': int(y),
                'width': int(w),
                'height': int(h),
                'confidence': float(confidence),
                'area': int(area),
            })

        bboxes.sort(key=lambda b: b['confidence'], reverse=True)
        return bboxes[:self.max_boxes]


# ============================================================================
# INFERENCE RESULT CONTAINER
# ============================================================================
class InferenceResult:
    """Container for all inference outputs (every field starts as None).

    NOTE(review): not used by the endpoints in this module.
    """

    def __init__(self):
        self.prediction = None      # Class index
        self.confidence = None      # Confidence score
        self.probabilities = None   # All class probabilities
        self.gradcam = None         # GradCAM heatmap (numpy)
        self.gradcam_image = None   # GradCAM overlay (PIL Image)
        self.bbox_list = None       # List of bounding boxes
        self.original_image = None  # Input image (PIL Image)


# ============================================================================
# MODEL DEFINITION
# ============================================================================
class KitchenHygieneModelWithBBox(nn.Module):
    """EfficientNet with attention-based bbox localization AND integrated GradCAM.

    ``forward`` returns ``(logits, attention_map)``. NOTE(review): the
    endpoints in this module derive boxes from Grad-CAM, not from
    ``attention_map``.

    Submodule names and order define the state_dict keys loaded from
    ``kitchen_model_new.pth``; do not rename or reorder them.
    """

    def __init__(self, num_classes=NUM_CLASSES):
        super().__init__()

        # Base model (weights come from the saved state_dict, not torchvision).
        base_model = models.efficientnet_b0(weights=None)

        # Freeze all but the last 35 parameter tensors. Grad-CAM's backward
        # pass presumably relies on the final feature block still requiring
        # grad, so don't freeze more without checking generate_gradcam.
        for param in list(base_model.parameters())[:-35]:
            param.requires_grad = False

        self.features = base_model.features
        self.avgpool = base_model.avgpool

        in_features = base_model.classifier[1].in_features

        # Classification head
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(in_features, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )

        # Attention head for bbox localization (1-channel sigmoid map).
        self.attention_head = nn.Sequential(
            nn.Conv2d(in_features, 128, kernel_size=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=1),
            nn.Sigmoid(),
        )

        # Filled by the hooks below on every forward/backward through the
        # last feature block; consumed by generate_gradcam.
        self.gradients = None
        self.activations = None

        # Register hooks for GradCAM
        self.features[-1].register_forward_hook(self._save_activations)
        self.features[-1].register_full_backward_hook(self._save_gradients)

        self.num_classes = num_classes

    def _save_activations(self, module, input, output):
        """Forward hook: keep the last feature block's output."""
        self.activations = output.detach()

    def _save_gradients(self, module, grad_input, grad_output):
        """Backward hook: keep the gradient w.r.t. the last feature block's output."""
        self.gradients = grad_output[0].detach()

    def forward(self, x):
        """Return ``(logits, attention_map)`` for an image batch ``x``."""
        features = self.features(x)

        pool = self.avgpool(features)
        pool = torch.flatten(pool, 1)
        logits = self.classifier(pool)

        attention_map = self.attention_head(features)

        return logits, attention_map

    def generate_gradcam(self, input_tensor, class_idx):
        """Generate a Grad-CAM map for ``class_idx`` on batch element 0.

        Runs its own forward and backward pass, so it works even when called
        inside a ``torch.no_grad()`` block.

        Returns:
            2-D numpy array at the spatial size of the last feature map,
            shifted to a minimum of 0 and divided by its max (values in
            [0, 1)), or None if the hooks did not capture anything.
        """
        # Clear state from any previous call so a hook that fails to fire
        # cannot silently reuse stale maps.
        self.gradients = None
        self.activations = None

        with torch.enable_grad():
            outputs, _ = self(input_tensor)

            self.zero_grad()
            one_hot = torch.zeros_like(outputs)
            one_hot[0][class_idx] = 1
            outputs.backward(gradient=one_hot)

        # Parameter grads are not needed for Grad-CAM; release them.
        self.zero_grad(set_to_none=True)

        if self.gradients is None or self.activations is None:
            return None

        gradients = self.gradients[0]
        activations = self.activations[0]

        # Channel weights: gradients averaged over the spatial dimensions.
        weights = gradients.mean(dim=(1, 2), keepdim=True)

        # Weighted sum of activation maps, keeping only positive evidence.
        cam = (weights * activations).sum(dim=0)
        cam = torch.clamp(cam, min=0)

        cam = cam - cam.min()
        cam = cam / (cam.max() + 1e-8)

        return cam.cpu().numpy()


# ============================================================================
# UTILITY FUNCTIONS
# ============================================================================
def overlay_gradcam_on_image(image, cam, alpha=0.5):
    """Blend a [0, 1] heatmap, coloured with matplotlib's 'hot' map, onto ``image``.

    ``cam`` is resized to the image size; ``alpha`` is the heatmap weight.
    Returns a new RGB PIL image.
    """
    cam_resized = Image.fromarray((cam * 255).astype(np.uint8)).resize(
        (image.width, image.height), Image.BILINEAR
    )
    cam_array = np.array(cam_resized)

    heatmap = plt.cm.hot(cam_array / 255.0)
    heatmap_rgb = Image.fromarray((heatmap[:, :, :3] * 255).astype(np.uint8))

    return Image.blend(image.convert('RGB'), heatmap_rgb, alpha)


def image_to_base64(image):
    """Convert PIL Image to a base64 string of its PNG encoding."""
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()


def draw_bboxes_on_image(image, bboxes, class_idx, class_names):
    """Draw bounding boxes with confidence labels and return a new PIL Image.

    Boxes narrower or shorter than 20 px are skipped. The box colour depends
    on ``class_idx``. ``class_names`` is accepted for API compatibility but
    is not used.
    """
    img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    # BGR shades of red per class.
    colors = {
        1: (0, 0, 180),
        2: (0, 0, 220),
        3: (0, 0, 255),
    }
    color = colors.get(class_idx, (0, 0, 200))

    for bbox in bboxes:
        x, y, w, h = int(bbox['x']), int(bbox['y']), int(bbox['width']), int(bbox['height'])
        conf = bbox['confidence']

        # skip useless tiny boxes
        if w < 20 or h < 20:
            continue

        cv2.rectangle(img_cv, (x, y), (x + w, y + h), color, 6)

        label = f"{conf:.0%}"
        (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2)

        # The label normally sits just above the box; if that would fall
        # off the top of the image, draw it just inside the box instead.
        label_top = y - th - 10
        if label_top < 0:
            label_top = y

        cv2.rectangle(img_cv, (x, label_top), (x + tw + 5, label_top + th + 10), color, -1)
        cv2.putText(img_cv, label, (x + 2, label_top + th + 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)

    return Image.fromarray(cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB))


# ============================================================================
# FASTAPI APP INITIALIZATION
# ============================================================================
app = FastAPI(
    title="Kitchen Hygiene Classification API",
    description="Complete inference with GradCAM, Bounding Box Detection, and Prediction",
    version="1.0.0"
)


@app.get("/", response_class=HTMLResponse)
async def home():
    return """
Go to /docs to test the API.
"""


# Global variables, populated by load_model() at startup.
model = None
class_names = None

# Box detector shared by the endpoints (it keeps no per-call state).
_BBOX_DETECTOR = BoundingBoxDetector(threshold=0.15, min_area=10, max_boxes=10)


@app.on_event("startup")
async def load_model():
    """Load model weights and class names on startup; re-raise on failure."""
    global model, class_names

    try:
        loaded_model = KitchenHygieneModelWithBBox(num_classes=NUM_CLASSES)
        loaded_model.load_state_dict(torch.load("kitchen_model_new.pth", map_location=DEVICE))
        loaded_model.to(DEVICE)
        loaded_model.eval()

        # Load class names from model info
        with open("model_info.json", "r", encoding="utf-8") as f:
            model_info = json.load(f)
        loaded_names = model_info["classes"]

        # Endpoints index class_names with model output indices; fail fast
        # here instead of with an IndexError on some later request.
        if len(loaded_names) != NUM_CLASSES:
            raise ValueError(
                f"model_info.json lists {len(loaded_names)} classes, "
                f"expected {NUM_CLASSES}"
            )

        # Publish only once everything loaded, so /health never reports a
        # model without class names.
        model, class_names = loaded_model, loaded_names

        print("✓ Model loaded successfully")
        print(f"  Classes: {class_names}")
        print(f"  Device: {DEVICE}")

    except Exception as e:
        print(f"ERROR loading model: {str(e)}")
        raise


# ============================================================================
# REQUEST HELPERS
# ============================================================================
# NOTE: the endpoints are `async def` and never await between model calls, so
# requests run one at a time on the event loop. That also protects the
# model's shared gradients/activations state used by Grad-CAM; add a lock
# before turning these into plain `def` (thread-pool) endpoints.

def _load_rgb_image(contents):
    """Decode uploaded bytes into an RGB PIL image."""
    return Image.open(io.BytesIO(contents)).convert('RGB')


def _classify(image):
    """Run the classifier on ``image``.

    Returns ``(image_tensor, probabilities, predicted_class_idx, confidence)``
    where ``probabilities`` is a numpy softmax vector over the classes.
    """
    image_tensor = _PREPROCESS(image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        outputs, _ = model(image_tensor)

    probabilities = torch.softmax(outputs[0], dim=0).cpu().numpy()
    predicted_class_idx = int(np.argmax(probabilities))
    confidence = float(probabilities[predicted_class_idx])

    return image_tensor, probabilities, predicted_class_idx, confidence


def _compute_gradcam(image_tensor, class_idx):
    """Grad-CAM map for ``class_idx``; an all-zero map if it cannot be computed."""
    gradcam = model.generate_gradcam(image_tensor, class_idx)
    if gradcam is None:
        gradcam = np.zeros((IMG_SIZE, IMG_SIZE))
    return gradcam


def _detect_problem_boxes(gradcam, orig_width, orig_height, class_idx):
    """Boxes (original-image pixels) for unhygienic classes, else an empty list."""
    # Boxes are only reported for unhygienic classes; skip the work otherwise.
    if class_idx not in UNHYGIENIC_CLASSES:
        return []

    attention_np = gradcam
    if attention_np.max() > 0:
        attention_np = (attention_np - attention_np.min()) / (
            attention_np.max() - attention_np.min() + 1e-8)

    return _BBOX_DETECTOR.get_bboxes_from_heatmap(attention_np, orig_width, orig_height)


def _probabilities_by_name(probabilities):
    """Map each class name to its probability as a plain float."""
    return {name: float(probabilities[i]) for i, name in enumerate(class_names)}


def _png_data_uri(image):
    """Encode a PIL image as a ``data:image/png;base64,...`` URI."""
    return f"data:image/png;base64,{image_to_base64(image)}"


def _bad_request(error):
    """HTTP 400 carrying the original error message."""
    return HTTPException(status_code=400, detail=f"Error processing image: {str(error)}")


# ============================================================================
# API ENDPOINTS
# ============================================================================
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "device": str(DEVICE),
        "num_classes": NUM_CLASSES,
        "classes": class_names
    }


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """
    Complete inference endpoint

    Returns:
    - prediction: predicted class name
    - confidence: confidence score
    - probabilities: all class probabilities
    - bounding_boxes: list of detected problem regions
    - gradcam_image: base64 encoded GradCAM overlay
    - bbox_image: base64 encoded image with bounding boxes
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")

    try:
        contents = await file.read()
        image = _load_rgb_image(contents)
        orig_width, orig_height = image.size

        # Step 1: prediction
        image_tensor, probabilities, predicted_class_idx, confidence = _classify(image)

        # Step 2: GradCAM overlay
        gradcam = _compute_gradcam(image_tensor, predicted_class_idx)
        gradcam_image = overlay_gradcam_on_image(image, gradcam, alpha=0.4)

        # Step 3: bounding boxes (unhygienic classes only)
        filtered_bboxes = _detect_problem_boxes(
            gradcam, orig_width, orig_height, predicted_class_idx)
        bbox_image = draw_bboxes_on_image(
            image, filtered_bboxes, predicted_class_idx, class_names)

        response = {
            "prediction": class_names[predicted_class_idx],
            "confidence": confidence,
            "probabilities": _probabilities_by_name(probabilities),
            "bounding_boxes": filtered_bboxes,
            "num_problems_detected": len(filtered_bboxes),
            "gradcam_image": _png_data_uri(gradcam_image),
            "bbox_image": _png_data_uri(bbox_image)
        }
        return JSONResponse(content=response)

    except Exception as e:
        raise _bad_request(e) from e


@app.post("/predict-simple")
async def predict_simple(file: UploadFile = File(...)):
    """
    Simplified prediction endpoint (returns only prediction and probabilities, no images)
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")

    try:
        contents = await file.read()
        image = _load_rgb_image(contents)

        _, probabilities, predicted_class_idx, confidence = _classify(image)

        response = {
            "prediction": class_names[predicted_class_idx],
            "confidence": confidence,
            "probabilities": _probabilities_by_name(probabilities)
        }
        return JSONResponse(content=response)

    except Exception as e:
        raise _bad_request(e) from e


@app.post("/gradcam-only")
async def gradcam_only(file: UploadFile = File(...)):
    """
    GradCAM only endpoint (returns GradCAM heatmap and prediction)
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")

    try:
        contents = await file.read()
        image = _load_rgb_image(contents)

        image_tensor, probabilities, predicted_class_idx, confidence = _classify(image)

        gradcam = _compute_gradcam(image_tensor, predicted_class_idx)
        gradcam_image = overlay_gradcam_on_image(image, gradcam, alpha=0.4)

        response = {
            "prediction": class_names[predicted_class_idx],
            "confidence": confidence,
            "probabilities": _probabilities_by_name(probabilities),
            "gradcam_image": _png_data_uri(gradcam_image)
        }
        return JSONResponse(content=response)

    except Exception as e:
        raise _bad_request(e) from e


@app.post("/bbox-detection")
async def bbox_detection(file: UploadFile = File(...)):
    """
    Bounding box detection only endpoint
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")

    try:
        contents = await file.read()
        image = _load_rgb_image(contents)
        orig_width, orig_height = image.size

        image_tensor, _, predicted_class_idx, confidence = _classify(image)

        # Boxes are derived from the Grad-CAM map of the predicted class.
        gradcam = _compute_gradcam(image_tensor, predicted_class_idx)
        filtered_bboxes = _detect_problem_boxes(
            gradcam, orig_width, orig_height, predicted_class_idx)
        bbox_image = draw_bboxes_on_image(
            image, filtered_bboxes, predicted_class_idx, class_names)

        response = {
            "prediction": class_names[predicted_class_idx],
            "confidence": confidence,
            "bounding_boxes": filtered_bboxes,
            "num_problems_detected": len(filtered_bboxes),
            "bbox_image": _png_data_uri(bbox_image)
        }
        return JSONResponse(content=response)

    except Exception as e:
        raise _bad_request(e) from e


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)