# pia-13 / app.py
# eliason1's picture
# Upload app.py with huggingface_hub
# a653106 verified
import cv2
import numpy as np
import json
import os
from pathlib import Path
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import JSONResponse
from typing import Dict, Any, Tuple, Optional, Union
import io
import aiohttp
import uvicorn
from urllib.parse import urlparse
# --- Original Cursor Detection Functions (Adapted for Server) ---
def to_rgb(img: np.ndarray) -> Optional[np.ndarray]:
    """Return a 3-channel version of ``img``; ``None`` is passed through.

    Despite the function name, the conversions used are OpenCV's
    GRAY->BGR and BGRA->BGR, so the result is in BGR channel order.

    Args:
        img: A 2-D (grayscale) or 3-D image array, or ``None``.

    Returns:
        ``None`` when ``img`` is ``None``; a converted array for 2-D and
        4-channel inputs; otherwise ``img`` itself, unchanged.
    """
    if img is None:
        return None

    if img.ndim == 2:
        # Grayscale: expand the single channel into three.
        return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    has_alpha = img.shape[2] == 4
    if has_alpha:
        # Drop the alpha channel, keep the colour channels.
        return cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)

    # Anything else is assumed to already be a 3-channel image.
    return img
def get_mask_from_alpha(template_img: np.ndarray) -> Optional[np.ndarray]:
    """Build a binary mask from the alpha channel of a 4-channel image.

    Args:
        template_img: Image array, or ``None``.

    Returns:
        A ``uint8`` array with 255 where alpha is greater than 0 and 0
        elsewhere, or ``None`` when the input is ``None`` or is not a
        3-D array with exactly four channels.
    """
    if template_img is None:
        return None
    if template_img.ndim != 3 or template_img.shape[2] != 4:
        return None

    alpha = template_img[:, :, 3]
    opaque = alpha > 0
    return opaque.astype(np.uint8) * 255
