File size: 10,609 Bytes
a653106 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 |
import cv2
import numpy as np
import json
import os
from pathlib import Path
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import JSONResponse
from typing import Dict, Any, Tuple, Optional, Union
import io
import aiohttp
import uvicorn
from urllib.parse import urlparse
# --- Original Cursor Detection Functions (Adapted for Server) ---
def to_rgb(img: np.ndarray) -> Optional[np.ndarray]:
    """Return a 3-channel version of ``img`` for template matching.

    Despite the name, no RGB reordering happens: 2-D (grayscale) input is
    expanded with ``COLOR_GRAY2BGR`` and 4-channel input has its alpha
    channel dropped with ``COLOR_BGRA2BGR``. Any other input is returned
    unchanged, and ``None`` is passed straight through.
    """
    if img is None:
        return None

    dims = len(img.shape)
    if dims == 2:
        # Single plane -> replicate into three channels.
        return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    has_alpha = img.shape[2] == 4
    # Strip alpha when present; otherwise the image is handed back as-is.
    return cv2.cvtColor(img, cv2.COLOR_BGRA2BGR) if has_alpha else img
def get_mask_from_alpha(template_img: np.ndarray) -> Optional[np.ndarray]:
    """Build a binary ``uint8`` mask from a template's alpha channel.

    Returns an array holding 255 wherever alpha is greater than 0 and 0
    elsewhere. Returns ``None`` when the input is ``None`` or is not a
    3-D array with exactly four channels.
    """
    if template_img is None:
        return None

    shape = template_img.shape
    if len(shape) != 3 or shape[2] != 4:
        return None

    alpha = template_img[:, :, 3]
    # Any non-zero alpha counts as "part of the cursor".
    return np.where(alpha > 0, 255, 0).astype(np.uint8)
