# Kory — changing to flask (commit 3709a55)
"""
Waldo and Wilma Detector - Flask Web App for Hugging Face Spaces
Detects Waldo and Wilma in uploaded images using Grounding DINO
"""
from flask import Flask, render_template, request, jsonify, send_from_directory
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
import numpy as np
import os
import io
import base64
from werkzeug.utils import secure_filename
import uuid
app = Flask(__name__)
# Reject request bodies larger than 16 MB before they reach any handler.
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB max file size
# Relative path; resolved against the process working directory.
app.config['UPLOAD_FOLDER'] = 'uploads'
# Extensions accepted by allowed_file() for uploaded images.
app.config['ALLOWED_EXTENSIONS'] = {'png', 'jpg', 'jpeg', 'gif', 'bmp'}
# Create upload folder if it doesn't exist
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
class WaldoWilmaDetector:
    """Zero-shot detector for Waldo and Wilma using Grounding DINO.

    Loads the IDEA-Research/grounding-dino-base checkpoint once at
    construction time and exposes a single ``detect`` entry point that
    returns an annotated copy of the input image plus detection metadata.
    """

    def __init__(self):
        # Prefer GPU when available; CPU inference works but is slow.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Loading model on {self.device}...")
        model_id = "IDEA-Research/grounding-dino-base"
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(self.device)
        print("Model loaded successfully!")

    def detect(self, image, box_threshold=0.35, text_threshold=0.25):
        """
        Detect Waldo and Wilma in an image.

        Args:
            image: PIL Image
            box_threshold: Confidence threshold for bounding boxes
            text_threshold: Confidence threshold for text matching

        Returns:
            tuple: (annotated_image, detection_info) where detection_info is
            a dict with 'num_detections', 'boxes', 'scores', 'summary' keys.
        """
        # Grounding DINO expects 3-channel RGB input.
        if image.mode != "RGB":
            image = image.convert("RGB")
        # Text prompt optimized for Waldo and Wilma detection; Grounding DINO
        # treats each '.'-terminated phrase as a separate query.
        text_prompt = (
            "a person wearing red and white horizontal striped shirt with glasses and beanie. "
            "a person wearing red and white striped shirt. "
            "a character with red white stripes."
        )
        # Process image
        inputs = self.processor(images=image, text=text_prompt, return_tensors="pt").to(self.device)
        # Run detection (inference only, no gradients needed)
        with torch.no_grad():
            outputs = self.model(**inputs)
        # BUG FIX: box_threshold was accepted by this method but never
        # forwarded, so the caller's box-confidence setting had no effect.
        results = self.processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            target_sizes=[image.size[::-1]]  # PIL size is (w, h); model wants (h, w)
        )
        # Visualize results on a copy so the caller's image stays untouched.
        annotated_image = self._draw_boxes(image.copy(), results[0])
        num_detections = len(results[0]["boxes"])
        detection_info = {
            'num_detections': num_detections,
            'boxes': results[0]["boxes"].tolist(),
            'scores': results[0]["scores"].tolist(),
            'summary': f"Found {num_detections} character(s) wearing red and white stripes!"
        }
        return annotated_image, detection_info

    def _draw_boxes(self, image, result):
        """Draw labeled bounding boxes on *image* in place and return it."""
        draw = ImageDraw.Draw(image)
        # Arial is not guaranteed to exist (e.g. Linux containers); fall back
        # to PIL's built-in bitmap font. Catch only OSError — a bare except
        # would also swallow KeyboardInterrupt/SystemExit.
        try:
            font = ImageFont.truetype("arial.ttf", 20)
        except OSError:
            font = ImageFont.load_default()
        boxes = result["boxes"]
        scores = result["scores"]
        labels = result["labels"]
        # Cycle through a fixed palette so adjacent detections differ visually.
        colors = ["red", "blue", "green", "yellow", "purple"]
        for idx, (box, score, label) in enumerate(zip(boxes, scores, labels)):
            # Convert tensor box coordinates to plain ints for drawing.
            box = [int(i) for i in box.tolist()]
            color = colors[idx % len(colors)]
            draw.rectangle(box, outline=color, width=4)
            text_label = f"Waldo/Wilma: {score:.2%}"
            # Paint a filled background behind the label so it stays readable.
            bbox = draw.textbbox((box[0], box[1] - 25), text_label, font=font)
            draw.rectangle(bbox, fill=color)
            draw.text((box[0], box[1] - 25), text_label, fill="white", font=font)
        return image
# Initialize detector
# Done at import time so the model is ready before the first request; this
# downloads/loads the Grounding DINO checkpoint and may take a while.
print("Initializing Waldo & Wilma Detector...")
detector = WaldoWilmaDetector()
print("Detector ready!")
def allowed_file(filename):
    """Return True when *filename* has an extension listed in ALLOWED_EXTENSIONS."""
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1].lower()
    return extension in app.config['ALLOWED_EXTENSIONS']
def image_to_base64(image):
    """Encode a PIL Image as a PNG data-URI string."""
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue()).decode()
    return "data:image/png;base64," + encoded
@app.route('/')
def index():
    """Serve the single-page upload UI."""
    page = render_template('index.html')
    return page
@app.route('/detect', methods=['POST'])
def detect():
    """Handle image upload and run Waldo/Wilma detection.

    Expects multipart form data with an 'image' file and optional
    'box_threshold' / 'text_threshold' numeric fields. Returns JSON with a
    base64-encoded annotated image and detection metadata on success, or an
    'error' payload with an appropriate HTTP status code.
    """
    try:
        # Check if image was uploaded
        if 'image' not in request.files:
            return jsonify({'error': 'No image uploaded'}), 400
        file = request.files['image']
        if file.filename == '':
            return jsonify({'error': 'No image selected'}), 400
        if not allowed_file(file.filename):
            return jsonify({'error': 'Invalid file type. Please upload an image.'}), 400
        # Validate threshold parameters up front so malformed client input
        # yields a 400 instead of surfacing as a 500 from the broad handler.
        try:
            box_threshold = float(request.form.get('box_threshold', 0.35))
            text_threshold = float(request.form.get('text_threshold', 0.25))
        except (TypeError, ValueError):
            return jsonify({'error': 'Thresholds must be numeric.'}), 400
        # A corrupt or unreadable upload is a client error, not a server fault.
        try:
            image = Image.open(file.stream)
        except Exception as e:
            return jsonify({'error': f'Could not read image: {str(e)}'}), 400
        # Run detection
        annotated_image, detection_info = detector.detect(
            image,
            box_threshold=box_threshold,
            text_threshold=text_threshold
        )
        return jsonify({
            'success': True,
            'image': image_to_base64(annotated_image),
            'detections': detection_info
        })
    except Exception as e:
        # Top-level boundary: log and return a generic 500 for anything else.
        print(f"Error during detection: {str(e)}")
        return jsonify({'error': f'Detection failed: {str(e)}'}), 500
@app.route('/health')
def health():
    """Liveness probe; the app only starts after the model loads at import time."""
    payload = {'status': 'healthy', 'model_loaded': True}
    return jsonify(payload)
if __name__ == '__main__':
    # For local development
    # 0.0.0.0 so containers/LAN peers can reach it; 7860 is the Hugging Face
    # Spaces convention. Debug stays off to avoid the werkzeug debugger.
    app.run(host='0.0.0.0', port=7860, debug=False)