Spaces:
Runtime error
Runtime error
da03
commited on
Commit
·
e858976
1
Parent(s):
042c554
main.py
CHANGED
|
@@ -14,8 +14,8 @@ import time
|
|
| 14 |
from typing import Any, Dict
|
| 15 |
from ldm.models.diffusion.ddpm import LatentDiffusion, DDIMSampler
|
| 16 |
|
| 17 |
-
|
| 18 |
-
|
| 19 |
SCREEN_WIDTH = 512
|
| 20 |
SCREEN_HEIGHT = 384
|
| 21 |
NUM_SAMPLING_STEPS = 8
|
|
@@ -167,6 +167,11 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 167 |
hidden_states = None
|
| 168 |
keys_down = set() # Initialize as an empty set
|
| 169 |
frame_num = -1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
while True:
|
| 171 |
try:
|
| 172 |
# Receive user input with a timeout
|
|
@@ -176,7 +181,13 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 176 |
await websocket.send_json({"type": "heartbeat_response"})
|
| 177 |
continue
|
| 178 |
frame_num += 1
|
|
|
|
| 179 |
start_frame = time.perf_counter()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
x = data.get("x")
|
| 181 |
y = data.get("y")
|
| 182 |
is_left_click = data.get("is_left_click")
|
|
@@ -197,6 +208,9 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 197 |
# Use the provided function to print timing statistics
|
| 198 |
print_timing_stats(timing_info, frame_num)
|
| 199 |
|
|
|
|
|
|
|
|
|
|
| 200 |
img = Image.fromarray(sample_img)
|
| 201 |
buffered = io.BytesIO()
|
| 202 |
img.save(buffered, format="PNG")
|
|
@@ -217,5 +231,13 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 217 |
print(f"Error in WebSocket connection {client_id}: {e}")
|
| 218 |
|
| 219 |
finally:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
print(f"WebSocket connection closed: {client_id}")
|
| 221 |
#await websocket.close() # Ensure the WebSocket is closed
|
|
|
|
| 14 |
from typing import Any, Dict
|
| 15 |
from ldm.models.diffusion.ddpm import LatentDiffusion, DDIMSampler
|
| 16 |
|
| 17 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 18 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 19 |
SCREEN_WIDTH = 512
|
| 20 |
SCREEN_HEIGHT = 384
|
| 21 |
NUM_SAMPLING_STEPS = 8
|
|
|
|
| 167 |
hidden_states = None
|
| 168 |
keys_down = set() # Initialize as an empty set
|
| 169 |
frame_num = -1
|
| 170 |
+
|
| 171 |
+
# Start timing for global FPS calculation
|
| 172 |
+
connection_start_time = time.perf_counter()
|
| 173 |
+
frame_count = 0
|
| 174 |
+
|
| 175 |
while True:
|
| 176 |
try:
|
| 177 |
# Receive user input with a timeout
|
|
|
|
| 181 |
await websocket.send_json({"type": "heartbeat_response"})
|
| 182 |
continue
|
| 183 |
frame_num += 1
|
| 184 |
+
frame_count += 1 # Increment total frame counter
|
| 185 |
start_frame = time.perf_counter()
|
| 186 |
+
|
| 187 |
+
# Calculate global FPS
|
| 188 |
+
total_elapsed = start_frame - connection_start_time
|
| 189 |
+
global_fps = frame_count / total_elapsed if total_elapsed > 0 else 0
|
| 190 |
+
|
| 191 |
x = data.get("x")
|
| 192 |
y = data.get("y")
|
| 193 |
is_left_click = data.get("is_left_click")
|
|
|
|
| 208 |
# Use the provided function to print timing statistics
|
| 209 |
print_timing_stats(timing_info, frame_num)
|
| 210 |
|
| 211 |
+
# Print global FPS measurement
|
| 212 |
+
print(f" Global FPS: {global_fps:.2f} (total: {frame_count} frames in {total_elapsed:.2f}s)")
|
| 213 |
+
|
| 214 |
img = Image.fromarray(sample_img)
|
| 215 |
buffered = io.BytesIO()
|
| 216 |
img.save(buffered, format="PNG")
|
|
|
|
| 231 |
print(f"Error in WebSocket connection {client_id}: {e}")
|
| 232 |
|
| 233 |
finally:
|
| 234 |
+
# Print final FPS statistics when connection ends
|
| 235 |
+
if frame_num >= 0: # Only if we processed at least one frame
|
| 236 |
+
total_time = time.perf_counter() - connection_start_time
|
| 237 |
+
print(f"\nConnection {client_id} summary:")
|
| 238 |
+
print(f" Total frames processed: {frame_count}")
|
| 239 |
+
print(f" Total elapsed time: {total_time:.2f} seconds")
|
| 240 |
+
print(f" Average FPS: {frame_count/total_time:.2f}")
|
| 241 |
+
|
| 242 |
print(f"WebSocket connection closed: {client_id}")
|
| 243 |
#await websocket.close() # Ensure the WebSocket is closed
|