Spaces:
Runtime error
Runtime error
da03
committed on
Commit
·
fc0bb07
1
Parent(s):
9ced953
main.py
CHANGED
|
@@ -13,6 +13,7 @@ import os
|
|
| 13 |
import time
|
| 14 |
from typing import Any, Dict
|
| 15 |
from ldm.models.diffusion.ddpm import LatentDiffusion, DDIMSampler
|
|
|
|
| 16 |
|
| 17 |
torch.backends.cuda.matmul.allow_tf32 = True
|
| 18 |
torch.backends.cudnn.allow_tf32 = True
|
|
@@ -74,6 +75,9 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
|
|
| 74 |
|
| 75 |
# Add this at the top with other global variables
|
| 76 |
|
|
|
|
|
|
|
|
|
|
| 77 |
def prepare_model_inputs(
|
| 78 |
previous_frame: torch.Tensor,
|
| 79 |
hidden_states: Any,
|
|
@@ -110,11 +114,20 @@ def prepare_model_inputs(
|
|
| 110 |
return inputs
|
| 111 |
|
| 112 |
@torch.no_grad()
|
| 113 |
-
def process_frame(
|
| 114 |
model: LatentDiffusion,
|
| 115 |
inputs: Dict[str, torch.Tensor]
|
| 116 |
) -> Tuple[torch.Tensor, np.ndarray, Any, Dict[str, float]]:
|
| 117 |
"""Process a single frame through the model."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
timing = {}
|
| 119 |
# Temporal encoding
|
| 120 |
start = time.perf_counter()
|
|
@@ -136,7 +149,10 @@ def process_frame(
|
|
| 136 |
# Decoding
|
| 137 |
start = time.perf_counter()
|
| 138 |
sample = sample_latent * DATA_NORMALIZATION['std'] + DATA_NORMALIZATION['mean']
|
|
|
|
|
|
|
| 139 |
time.sleep(10)
|
|
|
|
| 140 |
sample = model.decode_first_stage(sample)
|
| 141 |
sample = sample.squeeze(0).clamp(-1, 1)
|
| 142 |
timing['decode'] = time.perf_counter() - start
|
|
@@ -212,7 +228,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 212 |
|
| 213 |
inputs = prepare_model_inputs(previous_frame, hidden_states, x, y, is_right_click, is_left_click, list(keys_down), stoi, itos, frame_num)
|
| 214 |
print(f"[{time.perf_counter():.3f}] Starting model inference...")
|
| 215 |
-
previous_frame, sample_img, hidden_states, timing_info = process_frame(model, inputs)
|
| 216 |
timing_info['full_frame'] = time.perf_counter() - process_start_time
|
| 217 |
|
| 218 |
print(f"[{time.perf_counter():.3f}] Model inference complete. Queue size now: {input_queue.qsize()}")
|
|
|
|
| 13 |
import time
|
| 14 |
from typing import Any, Dict
|
| 15 |
from ldm.models.diffusion.ddpm import LatentDiffusion, DDIMSampler
|
| 16 |
+
import concurrent.futures
|
| 17 |
|
| 18 |
torch.backends.cuda.matmul.allow_tf32 = True
|
| 19 |
torch.backends.cudnn.allow_tf32 = True
|
|
|
|
| 75 |
|
| 76 |
# Add this at the top with other global variables
|
| 77 |
|
| 78 |
+
# Single-worker thread pool: runs model inference off the asyncio event loop, one frame at a time
|
| 79 |
+
thread_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
| 80 |
+
|
| 81 |
def prepare_model_inputs(
|
| 82 |
previous_frame: torch.Tensor,
|
| 83 |
hidden_states: Any,
|
|
|
|
| 114 |
return inputs
|
| 115 |
|
| 116 |
async def process_frame(
    model: LatentDiffusion,
    inputs: Dict[str, torch.Tensor]
) -> Tuple[torch.Tensor, np.ndarray, Any, Dict[str, float]]:
    """Process a single frame through the model without blocking the event loop.

    The synchronous work is delegated to ``_process_frame_sync`` and run on
    the module-level ``thread_executor``; this coroutine only awaits the
    result and returns it unchanged.

    Args:
        model: The model passed through to ``_process_frame_sync``.
        inputs: The prepared model inputs passed through to
            ``_process_frame_sync``.

    Returns:
        Whatever ``_process_frame_sync(model, inputs)`` returns (the caller
        unpacks it into four values).
    """
    def _run_without_grad():
        # Grad mode is thread-local in PyTorch, and decorating an ``async def``
        # with ``@torch.no_grad()`` only covers creation of the coroutine
        # object. The context must therefore be entered here, on the worker
        # thread that actually executes the model.
        with torch.no_grad():
            return _process_frame_sync(model, inputs)

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(thread_executor, _run_without_grad)
|
| 128 |
+
|
| 129 |
+
def _process_frame_sync(model, inputs):
|
| 130 |
+
"""Synchronous version of process_frame that runs in a thread"""
|
| 131 |
timing = {}
|
| 132 |
# Temporal encoding
|
| 133 |
start = time.perf_counter()
|
|
|
|
| 149 |
# Decoding
|
| 150 |
start = time.perf_counter()
|
| 151 |
sample = sample_latent * DATA_NORMALIZATION['std'] + DATA_NORMALIZATION['mean']
|
| 152 |
+
|
| 153 |
+
# Use time.sleep(10) here since it's in a separate thread
|
| 154 |
time.sleep(10)
|
| 155 |
+
|
| 156 |
sample = model.decode_first_stage(sample)
|
| 157 |
sample = sample.squeeze(0).clamp(-1, 1)
|
| 158 |
timing['decode'] = time.perf_counter() - start
|
|
|
|
| 228 |
|
| 229 |
inputs = prepare_model_inputs(previous_frame, hidden_states, x, y, is_right_click, is_left_click, list(keys_down), stoi, itos, frame_num)
|
| 230 |
print(f"[{time.perf_counter():.3f}] Starting model inference...")
|
| 231 |
+
previous_frame, sample_img, hidden_states, timing_info = await process_frame(model, inputs)
|
| 232 |
timing_info['full_frame'] = time.perf_counter() - process_start_time
|
| 233 |
|
| 234 |
print(f"[{time.perf_counter():.3f}] Model inference complete. Queue size now: {input_queue.qsize()}")
|