Fred808 commited on
Commit
27762cf
·
verified ·
1 Parent(s): aef1217

Update download_api.py

Browse files
Files changed (1) hide show
  1. download_api.py +222 -90
download_api.py CHANGED
@@ -1,104 +1,236 @@
1
- from fastapi import FastAPI, HTTPException
2
- from fastapi.responses import FileResponse
3
- from pathlib import Path
4
  import os
5
  import json
 
6
  import threading
7
- import requests
8
- from io import BytesIO
9
- from PIL import Image
10
- import numpy as np
11
- from cursor_tracker import track_cursor_from_images
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- app = FastAPI()
 
 
 
 
 
 
 
 
14
 
15
- ANNOTATIONS_DIR = Path("annotations").resolve()
16
- CURSOR_TEMPLATES_DIR = Path("cursors").resolve()
17
- ANNOTATION_OUTPUT = ANNOTATIONS_DIR / "blender1.json"
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- REMOTE_FRAMES_LIST_URL = "https://fred808-cu2.hf.space/frames"
20
- REMOTE_FRAME_BASE_URL = "https://fred808-cu2.hf.space/frames"
 
 
 
 
 
 
21
 
22
- ANNOTATIONS_DIR.mkdir(parents=True, exist_ok=True)
23
- cursor_tracking_started = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- def append_to_annotation_file(new_data: list, output_path: str):
 
 
 
 
 
 
 
 
 
 
26
  try:
27
- if os.path.exists(output_path):
28
- with open(output_path, "r", encoding="utf-8") as f:
29
- existing = json.load(f)
30
- else:
31
- existing = []
32
- except Exception:
33
- existing = []
34
-
35
- with open(output_path, "w", encoding="utf-8") as f:
36
- json.dump(existing + new_data, f, indent=2)
37
-
38
- def run_cursor_tracker_in_batches(batch_size=5000):
39
- global cursor_tracking_started
40
- if cursor_tracking_started:
41
- return
42
- cursor_tracking_started = True
43
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  try:
45
- res = requests.get(REMOTE_FRAMES_LIST_URL)
46
- res.raise_for_status()
47
- frame_list = res.json().get("frames", [])
48
- total = len(frame_list)
49
- print(f"[INFO] Found {total} frames.")
50
-
51
- for start in range(0, total, batch_size):
52
- end = min(start + batch_size, total)
53
- batch = frame_list[start:end]
54
- print(f"[BATCH] Fetching frames {start} to {end - 1}")
55
-
56
- images = []
57
- for fname in batch:
58
- try:
59
- url = f"{REMOTE_FRAME_BASE_URL}/{fname}"
60
- r = requests.get(url)
61
- if r.status_code == 429:
62
- import time; time.sleep(1.0)
63
- r = requests.get(url)
64
- r.raise_for_status()
65
- img = Image.open(BytesIO(r.content)).convert("RGB")
66
- images.append((fname, np.array(img)))
67
- import time; time.sleep(0.3)
68
- except Exception as e:
69
- print(f"[WARN] Could not fetch {fname}: {e}")
70
-
71
- if not images:
72
- continue
73
-
74
- results = track_cursor_from_images(
75
- images=images,
76
- cursor_templates_dir=str(CURSOR_TEMPLATES_DIR),
77
- output_json_path=None,
78
- threshold=0.8,
79
- return_results=True
80
- )
81
- append_to_annotation_file(results, str(ANNOTATION_OUTPUT))
82
- print(f"[✅] Appended {len(results)} annotations to JSON.")
83
-
 
 
 
 
 
 
 
84
  except Exception as e:
