"""Cursor Tracker API.

FastAPI service that locates a mouse cursor inside an image by template
matching against the PNG templates found in the ``cursors`` directory.
"""

import io
import json
import math
import os
from pathlib import Path
from typing import Dict, Any, Tuple, Optional, Union
from urllib.parse import urlparse

import aiohttp
import cv2
import numpy as np
import uvicorn
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import JSONResponse


# --- Original Cursor Detection Functions (Adapted for Server) ---

def to_rgb(img: np.ndarray) -> Optional[np.ndarray]:
    """Return *img* as a 3-channel BGR image.

    Despite the name, the output channel order is BGR (OpenCV's native
    order); no BGR->RGB swap is performed.

    Args:
        img: Grayscale (2-D), BGRA (4-channel) or already 3-channel image,
            or ``None``.

    Returns:
        The converted image, the input unchanged when it is neither 2-D nor
        4-channel, or ``None`` when *img* is ``None``.
    """
    if img is None:
        return None
    if len(img.shape) == 2:
        # Grayscale to BGR
        return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    if img.shape[2] == 4:
        # BGRA to BGR (drops the alpha channel)
        return cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
    # Anything else is passed through untouched.
    return img


def get_mask_from_alpha(template_img: np.ndarray) -> Optional[np.ndarray]:
    """Build a binary mask from the alpha channel of a 4-channel image.

    Args:
        template_img: Image as loaded with ``cv2.IMREAD_UNCHANGED``.

    Returns:
        A ``uint8`` array holding 255 where alpha > 0 and 0 elsewhere, or
        ``None`` when the image is ``None`` or has no fourth channel.
    """
    if (
        template_img is not None
        and len(template_img.shape) == 3
        and template_img.shape[2] == 4
    ):
        return (template_img[:, :, 3] > 0).astype(np.uint8) * 255
    return None


def detect_cursor_in_frame_multi(
    frame: np.ndarray,
    cursor_templates: Dict[str, np.ndarray],
    threshold: float = 0.8
) -> Tuple[Optional[Tuple[int, int]], float, Optional[str]]:
    """Find the best matching cursor template in a single frame.

    Every template is matched with ``cv2.TM_CCOEFF_NORMED`` (using the
    template's alpha channel as a mask when it has one) and the highest
    scoring template wins.

    Args:
        frame: Image to search.
        cursor_templates: Mapping of template name to template image.
        threshold: Minimum score for a match to count as a detection.

    Returns:
        ``(position, confidence, template_name)``. ``position`` is the
        ``(x, y)`` centre of the matched area. When the best score is below
        *threshold* the result is ``(None, best_score, None)``; when no
        template could be matched at all the score is ``-1.0``.
    """
    best_pos = None
    best_conf = -1.0
    best_template_name = None

    frame_rgb = to_rgb(frame)
    if frame_rgb is None:
        return None, -1.0, None

    for template_name, cursor_template in cursor_templates.items():
        template_rgb = to_rgb(cursor_template)
        mask = get_mask_from_alpha(cursor_template)

        # Skip templates that failed to load or whose channel count differs
        # from the frame's.
        if template_rgb is None or template_rgb.shape[2] != frame_rgb.shape[2]:
            continue

        # matchTemplate needs the template to fit inside the frame.
        if (
            template_rgb.shape[0] > frame_rgb.shape[0]
            or template_rgb.shape[1] > frame_rgb.shape[1]
        ):
            continue

        try:
            # The mask restricts matching to the non-transparent pixels of
            # non-rectangular templates.
            result = cv2.matchTemplate(
                frame_rgb, template_rgb, cv2.TM_CCOEFF_NORMED, mask=mask
            )
        except Exception:
            # Deliberately best-effort: one unusable template must not
            # prevent the remaining templates from being tried.
            continue

        _, max_val, _, max_loc = cv2.minMaxLoc(result)

        if max_val > best_conf:
            best_conf = max_val
            if max_val >= threshold:
                cursor_w, cursor_h = template_rgb.shape[1], template_rgb.shape[0]
                # max_loc is the top-left corner; report the centre instead.
                cursor_x = max_loc[0] + cursor_w // 2
                cursor_y = max_loc[1] + cursor_h // 2
                best_pos = (cursor_x, cursor_y)
                best_template_name = template_name

    if best_conf >= threshold:
        return best_pos, best_conf, best_template_name
    return None, best_conf, None


