File size: 10,609 Bytes
a653106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
import asyncio
import io
import json
import os
from pathlib import Path
from typing import Dict, Any, Tuple, Optional, Union
from urllib.parse import urlparse

import aiohttp
import cv2
import numpy as np
import uvicorn
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import JSONResponse

# --- Original Cursor Detection Functions (Adapted for Server) ---

def to_rgb(img: np.ndarray) -> Optional[np.ndarray]:
    """Return a 3-channel version of ``img`` for template matching.

    Despite the name, no channel reordering happens: the result keeps
    OpenCV's BGR ordering.

    * ``None`` is passed straight through.
    * A 2-D (grayscale) array is expanded with ``cv2.COLOR_GRAY2BGR``.
    * An array whose third axis has size 4 has its alpha channel dropped
      with ``cv2.COLOR_BGRA2BGR``.
    * Anything else is returned as-is (the same object, not a copy).
    """
    if img is None:
        return None

    if img.ndim == 2:
        conversion = cv2.COLOR_GRAY2BGR
    elif img.shape[2] == 4:
        conversion = cv2.COLOR_BGRA2BGR
    else:
        # Assumed to already be a BGR image as produced by OpenCV readers.
        return img

    return cv2.cvtColor(img, conversion)

def get_mask_from_alpha(template_img: np.ndarray) -> Optional[np.ndarray]:
    """Build a binary mask from the alpha channel of a 4-channel image.

    Returns a ``uint8`` array with 255 wherever alpha is greater than 0
    and 0 elsewhere. Returns ``None`` when the input is ``None`` or is
    not a 3-D array whose last axis has exactly 4 channels.
    """
    if template_img is None:
        return None

    has_alpha = template_img.ndim == 3 and template_img.shape[2] == 4
    if not has_alpha:
        return None

    alpha = template_img[:, :, 3]
    return (alpha > 0).astype(np.uint8) * 255

def detect_cursor_in_frame_multi(
    frame: np.ndarray,
    cursor_templates: Dict[str, np.ndarray],
    threshold: float = 0.8
) -> Tuple[Optional[Tuple[int, int]], float, Optional[str]]:
    """Detect the best matching cursor template in a single frame.

    Every template is matched against the frame with
    ``cv2.TM_CCOEFF_NORMED``; a template's alpha channel, when present,
    is used as the matching mask. The template with the highest score
    wins.

    Args:
        frame: Image to search (grayscale, BGR or BGRA).
        cursor_templates: Mapping of template name to template image.
        threshold: Minimum score for a match to count as a detection.

    Returns:
        ``(position, confidence, template_name)``. ``position`` is the
        ``(x, y)`` centre of the matched region. When the best score is
        below ``threshold``, ``position`` and ``template_name`` are
        ``None`` and ``confidence`` is the best score seen (``-1.0`` if
        no template could be matched at all).
    """
    frame_bgr = to_rgb(frame)
    if frame_bgr is None:
        return None, -1.0, None

    frame_h, frame_w = frame_bgr.shape[:2]

    best_pos = None
    best_conf = -1.0
    best_template_name = None

    for template_name, cursor_template in cursor_templates.items():
        template_bgr = to_rgb(cursor_template)

        # Skip templates that failed to load or whose channel count
        # differs from the frame's.
        if template_bgr is None or template_bgr.shape[2] != frame_bgr.shape[2]:
            continue

        # matchTemplate requires the template to fit inside the frame.
        template_h, template_w = template_bgr.shape[:2]
        if template_h > frame_h or template_w > frame_w:
            continue

        mask = get_mask_from_alpha(cursor_template)

        try:
            # Match template. Use mask for non-rectangular templates.
            result = cv2.matchTemplate(
                frame_bgr, template_bgr, cv2.TM_CCOEFF_NORMED, mask=mask
            )
        except Exception:
            # Best effort: a template OpenCV refuses (e.g. depth
            # mismatch) must not prevent the others from being tried.
            continue

        # Normalised masked matching can divide by zero on flat regions
        # and emit inf/NaN scores. Left alone, an inf would beat every
        # genuine match and report an arbitrary location, so non-finite
        # scores are replaced with the worst possible score.
        result[~np.isfinite(result)] = -1.0

        _, max_val, _, max_loc = cv2.minMaxLoc(result)
        max_val = float(max_val)

        if max_val > best_conf:
            best_conf = max_val
            if max_val >= threshold:
                # Report the centre of the matched rectangle.
                best_pos = (
                    max_loc[0] + template_w // 2,
                    max_loc[1] + template_h // 2,
                )
                best_template_name = template_name

    if best_conf >= threshold:
        return best_pos, best_conf, best_template_name
    return None, best_conf, None