85
- print(f"[ERROR] Cursor tracking failed: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
- @app.on_event("startup")
88
- def startup_event():
89
- threading.Thread(target=run_cursor_tracker_in_batches, daemon=True).start()
90
 
91
- @app.get("/download/{filename}")
92
- def download_file(filename: str):
93
- file_path = ANNOTATIONS_DIR / filename
94
- if not file_path.exists() or not file_path.is_file():
95
- raise HTTPException(status_code=404, detail="File not found")
96
- return FileResponse(str(file_path), filename=filename)
97
 
98
- @app.get("/")
99
- def root():
100
- files = [f.name for f in ANNOTATIONS_DIR.glob("*.json")]
101
- return {
102
- "message": "Use /download/{filename} to get cursor tracking results.",
103
- "available_files": files
104
- }
 
 
 
 
1
  import os
2
  import json
3
+ import time
4
  import threading
5
+ from fastapi import FastAPI, HTTPException, BackgroundTasks
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+ from fastapi.responses import JSONResponse
8
+ import uvicorn
9
+
10
+ # Import from cursor_tracker
11
+ from cursor_tracker import (
12
+ main_processing_loop,
13
+ processing_status,
14
+ CURSOR_TRACKING_OUTPUT_FOLDER,
15
+ CURSOR_TEMPLATES_DIR,
16
+ log_message
17
+ )
18
+
19
+ # FastAPI App Definition
20
+ app = FastAPI(title="Cursor Tracking API", description="API to access cursor tracking results", version="1.0.0")
21
+
22
+ # Add CORS middleware to allow cross-origin requests
23
+ app.add_middleware(
24
+ CORSMiddleware,
25
+ allow_origins=["*"], # Allows all origins
26
+ allow_credentials=True,
27
+ allow_methods=["*"], # Allows all methods
28
+ allow_headers=["*"],
29
+ )
30
+
31
+ # Global variable to track if processing is running
32
+ processing_thread = None
33
 
34
+ @app.on_event("startup")
35
+ async def startup_event():
36
+ """Run the processing loop in the background when the API starts"""
37
+ global processing_thread
38
+ if not (processing_thread and processing_thread.is_alive()):
39
+ log_message("🚀 Starting RAR extraction, frame extraction, and cursor tracking pipeline in background...")
40
+ processing_thread = threading.Thread(target=main_processing_loop)
41
+ processing_thread.daemon = True
42
+ processing_thread.start()
43
 
44
+ @app.get("/")
45
+ async def root():
46
+ """Root endpoint with API information"""
47
+ return {
48
+ "message": "Cursor Tracking API",
49
+ "version": "1.0.0",
50
+ "endpoints": {
51
+ "/status": "Get processing status",
52
+ "/cursor-data": "List all cursor tracking JSON files",
53
+ "/cursor-data/{filename}": "Get specific cursor tracking data",
54
+ "/start-processing": "Start the RAR processing pipeline",
55
+ "/stop-processing": "Stop the RAR processing pipeline"
56
+ }
57
+ }
58
 
59
+ @app.get("/status")
60
+ async def get_status():
61
+ """Get current processing status"""
62
+ return {
63
+ "processing_status": processing_status,
64
+ "cursor_tracking_folder": CURSOR_TRACKING_OUTPUT_FOLDER,
65
+ "folder_exists": os.path.exists(CURSOR_TRACKING_OUTPUT_FOLDER)
66
+ }
67
 
68
+ @app.get("/cursor-data")
69
+ async def list_cursor_data():
70
+ """List all available cursor tracking JSON files"""
71
+ if not os.path.exists(CURSOR_TRACKING_OUTPUT_FOLDER):
72
+ return {"files": [], "message": "Cursor tracking output folder does not exist yet"}
73
+
74
+ json_files = []
75
+ for file in os.listdir(CURSOR_TRACKING_OUTPUT_FOLDER):
76
+ if file.endswith(".json"):
77
+ file_path = os.path.join(CURSOR_TRACKING_OUTPUT_FOLDER, file)
78
+ file_stats = os.stat(file_path)
79
+ json_files.append({
80
+ "filename": file,
81
+ "size_bytes": file_stats.st_size,
82
+ "modified_time": time.ctime(file_stats.st_mtime),
83
+ "download_url": f"/cursor-data/{file}"
84
+ })
85
+
86
+ return {
87
+ "files": json_files,
88
+ "total_files": len(json_files),
89
+ "folder_path": CURSOR_TRACKING_OUTPUT_FOLDER
90
+ }
91
 
92
+ @app.get("/cursor-data/{filename}")
93
+ async def get_cursor_data(filename: str):
94
+ """Get specific cursor tracking data by filename"""
95
+ if not filename.endswith(".json"):
96
+ raise HTTPException(status_code=400, detail="File must be a JSON file")
97
+
98
+ file_path = os.path.join(CURSOR_TRACKING_OUTPUT_FOLDER, filename)
99
+
100
+ if not os.path.exists(file_path):
101
+ raise HTTPException(status_code=404, detail=f"File {filename} not found")
102
+
103
  try:
104
+ with open(file_path, "r") as f:
105
+ data = json.load(f)
106
+
107
+ # Add metadata
108
+ file_stats = os.stat(file_path)
109
+ response_data = {
110
+ "filename": filename,
111
+ "file_size_bytes": file_stats.st_size,
112
+ "modified_time": time.ctime(file_stats.st_mtime),
113
+ "total_frames": len(data),
114
+ "cursor_active_frames": len([frame for frame in data if frame.get("cursor_active", False)]),
115
+ "data": data
116
+ }
117
+
118
+ return response_data
119
+
120
+ except json.JSONDecodeError:
121
+ raise HTTPException(status_code=500, detail=f"Invalid JSON in file {filename}")
122
+ except Exception as e:
123
+ raise HTTPException(status_code=500, detail=f"Error reading file {filename}: {str(e)}")
124
+
125
+ @app.post("/start-processing")
126
+ async def start_processing(background_tasks: BackgroundTasks):
127
+ """Start the RAR processing pipeline in the background"""
128
+ global processing_thread
129
+
130
+ if processing_thread and processing_thread.is_alive():
131
+ return {"message": "Processing is already running", "status": "already_running"}
132
+
133
+ if processing_status["is_running"]:
134
+ return {"message": "Processing is already running", "status": "already_running"}
135
+
136
+ # Start processing in a background thread
137
+ processing_thread = threading.Thread(target=main_processing_loop)
138
+ processing_thread.daemon = True
139
+ processing_thread.start()
140
+
141
+ return {"message": "Processing started in background", "status": "started"}
142
+
143
+ @app.post("/stop-processing")
144
+ async def stop_processing():
145
+ """Stop the RAR processing pipeline"""
146
+ global processing_thread
147
+
148
+ if not processing_status["is_running"] and (not processing_thread or not processing_thread.is_alive()):
149
+ return {"message": "No processing is currently running", "status": "not_running"}
150
+
151
+ # Note: This is a graceful stop request. The actual stopping depends on the processing loop
152
+ # checking the processing_status["is_running"] flag
153
+ processing_status["is_running"] = False
154
+
155
+ return {"message": "Stop signal sent to processing pipeline", "status": "stop_requested"}
156
+
157
+ @app.get("/cursor-data/{filename}/summary")
158
+ async def get_cursor_data_summary(filename: str):
159
+ """Get a summary of cursor tracking data without the full frame data"""
160
+ if not filename.endswith(".json"):
161
+ raise HTTPException(status_code=400, detail="File must be a JSON file")
162
+
163
+ file_path = os.path.join(CURSOR_TRACKING_OUTPUT_FOLDER, filename)
164
+
165
+ if not os.path.exists(file_path):
166
+ raise HTTPException(status_code=404, detail=f"File {filename} not found")
167
+
168
  try:
169
+ with open(file_path, "r") as f:
170
+ data = json.load(f)
171
+
172
+ # Calculate summary statistics
173
+ total_frames = len(data)
174
+ cursor_active_frames = len([frame for frame in data if frame.get("cursor_active", False)])
175
+ cursor_inactive_frames = total_frames - cursor_active_frames
176
+
177
+ # Get unique templates used
178
+ templates_used = set()
179
+ confidence_scores = []
180
+
181
+ for frame in data:
182
+ if frame.get("cursor_active", False) and frame.get("template"):
183
+ templates_used.add(frame["template"])
184
+ if frame.get("confidence") is not None:
185
+ confidence_scores.append(frame["confidence"])
186
+
187
+ # Calculate confidence statistics
188
+ avg_confidence = sum(confidence_scores) / len(confidence_scores) if confidence_scores else 0
189
+ max_confidence = max(confidence_scores) if confidence_scores else 0
190
+ min_confidence = min(confidence_scores) if confidence_scores else 0
191
+
192
+ file_stats = os.stat(file_path)
193
+
194
+ summary = {
195
+ "filename": filename,
196
+ "file_size_bytes": file_stats.st_size,
197
+ "modified_time": time.ctime(file_stats.st_mtime),
198
+ "total_frames": total_frames,
199
+ "cursor_active_frames": cursor_active_frames,
200
+ "cursor_inactive_frames": cursor_inactive_frames,
201
+ "cursor_detection_rate": cursor_active_frames / total_frames if total_frames > 0 else 0,
202
+ "templates_used": list(templates_used),
203
+ "confidence_stats": {
204
+ "average": avg_confidence,
205
+ "maximum": max_confidence,
206
+ "minimum": min_confidence,
207
+ "total_measurements": len(confidence_scores)
208
+ }
209
+ }
210
+
211
+ return summary
212
+
213
+ except json.JSONDecodeError:
214
+ raise HTTPException(status_code=500, detail=f"Invalid JSON in file {filename}")
215
  except Exception as e:
216
+ raise HTTPException(status_code=500, detail=f"Error reading file {filename}: {str(e)}")
217
+
218
+ if __name__ == "__main__":
219
+ # Start the FastAPI server
220
+ print("Starting Cursor Tracking FastAPI Server...")
221
+ print("API Documentation will be available at: http://localhost:8000/docs")
222
+ print("API Root endpoint: http://localhost:8000/")
223
+
224
+ # Ensure the cursor tracking output folder exists
225
+ os.makedirs(CURSOR_TRACKING_OUTPUT_FOLDER, exist_ok=True)
226
+
227
+ uvicorn.run(
228
+ app,
229
+ host="0.0.0.0",
230
+ port=8000,
231
+ log_level="info",
232
+ reload=False # Set to False for production
233
+ )
234
 
 
 
 
235
 
 
 
 
 
 
 
236