# doniramdani820's picture
# Upload 6 files
# d91a6ec verified
"""
reCAPTCHA 4x4 Segmentation API - Hugging Face Space
Lightweight API for 4x4 grid challenge solving using YOLO segmentation
"""
import base64
import hmac
import io
import os
import time
from functools import lru_cache, wraps

from flask import Flask, request, jsonify
from flask_cors import CORS
from ultralytics import YOLO
import cv2
import numpy as np
import yaml
from PIL import Image
app = Flask(__name__)
CORS(app)  # Enable CORS for browser extension

# API Security - load secret key from environment.
# An empty value is treated exactly like "unset": otherwise the startup log
# below would report public access while require_api_key (which checks
# `API_SECRET_KEY is None`) would reject every request.
API_SECRET_KEY = os.environ.get('API_SECRET_KEY') or None
print(f"\nπŸ” API Security: {'ENABLED' if API_SECRET_KEY else 'DISABLED (PUBLIC ACCESS)'}")
if API_SECRET_KEY:
    # Never echo key material into logs (the old prefix print revealed the
    # whole key when it was 8 characters or shorter); only confirm it is set.
    print(f" Secret Key: {'*' * 8} ({len(API_SECRET_KEY)} chars, hidden)")

# Global state
model = None  # YOLO model instance, assigned by load_model()
class_names = None  # value of 'names' in data.yaml, indexed by class id
model_load_time = 0  # seconds spent inside load_model()
request_count = 0  # /predict calls received
successful_count = 0  # /predict calls that returned 200
failed_count = 0  # /predict calls that returned 400/500
total_latency = 0.0  # accumulated /predict latency in seconds
unauthorized_count = 0  # requests rejected by require_api_key

# Configuration
MODEL_FOLDER = "."  # Models in root folder (no subfolder)
CONFIDENCE_THRESHOLD = 0.25  # `conf` passed to model.predict
IOU_THRESHOLD = 0.45  # `iou` (NMS threshold) passed to model.predict
MASK_THRESHOLD_PERCENTAGE = 1.0  # 1% minimum tile coverage to select a tile
def require_api_key(f):
    """
    Decorator to require API key authentication on a Flask view.

    If API_SECRET_KEY is set in the environment:
      - the request's X-API-Key header must equal the secret;
      - a missing or wrong key returns a 401 JSON error and increments
        ``unauthorized_count``.
    If API_SECRET_KEY is NOT set (None):
      - all requests are allowed (public access).
    """
    @wraps(f)
    def decorated_function(*args, **kwargs):
        global unauthorized_count
        # If no secret key configured, allow all requests
        if API_SECRET_KEY is None:
            return f(*args, **kwargs)
        # Check for API key in header
        provided_key = request.headers.get('X-API-Key')
        if not provided_key:
            unauthorized_count += 1
            return jsonify({
                'success': False,
                'error': 'Missing API key',
                'message': 'Please provide X-API-Key header'
            }), 401
        # Constant-time comparison so response timing does not reveal how many
        # leading characters matched. Compare bytes: compare_digest raises
        # TypeError for str arguments that contain non-ASCII characters.
        if not hmac.compare_digest(provided_key.encode('utf-8'),
                                   API_SECRET_KEY.encode('utf-8')):
            unauthorized_count += 1
            return jsonify({
                'success': False,
                'error': 'Invalid API key',
                'message': 'The provided API key is incorrect'
            }), 401
        # Valid key, proceed with request
        return f(*args, **kwargs)
    return decorated_function
