Fred808 committed on
Commit
200a2f9
·
verified ·
1 Parent(s): 8a5c31a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +507 -0
app.py ADDED
@@ -0,0 +1,507 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ from PIL import Image
4
+ import os
5
+ import threading
6
+ import time
7
+ import urllib.parse
8
+ from fastapi import FastAPI, UploadFile, File, HTTPException, Form, Request
9
+ from fastapi.responses import JSONResponse, HTMLResponse
10
+ from fastapi.staticfiles import StaticFiles
11
+ from fastapi.templating import Jinja2Templates
12
+ import json
13
+ import io
14
+ from pathlib import Path
15
+ from huggingface_hub import HfApi, hf_hub_download
16
+ import asyncio
17
+ import uvicorn
18
+ from typing import Optional, Dict, Tuple, List
19
+ import aiohttp
20
+ from urllib.parse import urlparse
21
+
22
# FastAPI application object for the cursor detection service.
app = FastAPI(
    title="Cursor Detection and Tracking Server",
    description="Processes images to detect cursors and uploads results to dataset",
)

# Setup static files and templates
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

# --- Environment Configuration ---
# Each value can be overridden through an environment variable of the same name.
HF_TOKEN = os.environ.get("HF_TOKEN", "")
HF_DATASET_ID = os.environ.get("HF_DATASET_ID", "Fred808/data")  # Dataset to store results
HF_STATE_FILE = os.environ.get("HF_STATE_FILE", "processing_state_cursors.json")

# Local scratch directory; created at import time if it does not exist yet.
TEMP_DATASET_DIR = Path("temp_cursor_detection")
TEMP_DATASET_DIR.mkdir(exist_ok=True)

# Global variable to store loaded templates, keyed by template file name.
CURSOR_TEMPLATES: Dict[str, np.ndarray] = {}
CURSOR_TEMPLATES_DIR = Path("cursors")
41
+
42
+ # --- Cursor Detection Functions ---
43
+
44
def to_rgb(img: np.ndarray) -> Optional[np.ndarray]:
    """Return a 3-channel version of ``img``.

    Despite the function name, the conversions applied are
    ``cv2.COLOR_GRAY2BGR`` and ``cv2.COLOR_BGRA2BGR``, so converted images
    come back in BGR channel order.

    Args:
        img: Image array, or ``None``.

    Returns:
        ``None`` when ``img`` is ``None``; a converted copy for 2-D input or
        for input whose third axis has length 4; otherwise ``img`` itself.
    """
    if img is None:
        return None

    if img.ndim == 2:
        conversion = cv2.COLOR_GRAY2BGR
    elif img.shape[2] == 4:
        conversion = cv2.COLOR_BGRA2BGR
    else:
        # Any other layout is passed through untouched.
        return img

    return cv2.cvtColor(img, conversion)
53
+
54
def get_mask_from_alpha(template_img: np.ndarray) -> Optional[np.ndarray]:
    """Build a binary mask from the fourth channel of a 4-channel image.

    Args:
        template_img: Image array, or ``None``.

    Returns:
        A ``uint8`` array holding 255 wherever channel index 3 is greater
        than zero and 0 elsewhere, or ``None`` when the input is ``None`` or
        is not a 3-D array with exactly four channels.
    """
    if template_img is None:
        return None
    if template_img.ndim != 3 or template_img.shape[2] != 4:
        return None

    opaque = template_img[:, :, 3] > 0
    return np.where(opaque, np.uint8(255), np.uint8(0))
