|
|
import cv2
|
|
|
import numpy as np
|
|
|
import json
|
|
|
import os
|
|
|
from pathlib import Path
|
|
|
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
|
|
|
from fastapi.responses import JSONResponse
|
|
|
from typing import Dict, Any, Tuple, Optional, Union
|
|
|
import io
|
|
|
import aiohttp
|
|
|
import uvicorn
|
|
|
from urllib.parse import urlparse
|
|
|
|
|
|
|
|
|
|
|
|
def to_rgb(img: np.ndarray) -> Optional[np.ndarray]:
    """
    Normalise an image to three channels in OpenCV's BGR order.

    Despite the name, no BGR -> RGB swap happens; only the channel count changes.

    - ``None`` is passed through as ``None``.
    - A 2-D (grayscale) array is expanded with ``cv2.COLOR_GRAY2BGR``.
    - A 4-channel (BGRA) array has its alpha dropped with ``cv2.COLOR_BGRA2BGR``.
    - Any other 3-D array is returned unchanged (the very same object).
    """
    if img is None:
        return None

    if len(img.shape) == 2:
        conversion = cv2.COLOR_GRAY2BGR
    elif img.shape[2] == 4:
        conversion = cv2.COLOR_BGRA2BGR
    else:
        # Already 3-channel (or some other channel count): leave untouched.
        return img

    return cv2.cvtColor(img, conversion)
|
|
|
|
|
|
def get_mask_from_alpha(template_img: np.ndarray) -> Optional[np.ndarray]:
    """
    Build a binary uint8 mask from the alpha channel of a 4-channel image.

    Pixels with any non-zero alpha become 255 and fully transparent pixels
    become 0. Returns ``None`` when the input is ``None`` or is not a 3-D
    array with exactly four channels.
    """
    has_alpha = (
        template_img is not None
        and template_img.ndim == 3
        and template_img.shape[2] == 4
    )
    if not has_alpha:
        return None

    alpha = template_img[..., 3]
    return np.where(alpha > 0, 255, 0).astype(np.uint8)