# Startup banner: a 60-character rule, the service title, and another rule.
print("\n".join(("=" * 60, "πŸš€ reCAPTCHA 4x4 Segmentation API", "=" * 60)))
@lru_cache(maxsize=1)
def load_model():
    """Load the YOLO segmentation model and its class names, once.

    Sets the module globals ``model`` (from ``best.onnx``), ``class_names``
    (the ``names`` entry of ``data.yaml``) and ``model_load_time``. Because of
    ``lru_cache(maxsize=1)`` later calls return the first result without
    reloading; this includes a cached ``False`` after a failed attempt.

    Returns:
        True when both files loaded, False if any step raised (the error is
        printed, not re-raised).
    """
    global model, class_names, model_load_time
    started = time.time()
    print(f"πŸ“¦ Loading model from {MODEL_FOLDER}/...")
    try:
        weights = os.path.join(MODEL_FOLDER, "best.onnx")
        # Thread settings intended for a 2 vCPU host with low-latency waiting.
        # NOTE(review): cv2/ultralytics are imported at module top before these
        # are set; whether the inference runtime still honours them is
        # unverified.
        os.environ['OMP_NUM_THREADS'] = '2'
        os.environ['OMP_WAIT_POLICY'] = 'ACTIVE'
        model = YOLO(weights, task='segment')
        for line in (
            f" βœ“ Model loaded: {weights}",
            " βœ“ CPU threads: 2 (OMP_NUM_THREADS)",
            " βœ“ Wait policy: ACTIVE (low latency)",
        ):
            print(line)

        yaml_path = os.path.join(MODEL_FOLDER, "data.yaml")
        with open(yaml_path, 'r') as handle:
            class_names = yaml.safe_load(handle)['names']
        print(f" βœ“ Classes loaded: {len(class_names)} classes")

        model_load_time = time.time() - started
        print(f" ⏱️ Load time: {model_load_time:.2f}s")
        return True
    except Exception as exc:
        print(f" βœ— Error loading model: {exc}")
        return False
def base64_to_image(base64_string):
    """Decode a base64 image (optionally a ``data:`` URL) into an OpenCV array.

    Args:
        base64_string: base64 payload, with or without a
            ``data:<mime>;base64,`` prefix.

    Returns:
        A 3-channel BGR ``numpy`` array, or None when the input cannot be
        decoded as an image (the error is printed, not raised).
    """
    try:
        # Remove data URL prefix if present
        if ',' in base64_string:
            base64_string = base64_string.split(',')[1]
        image_data = base64.b64decode(base64_string)
        pil_image = Image.open(io.BytesIO(image_data))
        # Normalize to 3-channel RGB first: palette ('P') and grayscale ('L')
        # images would otherwise become 2-D arrays that COLOR_RGB2BGR rejects,
        # and 'CMYK' data would be mis-coloured. RGB/RGBA inputs give the same
        # result as before (alpha is dropped either way).
        rgb = np.array(pil_image.convert('RGB'))
        return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    except Exception as e:
        print(f"Error converting base64 to image: {e}")
        return None
def run_segmentation(img):
    """Run YOLO segmentation on one image and return the raw detections.

    The model is loaded on demand if it is not loaded yet. At startup it is
    only loaded from the ``__main__`` block, so without this a server that
    imports the module (instead of running it) would fail on every request.

    Args:
        img: image array whose ``shape[:2]`` is ``(height, width)``, such as
            the BGR array returned by ``base64_to_image``.

    Returns:
        ``(boxes, scores, class_ids, masks, orig_w, orig_h)``:
        ``xyxy`` boxes, confidences and int class ids from ``result.boxes``,
        ``result.masks.data`` as a numpy array (one mask per detection), and
        the input width/height. Detection entries are empty when nothing is
        found.

    Raises:
        RuntimeError: if no model is loaded and ``load_model()`` fails.
    """
    if model is None and not load_model():
        raise RuntimeError("Model is not loaded")

    orig_h, orig_w = img.shape[:2]
    results = model.predict(
        img,
        conf=CONFIDENCE_THRESHOLD,
        iou=IOU_THRESHOLD,
        verbose=False,
        half=False,  # No FP16 on CPU
        device='cpu',  # Explicit CPU
        max_det=300,  # Cap on detections per image
        agnostic_nms=False,  # Class-specific NMS
        # NOTE(review): with retina_masks=False the masks are produced at the
        # inference resolution; get_tiles_to_click stretches them to the
        # original size, which presumably only lines up when no letterbox
        # padding was added (square inputs) — confirm for non-square images.
        retina_masks=False,
    )
    if not results:
        return [], [], [], [], orig_w, orig_h

    result = results[0]
    detections = result.boxes
    if detections is not None and len(detections) > 0:
        boxes = detections.xyxy.cpu().numpy()  # [x1, y1, x2, y2]
        scores = detections.conf.cpu().numpy()
        class_ids = detections.cls.cpu().numpy().astype(int)
    else:
        boxes = np.array([])
        scores = np.array([])
        class_ids = np.array([], dtype=int)

    if result.masks is not None and len(result.masks) > 0:
        masks = result.masks.data.cpu().numpy()  # [N, H, W]
    else:
        masks = np.array([])
    return boxes, scores, class_ids, masks, orig_w, orig_h