async def download_image_from_url(url: str) -> bytes:
    """Download *url* and return the response body as bytes.

    Raises:
        HTTPException: 400 when the server answers with a status other
            than 200.
        aiohttp.ClientError: On connection-level failures (propagated from
            aiohttp).
    """
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            if response.status != 200:
                raise HTTPException(
                    status_code=400,
                    detail=f"Failed to fetch image from URL. Status code: {response.status}"
                )
            return await response.read()


def _decode_image(content: bytes) -> Optional[np.ndarray]:
    """Decode raw image bytes into an OpenCV array, or None if undecodable."""
    if not content:
        # An empty buffer cannot be an image; some OpenCV builds raise on it
        # instead of returning None.
        return None
    try:
        return cv2.imdecode(np.frombuffer(content, np.uint8), cv2.IMREAD_UNCHANGED)
    except cv2.error:
        return None


def _finite_confidence(conf: float) -> float:
    """Map a raw match score to a JSON-serialisable float.

    JSONResponse rejects non-finite floats, so +inf is reported as 1.0 and
    -inf / NaN as 0.0; finite scores are returned unchanged.
    """
    value = float(conf)
    if math.isfinite(value):
        return value
    return 1.0 if value > 0 else 0.0


def _build_response(
    frame: np.ndarray,
    pos: Optional[Tuple[int, int]],
    conf: float,
    template_name: Optional[str],
) -> Dict[str, Any]:
    """Assemble the JSON payload shared by both tracking endpoints."""
    found = pos is not None
    return {
        'cursor_active': found,
        'x': int(pos[0]) if found else None,
        'y': int(pos[1]) if found else None,
        'confidence': _finite_confidence(conf),
        'template': template_name if found else None,
        'image_shape': list(frame.shape),
    }


# --- Server Setup ---

app = FastAPI(
    title="Cursor Tracker API",
    description="API to detect and track mouse cursors in uploaded images using template matching."
)

# Global variable to store loaded templates
CURSOR_TEMPLATES: Dict[str, np.ndarray] = {}
CURSOR_TEMPLATES_DIR = Path("cursors")


def load_cursor_templates():
    """Load every ``*.png`` in CURSOR_TEMPLATES_DIR into CURSOR_TEMPLATES.

    Does nothing when templates are already loaded or the directory is
    missing. Files that OpenCV cannot read are skipped with a warning.
    """
    global CURSOR_TEMPLATES
    if CURSOR_TEMPLATES:
        print("Templates already loaded.")
        return

    print(f"Loading cursor templates from: {CURSOR_TEMPLATES_DIR}")
    if not CURSOR_TEMPLATES_DIR.is_dir():
        print(f"Error: Template directory not found at {CURSOR_TEMPLATES_DIR}")
        return

    # Sorted so that tie-breaking between equally scoring templates does not
    # depend on filesystem enumeration order.
    for template_file in sorted(CURSOR_TEMPLATES_DIR.glob('*.png')):
        # IMREAD_UNCHANGED keeps the alpha channel, which is used as a mask.
        template_img = cv2.imread(str(template_file), cv2.IMREAD_UNCHANGED)
        if template_img is not None:
            CURSOR_TEMPLATES[template_file.name] = template_img
        else:
            print(f"[WARN] Could not load template: {template_file.name}")

    if not CURSOR_TEMPLATES:
        print(f"FATAL: No cursor templates found in: {CURSOR_TEMPLATES_DIR}")
    else:
        print(f"Successfully loaded {len(CURSOR_TEMPLATES)} templates.")


@app.on_event("startup")
async def startup_event():
    """Load templates when the application starts."""
    load_cursor_templates()


@app.get("/")
async def root():
    """Simple root endpoint for health check."""
    return {"message": "Cursor Tracker API is running. Use /track_cursor to upload an image."}