def detect_cursor_in_frame_multi(
    frame: np.ndarray,
    cursor_templates: Dict[str, np.ndarray],
    threshold: float = 0.8
) -> Tuple[Optional[Tuple[int, int]], float, Optional[str]]:
    """
    Detects the best matching cursor template in a single frame.

    Every template is matched against the frame with
    ``cv2.TM_CCOEFF_NORMED`` (using the template's alpha channel as a mask
    when it has one) and the highest-scoring match across all templates
    is kept.

    Args:
        frame: Image to search.
        cursor_templates: Mapping of template name to template image.
        threshold: Minimum score the best match must reach to count as a
            detection.

    Returns:
        ``(position, confidence, template_name)``. ``position`` is the
        ``(x, y)`` centre of the best match and ``template_name`` its key
        in ``cursor_templates``; both are ``None`` when the best score is
        below ``threshold``. ``confidence`` is the best score seen, or
        ``-1.0`` when no template could be matched at all.
    """
    frame_bgr = to_rgb(frame)
    if frame_bgr is None:
        return None, -1.0, None

    best_pos = None
    best_conf = -1.0
    best_template_name = None

    for template_name, cursor_template in cursor_templates.items():
        template_bgr = to_rgb(cursor_template)
        if template_bgr is None or template_bgr.shape[2] != frame_bgr.shape[2]:
            # Unusable template or channel-count mismatch.
            continue

        tpl_h, tpl_w = template_bgr.shape[:2]
        if tpl_h > frame_bgr.shape[0] or tpl_w > frame_bgr.shape[1]:
            # matchTemplate needs the template to fit inside the frame.
            continue

        mask = get_mask_from_alpha(cursor_template)
        try:
            # Match template. Use mask for non-rectangular templates.
            result = cv2.matchTemplate(
                frame_bgr, template_bgr, cv2.TM_CCOEFF_NORMED, mask=mask
            )
        except Exception:
            # Deliberate best-effort: a template that cannot be matched
            # must not prevent the remaining templates from being tried.
            continue

        # Normalised masked matching can produce inf/NaN where the
        # normalisation term is zero (e.g. flat image regions). Such
        # values would otherwise win minMaxLoc and be reported as a
        # detection, so demote them to the lowest possible score.
        non_finite = ~np.isfinite(result)
        if non_finite.any():
            result[non_finite] = -1.0

        _, max_val, _, max_loc = cv2.minMaxLoc(result)
        if max_val > best_conf:
            best_conf = float(max_val)
            # max_loc is the top-left corner of the match; report its centre.
            best_pos = (max_loc[0] + tpl_w // 2, max_loc[1] + tpl_h // 2)
            best_template_name = template_name

    if best_conf >= threshold:
        return best_pos, best_conf, best_template_name
    return None, best_conf, None
async def download_image_from_url(url: str, timeout_seconds: float = 30.0) -> bytes:
    """Download the resource at ``url`` and return its body as bytes.

    Args:
        url: Address to fetch with an HTTP GET.
        timeout_seconds: Upper bound, in seconds, for the whole request
            (connect plus body read). Keeps a slow or stalled remote
            host from holding the request open for aiohttp's much longer
            default.

    Returns:
        The raw response body.

    Raises:
        HTTPException: 400 when the response status is not 200; 504 when
            the request does not finish within ``timeout_seconds``.
        aiohttp.ClientError: Propagated for connection-level failures.
    """
    import asyncio  # local import: asyncio is not imported at module level

    timeout = aiohttp.ClientTimeout(total=timeout_seconds)
    try:
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.get(url) as response:
                if response.status != 200:
                    raise HTTPException(
                        status_code=400,
                        detail=f"Failed to fetch image from URL. Status code: {response.status}"
                    )
                return await response.read()
    except asyncio.TimeoutError:
        raise HTTPException(
            status_code=504,
            detail=f"Timed out fetching image from URL after {timeout_seconds} seconds"
        ) from None
# --- Server Setup ---
# Application object; the route decorators in this file register on it.
app = FastAPI(
    title="Cursor Tracker API",
    description=(
        "API to detect and track mouse cursors in uploaded images using template matching."
    ),
)

# Cache of loaded templates, keyed by template file name.
# Populated by load_cursor_templates().
CURSOR_TEMPLATES: Dict[str, np.ndarray] = {}

# Directory searched for template PNG files (relative path).
CURSOR_TEMPLATES_DIR = Path("cursors")
def load_cursor_templates():
    """Fill ``CURSOR_TEMPLATES`` from the PNG files in ``CURSOR_TEMPLATES_DIR``.

    Does nothing (beyond printing a message) when templates are already
    loaded or when the directory does not exist. Each ``*.png`` file is
    read with its alpha channel preserved and stored under its file
    name; unreadable files are reported and skipped. Progress and the
    final outcome are printed to stdout.
    """
    global CURSOR_TEMPLATES

    if CURSOR_TEMPLATES:
        print("Templates already loaded.")
        return

    print(f"Loading cursor templates from: {CURSOR_TEMPLATES_DIR}")

    if not CURSOR_TEMPLATES_DIR.is_dir():
        print(f"Error: Template directory not found at {CURSOR_TEMPLATES_DIR}")
        return

    for png_path in CURSOR_TEMPLATES_DIR.glob('*.png'):
        # IMREAD_UNCHANGED keeps the alpha channel, which is later used as a mask.
        image = cv2.imread(str(png_path), cv2.IMREAD_UNCHANGED)
        if image is None:
            print(f"[WARN] Could not load template: {png_path.name}")
            continue
        CURSOR_TEMPLATES[png_path.name] = image

    loaded_count = len(CURSOR_TEMPLATES)
    if loaded_count:
        print(f"Successfully loaded {loaded_count} templates.")
    else:
        print(f"FATAL: No cursor templates found in: {CURSOR_TEMPLATES_DIR}")
@app.on_event("startup")
async def startup_event():
    """Handler for the application's "startup" event: loads the cursor templates."""
    load_cursor_templates()
@app.get("/")
async def root():
    """Health-check endpoint: returns a fixed status message."""
    status_message = "Cursor Tracker API is running. Use /track_cursor to upload an image."
    return {"message": status_message}
@app.post("/track_cursor")
async def track_cursor_endpoint(
    file: UploadFile = File(...),
    threshold: float = Form(0.8)
):
    """
    Accepts an image file and returns the detected cursor position and details.

    Args:
        file: Uploaded image; its bytes are decoded with OpenCV.
        threshold: Minimum match score for a detection to be reported.

    Returns:
        JSONResponse with the keys ``cursor_active``, ``x``, ``y``,
        ``confidence``, ``template`` and ``image_shape``. ``x``, ``y`` and
        ``template`` are ``None`` when no cursor was detected.

    Raises:
        HTTPException: 503 when no templates are loaded; 400 when the
            upload is empty or cannot be decoded as an image.
    """
    if not CURSOR_TEMPLATES:
        raise HTTPException(
            status_code=503,
            detail="Cursor templates are not loaded. Server initialization failed."
        )

    # 1. Read image file content
    content = await file.read()
    if not content:
        # cv2.imdecode raises on an empty buffer; report it as a client error.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is empty."
        )

    # 2. Convert file content to OpenCV image format
    np_array = np.frombuffer(content, np.uint8)
    try:
        frame = cv2.imdecode(np_array, cv2.IMREAD_UNCHANGED)
    except cv2.error:
        # Treat a decoder error the same as an undecodable image.
        frame = None
    if frame is None:
        raise HTTPException(
            status_code=400,
            detail="Could not decode image file. Ensure it is a valid image format (e.g., PNG, JPEG)."
        )

    # 3. Detect cursor
    pos, conf, template_name = detect_cursor_in_frame_multi(frame, CURSOR_TEMPLATES, threshold)

    # 4. Log values for debugging
    print(f"pos: {pos}, type: {type(pos)}")
    print(f"conf: {conf}, type: {type(conf)}")
    print(f"template_name: {template_name}, type: {type(template_name)}")
    print(f"frame.shape: {frame.shape}, type: {type(frame.shape)}")

    # 5. Prepare response
    # JSON cannot carry inf/NaN (serialising them fails), so map +inf to
    # 1.0 and -inf / NaN to 0.0.
    confidence = float(conf)
    if np.isfinite(confidence):
        confidence_val = confidence
    else:
        confidence_val = 1.0 if confidence > 0 else 0.0

    detected = pos is not None
    response_data = {
        'cursor_active': detected,
        'x': pos[0] if detected else None,
        'y': pos[1] if detected else None,
        'confidence': confidence_val,
        'template': template_name if detected else None,
        'image_shape': list(frame.shape)
    }
    return JSONResponse(content=response_data)