# Challenge wording -> model class name (4x4 model).
# Model classes: 'a fire hydrant', 'bicycles', 'buses', 'cars', 'chimneys',
# 'crosswalks', 'motorcycles', 'parking meters', 'stairs', 'taxis',
# 'tractors', 'traffic lights'.
# Insertion order matters: normalize_text's partial-match fallback returns
# the first entry that matches.
_CHALLENGE_SYNONYMS = {
    # Fire hydrant (model: "a fire hydrant"; the article is stripped first)
    'fire hydrant': 'fire hydrant',
    'fire hydrants': 'fire hydrant',
    'hydrant': 'fire hydrant',
    'hydrants': 'fire hydrant',
    # Bicycle -> bicycles
    'bicycle': 'bicycles',
    'bike': 'bicycles',
    'bikes': 'bicycles',
    # Bus -> buses
    'bus': 'buses',
    # Car -> cars
    'car': 'cars',
    'vehicle': 'cars',
    'vehicles': 'cars',
    'automobile': 'cars',
    'automobiles': 'cars',
    'taxi': 'taxis',  # Model has taxis
    'cab': 'taxis',
    'cabs': 'taxis',
    # Chimney -> chimneys
    'chimney': 'chimneys',
    # Crosswalk -> crosswalks
    'crosswalk': 'crosswalks',
    'pedestrian crossing': 'crosswalks',
    'zebra crossing': 'crosswalks',
    # Motorcycle -> motorcycles
    'motorcycle': 'motorcycles',
    'motorbike': 'motorcycles',
    'motorbikes': 'motorcycles',
    # Parking meters
    'parking meter': 'parking meters',
    'parking metre': 'parking meters',
    'parking metres': 'parking meters',
    # IMPORTANT: Ladder -> Stairs (model: "stairs")
    'ladder': 'stairs',
    'ladders': 'stairs',
    'stair': 'stairs',
    'staircase': 'stairs',
    'staircases': 'stairs',
    'step': 'stairs',
    'steps': 'stairs',
    # IMPORTANT: Tractor -> Tractors (model: "tractors" plural!)
    'tractor': 'tractors',
    'farm tractor': 'tractors',
    'farm tractors': 'tractors',
    # Traffic light -> traffic lights
    'traffic light': 'traffic lights',
    'traffic signal': 'traffic lights',
    'traffic signals': 'traffic lights',
    # Boat variations (not in 4x4, but keep for fallback)
    'boat': 'boats',
    'boats': 'boats',
    'ship': 'boats',
    'ships': 'boats',
    # Bridge variations (not in 4x4, but keep for fallback)
    'bridge': 'bridges',
    'bridges': 'bridges',
    # Tree/Palm variations (not in 4x4 model!)
    'tree': 'trees',
    'trees': 'trees',
    'palm': 'trees',
    'palms': 'trees',
    'palm tree': 'trees',
    'palm trees': 'trees',
}

# Article words removed before matching.
_ARTICLES = frozenset({'a', 'an', 'the'})