@app.post("/track_cursor")
async def track_cursor_endpoint(
    file: UploadFile = File(...),
    threshold: float = Form(0.8)
):
    """Detect the cursor in an uploaded image file.

    Args:
        file: The uploaded image.
        threshold: Minimum match score for a detection.

    Returns:
        JSON with ``cursor_active``, ``x``, ``y``, ``confidence``,
        ``template`` and ``image_shape``.

    Raises:
        HTTPException: 503 when no templates are loaded, 400 when the upload
            cannot be decoded as an image.
    """
    if not CURSOR_TEMPLATES:
        raise HTTPException(
            status_code=503,
            detail="Cursor templates are not loaded. Server initialization failed."
        )

    # 1. Read image file content
    content = await file.read()

    # 2. Convert file content to OpenCV image format
    frame = _decode_image(content)
    if frame is None:
        raise HTTPException(
            status_code=400,
            detail="Could not decode image file. Ensure it is a valid image format (e.g., PNG, JPEG)."
        )

    # 3. Detect cursor
    pos, conf, template_name = detect_cursor_in_frame_multi(frame, CURSOR_TEMPLATES, threshold)

    # 4. Log values for debugging
    print(f"pos: {pos}, type: {type(pos)}")
    print(f"conf: {conf}, type: {type(conf)}")
    print(f"template_name: {template_name}, type: {type(template_name)}")
    print(f"frame.shape: {frame.shape}, type: {type(frame.shape)}")

    # 5. Prepare response (non-finite confidence values are sanitised there)
    return JSONResponse(content=_build_response(frame, pos, conf, template_name))


@app.post("/track_cursor_url")
async def track_cursor_url_endpoint(
    image_url: str = Form(...),
    threshold: float = Form(0.8)
):
    """Detect the cursor in an image fetched from a URL.

    Args:
        image_url: http(s) URL of the image to analyse.
        threshold: Minimum match score for a detection.

    Returns:
        The same JSON as ``/track_cursor`` plus ``source_url``.

    Raises:
        HTTPException: 503 when no templates are loaded; 400 for an invalid
            URL, a failed download or an undecodable image; 500 for any
            other processing error.
    """
    if not CURSOR_TEMPLATES:
        raise HTTPException(
            status_code=503,
            detail="Cursor templates are not loaded. Server initialization failed."
        )

    # Validate URL. NOTE(security): this endpoint makes the server fetch a
    # caller-supplied address (SSRF surface). Only the scheme is restricted
    # here; deployments exposed to untrusted callers should also restrict
    # the reachable hosts.
    try:
        parsed_url = urlparse(image_url)
    except ValueError:
        parsed_url = None
    if (
        parsed_url is None
        or parsed_url.scheme not in ("http", "https")
        or not parsed_url.netloc
    ):
        raise HTTPException(
            status_code=400,
            detail="Invalid URL provided"
        )

    try:
        # Download image
        content = await download_image_from_url(image_url)

        # Convert to OpenCV format
        frame = _decode_image(content)
        if frame is None:
            raise HTTPException(
                status_code=400,
                detail="Could not decode image from URL. Ensure it is a valid image format (e.g., PNG, JPEG)."
            )

        # Detect cursor
        pos, conf, template_name = detect_cursor_in_frame_multi(frame, CURSOR_TEMPLATES, threshold)

        # Prepare response
        response_data = _build_response(frame, pos, conf, template_name)
        response_data['source_url'] = image_url
        return JSONResponse(content=response_data)

    except HTTPException:
        # HTTPException is itself an Exception; without this clause the
        # generic handler below would turn deliberate 400s into 500s.
        raise
    except aiohttp.ClientError as e:
        raise HTTPException(
            status_code=400,
            detail=f"Failed to fetch image from URL: {str(e)}"
        ) from e
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred while processing the image: {str(e)}"
        ) from e


# Optional: Endpoint to get a list of loaded templates
@app.get("/templates")
async def list_templates():
    """Returns a list of all loaded cursor template names."""
    return {"templates": list(CURSOR_TEMPLATES.keys()), "count": len(CURSOR_TEMPLATES)}


port = int(os.environ.get("PORT", 7860))

# Launch FastAPI with uvicorn when run directly
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=port, timeout_keep_alive=75)