|
|
|
|
|
|
def _match_depths(image: np.ndarray, templ: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Return ``(image, templ)`` in a pixel depth cv2.matchTemplate accepts.

    matchTemplate requires both inputs to be uint8 or both to be float32.
    Any other combination (e.g. a 16-bit PNG decoded with IMREAD_UNCHANGED)
    is converted to float32 without rescaling. TM_CCOEFF_NORMED subtracts the
    mean and divides by the mean-subtracted magnitudes, so its score does not
    depend on the absolute intensity scale of either input.
    """
    if image.dtype == templ.dtype and image.dtype in (np.uint8, np.float32):
        return image, templ
    return image.astype(np.float32), templ.astype(np.float32)


def detect_cursor_in_frame_multi(
    frame: np.ndarray,
    cursor_templates: Dict[str, np.ndarray],
    threshold: float = 0.8
) -> Tuple[Optional[Tuple[int, int]], float, Optional[str]]:
    """
    Detects the best matching cursor template in a single frame.

    Each template is matched against the frame with ``cv2.TM_CCOEFF_NORMED``.
    Templates that carry an alpha channel are matched through a mask built
    from it. The template with the highest score wins; on ties, the template
    seen first in ``cursor_templates`` is kept.

    Args:
        frame: decoded image (grayscale, BGR or BGRA).
        cursor_templates: mapping of template name -> template image.
        threshold: minimum score required for a detection.

    Returns:
        ``(position, confidence, template_name)``. When the best score reaches
        ``threshold``, ``position`` is the (x, y) centre of the best match and
        ``template_name`` is its key. Otherwise both are None. ``confidence``
        is the best finite score seen, or -1.0 if no template could be matched.
    """
    best_pos = None
    best_conf = -1.0
    best_template_name = None

    frame_rgb = to_rgb(frame)
    if frame_rgb is None:
        return None, -1.0, None

    for template_name, cursor_template in cursor_templates.items():
        template_rgb = to_rgb(cursor_template)
        mask = get_mask_from_alpha(cursor_template)

        # matchTemplate needs identical channel counts.
        if template_rgb is None or template_rgb.shape[2] != frame_rgb.shape[2]:
            continue
        # A template larger than the frame cannot be slid across it.
        if (template_rgb.shape[0] > frame_rgb.shape[0]
                or template_rgb.shape[1] > frame_rgb.shape[1]):
            continue
        # A fully transparent template has no pixels to compare.
        if mask is not None and not mask.any():
            continue

        image, templ = _match_depths(frame_rgb, template_rgb)
        try:
            result = cv2.matchTemplate(image, templ, cv2.TM_CCOEFF_NORMED, mask=mask)
        except Exception as e:
            # Best-effort: one bad template must not abort the whole request.
            # Report the failure instead of silently dropping it.
            print(f"[WARN] Template matching failed for {template_name}: {e}")
            continue

        # With a mask, TM_CCOEFF_NORMED can produce inf/NaN where the frame
        # patch has no variance (e.g. a flat background). Those are not
        # matches. Left in place, +inf would win minMaxLoc and report a bogus
        # cursor position.
        result[~np.isfinite(result)] = -1.0

        _, max_val, _, max_loc = cv2.minMaxLoc(result)

        if max_val > best_conf:
            best_conf = max_val
            if max_val >= threshold:
                cursor_w, cursor_h = template_rgb.shape[1], template_rgb.shape[0]
                # minMaxLoc gives the top-left corner; report the template centre.
                best_pos = (max_loc[0] + cursor_w // 2, max_loc[1] + cursor_h // 2)
                best_template_name = template_name

    if best_conf >= threshold:
        return best_pos, best_conf, best_template_name
    return None, best_conf, None
|
|
|
|
|
|
async def download_image_from_url(
    url: str,
    timeout: float = 30.0,
    max_bytes: int = 50 * 1024 * 1024,
) -> bytes:
    """
    Download image from URL and return as bytes.

    The URL is user-supplied, so both the time spent and the amount of data
    read are bounded.

    Args:
        url: absolute URL to fetch.
        timeout: total seconds allowed for the whole request, including
            reading the body.
        max_bytes: largest response body accepted.

    Raises:
        HTTPException(400): the server answered with a status other than 200.
        HTTPException(413): the body is larger than ``max_bytes``.
        HTTPException(504): the request did not finish within ``timeout``.
        aiohttp.ClientError: connection-level failures propagate to the caller.
    """
    import asyncio

    def _too_large() -> HTTPException:
        return HTTPException(
            status_code=413,
            detail=f"Image at URL exceeds the maximum allowed size of {max_bytes} bytes."
        )

    client_timeout = aiohttp.ClientTimeout(total=timeout)
    try:
        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            async with session.get(url) as response:
                if response.status != 200:
                    raise HTTPException(
                        status_code=400,
                        detail=f"Failed to fetch image from URL. Status code: {response.status}"
                    )

                # Reject early when the server announces an oversized body.
                if response.content_length is not None and response.content_length > max_bytes:
                    raise _too_large()

                # Content-Length may be missing or wrong, so also enforce the
                # cap while streaming.
                chunks = []
                received = 0
                async for chunk in response.content.iter_chunked(64 * 1024):
                    received += len(chunk)
                    if received > max_bytes:
                        raise _too_large()
                    chunks.append(chunk)
                return b"".join(chunks)
    except asyncio.TimeoutError as exc:
        raise HTTPException(
            status_code=504,
            detail="Timed out while fetching image from URL."
        ) from exc
|
|
|
|
|
|
|
|
|
|
|
|
# The FastAPI application; every route below registers itself on it through
# a decorator.
app = FastAPI(
    description=(
        "API to detect and track mouse cursors in uploaded images "
        "using template matching."
    ),
    title="Cursor Tracker API",
)
|
|
|
|
|
|
|
|
|
# Loaded cursor templates keyed by file name; populated by load_cursor_templates().
CURSOR_TEMPLATES: Dict[str, np.ndarray] = {}

# Directory scanned for template images. Defaults to "cursors" (resolved
# against the process working directory). Set the CURSOR_TEMPLATES_DIR
# environment variable to point elsewhere without editing code.
CURSOR_TEMPLATES_DIR = Path(os.environ.get("CURSOR_TEMPLATES_DIR", "cursors"))
|
|
|
|
|
|
def load_cursor_templates():
    """
    Loads all cursor templates from the specified directory.

    Every regular file directly inside CURSOR_TEMPLATES_DIR whose extension is
    ``.png`` (matched case-insensitively) is read with ``cv2.IMREAD_UNCHANGED``,
    which keeps any alpha channel. Each image is stored in CURSOR_TEMPLATES
    under its file name. Files are loaded in sorted order, so the iteration
    order of CURSOR_TEMPLATES is the same on every run and filesystem.

    Does nothing if templates are already loaded. Problems are reported with
    ``print`` and never raised; if nothing loads, CURSOR_TEMPLATES stays empty.
    """
    # CURSOR_TEMPLATES is mutated in place, never rebound, so no `global`
    # statement is required.
    if CURSOR_TEMPLATES:
        print("Templates already loaded.")
        return

    print(f"Loading cursor templates from: {CURSOR_TEMPLATES_DIR}")

    if not CURSOR_TEMPLATES_DIR.is_dir():
        print(f"Error: Template directory not found at {CURSOR_TEMPLATES_DIR}")
        return

    # Case-insensitive suffix check so "Arrow.PNG" is picked up too. Sorting
    # removes the dependence on directory listing order.
    template_files = sorted(
        path for path in CURSOR_TEMPLATES_DIR.iterdir()
        if path.is_file() and path.suffix.lower() == ".png"
    )

    for template_file in template_files:
        template_img = cv2.imread(str(template_file), cv2.IMREAD_UNCHANGED)
        if template_img is None:
            print(f"[WARN] Could not load template: {template_file.name}")
            continue
        CURSOR_TEMPLATES[template_file.name] = template_img

    if not CURSOR_TEMPLATES:
        print(f"FATAL: No cursor templates found in: {CURSOR_TEMPLATES_DIR}")
    else:
        print(f"Successfully loaded {len(CURSOR_TEMPLATES)} templates.")
|
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Populate the cursor-template registry when the application starts up."""
    # NOTE: FastAPI has deprecated on_event in favour of lifespan handlers.
    # It is kept here so the existing `app` construction stays unchanged.
    load_cursor_templates()
|
|
|
|
|
|
@app.get("/")
async def root():
    """Health-check endpoint: confirm the service is up and point at /track_cursor."""
    status_message = "Cursor Tracker API is running. Use /track_cursor to upload an image."
    return {"message": status_message}
|
|
|
|
|
|
@app.post("/track_cursor")
async def track_cursor_endpoint(
    file: UploadFile = File(...),
    threshold: float = Form(0.8)
):
    """
    Accepts an image file and returns the detected cursor position and details.

    Form fields:
        file: the uploaded image, decoded with ``cv2.imdecode``.
        threshold: minimum match score required for a detection.

    Returns a JSON object with ``cursor_active``, ``x`` and ``y`` (centre of
    the matched template, or None), ``confidence``, ``template`` and
    ``image_shape``.

    Raises:
        HTTPException(503): no cursor templates are loaded.
        HTTPException(400): the upload is empty or cannot be decoded.
    """
    import math

    from fastapi.concurrency import run_in_threadpool

    if not CURSOR_TEMPLATES:
        raise HTTPException(
            status_code=503,
            detail="Cursor templates are not loaded. Server initialization failed."
        )

    content = await file.read()

    # cv2.imdecode raises on an empty buffer instead of returning None.
    # Route empty uploads to the same 400 as undecodable data.
    frame = None
    if content:
        frame = cv2.imdecode(np.frombuffer(content, np.uint8), cv2.IMREAD_UNCHANGED)
    if frame is None:
        raise HTTPException(
            status_code=400,
            detail="Could not decode image file. Ensure it is a valid image format (e.g., PNG, JPEG)."
        )

    # Template matching is CPU-bound. Run it off the event loop so other
    # requests are not stalled while it works.
    pos, conf, template_name = await run_in_threadpool(
        detect_cursor_in_frame_multi, frame, CURSOR_TEMPLATES, threshold
    )

    # JSONResponse refuses inf/NaN, so map non-finite scores:
    # +inf -> 1.0, while -inf and NaN -> 0.0.
    confidence = float(conf)
    if not math.isfinite(confidence):
        confidence = 1.0 if confidence > 0 else 0.0

    detected = pos is not None
    response_data = {
        'cursor_active': detected,
        'x': pos[0] if detected else None,
        'y': pos[1] if detected else None,
        'confidence': confidence,
        'template': template_name if detected else None,
        'image_shape': list(frame.shape),
    }
    return JSONResponse(content=response_data)
|
|
|
|
|
|
|
|
|
@app.post("/track_cursor_url")
async def track_cursor_url_endpoint(
    image_url: str = Form(...),
    threshold: float = Form(0.8)
):
    """
    Accepts an image URL and returns the detected cursor position and details.

    Form fields:
        image_url: absolute http(s) URL of the image to analyse.
        threshold: minimum match score required for a detection.

    Returns the same JSON shape as /track_cursor plus ``source_url``.

    Raises:
        HTTPException(503): no cursor templates are loaded.
        HTTPException(400): invalid URL, fetch failure, or undecodable image.
        HTTPException: any status raised by download_image_from_url is
            passed through unchanged.
        HTTPException(500): any other unexpected error during processing.
    """
    import math

    from fastapi.concurrency import run_in_threadpool

    if not CURSOR_TEMPLATES:
        raise HTTPException(
            status_code=503,
            detail="Cursor templates are not loaded. Server initialization failed."
        )

    try:
        # Only absolute http(s) URLs are fetched. urlparse lowercases the scheme.
        parsed_url = urlparse(image_url)
        if parsed_url.scheme not in ("http", "https") or not parsed_url.netloc:
            raise HTTPException(
                status_code=400,
                detail="Invalid URL provided"
            )

        content = await download_image_from_url(image_url)

        # cv2.imdecode raises on an empty buffer instead of returning None.
        frame = None
        if content:
            frame = cv2.imdecode(np.frombuffer(content, np.uint8), cv2.IMREAD_UNCHANGED)
        if frame is None:
            raise HTTPException(
                status_code=400,
                detail="Could not decode image from URL. Ensure it is a valid image format (e.g., PNG, JPEG)."
            )

        # CPU-bound matching runs off the event loop.
        pos, conf, template_name = await run_in_threadpool(
            detect_cursor_in_frame_multi, frame, CURSOR_TEMPLATES, threshold
        )

        # JSONResponse refuses inf/NaN, so map non-finite scores:
        # +inf -> 1.0, while -inf and NaN -> 0.0.
        confidence = float(conf)
        if not math.isfinite(confidence):
            confidence = 1.0 if confidence > 0 else 0.0

        detected = pos is not None
        response_data = {
            'cursor_active': detected,
            'x': pos[0] if detected else None,
            'y': pos[1] if detected else None,
            'confidence': confidence,
            'template': template_name if detected else None,
            'image_shape': list(frame.shape),
            'source_url': image_url,
        }
        return JSONResponse(content=response_data)

    except HTTPException:
        # Deliberate client-facing errors raised above (or by the downloader)
        # must keep their status code. Without this they would be rewrapped
        # as 500 by the generic handler below.
        raise
    except aiohttp.ClientError as e:
        raise HTTPException(
            status_code=400,
            detail=f"Failed to fetch image from URL: {str(e)}"
        ) from e
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred while processing the image: {str(e)}"
        ) from e
|
|
|
|
|
|
@app.get("/templates")
async def list_templates():
    """Report the names of the loaded cursor templates together with their count."""
    names = list(CURSOR_TEMPLATES)
    return {"templates": names, "count": len(names)}
|
|
|
|
|
|
# Listening port, read from the PORT environment variable (default 7860).
port = int(os.environ.get("PORT", 7860))

if __name__ == "__main__":
    # uvicorn is already imported at module level; re-importing it here was redundant.
    uvicorn.run(app, host="0.0.0.0", port=port, timeout_keep_alive=75)