Spaces:
Runtime error
Runtime error
da03
commited on
Commit
·
c0820d5
1
Parent(s):
14aa927
main.py
CHANGED
|
@@ -96,6 +96,7 @@ app = FastAPI()
|
|
| 96 |
app.mount("/static", StaticFiles(directory="static"), name="static")
|
| 97 |
|
| 98 |
# Add this at the top with other global variables
|
|
|
|
| 99 |
|
| 100 |
# Create a thread pool executor
|
| 101 |
thread_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
|
@@ -226,7 +227,9 @@ async def get():
|
|
| 226 |
# WebSocket endpoint for continuous user interaction
|
| 227 |
@app.websocket("/ws")
|
| 228 |
async def websocket_endpoint(websocket: WebSocket):
|
| 229 |
-
|
|
|
|
|
|
|
| 230 |
print(f"New WebSocket connection: {client_id}")
|
| 231 |
await websocket.accept()
|
| 232 |
|
|
@@ -350,6 +353,9 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 350 |
print(f"[{time.perf_counter():.3f}] Sending image to client...")
|
| 351 |
await websocket.send_json({"image": img_str})
|
| 352 |
print(f"[{time.perf_counter():.3f}] Image sent. Queue size before next_input: {input_queue.qsize()}")
|
|
|
|
|
|
|
|
|
|
| 353 |
finally:
|
| 354 |
is_processing = False
|
| 355 |
print(f"[{time.perf_counter():.3f}] Processing complete. Queue size before checking next input: {input_queue.qsize()}")
|
|
@@ -457,7 +463,9 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 457 |
print("WebSocket connection timed out")
|
| 458 |
|
| 459 |
except WebSocketDisconnect:
|
| 460 |
-
|
|
|
|
|
|
|
| 461 |
break
|
| 462 |
|
| 463 |
except Exception as e:
|
|
@@ -475,3 +483,44 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 475 |
print(f" Average FPS: {frame_count/total_time:.2f}")
|
| 476 |
|
| 477 |
print(f"WebSocket connection closed: {client_id}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
app.mount("/static", StaticFiles(directory="static"), name="static")
|
| 97 |
|
| 98 |
# Add this at the top with other global variables
|
| 99 |
+
connection_counter = 0
|
| 100 |
|
| 101 |
# Create a thread pool executor
|
| 102 |
thread_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
|
|
|
| 227 |
# WebSocket endpoint for continuous user interaction
|
| 228 |
@app.websocket("/ws")
|
| 229 |
async def websocket_endpoint(websocket: WebSocket):
|
| 230 |
+
global connection_counter
|
| 231 |
+
connection_counter += 1
|
| 232 |
+
client_id = f"{int(time.time())}_{connection_counter}"
|
| 233 |
print(f"New WebSocket connection: {client_id}")
|
| 234 |
await websocket.accept()
|
| 235 |
|
|
|
|
| 353 |
print(f"[{time.perf_counter():.3f}] Sending image to client...")
|
| 354 |
await websocket.send_json({"image": img_str})
|
| 355 |
print(f"[{time.perf_counter():.3f}] Image sent. Queue size before next_input: {input_queue.qsize()}")
|
| 356 |
+
|
| 357 |
+
# Log the input
|
| 358 |
+
log_interaction(client_id, data, generated_frame=sample_img)
|
| 359 |
finally:
|
| 360 |
is_processing = False
|
| 361 |
print(f"[{time.perf_counter():.3f}] Processing complete. Queue size before checking next input: {input_queue.qsize()}")
|
|
|
|
| 463 |
print("WebSocket connection timed out")
|
| 464 |
|
| 465 |
except WebSocketDisconnect:
|
| 466 |
+
# Log final EOS entry
|
| 467 |
+
log_interaction(client_id, {}, is_end_of_session=True)
|
| 468 |
+
print(f"WebSocket disconnected: {client_id}")
|
| 469 |
break
|
| 470 |
|
| 471 |
except Exception as e:
|
|
|
|
| 483 |
print(f" Average FPS: {frame_count/total_time:.2f}")
|
| 484 |
|
| 485 |
print(f"WebSocket connection closed: {client_id}")
|
| 486 |
+
|
| 487 |
+
def log_interaction(client_id, data, generated_frame=None, is_end_of_session=False):
|
| 488 |
+
"""Log user interaction and optionally the generated frame."""
|
| 489 |
+
timestamp = time.time()
|
| 490 |
+
|
| 491 |
+
# Create directory structure if it doesn't exist
|
| 492 |
+
os.makedirs("interaction_logs", exist_ok=True)
|
| 493 |
+
|
| 494 |
+
# Structure the log entry
|
| 495 |
+
log_entry = {
|
| 496 |
+
"timestamp": timestamp,
|
| 497 |
+
"client_id": client_id,
|
| 498 |
+
"is_eos": is_end_of_session
|
| 499 |
+
}
|
| 500 |
+
|
| 501 |
+
# Only include input data if this isn't an EOS token or if data is provided
|
| 502 |
+
if not is_end_of_session or data:
|
| 503 |
+
log_entry["inputs"] = {
|
| 504 |
+
"x": data.get("x"),
|
| 505 |
+
"y": data.get("y"),
|
| 506 |
+
"is_left_click": data.get("is_left_click"),
|
| 507 |
+
"is_right_click": data.get("is_right_click"),
|
| 508 |
+
"keys_down": data.get("keys_down", []),
|
| 509 |
+
"keys_up": data.get("keys_up", [])
|
| 510 |
+
}
|
| 511 |
+
else:
|
| 512 |
+
# For EOS records with empty data, just include minimal info
|
| 513 |
+
log_entry["inputs"] = None
|
| 514 |
+
|
| 515 |
+
# Save to a file (one file per session)
|
| 516 |
+
session_file = f"interaction_logs/session_{client_id}.jsonl"
|
| 517 |
+
with open(session_file, "a") as f:
|
| 518 |
+
f.write(json.dumps(log_entry) + "\n")
|
| 519 |
+
|
| 520 |
+
# Optionally save the frame if provided
|
| 521 |
+
if generated_frame is not None and not is_end_of_session:
|
| 522 |
+
frame_dir = f"interaction_logs/frames_{client_id}"
|
| 523 |
+
os.makedirs(frame_dir, exist_ok=True)
|
| 524 |
+
frame_file = f"{frame_dir}/{timestamp:.6f}.png"
|
| 525 |
+
# Save the frame as PNG
|
| 526 |
+
Image.fromarray(generated_frame).save(frame_file)
|