def normalize_text(text):
    """
    Normalize challenge text to a model class name via synonym mapping.

    Steps: lowercase, split on whitespace, drop the article words
    'a'/'an'/'the', then look the result up in ``_CHALLENGE_SYNONYMS``
    (exact match first, then the first entry where one string contains the
    other). Whole-word article removal matters: a plain ``replace('a ', '')``
    would turn "zebra crossing" into "zebr crossing".

    Returns:
        The mapped class name, the cleaned text when nothing matches, or ""
        when the text is empty or consists only of articles. (An empty string
        is a substring of every key and would otherwise map to the first
        entry.)
    """
    text = ' '.join(word for word in text.lower().split() if word not in _ARTICLES)
    if not text:
        return text
    # Check for exact matches first
    if text in _CHALLENGE_SYNONYMS:
        return _CHALLENGE_SYNONYMS[text]
    # Then partial matches (e.g. "cars" contains "car")
    for challenge_variant, model_class in _CHALLENGE_SYNONYMS.items():
        if challenge_variant in text or text in challenge_variant:
            return model_class
    # No mapping found, return as-is
    return text
def get_tiles_to_click(boxes, scores, class_ids, masks, challenge_title, img_width, img_height):
    """Map segmentation masks of matching detections onto the 4x4 tile grid.

    A detection matches when its class name and the normalized challenge
    title contain one another. The class name is lowercased and the article
    words 'a'/'an'/'the' are removed. Each matching mask is resized to
    ``(img_width, img_height)`` and binarized at 0.5. Every tile whose mask
    coverage is at least ``MASK_THRESHOLD_PERCENTAGE`` percent is selected.

    Args:
        boxes, scores: per-detection arrays; only used for the emptiness
            check and zip alignment here.
        class_ids: per-detection indices into ``class_names``.
        masks: per-detection 2-D masks, aligned with ``class_ids``.
        challenge_title: raw challenge text, normalized with ``normalize_text``.
        img_width, img_height: original image size in pixels (ints).

    Returns:
        Sorted list of tile ids (``row * 4 + col``, 0-15). It is empty when
        there are no detections or the title is empty or normalizes to "".
    """
    if len(boxes) == 0 or not challenge_title:
        return []

    normalized_title = normalize_text(challenge_title)
    if not normalized_title:
        # "" is a substring of every class name and would select every
        # detection; treat it as "no target".
        return []

    articles = ('a', 'an', 'the')
    tile_width = img_width / 4
    tile_height = img_height / 4
    # (tile_id, y1, y2, x1, x2), row-major
    tile_boundaries = [
        (
            row * 4 + col,
            int(row * tile_height),
            int((row + 1) * tile_height),
            int(col * tile_width),
            int((col + 1) * tile_width),
        )
        for row in range(4)
        for col in range(4)
    ]

    tiles_to_click = set()
    for _box, _score, class_id, mask in zip(boxes, scores, class_ids, masks):
        # Drop article *words* only; replace('a ', '') would also corrupt
        # names containing a word that ends in "a".
        det_class = ' '.join(
            word for word in class_names[class_id].lower().split()
            if word not in articles
        )
        if not det_class:
            continue
        if normalized_title not in det_class and det_class not in normalized_title:
            continue

        # INTER_NEAREST: cheap, and keeps mask values unblended
        mask_resized = cv2.resize(mask, (img_width, img_height), interpolation=cv2.INTER_NEAREST)
        mask_binary = (mask_resized > 0.5).astype(np.uint8)

        for tile_id, y1, y2, x1, x2 in tile_boundaries:
            if tile_id in tiles_to_click:
                continue  # already selected by an earlier detection
            tile_mask = mask_binary[y1:y2, x1:x2]
            if tile_mask.size == 0:
                continue  # degenerate tile on an image smaller than 4 px; avoid 0/0
            coverage_percentage = tile_mask.sum() / tile_mask.size * 100
            if coverage_percentage >= MASK_THRESHOLD_PERCENTAGE:
                tiles_to_click.add(tile_id)

    return sorted(tiles_to_click)
