eliason1 committed on
Commit
dc7c040
·
verified ·
1 Parent(s): cfb87f0

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +297 -0
app.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import json
4
+ import os
5
+ from pathlib import Path
6
+ from fastapi import FastAPI, File, UploadFile, Form, HTTPException
7
+ from fastapi.responses import JSONResponse
8
+ from typing import Dict, Any, Tuple, Optional, Union
9
+ import io
10
+ import aiohttp
11
+ import uvicorn
12
+ from urllib.parse import urlparse
13
+
14
+ # --- Original Cursor Detection Functions (Adapted for Server) ---
15
+
16
def to_rgb(img: np.ndarray) -> Optional[np.ndarray]:
    """Return a 3-channel BGR version of ``img``.

    Despite the name, the output follows OpenCV's BGR channel order.

    Args:
        img: Image array, or ``None``.

    Returns:
        ``None`` when ``img`` is ``None``; a converted copy for 2-D
        (grayscale) and 4-channel (BGRA) inputs; otherwise ``img`` itself,
        unchanged.
    """
    if img is None:
        return None

    if img.ndim == 2:
        # Single-plane grayscale image: replicate into three channels.
        conversion = cv2.COLOR_GRAY2BGR
    elif img.shape[2] == 4:
        # BGRA image: drop the alpha plane.
        conversion = cv2.COLOR_BGRA2BGR
    else:
        # Treated as already being 3-channel BGR.
        return img

    return cv2.cvtColor(img, conversion)
28
+
29
def get_mask_from_alpha(template_img: np.ndarray) -> Optional[np.ndarray]:
    """Build a binary mask from the alpha plane of a 4-channel image.

    Args:
        template_img: Image array, or ``None``.

    Returns:
        A ``uint8`` array with 255 wherever alpha is greater than 0 and 0
        elsewhere, or ``None`` if the input is ``None`` or is not a
        3-dimensional array with exactly four channels.
    """
    if template_img is None:
        return None

    dims = template_img.shape
    if len(dims) != 3 or dims[2] != 4:
        return None

    alpha_plane = template_img[:, :, 3]
    opaque = alpha_plane > 0
    return opaque.astype(np.uint8) * 255