def detect_cursor_in_frame_multi(
    frame: np.ndarray,
    cursor_templates: Dict[str, np.ndarray],
    threshold: float = 0.8
) -> Tuple[Optional[Tuple[int, int]], float, Optional[str]]:
    """Find the best-matching cursor template in a single frame.

    Every template is matched against the frame with
    ``cv2.TM_CCOEFF_NORMED``; a template's alpha channel, when present,
    is used as the match mask.

    Args:
        frame: Image to search.
        cursor_templates: Mapping of template name to template image.
        threshold: Minimum match score for a detection to be reported.

    Returns:
        ``(position, confidence, template_name)``. ``position`` is the
        centre ``(x, y)`` of the best match. When no template reaches
        ``threshold`` the result is ``(None, best_score, None)``, where
        ``best_score`` is the highest score seen, or ``-1.0`` if no
        template could be scored.
    """
    best_pos = None
    best_conf = -1.0
    best_template_name = None

    frame_rgb = to_rgb(frame)
    if frame_rgb is None:
        return None, -1.0, None

    frame_h, frame_w = frame_rgb.shape[:2]

    for template_name, cursor_template in cursor_templates.items():
        template_rgb = to_rgb(cursor_template)
        if template_rgb is None or template_rgb.shape[2] != frame_rgb.shape[2]:
            # Unusable template or channel-count mismatch.
            continue

        tmpl_h, tmpl_w = template_rgb.shape[:2]
        if tmpl_h > frame_h or tmpl_w > frame_w:
            # matchTemplate needs the template to fit inside the frame.
            continue

        mask = get_mask_from_alpha(cursor_template)
        try:
            # Use the mask for non-rectangular templates.
            result = cv2.matchTemplate(
                frame_rgb, template_rgb, cv2.TM_CCOEFF_NORMED, mask=mask
            )
        except Exception:
            # Best-effort: one bad template must not abort the others.
            continue

        # Masked normalised matching can divide by zero on flat regions and
        # emit inf/NaN. Such scores are not real matches, so they are pushed
        # to the bottom of the valid [-1, 1] range before picking a maximum.
        result[~np.isfinite(result)] = -1.0

        _, max_val, _, max_loc = cv2.minMaxLoc(result)
        if max_val > best_conf:
            best_conf = float(max_val)
            if max_val >= threshold:
                # Report the centre of the matched rectangle.
                best_pos = (max_loc[0] + tmpl_w // 2, max_loc[1] + tmpl_h // 2)
                best_template_name = template_name

    if best_conf >= threshold:
        return best_pos, best_conf, best_template_name
    return None, best_conf, None
async def download_image_from_url(url: str) -> bytes:
    """Fetch ``url`` with an HTTP GET and return the raw response body.

    Raises:
        HTTPException: with status 400 when the remote server answers with
            anything other than HTTP 200.
    """
    async with aiohttp.ClientSession() as http_session, http_session.get(url) as resp:
        status = resp.status
        if status == 200:
            return await resp.read()
        raise HTTPException(
            status_code=400,
            detail=f"Failed to fetch image from URL. Status code: {status}"
        )
# --- Server Setup ---
# Directory holding the cursor template images. The path is relative, so it
# is resolved against the process working directory.
CURSOR_TEMPLATES_DIR = Path("cursors")

# Module-wide cache of loaded templates, keyed by template file name.
CURSOR_TEMPLATES: Dict[str, np.ndarray] = dict()

# The FastAPI application object that all endpoints below attach to.
app = FastAPI(
    description="API to detect and track mouse cursors in uploaded images using template matching.",
    title="Cursor Tracker API",
)
def load_cursor_templates():
    """Fill ``CURSOR_TEMPLATES`` with every ``*.png`` in ``CURSOR_TEMPLATES_DIR``.

    Does nothing if templates are already loaded or the directory is
    missing. Images are read with ``IMREAD_UNCHANGED`` so an alpha channel,
    if any, is kept. Progress and problems are reported with ``print``.
    """
    global CURSOR_TEMPLATES

    if CURSOR_TEMPLATES:
        print("Templates already loaded.")
        return

    print(f"Loading cursor templates from: {CURSOR_TEMPLATES_DIR}")
    if not CURSOR_TEMPLATES_DIR.is_dir():
        print(f"Error: Template directory not found at {CURSOR_TEMPLATES_DIR}")
        return

    for png_path in CURSOR_TEMPLATES_DIR.glob('*.png'):
        # Keep the alpha channel; it is used later as the match mask.
        image = cv2.imread(str(png_path), cv2.IMREAD_UNCHANGED)
        if image is None:
            print(f"[WARN] Could not load template: {png_path.name}")
            continue
        CURSOR_TEMPLATES[png_path.name] = image

    if CURSOR_TEMPLATES:
        print(f"Successfully loaded {len(CURSOR_TEMPLATES)} templates.")
    else:
        print(f"FATAL: No cursor templates found in: {CURSOR_TEMPLATES_DIR}")
@app.on_event("startup")
async def startup_event():
    """Startup hook: load the cursor templates before requests are served."""
    load_cursor_templates()
@app.get("/")
async def root():
    """Health-check endpoint returning a fixed status message."""
    status_message = "Cursor Tracker API is running. Use /track_cursor to upload an image."
    return {"message": status_message}
@app.post("/track_cursor")
async def track_cursor_endpoint(
    file: UploadFile = File(...),
    threshold: float = Form(0.8)
):
    """Detect a cursor in an uploaded image.

    Args:
        file: Uploaded image file.
        threshold: Minimum match score required to report a detection.

    Returns:
        JSON with ``cursor_active``, ``x``, ``y``, ``confidence``,
        ``template`` and ``image_shape``. ``x``, ``y`` and ``template`` are
        ``null`` when no cursor is detected.

    Raises:
        HTTPException: 503 when no templates are loaded; 400 when the
            upload is empty or cannot be decoded as an image.
    """
    if not CURSOR_TEMPLATES:
        raise HTTPException(
            status_code=503,
            detail="Cursor templates are not loaded. Server initialization failed."
        )

    # 1. Read image file content
    content = await file.read()

    # 2. Convert file content to OpenCV image format.
    # cv2.imdecode raises on an empty buffer instead of returning None, so an
    # empty upload is mapped to the same 400 as an undecodable one.
    frame = None
    if content:
        np_array = np.frombuffer(content, np.uint8)
        frame = cv2.imdecode(np_array, cv2.IMREAD_UNCHANGED)
    if frame is None:
        raise HTTPException(
            status_code=400,
            detail="Could not decode image file. Ensure it is a valid image format (e.g., PNG, JPEG)."
        )

    # 3. Detect cursor
    pos, conf, template_name = detect_cursor_in_frame_multi(frame, CURSOR_TEMPLATES, threshold)

    # 4. Log values for debugging
    print(f"pos: {pos}, type: {type(pos)}")
    print(f"conf: {conf}, type: {type(conf)}")
    print(f"template_name: {template_name}, type: {type(template_name)}")
    print(f"frame.shape: {frame.shape}, type: {type(frame.shape)}")

    # 5. Prepare response
    # JSONResponse cannot serialise inf or NaN, so non-finite confidences are
    # mapped to a finite value: +inf -> 1.0, -inf and NaN -> 0.0.
    confidence = float(conf)
    if np.isfinite(confidence):
        confidence_val = confidence
    elif confidence > 0:
        confidence_val = 1.0
    else:
        confidence_val = 0.0

    found = pos is not None
    response_data = {
        'cursor_active': found,
        'x': int(pos[0]) if found else None,
        'y': int(pos[1]) if found else None,
        'confidence': confidence_val,
        'template': template_name if found else None,
        'image_shape': list(frame.shape)
    }
    return JSONResponse(content=response_data)
# Endpoint to detect a cursor in an image fetched from a URL
@app.post("/track_cursor_url")
async def track_cursor_url_endpoint(
    image_url: str = Form(...),
    threshold: float = Form(0.8)
):
    """Detect a cursor in an image downloaded from ``image_url``.

    Args:
        image_url: Absolute http(s) URL of the image to analyse.
        threshold: Minimum match score required to report a detection.

    Returns:
        JSON with ``cursor_active``, ``x``, ``y``, ``confidence``,
        ``template``, ``image_shape`` and ``source_url``. ``x``, ``y`` and
        ``template`` are ``null`` when no cursor is detected.

    Raises:
        HTTPException: 503 when no templates are loaded; 400 for an invalid
            URL, a failed download or an undecodable image; 500 for any
            other processing error.
    """
    if not CURSOR_TEMPLATES:
        raise HTTPException(
            status_code=503,
            detail="Cursor templates are not loaded. Server initialization failed."
        )

    # Validate URL: it must be absolute and use http or https.
    # NOTE(security): the server fetches a caller-supplied URL, which can be
    # used to reach internal hosts (SSRF). Consider a host allow-list if this
    # endpoint is exposed to untrusted clients.
    parsed_url = urlparse(image_url)
    if parsed_url.scheme not in ("http", "https") or not parsed_url.netloc:
        raise HTTPException(
            status_code=400,
            detail="Invalid URL provided"
        )

    try:
        # Download image
        content = await download_image_from_url(image_url)

        # Convert to OpenCV format. cv2.imdecode raises on an empty buffer,
        # so an empty body is treated the same as an undecodable image.
        frame = None
        if content:
            np_array = np.frombuffer(content, np.uint8)
            frame = cv2.imdecode(np_array, cv2.IMREAD_UNCHANGED)
        if frame is None:
            raise HTTPException(
                status_code=400,
                detail="Could not decode image from URL. Ensure it is a valid image format (e.g., PNG, JPEG)."
            )

        # Detect cursor
        pos, conf, template_name = detect_cursor_in_frame_multi(frame, CURSOR_TEMPLATES, threshold)

        # JSONResponse cannot serialise inf or NaN, so non-finite confidences
        # are mapped to a finite value: +inf -> 1.0, -inf and NaN -> 0.0.
        confidence = float(conf)
        if np.isfinite(confidence):
            confidence_val = confidence
        elif confidence > 0:
            confidence_val = 1.0
        else:
            confidence_val = 0.0

        # Prepare response
        found = pos is not None
        response_data = {
            'cursor_active': found,
            'x': int(pos[0]) if found else None,
            'y': int(pos[1]) if found else None,
            'confidence': confidence_val,
            'template': template_name if found else None,
            'image_shape': list(frame.shape),
            'source_url': image_url
        }
        return JSONResponse(content=response_data)
    except HTTPException:
        # HTTPException is an Exception subclass; without this clause the
        # generic handler below would turn every intended 400 into a 500.
        raise
    except aiohttp.ClientError as e:
        raise HTTPException(
            status_code=400,
            detail=f"Failed to fetch image from URL: {str(e)}"
        ) from e
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred while processing the image: {str(e)}"
        ) from e
@app.get("/templates")
async def list_templates():
    """Report the names of the currently loaded cursor templates and how many there are."""
    names = list(CURSOR_TEMPLATES)
    return {"templates": names, "count": len(names)}
# Port to listen on: read from the PORT environment variable, falling back
# to 7860 when it is not set.
port = int(os.environ.get("PORT", 7860))

# Launch FastAPI with uvicorn when run directly. uvicorn is already imported
# at the top of the file, so no local import is needed here.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=port, timeout_keep_alive=75)