async def download_image_from_url(
    url: str,
    timeout: float = 30.0,
    max_bytes: int = 50 * 1024 * 1024,
) -> bytes:
    """Download the resource at ``url`` and return its body as bytes.

    The URL comes from the client, so the request is bounded in both
    time and size to stop a slow or huge response from tying up the
    server.

    Args:
        url: Address to fetch.
        timeout: Total time budget for the request, in seconds.
        max_bytes: Largest response body that will be accepted.

    Returns:
        The raw response body.

    Raises:
        HTTPException: 400 if the remote server does not answer with
            status 200, 413 if the body exceeds ``max_bytes``, 504 if
            the request times out.
        aiohttp.ClientError: For connection-level failures (left for
            the caller to translate).
    """
    client_timeout = aiohttp.ClientTimeout(total=timeout)
    try:
        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            async with session.get(url) as response:
                if response.status != 200:
                    raise HTTPException(
                        status_code=400,
                        detail=f"Failed to fetch image from URL. Status code: {response.status}"
                    )

                # Stream the body so the size cap is enforced before the
                # whole payload has been buffered in memory.
                chunks = []
                received = 0
                async for chunk in response.content.iter_chunked(64 * 1024):
                    received += len(chunk)
                    if received > max_bytes:
                        raise HTTPException(
                            status_code=413,
                            detail=f"Image at URL exceeds the maximum allowed size of {max_bytes} bytes"
                        )
                    chunks.append(chunk)
                return b"".join(chunks)
    except asyncio.TimeoutError as exc:
        raise HTTPException(
            status_code=504,
            detail="Timed out while fetching image from URL"
        ) from exc

# --- Server Setup ---

# Directory scanned for cursor template images (relative to the working
# directory the server is started from).
CURSOR_TEMPLATES_DIR = Path("cursors")

# Module-level cache of loaded templates, keyed by template file name.
# It starts empty and is filled in place by load_cursor_templates().
CURSOR_TEMPLATES: Dict[str, np.ndarray] = {}

# The ASGI application object that all routes below are registered on.
app = FastAPI(
    title="Cursor Tracker API",
    description="API to detect and track mouse cursors in uploaded images using template matching.",
)

def load_cursor_templates():
    """Populate ``CURSOR_TEMPLATES`` from the PNG files in ``CURSOR_TEMPLATES_DIR``.

    Does nothing (other than printing a notice) if templates are already
    loaded or if the directory does not exist. Images are read with
    ``cv2.IMREAD_UNCHANGED`` so an alpha channel, if present, is kept.
    Files that OpenCV cannot read are reported and skipped. Progress and
    problems are reported with ``print``; nothing is raised or returned.
    """
    # The dict is only mutated in place, so no ``global`` statement is needed.
    if CURSOR_TEMPLATES:
        print("Templates already loaded.")
        return

    print(f"Loading cursor templates from: {CURSOR_TEMPLATES_DIR}")

    if not CURSOR_TEMPLATES_DIR.is_dir():
        print(f"Error: Template directory not found at {CURSOR_TEMPLATES_DIR}")
        return

    for template_file in CURSOR_TEMPLATES_DIR.glob('*.png'):
        image = cv2.imread(str(template_file), cv2.IMREAD_UNCHANGED)
        if image is None:
            print(f"[WARN] Could not load template: {template_file.name}")
            continue
        CURSOR_TEMPLATES[template_file.name] = image

    if CURSOR_TEMPLATES:
        print(f"Successfully loaded {len(CURSOR_TEMPLATES)} templates.")
    else:
        print(f"FATAL: No cursor templates found in: {CURSOR_TEMPLATES_DIR}")

# NOTE(review): newer FastAPI releases recommend lifespan handlers over
# on_event — confirm against the FastAPI version this project pins.
@app.on_event("startup")
async def startup_event():
    """Startup hook: load the cursor templates once, before requests are served.

    Delegates to ``load_cursor_templates()``, which fills the module-level
    ``CURSOR_TEMPLATES`` dict.
    """
    load_cursor_templates()

@app.get("/")
async def root():
    """Health-check endpoint: reply with a fixed status message."""
    status_message = "Cursor Tracker API is running. Use /track_cursor to upload an image."
    return {"message": status_message}

@app.post("/track_cursor")
async def track_cursor_endpoint(
    file: UploadFile = File(...),
    threshold: float = Form(0.8)
):
    """Detect the cursor in an uploaded image.

    Args:
        file: Uploaded image file.
        threshold: Minimum match score for a detection to be reported.

    Returns:
        A JSON object with ``cursor_active``, ``x``, ``y``,
        ``confidence``, ``template`` and ``image_shape``. ``x``, ``y``
        and ``template`` are ``null`` when no cursor was detected.

    Raises:
        HTTPException: 503 if no templates are loaded, 400 if the upload
            is empty or cannot be decoded as an image.
    """
    if not CURSOR_TEMPLATES:
        raise HTTPException(
            status_code=503,
            detail="Cursor templates are not loaded. Server initialization failed."
        )

    # 1. Read image file content
    content = await file.read()

    # 2. Convert file content to OpenCV image format.
    # cv2.imdecode raises on an empty buffer instead of returning None,
    # so an empty upload is treated as undecodable up front.
    frame = None
    if content:
        np_array = np.frombuffer(content, np.uint8)
        frame = cv2.imdecode(np_array, cv2.IMREAD_UNCHANGED)

    if frame is None:
        raise HTTPException(
            status_code=400,
            detail="Could not decode image file. Ensure it is a valid image format (e.g., PNG, JPEG)."
        )

    # 3. Detect cursor
    pos, conf, template_name = detect_cursor_in_frame_multi(frame, CURSOR_TEMPLATES, threshold)

    # 4. Log values for debugging
    print(f"pos: {pos}, type: {type(pos)}")
    print(f"conf: {conf}, type: {type(conf)}")
    print(f"template_name: {template_name}, type: {type(template_name)}")
    print(f"frame.shape: {frame.shape}, type: {type(frame.shape)}")

    # 5. Prepare response
    # JSON cannot represent inf or NaN, and JSONResponse refuses to
    # serialise them, so any non-finite score is mapped to a finite one.
    confidence = float(conf)
    if np.isnan(confidence):
        confidence_val = 0.0
    elif np.isinf(confidence):
        confidence_val = 1.0 if confidence > 0 else 0.0
    else:
        confidence_val = confidence

    detected = pos is not None
    response_data = {
        'cursor_active': detected,
        'x': pos[0] if detected else None,
        'y': pos[1] if detected else None,
        'confidence': confidence_val,
        'template': template_name if detected else None,
        'image_shape': list(frame.shape)
    }

    return JSONResponse(content=response_data)