# Endpoint to track the cursor in an image fetched from a URL
@app.post("/track_cursor_url")
async def track_cursor_url_endpoint(
    image_url: str = Form(...),
    threshold: float = Form(0.8)
):
    """
    Accepts an image URL and returns the detected cursor position and details.

    Args:
        image_url: http(s) URL of the image to download and analyse.
        threshold: Minimum match score for a detection to be reported.

    Returns:
        JSONResponse with the keys ``cursor_active``, ``x``, ``y``,
        ``confidence``, ``template``, ``image_shape`` and ``source_url``.
        ``x``, ``y`` and ``template`` are ``None`` when no cursor was
        detected.

    Raises:
        HTTPException: 503 when no templates are loaded; 400 for an
            invalid URL, a failed download or an undecodable image; 500
            for any other processing error. Errors raised by the
            download helper keep their own status code.
    """
    if not CURSOR_TEMPLATES:
        raise HTTPException(
            status_code=503,
            detail="Cursor templates are not loaded. Server initialization failed."
        )

    try:
        # Validate URL: must be absolute and use http or https.
        # NOTE(review): the server fetches arbitrary user-supplied URLs, which
        # is an SSRF risk if internal hosts are reachable from it — consider
        # restricting the allowed hosts.
        parsed_url = urlparse(image_url)
        if parsed_url.scheme.lower() not in ("http", "https") or not parsed_url.netloc:
            raise HTTPException(
                status_code=400,
                detail="Invalid URL provided"
            )

        # Download image
        content = await download_image_from_url(image_url)

        # Convert to OpenCV format. An empty body cannot be an image, and
        # cv2.imdecode raises on an empty buffer, so treat it as undecodable.
        frame = None
        if content:
            np_array = np.frombuffer(content, np.uint8)
            frame = cv2.imdecode(np_array, cv2.IMREAD_UNCHANGED)
        if frame is None:
            raise HTTPException(
                status_code=400,
                detail="Could not decode image from URL. Ensure it is a valid image format (e.g., PNG, JPEG)."
            )

        # Detect cursor
        pos, conf, template_name = detect_cursor_in_frame_multi(frame, CURSOR_TEMPLATES, threshold)

        # JSON cannot carry inf/NaN; use the same mapping as /track_cursor
        # (+inf -> 1.0, -inf / NaN -> 0.0).
        confidence = float(conf)
        if not np.isfinite(confidence):
            confidence = 1.0 if confidence > 0 else 0.0

        # Prepare response
        detected = pos is not None
        response_data = {
            'cursor_active': detected,
            'x': pos[0] if detected else None,
            'y': pos[1] if detected else None,
            'confidence': confidence,
            'template': template_name if detected else None,
            'image_shape': list(frame.shape),
            'source_url': image_url
        }
        return JSONResponse(content=response_data)
    except HTTPException:
        # Already carries the intended status code; without this clause the
        # generic handler below would turn every 4xx raised above into a 500.
        raise
    except aiohttp.ClientError as e:
        raise HTTPException(
            status_code=400,
            detail=f"Failed to fetch image from URL: {str(e)}"
        ) from e
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred while processing the image: {str(e)}"
        ) from e
@app.get("/templates")
async def list_templates():
    """Report the names of the loaded cursor templates and how many there are."""
    names = list(CURSOR_TEMPLATES)
    return {"templates": names, "count": len(names)}
# Listening port: the PORT environment variable if set, otherwise 7860.
port = int(os.getenv("PORT", 7860))

# Start the uvicorn server only when this file is executed as a script.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
        timeout_keep_alive=75,
    )