Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -266,7 +266,35 @@ PRESET_RESOLUTIONS = {
|
|
| 266 |
"1440p (2560ร1440)": (2560, 1440),
|
| 267 |
"4K (3840ร2160)": (3840, 2160),
|
| 268 |
}
|
| 269 |
-
CHUNK_FRAMES = 121 # model hard limit per forward pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
|
| 271 |
# ── Chunked video SR ──────────────────────────────────────────────────────────
|
| 272 |
@spaces.GPU(duration=100)
|
|
@@ -279,11 +307,12 @@ def generation_loop(video_path, seed=666, fps_out=24, model_size="3b",
|
|
| 279 |
def _extract_text_embeds(n_chunks):
|
| 280 |
embeds = []
|
| 281 |
for _ in range(n_chunks):
|
| 282 |
-
text_pos_embeds = torch.load('pos_emb.pt')
|
| 283 |
-
text_neg_embeds = torch.load('neg_emb.pt')
|
| 284 |
embeds.append({"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]})
|
| 285 |
gc.collect()
|
| 286 |
-
torch.cuda.
|
|
|
|
| 287 |
return embeds
|
| 288 |
|
| 289 |
def cut_video_to_model(video, sp_size):
|
|
@@ -338,6 +367,12 @@ def generation_loop(video_path, seed=666, fps_out=24, model_size="3b",
|
|
| 338 |
res_w = int(in_W * scale)
|
| 339 |
print(f"Target resolution: {res_w}ร{res_h} (mode={res_mode})")
|
| 340 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
target_resolution = (res_h * res_w) ** 0.5
|
| 342 |
|
| 343 |
def make_transform(target_res):
|
|
@@ -379,14 +414,20 @@ def generation_loop(video_path, seed=666, fps_out=24, model_size="3b",
|
|
| 379 |
return output_dir, None, output_dir
|
| 380 |
|
| 381 |
# โโ Chunked video processing โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 382 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 383 |
frame_chunks = []
|
| 384 |
-
for start in range(0, T_total,
|
| 385 |
-
end = min(start +
|
| 386 |
frame_chunks.append(full_video[start:end]) # each: (t_chunk, C, H, W)
|
| 387 |
|
| 388 |
n_chunks = len(frame_chunks)
|
| 389 |
-
print(f"Processing {n_chunks} chunk(s) of up to {
|
| 390 |
text_embeds_list = _extract_text_embeds(n_chunks)
|
| 391 |
|
| 392 |
all_output_frames = [] # will collect numpy uint8 frames
|
|
@@ -394,41 +435,64 @@ def generation_loop(video_path, seed=666, fps_out=24, model_size="3b",
|
|
| 394 |
for chunk_idx, (chunk_frames, text_embeds) in enumerate(zip(frame_chunks, text_embeds_list)):
|
| 395 |
print(f" Chunk {chunk_idx+1}/{n_chunks}: {chunk_frames.shape[0]} frames")
|
| 396 |
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
# Pad to model alignment
|
| 402 |
-
cond_padded = cut_video_to_model(cond, sp_size)
|
| 403 |
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
sample =
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 432 |
|
| 433 |
# โโ Concatenate chunks and write โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 434 |
import numpy as np
|
|
|
|
| 266 |
"1440p (2560ร1440)": (2560, 1440),
|
| 267 |
"4K (3840ร2160)": (3840, 2160),
|
| 268 |
}
|
| 269 |
+
CHUNK_FRAMES = 121 # absolute model hard limit per forward pass
|
| 270 |
+
|
| 271 |
+
def _choose_safe_chunk_frames(h: int, w: int, requested: int = CHUNK_FRAMES) -> int:
|
| 272 |
+
"""
|
| 273 |
+
Pick a safer temporal chunk size for high-resolution videos to avoid allocator/NVML crashes.
|
| 274 |
+
720p can usually use the full 121 frames; above that we shrink aggressively.
|
| 275 |
+
"""
|
| 276 |
+
pixels = int(h) * int(w)
|
| 277 |
+
if pixels >= 3840 * 2160: # 4K+
|
| 278 |
+
return min(requested, 8)
|
| 279 |
+
if pixels >= 2560 * 1440: # 1440p
|
| 280 |
+
return min(requested, 12)
|
| 281 |
+
if pixels >= 1920 * 1080: # 1080p
|
| 282 |
+
return min(requested, 16)
|
| 283 |
+
if pixels >= 1280 * 720: # 720p
|
| 284 |
+
return min(requested, 32)
|
| 285 |
+
return min(requested, 64)
|
| 286 |
+
|
| 287 |
+
def _is_cuda_memory_error(exc: BaseException) -> bool:
|
| 288 |
+
msg = str(exc)
|
| 289 |
+
keys = (
|
| 290 |
+
"out of memory",
|
| 291 |
+
"cuda out of memory",
|
| 292 |
+
"cudacachingallocator",
|
| 293 |
+
"nvml_success == r internal assert failed",
|
| 294 |
+
"allocator",
|
| 295 |
+
)
|
| 296 |
+
msg_low = msg.lower()
|
| 297 |
+
return any(k in msg_low for k in keys)
|
| 298 |
|
| 299 |
# ── Chunked video SR ──────────────────────────────────────────────────────────
|
| 300 |
@spaces.GPU(duration=100)
|
|
|
|
| 307 |
def _extract_text_embeds(n_chunks):
|
| 308 |
embeds = []
|
| 309 |
for _ in range(n_chunks):
|
| 310 |
+
text_pos_embeds = torch.load('pos_emb.pt', map_location='cpu', weights_only=True)
|
| 311 |
+
text_neg_embeds = torch.load('neg_emb.pt', map_location='cpu', weights_only=True)
|
| 312 |
embeds.append({"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]})
|
| 313 |
gc.collect()
|
| 314 |
+
if torch.cuda.is_available():
|
| 315 |
+
torch.cuda.empty_cache()
|
| 316 |
return embeds
|
| 317 |
|
| 318 |
def cut_video_to_model(video, sp_size):
|
|
|
|
| 367 |
res_w = int(in_W * scale)
|
| 368 |
print(f"Target resolution: {res_w}ร{res_h} (mode={res_mode})")
|
| 369 |
|
| 370 |
+
if is_video and (res_h * res_w) > (1920 * 1080):
|
| 371 |
+
print(
|
| 372 |
+
"โ ๏ธ High-memory mode detected. 2K/4K video restoration is very likely to fail on limited GPU "
|
| 373 |
+
"memory; the code will use smaller temporal chunks automatically."
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
target_resolution = (res_h * res_w) ** 0.5
|
| 377 |
|
| 378 |
def make_transform(target_res):
|
|
|
|
| 414 |
return output_dir, None, output_dir
|
| 415 |
|
| 416 |
# โโ Chunked video processing โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 417 |
+
safe_chunk_frames = _choose_safe_chunk_frames(res_h, res_w, CHUNK_FRAMES)
|
| 418 |
+
if safe_chunk_frames != CHUNK_FRAMES:
|
| 419 |
+
print(
|
| 420 |
+
f"Reducing chunk size from {CHUNK_FRAMES} to {safe_chunk_frames} "
|
| 421 |
+
f"for safer memory usage at {res_w}ร{res_h}."
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
frame_chunks = []
|
| 425 |
+
for start in range(0, T_total, safe_chunk_frames):
|
| 426 |
+
end = min(start + safe_chunk_frames, T_total)
|
| 427 |
frame_chunks.append(full_video[start:end]) # each: (t_chunk, C, H, W)
|
| 428 |
|
| 429 |
n_chunks = len(frame_chunks)
|
| 430 |
+
print(f"Processing {n_chunks} chunk(s) of up to {safe_chunk_frames} frames each โฆ")
|
| 431 |
text_embeds_list = _extract_text_embeds(n_chunks)
|
| 432 |
|
| 433 |
all_output_frames = [] # will collect numpy uint8 frames
|
|
|
|
| 435 |
for chunk_idx, (chunk_frames, text_embeds) in enumerate(zip(frame_chunks, text_embeds_list)):
|
| 436 |
print(f" Chunk {chunk_idx+1}/{n_chunks}: {chunk_frames.shape[0]} frames")
|
| 437 |
|
| 438 |
+
cond = None
|
| 439 |
+
cond_padded = None
|
| 440 |
+
latent = None
|
| 441 |
+
sample = None
|
|
|
|
|
|
|
| 442 |
|
| 443 |
+
try:
|
| 444 |
+
# Transform to model input space
|
| 445 |
+
cond = video_transform(chunk_frames.to(torch.device("cuda"), non_blocking=True))
|
| 446 |
+
ori_length = cond.size(1)
|
| 447 |
+
|
| 448 |
+
# Pad to model alignment
|
| 449 |
+
cond_padded = cut_video_to_model(cond, sp_size)
|
| 450 |
+
|
| 451 |
+
# Move text embeds to GPU lazily right before use
|
| 452 |
+
for i, emb in enumerate(text_embeds["texts_pos"]):
|
| 453 |
+
text_embeds["texts_pos"][i] = emb.to("cuda", non_blocking=True)
|
| 454 |
+
for i, emb in enumerate(text_embeds["texts_neg"]):
|
| 455 |
+
text_embeds["texts_neg"][i] = emb.to("cuda", non_blocking=True)
|
| 456 |
+
|
| 457 |
+
# Encode โ diffuse โ decode
|
| 458 |
+
latent = runner.vae_encode([cond_padded])
|
| 459 |
+
sample = generation_step(runner, text_embeds, cond_latents=latent)[0]
|
| 460 |
+
|
| 461 |
+
# Trim padding
|
| 462 |
+
if ori_length < sample.shape[0]:
|
| 463 |
+
sample = sample[:ori_length]
|
| 464 |
+
|
| 465 |
+
# Color fix
|
| 466 |
+
input_pixel = rearrange(cond, "c t h w -> t c h w")
|
| 467 |
+
if use_colorfix:
|
| 468 |
+
sample = wavelet_reconstruction(sample.to("cpu"), input_pixel[:sample.size(0)].to("cpu"))
|
| 469 |
+
else:
|
| 470 |
+
sample = sample.to("cpu")
|
| 471 |
+
|
| 472 |
+
# Convert to uint8 numpy (T, H, W, C)
|
| 473 |
+
sample = rearrange(sample, "t c h w -> t h w c")
|
| 474 |
+
sample = sample.clip(-1,1).mul_(0.5).add_(0.5).mul_(255).round().to(torch.uint8).numpy()
|
| 475 |
+
all_output_frames.append(sample)
|
| 476 |
+
|
| 477 |
+
except RuntimeError as e:
|
| 478 |
+
if _is_cuda_memory_error(e):
|
| 479 |
+
raise RuntimeError(
|
| 480 |
+
f"GPU memoryไธ่ถณ๏ผๅฝๅๅ่พจ็ {res_w}ร{res_h}ใๅๅ {chunk_frames.shape[0]} ๅธงไป็ถ่ถ
ๅบๆพๅญใ"
|
| 481 |
+
|
| 482 |
+
f"่ฏทๆนไธบๆดไฝ่พๅบๅ่พจ็๏ผๅปบ่ฎฎ 720p/1080p๏ผใๆดๅฐ upscale_factor๏ผๆ็ปง็ปญ้ไฝ safe_chunk_framesใ"
|
| 483 |
+
|
| 484 |
+
f"ๅๅง้่ฏฏ: {e}"
|
| 485 |
+
) from e
|
| 486 |
+
raise
|
| 487 |
+
finally:
|
| 488 |
+
del latent, cond, cond_padded, sample
|
| 489 |
+
for k in ("texts_pos", "texts_neg"):
|
| 490 |
+
for i, emb in enumerate(text_embeds[k]):
|
| 491 |
+
if isinstance(emb, torch.Tensor):
|
| 492 |
+
text_embeds[k][i] = emb.to("cpu")
|
| 493 |
+
gc.collect()
|
| 494 |
+
if torch.cuda.is_available():
|
| 495 |
+
torch.cuda.empty_cache()
|
| 496 |
|
| 497 |
# โโ Concatenate chunks and write โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 498 |
import numpy as np
|