# Endpoint to track the cursor in an image downloaded from a URL
@app.post("/track_cursor_url")
async def track_cursor_url_endpoint(
    image_url: str = Form(...),
    threshold: float = Form(0.8)
):
    """Detect the cursor in an image fetched from a URL.

    Args:
        image_url: HTTP(S) address of the image to analyse.
        threshold: Minimum match score for a detection to be reported.

    Returns:
        A JSON object with ``cursor_active``, ``x``, ``y``,
        ``confidence``, ``template``, ``image_shape`` and
        ``source_url``. ``x``, ``y`` and ``template`` are ``null`` when
        no cursor was detected.

    Raises:
        HTTPException: 503 if no templates are loaded; 400 for an
            invalid URL, a failed download or an undecodable image;
            500 for any other unexpected failure. HTTP errors raised by
            the download helper are passed through unchanged.
    """
    if not CURSOR_TEMPLATES:
        raise HTTPException(
            status_code=503,
            detail="Cursor templates are not loaded. Server initialization failed."
        )

    try:
        # Validate URL. urlparse itself can raise ValueError on malformed
        # input; that is a client error, not a server fault.
        try:
            parsed_url = urlparse(image_url)
        except ValueError:
            parsed_url = None

        # Only plain web URLs are fetched; other schemes are rejected
        # before any network access is attempted.
        if (
            parsed_url is None
            or parsed_url.scheme not in ("http", "https")
            or not parsed_url.netloc
        ):
            raise HTTPException(
                status_code=400,
                detail="Invalid URL provided"
            )

        # Download image
        content = await download_image_from_url(image_url)

        # Convert to OpenCV format. cv2.imdecode raises on an empty
        # buffer, so an empty body is treated as undecodable up front.
        frame = None
        if content:
            np_array = np.frombuffer(content, np.uint8)
            frame = cv2.imdecode(np_array, cv2.IMREAD_UNCHANGED)

        if frame is None:
            raise HTTPException(
                status_code=400,
                detail="Could not decode image from URL. Ensure it is a valid image format (e.g., PNG, JPEG)."
            )

        # Detect cursor
        pos, conf, template_name = detect_cursor_in_frame_multi(frame, CURSOR_TEMPLATES, threshold)

        # JSON cannot represent inf or NaN, and JSONResponse refuses to
        # serialise them, so any non-finite score is mapped to a finite
        # one (same policy as the file-upload endpoint).
        confidence = float(conf)
        if np.isnan(confidence):
            confidence_val = 0.0
        elif np.isinf(confidence):
            confidence_val = 1.0 if confidence > 0 else 0.0
        else:
            confidence_val = confidence

        # Prepare response
        detected = pos is not None
        response_data = {
            'cursor_active': detected,
            'x': pos[0] if detected else None,
            'y': pos[1] if detected else None,
            'confidence': confidence_val,
            'template': template_name if detected else None,
            'image_shape': list(frame.shape),
            'source_url': image_url
        }

        return JSONResponse(content=response_data)

    except HTTPException:
        # Deliberate HTTP errors raised above must keep their status
        # code; without this clause the generic handler below would
        # turn every 400 into a 500.
        raise
    except aiohttp.ClientError as e:
        raise HTTPException(
            status_code=400,
            detail=f"Failed to fetch image from URL: {str(e)}"
        ) from e
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred while processing the image: {str(e)}"
        ) from e

@app.get("/templates")
async def list_templates():
    """Report the names of the currently loaded cursor templates and how many there are."""
    template_names = list(CURSOR_TEMPLATES)
    return {"templates": template_names, "count": len(template_names)}

# Listening port: taken from the PORT environment variable, falling back
# to 7860 when it is not set.
port = int(os.getenv("PORT", 7860))

# Serve the app with uvicorn on all interfaces when this file is executed
# as a script (uvicorn is already imported at the top of the file).
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=port, timeout_keep_alive=75)