@app.route('/health', methods=['GET'])
def health():
    """Public health check: model state plus request statistics.

    Not wrapped by ``require_api_key``, so no API key is needed. The average
    latency divides the accumulated latency by the request count, falling
    back to a divisor of 1 before any request has arrived.
    """
    report = dict(
        status='healthy',
        model_loaded=model is not None,
        security_enabled=API_SECRET_KEY is not None,
        model_load_time_s=model_load_time,
        requests_total=request_count,
        requests_successful=successful_count,
        requests_failed=failed_count,
        requests_unauthorized=unauthorized_count,
        avg_latency_s=total_latency / max(request_count, 1),
    )
    return jsonify(report)
@app.route('/predict', methods=['POST'])
@require_api_key  # Require API key for prediction
def predict():
    """Main prediction endpoint (protected by API key).

    Expects a JSON object body with ``image`` (a base64 string, optionally a
    ``data:`` URL) and an optional ``challenge_title`` string.

    Responses:
        200: ``success``, ``tiles_to_click``, ``num_detections``,
             ``latency_s`` and ``challenge_title``.
        400: the body is not a JSON object, ``image`` is missing or cannot be
             decoded, or ``challenge_title`` is not a string.
        500: unexpected processing error (message in ``error``).

    Updates the request counters and total latency reported by /health.
    """
    global request_count, successful_count, failed_count, total_latency
    start_time = time.time()
    request_count += 1
    try:
        # silent=True returns None for a missing, non-JSON or malformed body
        # instead of raising BadRequest. The generic handler below would turn
        # that into a 500 even though it is a client error.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or 'image' not in data:
            failed_count += 1
            return jsonify({'success': False, 'error': 'Missing image data'}), 400

        # An explicit null title is treated like an absent one.
        challenge_title = data.get('challenge_title') or ''
        if not isinstance(challenge_title, str):
            failed_count += 1
            return jsonify({'success': False, 'error': 'Invalid challenge_title'}), 400

        img = base64_to_image(data['image'])
        if img is None:
            failed_count += 1
            return jsonify({'success': False, 'error': 'Invalid image data'}), 400

        boxes, scores, class_ids, masks, orig_w, orig_h = run_segmentation(img)
        tiles = get_tiles_to_click(
            boxes, scores, class_ids, masks, challenge_title, orig_w, orig_h
        )

        latency = time.time() - start_time
        total_latency += latency
        successful_count += 1
        return jsonify({
            'success': True,
            'tiles_to_click': tiles,
            'num_detections': len(boxes),
            'latency_s': round(latency, 3),
            'challenge_title': challenge_title
        })
    except Exception as e:
        failed_count += 1
        latency = time.time() - start_time
        total_latency += latency
        print(f"Error in predict: {e}")
        import traceback
        traceback.print_exc()
        return jsonify({
            'success': False,
            'error': str(e),
            'latency_s': round(latency, 3)
        }), 500
@app.route('/', methods=['GET'])
def index():
    """Service descriptor: name, version, model and the available endpoints."""
    endpoints = {
        'POST /predict': 'Predict tiles to click',
        'GET /health': 'Health check',
        'GET /': 'This page',
    }
    return jsonify(
        name='reCAPTCHA 4x4 Segmentation API',
        version='1.0.0',
        model='4X4 YOLO Segmentation',
        endpoints=endpoints,
    )
if __name__ == '__main__':
    print("\nπŸš€ Starting 4x4 Segmentation API...")
    # Load model on startup and refuse to serve without it
    if load_model():
        print("βœ… Model loaded successfully!\n")
    else:
        print("❌ Failed to load model!\n")
        # SystemExit instead of the `exit` helper: `exit` is injected by the
        # `site` module and is not guaranteed to exist (e.g. `python -S`).
        raise SystemExit(1)
    # Run Flask app
    app.run(host='0.0.0.0', port=7860, debug=False)