Spaces:
Sleeping
Sleeping
| """ | |
| Waldo and Wilma Detector - Flask Web App for Hugging Face Spaces | |
| Detects Waldo and Wilma in uploaded images using Grounding DINO | |
| """ | |
| from flask import Flask, render_template, request, jsonify, send_from_directory | |
| import torch | |
| from PIL import Image, ImageDraw, ImageFont | |
| from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection | |
| import numpy as np | |
| import os | |
| import io | |
| import base64 | |
| from werkzeug.utils import secure_filename | |
| import uuid | |
# Flask application object plus the upload-related settings used below.
app = Flask(__name__)
app.config.update(
    MAX_CONTENT_LENGTH=16 * 1024 * 1024,  # 16MB max file size
    UPLOAD_FOLDER='uploads',
    ALLOWED_EXTENSIONS={'png', 'jpg', 'jpeg', 'gif', 'bmp'},
)

# Make sure the upload folder is present; an existing folder is not an error.
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
class WaldoWilmaDetector:
    """Detector for Waldo and Wilma using Grounding DINO.

    The processor and model are loaded once in ``__init__`` and reused by
    every ``detect`` call.
    """

    def __init__(self):
        """Load the Grounding DINO processor and model.

        The model is moved to CUDA when ``torch.cuda.is_available()`` reports
        a GPU, otherwise it stays on the CPU.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Loading model on {self.device}...")
        model_id = "IDEA-Research/grounding-dino-base"
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(self.device)
        print("Model loaded successfully!")

    def detect(self, image, box_threshold=0.35, text_threshold=0.25):
        """Detect Waldo and Wilma in an image.

        Args:
            image: PIL Image. Converted to RGB when it is in another mode.
            box_threshold: Confidence threshold for bounding boxes.
            text_threshold: Confidence threshold for text matching.

        Returns:
            tuple: ``(annotated_image, detection_info)``. ``annotated_image``
            is a copy of the RGB image with boxes drawn on it;
            ``detection_info`` is a dict with the keys ``num_detections``,
            ``boxes``, ``scores`` and ``summary``.
        """
        # Ensure image is RGB
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Text prompt optimized for Waldo and Wilma detection
        text_prompt = (
            "a person wearing red and white horizontal striped shirt with glasses and beanie. "
            "a person wearing red and white striped shirt. "
            "a character with red white stripes."
        )

        # Process image
        inputs = self.processor(images=image, text=text_prompt, return_tensors="pt").to(self.device)

        # Run detection; gradients are not needed for inference.
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Post-process results. The box threshold used to be dropped here, so
        # the caller's value had no effect. It is passed positionally because
        # its keyword name differs between transformers releases
        # (``box_threshold`` in older ones, ``threshold`` in newer ones --
        # verify against the pinned version).
        results = self.processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            box_threshold,
            text_threshold,
            # PIL reports (width, height); the reversed tuple is (height, width).
            target_sizes=[image.size[::-1]],
        )
        result = results[0]

        # Visualize results on a copy so the input image stays untouched.
        annotated_image = self._draw_boxes(image.copy(), result)

        # Get detection info
        boxes = result["boxes"].tolist()
        scores = result["scores"].tolist()
        num_detections = len(boxes)
        detection_info = {
            'num_detections': num_detections,
            'boxes': boxes,
            'scores': scores,
            'summary': f"Found {num_detections} character(s) wearing red and white stripes!"
        }
        return annotated_image, detection_info

    def _draw_boxes(self, image, result):
        """Draw bounding boxes and score labels on ``image`` and return it.

        Args:
            image: PIL Image to draw on (modified in place).
            result: Mapping with ``"boxes"`` and ``"scores"`` entries, as taken
                from the post-processing output in ``detect``.

        Returns:
            The same ``image`` object, now annotated.
        """
        draw = ImageDraw.Draw(image)

        # Try to load a font, fallback to default if not available. Only the
        # failures that mean "font/FreeType unavailable" are swallowed.
        try:
            font = ImageFont.truetype("arial.ttf", 20)
        except (OSError, ImportError):
            font = ImageFont.load_default()

        # Colors are cycled when there are more detections than entries.
        colors = ["red", "blue", "green", "yellow", "purple"]

        for idx, (box, score) in enumerate(zip(result["boxes"], result["scores"])):
            # Convert box coordinates to integers
            left, top, right, bottom = (int(value) for value in box.tolist())
            color = colors[idx % len(colors)]

            # Draw rectangle
            draw.rectangle((left, top, right, bottom), outline=color, width=4)

            # Draw label with background. The label sits 25px above the box,
            # clamped to the top edge so it is not cut off for boxes that
            # start near y = 0.
            text_label = f"Waldo/Wilma: {score:.2%}"
            label_origin = (left, max(top - 25, 0))
            bbox = draw.textbbox(label_origin, text_label, font=font)
            draw.rectangle(bbox, fill=color)
            draw.text(label_origin, text_label, fill="white", font=font)

        return image
# Build the single shared detector at import time; constructing it loads the
# model, and the request handlers below reuse this instance.
print("Initializing Waldo & Wilma Detector...")
detector = WaldoWilmaDetector()
print("Detector ready!")
def allowed_file(filename):
    """Return True when ``filename`` ends in an allowed extension.

    The extension is the text after the last dot, compared case-insensitively
    against ``app.config['ALLOWED_EXTENSIONS']``. Names without a dot are
    rejected.
    """
    _, dot, extension = filename.rpartition('.')
    if not dot:
        return False
    return extension.lower() in app.config['ALLOWED_EXTENSIONS']
def image_to_base64(image):
    """Encode a PIL Image as a PNG ``data:`` URI string."""
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    return "data:image/png;base64," + encoded.decode()
def index():
    """Serve the main page by rendering the ``index.html`` template."""
    # NOTE(review): no @app.route decorator is visible on this view -- confirm
    # the route is registered (it may have been lost when the file was copied).
    page = render_template('index.html')
    return page
def _parse_threshold(field, default):
    """Read one confidence threshold from the submitted form data.

    Args:
        field: Name of the form field to read.
        default: Value used when the field is absent.

    Returns:
        The threshold as a float in the range [0, 1].

    Raises:
        ValueError: If the value is not a number or lies outside [0, 1].
    """
    value = float(request.form.get(field, default))
    # The chained comparison is also False for NaN, so non-finite input is
    # rejected along with out-of-range numbers.
    if not 0.0 <= value <= 1.0:
        raise ValueError(f"{field} must be between 0 and 1")
    return value


def detect():
    """Handle image upload and detection.

    Expects a multipart form with an ``image`` file and optional
    ``box_threshold`` / ``text_threshold`` fields.

    Returns:
        A JSON response with ``success``, ``image`` (base64 PNG data URI) and
        ``detections`` on success; ``({'error': ...}, 400)`` for invalid
        input; ``({'error': ...}, 500)`` when detection itself fails.
    """
    # NOTE(review): no @app.route decorator is visible on this view -- confirm
    # the route is registered (it may have been lost when the file was copied).

    # Request validation stays outside the catch-all handler below so that
    # client errors are not reported as 500s. This includes HTTP errors raised
    # while parsing the upload (e.g. a body over MAX_CONTENT_LENGTH), which
    # now propagate to Flask unchanged.
    if 'image' not in request.files:
        return jsonify({'error': 'No image uploaded'}), 400

    upload = request.files['image']
    if upload.filename == '':
        return jsonify({'error': 'No image selected'}), 400
    if not allowed_file(upload.filename):
        return jsonify({'error': 'Invalid file type. Please upload an image.'}), 400

    # Get threshold parameters
    try:
        box_threshold = _parse_threshold('box_threshold', 0.35)
        text_threshold = _parse_threshold('text_threshold', 0.25)
    except ValueError as e:
        return jsonify({'error': f'Invalid threshold: {e}'}), 400

    try:
        # Read the image. Image.open is lazy, so load() forces decoding here
        # and a corrupt or non-image upload is reported as a client error.
        try:
            image = Image.open(upload.stream)
            image.load()
        except OSError:
            return jsonify({'error': 'Could not read the uploaded file as an image.'}), 400

        # Run detection
        annotated_image, detection_info = detector.detect(
            image,
            box_threshold=box_threshold,
            text_threshold=text_threshold
        )

        # Convert annotated image to base64
        result_image_base64 = image_to_base64(annotated_image)
        return jsonify({
            'success': True,
            'image': result_image_base64,
            'detections': detection_info
        })
    except Exception as e:
        # Top-level boundary: report any remaining failure as a JSON 500.
        print(f"Error during detection: {str(e)}")
        return jsonify({'error': f'Detection failed: {str(e)}'}), 500
def health():
    """Health check endpoint; always reports a healthy status."""
    # NOTE(review): no @app.route decorator is visible on this view -- confirm
    # the route is registered (it may have been lost when the file was copied).
    payload = {'status': 'healthy', 'model_loaded': True}
    return jsonify(payload)
if __name__ == '__main__':
    # Local development entry point: start Flask's built-in server on port
    # 7860, bound to all interfaces, with the debugger off.
    app.run(debug=False, host='0.0.0.0', port=7860)