""" Waldo and Wilma Detector - Flask Web App for Hugging Face Spaces Detects Waldo and Wilma in uploaded images using Grounding DINO """ from flask import Flask, render_template, request, jsonify, send_from_directory import torch from PIL import Image, ImageDraw, ImageFont from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection import numpy as np import os import io import base64 from werkzeug.utils import secure_filename import uuid app = Flask(__name__) app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB max file size app.config['UPLOAD_FOLDER'] = 'uploads' app.config['ALLOWED_EXTENSIONS'] = {'png', 'jpg', 'jpeg', 'gif', 'bmp'} # Create upload folder if it doesn't exist os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True) class WaldoWilmaDetector: """Detector for Waldo and Wilma using Grounding DINO.""" def __init__(self): self.device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Loading model on {self.device}...") model_id = "IDEA-Research/grounding-dino-base" self.processor = AutoProcessor.from_pretrained(model_id) self.model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(self.device) print("Model loaded successfully!") def detect(self, image, box_threshold=0.35, text_threshold=0.25): """ Detect Waldo and Wilma in an image. Args: image: PIL Image box_threshold: Confidence threshold for bounding boxes text_threshold: Confidence threshold for text matching Returns: tuple: (annotated_image, detection_info) """ # Ensure image is RGB if image.mode != "RGB": image = image.convert("RGB") # Text prompt optimized for Waldo and Wilma detection text_prompt = ( "a person wearing red and white horizontal striped shirt with glasses and beanie. " "a person wearing red and white striped shirt. " "a character with red white stripes." ) # Process image inputs = self.processor(images=image, text=text_prompt, return_tensors="pt").to(self.device) # Run detection with torch.no_grad(): outputs = self.model(**inputs) # Post-process results results = self.processor.post_process_grounded_object_detection( outputs, inputs.input_ids, text_threshold=text_threshold, target_sizes=[image.size[::-1]] ) # Visualize results annotated_image = self._draw_boxes(image.copy(), results[0]) # Get detection info num_detections = len(results[0]["boxes"]) boxes = results[0]["boxes"].tolist() scores = results[0]["scores"].tolist() detection_info = { 'num_detections': num_detections, 'boxes': boxes, 'scores': scores, 'summary': f"Found {num_detections} character(s) wearing red and white stripes!" } return annotated_image, detection_info def _draw_boxes(self, image, result): """Draw bounding boxes on image.""" draw = ImageDraw.Draw(image) # Try to load a font, fallback to default if not available try: font = ImageFont.truetype("arial.ttf", 20) except: font = ImageFont.load_default() boxes = result["boxes"] scores = result["scores"] labels = result["labels"] # Define colors for different detections colors = ["red", "blue", "green", "yellow", "purple"] for idx, (box, score, label) in enumerate(zip(boxes, scores, labels)): # Convert box coordinates to integers box = [int(i) for i in box.tolist()] color = colors[idx % len(colors)] # Draw rectangle draw.rectangle(box, outline=color, width=4) # Draw label with background text_label = f"Waldo/Wilma: {score:.2%}" # Get text bounding box for background bbox = draw.textbbox((box[0], box[1] - 25), text_label, font=font) draw.rectangle(bbox, fill=color) draw.text((box[0], box[1] - 25), text_label, fill="white", font=font) return image # Initialize detector print("Initializing Waldo & Wilma Detector...") detector = WaldoWilmaDetector() print("Detector ready!") def allowed_file(filename): """Check if file extension is allowed.""" return '.' in filename and \ filename.rsplit('.', 1)[1].lower() in app.config['ALLOWED_EXTENSIONS'] def image_to_base64(image): """Convert PIL Image to base64 string.""" buffered = io.BytesIO() image.save(buffered, format="PNG") img_str = base64.b64encode(buffered.getvalue()).decode() return f"data:image/png;base64,{img_str}" @app.route('/') def index(): """Render the main page.""" return render_template('index.html') @app.route('/detect', methods=['POST']) def detect(): """Handle image upload and detection.""" try: # Check if image was uploaded if 'image' not in request.files: return jsonify({'error': 'No image uploaded'}), 400 file = request.files['image'] if file.filename == '': return jsonify({'error': 'No image selected'}), 400 if not allowed_file(file.filename): return jsonify({'error': 'Invalid file type. Please upload an image.'}), 400 # Get threshold parameters box_threshold = float(request.form.get('box_threshold', 0.35)) text_threshold = float(request.form.get('text_threshold', 0.25)) # Read and process image image = Image.open(file.stream) # Run detection annotated_image, detection_info = detector.detect( image, box_threshold=box_threshold, text_threshold=text_threshold ) # Convert annotated image to base64 result_image_base64 = image_to_base64(annotated_image) return jsonify({ 'success': True, 'image': result_image_base64, 'detections': detection_info }) except Exception as e: print(f"Error during detection: {str(e)}") return jsonify({'error': f'Detection failed: {str(e)}'}), 500 @app.route('/health') def health(): """Health check endpoint.""" return jsonify({'status': 'healthy', 'model_loaded': True}) if __name__ == '__main__': # For local development app.run(host='0.0.0.0', port=7860, debug=False)