35
+
36
def detect_cursor_in_frame_multi(
    frame: np.ndarray,
    cursor_templates: Dict[str, np.ndarray],
    threshold: float = 0.8
) -> Tuple[Optional[Tuple[int, int]], float, Optional[str]]:
    """Find the best-matching cursor template in a single frame.

    Each template is matched against the frame with
    ``cv2.TM_CCOEFF_NORMED``; the template's alpha channel, when present,
    is used as the matching mask.

    Args:
        frame: Image to search. Converted to 3-channel BGR before matching.
        cursor_templates: Mapping of template name to template image.
        threshold: Minimum score for a match to count as a detection.

    Returns:
        ``(position, confidence, template_name)``. ``position`` is the
        ``(x, y)`` centre of the matched region and ``template_name`` the
        winning template when the best score reaches ``threshold``;
        otherwise both are ``None``. ``confidence`` is the best score seen
        across all usable templates, or ``-1.0`` if none could be matched.
    """
    frame_bgr = to_rgb(frame)
    if frame_bgr is None:
        return None, -1.0, None

    frame_h, frame_w = frame_bgr.shape[:2]

    best_pos = None
    best_conf = -1.0
    best_template_name = None

    for template_name, cursor_template in cursor_templates.items():
        template_bgr = to_rgb(cursor_template)
        if template_bgr is None or template_bgr.shape[2] != frame_bgr.shape[2]:
            continue

        # matchTemplate requires the template to fit inside the frame.
        tmpl_h, tmpl_w = template_bgr.shape[:2]
        if tmpl_h > frame_h or tmpl_w > frame_w:
            continue

        mask = get_mask_from_alpha(cursor_template)

        try:
            result = cv2.matchTemplate(
                frame_bgr, template_bgr, cv2.TM_CCOEFF_NORMED, mask=mask
            )
        except (cv2.error, TypeError):
            # Best effort: a template OpenCV rejects (e.g. depth mismatch
            # with the frame) must not abort matching of the others.
            continue

        # Masked normalised matching can yield NaN/inf where the
        # normalisation term is zero (flat regions). Left alone, an inf
        # would win minMaxLoc and be reported as a perfect detection, so
        # treat every non-finite score as the worst possible one.
        result[~np.isfinite(result)] = -1.0

        _, max_val, _, max_loc = cv2.minMaxLoc(result)

        if max_val > best_conf:
            best_conf = max_val
            # Report the centre of the matched rectangle, not its corner.
            best_pos = (max_loc[0] + tmpl_w // 2, max_loc[1] + tmpl_h // 2)
            best_template_name = template_name

    if best_pos is not None and best_conf >= threshold:
        return best_pos, best_conf, best_template_name
    return None, best_conf, None
88
+
89
async def download_image_from_url(
    url: str,
    timeout: float = 30.0,
    max_bytes: int = 25 * 1024 * 1024,
) -> bytes:
    """Download the resource at ``url`` and return its body as bytes.

    Args:
        url: Address to fetch with an HTTP GET.
        timeout: Total number of seconds allowed for the whole request.
        max_bytes: Largest response body accepted, in bytes.

    Returns:
        The raw response body.

    Raises:
        HTTPException: 400 if the remote server does not answer with
            status 200, 413 if the body exceeds ``max_bytes``, 504 if the
            request does not complete within ``timeout`` seconds.
        aiohttp.ClientError: Propagated unchanged on connection-level
            failures.
    """
    import asyncio

    client_timeout = aiohttp.ClientTimeout(total=timeout)
    try:
        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            async with session.get(url) as response:
                if response.status != 200:
                    raise HTTPException(
                        status_code=400,
                        detail=f"Failed to fetch image from URL. Status code: {response.status}"
                    )

                # Stream the body so an oversized response from an
                # untrusted URL is rejected before it is fully buffered.
                chunks = []
                received = 0
                async for chunk in response.content.iter_chunked(64 * 1024):
                    received += len(chunk)
                    if received > max_bytes:
                        raise HTTPException(
                            status_code=413,
                            detail=f"Image at URL exceeds the maximum allowed size of {max_bytes} bytes"
                        )
                    chunks.append(chunk)
                return b"".join(chunks)
    except asyncio.TimeoutError as exc:
        raise HTTPException(
            status_code=504,
            detail=f"Timed out after {timeout} seconds while fetching image from URL"
        ) from exc
99
+
100
+ # --- Server Setup ---
101
+
102
# FastAPI application instance; the route decorators in this module
# register their handlers on it.
app = FastAPI(
    title="Cursor Tracker API",
    description=(
        "API to detect and track mouse cursors in uploaded images using template matching."
    ),
)

# Directory scanned for cursor template PNG files (a relative path).
CURSOR_TEMPLATES_DIR = Path("cursors")

# Module-wide cache of loaded templates, keyed by template file name.
CURSOR_TEMPLATES: Dict[str, np.ndarray] = {}
110
+
111
def load_cursor_templates():
    """Load every ``*.png`` in ``CURSOR_TEMPLATES_DIR`` into ``CURSOR_TEMPLATES``.

    Images are read with ``cv2.IMREAD_UNCHANGED`` so an alpha channel, if
    present, is kept. Does nothing when templates are already loaded or
    when the directory does not exist. Progress and problems are reported
    with ``print``; nothing is raised for a missing directory or an
    unreadable file.
    """
    if CURSOR_TEMPLATES:
        print("Templates already loaded.")
        return

    print(f"Loading cursor templates from: {CURSOR_TEMPLATES_DIR}")

    if not CURSOR_TEMPLATES_DIR.is_dir():
        print(f"Error: Template directory not found at {CURSOR_TEMPLATES_DIR}")
        return

    # Path.glob yields entries in arbitrary order. Sort so the dict's
    # insertion order (which the detector iterates, and which decides
    # ties between equally scored templates) is the same on every run.
    for template_file in sorted(CURSOR_TEMPLATES_DIR.glob('*.png')):
        template_img = cv2.imread(str(template_file), cv2.IMREAD_UNCHANGED)
        if template_img is None:
            print(f"[WARN] Could not load template: {template_file.name}")
            continue
        CURSOR_TEMPLATES[template_file.name] = template_img

    if not CURSOR_TEMPLATES:
        print(f"FATAL: No cursor templates found in: {CURSOR_TEMPLATES_DIR}")
    else:
        print(f"Successfully loaded {len(CURSOR_TEMPLATES)} templates.")
136
+
137
@app.on_event("startup")
async def startup_event():
    """Startup hook: fill the cursor template cache before serving requests."""
    load_cursor_templates()
141
+
142
@app.get("/")
async def root():
    """Health-check endpoint returning a fixed status message."""
    status_message = "Cursor Tracker API is running. Use /track_cursor to upload an image."
    return {"message": status_message}
146
+
147
@app.post("/track_cursor")
async def track_cursor_endpoint(
    file: UploadFile = File(...),
    threshold: float = Form(0.8)
):
    """Detect the cursor in an uploaded image.

    Args:
        file: Uploaded image file.
        threshold: Minimum match score for a detection to be reported.

    Returns:
        A JSON object with ``cursor_active``, ``x``, ``y``, ``confidence``,
        ``template`` and ``image_shape``. ``x``, ``y`` and ``template`` are
        ``null`` when no cursor reaches the threshold.

    Raises:
        HTTPException: 503 if no cursor templates are loaded; 400 if the
            upload is empty or cannot be decoded as an image.
    """
    import math

    if not CURSOR_TEMPLATES:
        raise HTTPException(
            status_code=503,
            detail="Cursor templates are not loaded. Server initialization failed."
        )

    # 1. Read image file content
    content = await file.read()

    # 2. Convert file content to OpenCV image format.
    # An empty buffer makes cv2.imdecode raise instead of returning None,
    # so reject it up front with the same 400 as any undecodable upload.
    frame = None
    if content:
        np_array = np.frombuffer(content, np.uint8)
        frame = cv2.imdecode(np_array, cv2.IMREAD_UNCHANGED)

    if frame is None:
        raise HTTPException(
            status_code=400,
            detail="Could not decode image file. Ensure it is a valid image format (e.g., PNG, JPEG)."
        )

    # 3. Detect cursor
    pos, conf, template_name = detect_cursor_in_frame_multi(frame, CURSOR_TEMPLATES, threshold)

    # 4. Log values for debugging
    print(f"pos: {pos}, type: {type(pos)}")
    print(f"conf: {conf}, type: {type(conf)}")
    print(f"template_name: {template_name}, type: {type(template_name)}")
    print(f"frame.shape: {frame.shape}, type: {type(frame.shape)}")

    # 5. Prepare response
    # JSON cannot carry inf or NaN: map +inf to 1.0 and -inf/NaN to 0.0.
    confidence = float(conf)
    if math.isfinite(confidence):
        confidence_val = confidence
    elif confidence > 0:
        confidence_val = 1.0
    else:
        confidence_val = 0.0

    detected = pos is not None
    response_data = {
        'cursor_active': detected,
        'x': int(pos[0]) if detected else None,
        'y': int(pos[1]) if detected else None,
        'confidence': confidence_val,
        'template': template_name if detected else None,
        'image_shape': list(frame.shape)
    }

    return JSONResponse(content=response_data)
211
+
212
+ # Endpoint to detect the cursor in an image fetched from a URL
213
@app.post("/track_cursor_url")
async def track_cursor_url_endpoint(
    image_url: str = Form(...),
    threshold: float = Form(0.8)
):
    """Detect the cursor in an image downloaded from ``image_url``.

    Args:
        image_url: Absolute http(s) URL of the image to analyse.
        threshold: Minimum match score for a detection to be reported.

    Returns:
        A JSON object with ``cursor_active``, ``x``, ``y``, ``confidence``,
        ``template``, ``image_shape`` and ``source_url``. ``x``, ``y`` and
        ``template`` are ``null`` when no cursor reaches the threshold.

    Raises:
        HTTPException: 503 if no cursor templates are loaded; 400 for an
            invalid URL, a failed download or an undecodable image; 500 for
            any other processing error. Errors raised by the download
            helper keep their own status code.
    """
    import math

    if not CURSOR_TEMPLATES:
        raise HTTPException(
            status_code=503,
            detail="Cursor templates are not loaded. Server initialization failed."
        )

    try:
        # Validate URL: must be absolute and use http or https.
        # NOTE(review): fetching a caller-supplied URL from the server
        # allows SSRF against internal hosts; consider a host allow-list.
        parsed_url = urlparse(image_url)
        if parsed_url.scheme not in ("http", "https") or not parsed_url.netloc:
            raise HTTPException(
                status_code=400,
                detail="Invalid URL provided"
            )

        # Download image
        content = await download_image_from_url(image_url)

        # Convert to OpenCV format. An empty buffer makes cv2.imdecode
        # raise rather than return None, so treat it as undecodable.
        frame = None
        if content:
            np_array = np.frombuffer(content, np.uint8)
            frame = cv2.imdecode(np_array, cv2.IMREAD_UNCHANGED)

        if frame is None:
            raise HTTPException(
                status_code=400,
                detail="Could not decode image from URL. Ensure it is a valid image format (e.g., PNG, JPEG)."
            )

        # Detect cursor
        pos, conf, template_name = detect_cursor_in_frame_multi(frame, CURSOR_TEMPLATES, threshold)

        # JSON cannot carry inf or NaN: map +inf to 1.0 and -inf/NaN to
        # 0.0, matching the /track_cursor endpoint.
        confidence = float(conf)
        if math.isfinite(confidence):
            confidence_val = confidence
        elif confidence > 0:
            confidence_val = 1.0
        else:
            confidence_val = 0.0

        # Prepare response
        detected = pos is not None
        response_data = {
            'cursor_active': detected,
            'x': int(pos[0]) if detected else None,
            'y': int(pos[1]) if detected else None,
            'confidence': confidence_val,
            'template': template_name if detected else None,
            'image_shape': list(frame.shape),
            'source_url': image_url
        }

        return JSONResponse(content=response_data)

    except HTTPException:
        # Deliberate HTTP errors raised above must keep their status code
        # instead of being rewrapped as a 500 by the generic handler below.
        raise
    except aiohttp.ClientError as e:
        raise HTTPException(
            status_code=400,
            detail=f"Failed to fetch image from URL: {str(e)}"
        ) from e
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred while processing the image: {str(e)}"
        ) from e
286
+
287
@app.get("/templates")
async def list_templates():
    """Report the names of the loaded cursor templates and how many there are."""
    names = list(CURSOR_TEMPLATES)
    return {"templates": names, "count": len(names)}
291
+
292
# Listening port: taken from the PORT environment variable, else 7860.
port = int(os.getenv("PORT", "7860"))

# Start the uvicorn server only when this module is executed as a script.
if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
        timeout_keep_alive=75,
    )