59
+
60
def detect_cursor_in_frame_multi(
    frame: np.ndarray,
    cursor_templates: Dict[str, np.ndarray],
    threshold: float = 0.8
) -> Tuple[Optional[Tuple[int, int]], float, Optional[str]]:
    """Find the best matching cursor template in a single frame.

    Every template is matched against the frame with
    ``cv2.TM_CCOEFF_NORMED``; when the template has an alpha channel it is
    used as the matching mask.

    Args:
        frame: Image to search, or ``None``.
        cursor_templates: Mapping of template name to template image.
        threshold: Minimum score for a match to count as a detection.

    Returns:
        ``(position, confidence, template_name)``. ``position`` is the
        ``(x, y)`` centre of the best match and ``template_name`` its key,
        both ``None`` when the best score is below ``threshold``.
        ``confidence`` is the best finite score seen across all templates,
        or ``-1.0`` when nothing could be matched.
    """
    frame_rgb = to_rgb(frame)
    if frame_rgb is None:
        return None, -1.0, None

    best_pos = None
    best_conf = -1.0
    best_template_name = None
    frame_h, frame_w = frame_rgb.shape[:2]

    for template_name, cursor_template in cursor_templates.items():
        template_rgb = to_rgb(cursor_template)
        if template_rgb is None or template_rgb.shape[2] != frame_rgb.shape[2]:
            continue

        # matchTemplate requires the template to fit inside the frame.
        cursor_h, cursor_w = template_rgb.shape[:2]
        if cursor_h > frame_h or cursor_w > frame_w:
            continue

        mask = get_mask_from_alpha(cursor_template)
        try:
            result = cv2.matchTemplate(frame_rgb, template_rgb, cv2.TM_CCOEFF_NORMED, mask=mask)
        except Exception as e:
            # Best-effort: one bad template must not abort the search, but
            # the failure should at least be visible in the logs.
            print(f"[WARN] Template matching failed for {template_name}: {e}")
            continue

        # Normalised matching can produce inf/NaN scores (presumably from a
        # zero denominator on flat regions). Such scores are numerical
        # artifacts; left alone, an inf would beat every genuine match.
        finite = np.isfinite(result)
        if not finite.all():
            result = np.where(finite, result, np.float32(-1.0))

        _, max_val, _, max_loc = cv2.minMaxLoc(result)

        if max_val > best_conf:
            best_conf = max_val
            if max_val >= threshold:
                # max_loc is the top-left corner of the match; report the centre.
                best_pos = (max_loc[0] + cursor_w // 2, max_loc[1] + cursor_h // 2)
                best_template_name = template_name

    if best_conf >= threshold:
        return best_pos, best_conf, best_template_name
    return None, best_conf, None
106
+
107
async def download_image_from_url(url: str) -> bytes:
    """Fetch ``url`` with an HTTP GET and return the response body.

    Args:
        url: Address of the image to download.

    Returns:
        The raw response body.

    Raises:
        HTTPException: With status 400 when the response status is not 200.
    """
    async with aiohttp.ClientSession() as http_session, http_session.get(url) as resp:
        if resp.status == 200:
            return await resp.read()
        raise HTTPException(
            status_code=400,
            detail=f"Failed to fetch image from URL. Status code: {resp.status}"
        )
117
+
118
def load_cursor_templates():
    """Fill ``CURSOR_TEMPLATES`` with every ``*.png`` in ``CURSOR_TEMPLATES_DIR``.

    Does nothing when templates are already loaded or when the directory is
    missing. Images are read with ``cv2.IMREAD_UNCHANGED`` and stored under
    their file name; files that cannot be read are reported and skipped.
    """
    if CURSOR_TEMPLATES:
        print("Templates already loaded.")
        return

    print(f"Loading cursor templates from: {CURSOR_TEMPLATES_DIR}")

    if not CURSOR_TEMPLATES_DIR.is_dir():
        print(f"Error: Template directory not found at {CURSOR_TEMPLATES_DIR}")
        return

    for png_path in CURSOR_TEMPLATES_DIR.glob('*.png'):
        loaded = cv2.imread(str(png_path), cv2.IMREAD_UNCHANGED)
        if loaded is None:
            print(f"[WARN] Could not load template: {png_path.name}")
            continue
        CURSOR_TEMPLATES[png_path.name] = loaded

    if CURSOR_TEMPLATES:
        print(f"Successfully loaded {len(CURSOR_TEMPLATES)} templates.")
    else:
        print(f"FATAL: No cursor templates found in: {CURSOR_TEMPLATES_DIR}")
142
+
143
+ # --- Dataset Management Functions ---
144
+
145
def _load_hf_state() -> dict:
    """Fetch and parse the processing-state JSON stored in the dataset repo.

    Returns:
        The parsed state with ``"file_states"`` (a dict) and
        ``"next_download_index"`` guaranteed to be present. A default empty
        state is returned when the state file is absent from the repo or
        when any step fails.
    """
    fallback = {"next_download_index": 0, "file_states": {}}
    try:
        api = HfApi(token=HF_TOKEN)
        repo_files = api.list_repo_files(repo_id=HF_DATASET_ID, repo_type="dataset")
        if HF_STATE_FILE not in repo_files:
            print(f"[DATASET] State file not found in {HF_DATASET_ID}. Using default state.")
            return fallback

        hf_hub_download(
            repo_id=HF_DATASET_ID,
            filename=HF_STATE_FILE,
            repo_type="dataset",
            token=HF_TOKEN,
            local_dir=TEMP_DATASET_DIR,
        )
        state_path = TEMP_DATASET_DIR / HF_STATE_FILE
        with state_path.open('r', encoding='utf-8') as fh:
            state = json.load(fh)

        # Repair missing or malformed fields so callers can index them directly.
        if "file_states" not in state or not isinstance(state["file_states"], dict):
            state["file_states"] = {}
        state.setdefault("next_download_index", 0)
        return state
    except Exception as e:
        print(f"[DATASET] Failed to load HF state: {e}")
        return fallback
168
+
169
def _upload_hf_state(state: dict) -> bool:
    """Write ``state`` to the local state file and upload it to the dataset.

    Args:
        state: Processing state to persist.

    Returns:
        ``True`` when the upload succeeded, ``False`` when any step raised.
    """
    try:
        state_path = TEMP_DATASET_DIR / HF_STATE_FILE
        with state_path.open('w', encoding='utf-8') as fh:
            json.dump(state, fh, indent=2)

        commit_message = f"Update processing state: next_index={state.get('next_download_index')}"
        HfApi(token=HF_TOKEN).upload_file(
            path_or_fileobj=str(state_path),
            path_in_repo=HF_STATE_FILE,
            repo_id=HF_DATASET_ID,
            repo_type="dataset",
            commit_message=commit_message,
        )
    except Exception as e:
        print(f"[DATASET] Failed to upload HF state: {e}")
        return False

    print(f"[DATASET] Uploaded state to {HF_DATASET_ID}.")
    return True
189
+
190
def _lock_file_for_processing(image_name: str, state: dict) -> bool:
    """Mark ``image_name`` as ``'processing'`` and publish the state as a lock.

    Args:
        image_name: Key of the image inside ``state['file_states']``.
        state: Processing state; mutated in place.

    Returns:
        ``True`` when the state upload succeeded. On failure the entry for
        ``image_name`` is put back exactly as it was before the call
        (restored, or removed if it did not exist) and ``False`` is returned.
    """
    print(f"[DATASET] Attempting to lock {image_name}...")
    file_states = state.setdefault('file_states', {})

    # Remember the prior entry so a failed lock does not erase e.g. 'failed'.
    had_previous = image_name in file_states
    previous = file_states.get(image_name)

    file_states[image_name] = 'processing'
    # NOTE(review): this is a plain write-then-upload with no compare-and-swap,
    # so two workers could presumably still claim the same image — confirm.
    if _upload_hf_state(state):
        print(f"[DATASET] Locked {image_name}.")
        return True

    if had_previous:
        file_states[image_name] = previous
    else:
        file_states.pop(image_name, None)
    return False
201
+
202
def _unlock_file_as_processed(image_name: str, state: dict, next_index: int) -> bool:
    """Record ``image_name`` as ``'processed'`` and publish the updated state.

    Args:
        image_name: Key of the image inside ``state['file_states']``.
        state: Processing state; mutated in place.
        next_index: Value stored under ``state['next_download_index']``.

    Returns:
        The result of uploading the state (``True`` on success).
    """
    print(f"[DATASET] Marking {image_name} as processed...")
    file_states = state.setdefault('file_states', {})
    file_states[image_name] = 'processed'
    state['next_download_index'] = next_index
    return _upload_hf_state(state)
209
+
210
def _get_image_list_from_hf() -> list:
    """List the image files stored in ``HF_DATASET_ID``.

    Returns:
        Sorted repo paths whose lower-cased name ends in ``.png``, ``.jpg``
        or ``.jpeg``; an empty list when the listing fails.
    """
    image_suffixes = ('.png', '.jpg', '.jpeg')
    try:
        repo_files = HfApi(token=HF_TOKEN).list_repo_files(
            repo_id=HF_DATASET_ID, repo_type="dataset"
        )
        image_files = [name for name in repo_files if name.lower().endswith(image_suffixes)]
        image_files.sort()
        print(f"[DATASET] Found {len(image_files)} image files in {HF_DATASET_ID}.")
        return image_files
    except Exception as e:
        print(f"[DATASET] Error listing HF dataset files: {e}")
        return []
221
+
222
def _upload_cursor_results(image_name: str, results: dict) -> bool:
    """Upload detection results as ``cursor_results/<image stem>.json``.

    Args:
        image_name: Name of the processed image; its suffix is replaced by
            ``.json`` to form the result file name.
        results: JSON-serialisable detection results.

    Returns:
        ``True`` when the upload succeeded, ``False`` when any step raised.
    """
    try:
        json_path = Path(image_name).with_suffix('.json')
        payload = json.dumps(results, indent=2, ensure_ascii=False).encode('utf-8')
        HfApi(token=HF_TOKEN).upload_file(
            path_or_fileobj=io.BytesIO(payload),
            path_in_repo=f"cursor_results/{json_path.name}",
            repo_id=HF_DATASET_ID,
            repo_type="dataset",
            commit_message=f"Cursor detection results for {image_name}",
        )
    except Exception as e:
        print(f"[DATASET] Failed to upload results for {image_name}: {e}")
        return False

    print(f"[DATASET] Uploaded results for {image_name} to {HF_DATASET_ID}.")
    return True
240
+
241
class DatasetProgress:
    """Mutable record of how far the dataset processing run has got."""

    def __init__(self):
        # Lifecycle label; starts out as "idle".
        self.status = "idle"
        self.error = None
        # Name of the image being worked on, if any.
        self.current_image = None
        self.processed_images = 0
        self.total_images = 0
        # time.time() timestamp of the run start; None until a run begins.
        self.start_time = None

    def to_dict(self):
        """Return a snapshot of the progress as a plain dict.

        ``progress`` is ``"<processed>/<total>"`` (``"0/0"`` while the total
        is zero) and ``elapsed`` is the number of seconds since
        ``start_time`` (``0`` while no start time is set).
        """
        if self.total_images:
            progress = f"{self.processed_images}/{self.total_images}"
        else:
            progress = "0/0"

        if self.start_time:
            elapsed = time.time() - self.start_time
        else:
            elapsed = 0

        return {
            "status": self.status,
            "current_image": self.current_image,
            "progress": progress,
            "elapsed": elapsed,
            "error": self.error,
        }


# Global progress tracker
dataset_progress = DatasetProgress()
262
+
263
async def process_image(image_bytes: bytes, threshold: float = 0.8) -> dict:
    """Decode an encoded image and run cursor detection on it.

    Args:
        image_bytes: Encoded image data.
        threshold: Minimum match score passed on to the detector.

    Returns:
        A dict with ``cursor_active``, ``x``, ``y``, ``confidence``,
        ``template`` and ``image_shape``. ``x``, ``y`` and ``template`` are
        ``None`` when no cursor was detected.

    Raises:
        ValueError: When the image cannot be decoded or detection fails; the
            original exception is attached as the cause.
    """
    try:
        buffer = np.frombuffer(image_bytes, np.uint8)
        frame = cv2.imdecode(buffer, cv2.IMREAD_UNCHANGED)
        if frame is None:
            raise ValueError("Could not decode image")

        pos, conf, template_name = detect_cursor_in_frame_multi(frame, CURSOR_TEMPLATES, threshold)

        confidence = float(conf)
        if not np.isfinite(confidence):
            # inf / -inf / NaN are not valid JSON numbers. Keep the existing
            # mapping (+inf -> 1.0, -inf -> 0.0) and treat NaN as no confidence.
            confidence = 1.0 if confidence == float('inf') else 0.0

        return {
            'cursor_active': pos is not None,
            'x': pos[0] if pos else None,
            'y': pos[1] if pos else None,
            'confidence': confidence,
            'template': template_name,
            'image_shape': list(frame.shape)
        }
    except Exception as e:
        raise ValueError(f"Error processing image: {str(e)}") from e
288
+
289
async def dataset_task(start_index: int = 1):
    """Process every image of the dataset, starting at ``start_index``.

    For each image that is not already ``'processed'`` or ``'processing'``
    in the shared state, the task locks it, downloads it, runs cursor
    detection, uploads the results and marks it as processed. Images whose
    processing or result upload fails are marked ``'failed'``. Progress is
    published through the module-level ``dataset_progress`` object.

    Args:
        start_index: 1-based index into the sorted image list; values below
            1 are treated as 1.

    Returns:
        ``True`` when the loop ran to completion, ``False`` when no
        templates are loaded, no images were found, or an unexpected error
        aborted the run.
    """
    global dataset_progress

    dataset_progress = DatasetProgress()
    dataset_progress.status = "starting"
    dataset_progress.start_time = time.time()

    print(f"[DATASET] Starting dataset task from index {start_index}...")

    if not CURSOR_TEMPLATES:
        err = "No cursor templates loaded"
        dataset_progress.status = "error"
        dataset_progress.error = err
        print(f"[DATASET] {err}")
        return False

    try:
        # The Hugging Face helpers block, so they run in worker threads.
        state = await asyncio.to_thread(_load_hf_state)
        image_list = await asyncio.to_thread(_get_image_list_from_hf)

        if not image_list:
            err = "No images found in dataset"
            dataset_progress.status = "error"
            dataset_progress.error = err
            print(f"[DATASET] {err}")
            return False

        total = len(image_list)
        dataset_progress.total_images = total
        dataset_progress.status = "processing"

        if start_index < 1:
            start_index = 1

        for idx in range(start_index - 1, total):
            try:
                image_path = image_list[idx]
                image_name = Path(image_path).name
                # Publish the image being handled so the status endpoint can show it.
                dataset_progress.current_image = image_name
                print(f"[DATASET] Processing image {idx + 1}/{total}: {image_name}")

                file_state = state.get('file_states', {}).get(image_name)
                if file_state == 'processed':
                    print(f"[DATASET] Skipping {image_name}: already processed.")
                    dataset_progress.processed_images += 1
                    continue
                if file_state == 'processing':
                    print(f"[DATASET] Skipping {image_name}: currently processing by another worker.")
                    continue

                # Try to lock
                locked = await asyncio.to_thread(_lock_file_for_processing, image_name, state)
                if not locked:
                    print(f"[DATASET] Could not lock {image_name}; skipping.")
                    continue

                try:
                    # Download and process image
                    print(f"[DATASET] Downloading {image_name}...")
                    # hf_hub_download returns the local path of the downloaded file.
                    local_path = await asyncio.to_thread(
                        hf_hub_download,
                        repo_id=HF_DATASET_ID,
                        filename=image_path,
                        repo_type="dataset",
                        token=HF_TOKEN,
                    )
                    content = await asyncio.to_thread(Path(local_path).read_bytes)

                    # Process image
                    results = await process_image(content)
                    results['image_name'] = image_name
                    results['image_path'] = image_path

                    # Upload results
                    uploaded = await asyncio.to_thread(_upload_cursor_results, image_name, results)

                    if uploaded:
                        next_index = idx + 2  # next 1-based index
                        ok = await asyncio.to_thread(
                            _unlock_file_as_processed, image_name, state, next_index
                        )
                        if not ok:
                            print(f"[DATASET] Warning: processed but failed to update state for {image_name}.")
                        dataset_progress.processed_images += 1
                        print(f"[DATASET] Successfully processed {image_name}")
                    else:
                        print(f"[DATASET] Failed to upload results for {image_name}")
                        state['file_states'][image_name] = 'failed'
                        await asyncio.to_thread(_upload_hf_state, state)

                except Exception as e:
                    print(f"[DATASET] Error processing {image_name}: {e}")
                    state['file_states'][image_name] = 'failed'
                    await asyncio.to_thread(_upload_hf_state, state)
                    continue

            except Exception as e:
                # Keep going: one broken entry must not stop the whole run.
                print(f"[DATASET] Error in image loop: {e}")
                continue

        print(f"[DATASET] Task completed. Processed {dataset_progress.processed_images}/{total} images.")
        dataset_progress.current_image = None
        dataset_progress.status = "completed"
        return True

    except Exception as e:
        err = f"Error in main processing loop: {str(e)}"
        dataset_progress.status = "error"
        dataset_progress.error = err
        print(f"[DATASET] {err}")
        return False
399
+
400
@app.on_event("startup")
async def startup_event():
    """Startup hook: load the cursor templates into ``CURSOR_TEMPLATES``."""
    load_cursor_templates()
404
+
405
@app.post('/start_dataset')
async def start_dataset(start_index: int = Form(1)):
    """Start dataset processing as a background task.

    Args:
        start_index: 1-based index of the first image to process.

    Returns:
        A JSON response: ``"started"`` on success, status 400 when a run is
        already in progress, status 503 when no cursor templates are loaded,
        and status 500 when scheduling the task raised.
    """
    try:
        if dataset_progress and dataset_progress.status in ("starting", "processing"):
            return JSONResponse(
                status_code=400,
                content={
                    "status": "error",
                    "error": "Dataset processing already running",
                    "progress": dataset_progress.to_dict()
                }
            )

        if not CURSOR_TEMPLATES:
            return JSONResponse(
                status_code=503,
                content={
                    "status": "error",
                    "error": "Cursor templates not loaded. Please ensure templates are available."
                }
            )

        # The event loop only keeps weak references to tasks, so hold a
        # strong one on the app; otherwise the run could be garbage-collected
        # before it finishes.
        app.state.dataset_background_task = asyncio.create_task(dataset_task(start_index))
        return JSONResponse(content={
            "status": "started",
            "start_index": start_index,
            "message": "Dataset processing started. Check /status endpoint for progress."
        })
    except Exception as e:
        return JSONResponse(status_code=500, content={"status": "error", "error": str(e)})
437
+
438
@app.get('/dataset_status')
async def get_dataset_status():
    """Report the state of the dataset run as a dict.

    Returns ``dataset_progress.to_dict()``, or ``{"status": "idle"}`` when
    the tracker is falsy.
    """
    return dataset_progress.to_dict() if dataset_progress else {"status": "idle"}
444
+
445
@app.post("/track_cursor")
async def track_cursor_endpoint(
    file: UploadFile = File(...),
    threshold: float = Form(0.8)
):
    """Run cursor detection on a single uploaded image.

    Args:
        file: Uploaded image file.
        threshold: Minimum match score for a detection.

    Returns:
        A JSON response with the detection results.

    Raises:
        HTTPException: 503 when no cursor templates are loaded; 400 when the
            upload cannot be processed as an image.
    """
    if not CURSOR_TEMPLATES:
        raise HTTPException(
            status_code=503,
            detail="Cursor templates are not loaded."
        )

    content = await file.read()
    try:
        results = await process_image(content, threshold)
    except ValueError as e:
        # process_image reports every failure as ValueError; answer with a
        # structured client error instead of an unhandled 500.
        raise HTTPException(status_code=400, detail=str(e)) from e
    return JSONResponse(content=results)
460
+
461
@app.post("/track_cursor_url")
async def track_cursor_url_endpoint(
    image_url: str = Form(...),
    threshold: float = Form(0.8)
):
    """Download an image from a URL and run cursor detection on it.

    Args:
        image_url: Absolute ``http``/``https`` URL of the image.
        threshold: Minimum match score for a detection.

    Returns:
        A JSON response with the detection results plus ``source_url``.

    Raises:
        HTTPException: 503 when no cursor templates are loaded; 400 for an
            invalid URL, a failed download or an image that cannot be
            processed; 500 for any other error.
    """
    if not CURSOR_TEMPLATES:
        raise HTTPException(
            status_code=503,
            detail="Cursor templates are not loaded."
        )

    try:
        parsed_url = urlparse(image_url)
        if parsed_url.scheme not in ("http", "https") or not parsed_url.netloc:
            raise HTTPException(
                status_code=400,
                detail="Invalid URL provided"
            )

        # NOTE(review): the server fetches a caller-supplied URL; consider
        # restricting reachable hosts if this runs on a private network.
        content = await download_image_from_url(image_url)
        results = await process_image(content, threshold)
        results['source_url'] = image_url
        return JSONResponse(content=results)

    except HTTPException:
        # Deliberate 4xx responses must not be rewritten into a 500 below.
        raise
    except ValueError as e:
        # Raised by process_image when the payload cannot be processed.
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred while processing the image: {str(e)}"
        ) from e
491
+
492
@app.get("/templates")
async def list_templates():
    """Return the names of the loaded cursor templates and how many there are."""
    template_names = list(CURSOR_TEMPLATES)
    return {"templates": template_names, "count": len(CURSOR_TEMPLATES)}
496
+
497
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Render ``home.html`` for the root path."""
    context = {"request": request}
    return templates.TemplateResponse("home.html", context)
500
+
501
# Port to listen on: taken from the PORT environment variable, else 7860.
port = int(os.getenv("PORT", 7860))

# Launch FastAPI with uvicorn when run directly
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=port, timeout_keep_alive=75)