baka999 committed on
Commit
15726ee
·
verified ·
1 Parent(s): d0005e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -42
app.py CHANGED
@@ -266,7 +266,35 @@ PRESET_RESOLUTIONS = {
266
  "1440p (2560×1440)": (2560, 1440),
267
  "4K (3840×2160)": (3840, 2160),
268
  }
269
- CHUNK_FRAMES = 121 # model hard limit per forward pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
 
271
  # โ”€โ”€ Chunked video SR โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
272
  @spaces.GPU(duration=100)
@@ -279,11 +307,12 @@ def generation_loop(video_path, seed=666, fps_out=24, model_size="3b",
279
  def _extract_text_embeds(n_chunks):
280
  embeds = []
281
  for _ in range(n_chunks):
282
- text_pos_embeds = torch.load('pos_emb.pt')
283
- text_neg_embeds = torch.load('neg_emb.pt')
284
  embeds.append({"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]})
285
  gc.collect()
286
- torch.cuda.empty_cache()
 
287
  return embeds
288
 
289
  def cut_video_to_model(video, sp_size):
@@ -338,6 +367,12 @@ def generation_loop(video_path, seed=666, fps_out=24, model_size="3b",
338
  res_w = int(in_W * scale)
339
  print(f"Target resolution: {res_w}×{res_h} (mode={res_mode})")
340
 
 
 
 
 
 
 
341
  target_resolution = (res_h * res_w) ** 0.5
342
 
343
  def make_transform(target_res):
@@ -379,14 +414,20 @@ def generation_loop(video_path, seed=666, fps_out=24, model_size="3b",
379
  return output_dir, None, output_dir
380
 
381
  # โ”€โ”€ Chunked video processing โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
382
- # Split full_video (T, C, H, W) into chunks of CHUNK_FRAMES
 
 
 
 
 
 
383
  frame_chunks = []
384
- for start in range(0, T_total, CHUNK_FRAMES):
385
- end = min(start + CHUNK_FRAMES, T_total)
386
  frame_chunks.append(full_video[start:end]) # each: (t_chunk, C, H, W)
387
 
388
  n_chunks = len(frame_chunks)
389
- print(f"Processing {n_chunks} chunk(s) of up to {CHUNK_FRAMES} frames each …")
390
  text_embeds_list = _extract_text_embeds(n_chunks)
391
 
392
  all_output_frames = [] # will collect numpy uint8 frames
@@ -394,41 +435,64 @@ def generation_loop(video_path, seed=666, fps_out=24, model_size="3b",
394
  for chunk_idx, (chunk_frames, text_embeds) in enumerate(zip(frame_chunks, text_embeds_list)):
395
  print(f" Chunk {chunk_idx+1}/{n_chunks}: {chunk_frames.shape[0]} frames")
396
 
397
- # Transform to model input space
398
- cond = video_transform(chunk_frames.to(torch.device("cuda"))) # (C, t, H_out, W_out)
399
- ori_length = cond.size(1)
400
-
401
- # Pad to model alignment
402
- cond_padded = cut_video_to_model(cond, sp_size)
403
 
404
- # Move text embeds to GPU
405
- for i, emb in enumerate(text_embeds["texts_pos"]):
406
- text_embeds["texts_pos"][i] = emb.to("cuda")
407
- for i, emb in enumerate(text_embeds["texts_neg"]):
408
- text_embeds["texts_neg"][i] = emb.to("cuda")
409
-
410
- # Encode โ†’ diffuse โ†’ decode
411
- latent = runner.vae_encode([cond_padded])
412
- sample = generation_step(runner, text_embeds, cond_latents=latent)[0]
413
- # Trim padding
414
- if ori_length < sample.shape[0]:
415
- sample = sample[:ori_length]
416
-
417
- # Color fix
418
- input_pixel = rearrange(cond, "c t h w -> t c h w")
419
- if use_colorfix:
420
- sample = wavelet_reconstruction(sample.to("cpu"), input_pixel[:sample.size(0)].to("cpu"))
421
- else:
422
- sample = sample.to("cpu")
423
-
424
- # Convert to uint8 numpy (T, H, W, C)
425
- sample = rearrange(sample, "t c h w -> t h w c")
426
- sample = sample.clip(-1,1).mul_(0.5).add_(0.5).mul_(255).round().to(torch.uint8).numpy()
427
- all_output_frames.append(sample)
428
-
429
- del latent, cond, cond_padded
430
- gc.collect()
431
- torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
432
 
433
  # โ”€โ”€ Concatenate chunks and write โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
434
  import numpy as np
 
266
  "1440p (2560×1440)": (2560, 1440),
267
  "4K (3840×2160)": (3840, 2160),
268
  }
269
+ CHUNK_FRAMES = 121 # absolute model hard limit per forward pass
270
+
271
+ def _choose_safe_chunk_frames(h: int, w: int, requested: int = CHUNK_FRAMES) -> int:
272
+ """
273
+ Pick a safer temporal chunk size for high-resolution videos to avoid allocator/NVML crashes.
274
+ 720p can usually use the full 121 frames; above that we shrink aggressively.
275
+ """
276
+ pixels = int(h) * int(w)
277
+ if pixels >= 3840 * 2160: # 4K+
278
+ return min(requested, 8)
279
+ if pixels >= 2560 * 1440: # 1440p
280
+ return min(requested, 12)
281
+ if pixels >= 1920 * 1080: # 1080p
282
+ return min(requested, 16)
283
+ if pixels >= 1280 * 720: # 720p
284
+ return min(requested, 32)
285
+ return min(requested, 64)
286
+
287
+ def _is_cuda_memory_error(exc: BaseException) -> bool:
288
+ msg = str(exc)
289
+ keys = (
290
+ "out of memory",
291
+ "cuda out of memory",
292
+ "cudacachingallocator",
293
+ "nvml_success == r internal assert failed",
294
+ "allocator",
295
+ )
296
+ msg_low = msg.lower()
297
+ return any(k in msg_low for k in keys)
298
 
299
  # โ”€โ”€ Chunked video SR โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
300
  @spaces.GPU(duration=100)
 
307
  def _extract_text_embeds(n_chunks):
308
  embeds = []
309
  for _ in range(n_chunks):
310
+ text_pos_embeds = torch.load('pos_emb.pt', map_location='cpu', weights_only=True)
311
+ text_neg_embeds = torch.load('neg_emb.pt', map_location='cpu', weights_only=True)
312
  embeds.append({"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]})
313
  gc.collect()
314
+ if torch.cuda.is_available():
315
+ torch.cuda.empty_cache()
316
  return embeds
317
 
318
  def cut_video_to_model(video, sp_size):
 
367
  res_w = int(in_W * scale)
368
  print(f"Target resolution: {res_w}×{res_h} (mode={res_mode})")
369
 
370
+ if is_video and (res_h * res_w) > (1920 * 1080):
371
+ print(
372
+ "โš ๏ธ High-memory mode detected. 2K/4K video restoration is very likely to fail on limited GPU "
373
+ "memory; the code will use smaller temporal chunks automatically."
374
+ )
375
+
376
  target_resolution = (res_h * res_w) ** 0.5
377
 
378
  def make_transform(target_res):
 
414
  return output_dir, None, output_dir
415
 
416
  # โ”€โ”€ Chunked video processing โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
417
+ safe_chunk_frames = _choose_safe_chunk_frames(res_h, res_w, CHUNK_FRAMES)
418
+ if safe_chunk_frames != CHUNK_FRAMES:
419
+ print(
420
+ f"Reducing chunk size from {CHUNK_FRAMES} to {safe_chunk_frames} "
421
+ f"for safer memory usage at {res_w}×{res_h}."
422
+ )
423
+
424
  frame_chunks = []
425
+ for start in range(0, T_total, safe_chunk_frames):
426
+ end = min(start + safe_chunk_frames, T_total)
427
  frame_chunks.append(full_video[start:end]) # each: (t_chunk, C, H, W)
428
 
429
  n_chunks = len(frame_chunks)
430
+ print(f"Processing {n_chunks} chunk(s) of up to {safe_chunk_frames} frames each …")
431
  text_embeds_list = _extract_text_embeds(n_chunks)
432
 
433
  all_output_frames = [] # will collect numpy uint8 frames
 
435
  for chunk_idx, (chunk_frames, text_embeds) in enumerate(zip(frame_chunks, text_embeds_list)):
436
  print(f" Chunk {chunk_idx+1}/{n_chunks}: {chunk_frames.shape[0]} frames")
437
 
438
+ cond = None
439
+ cond_padded = None
440
+ latent = None
441
+ sample = None
 
 
442
 
443
+ try:
444
+ # Transform to model input space
445
+ cond = video_transform(chunk_frames.to(torch.device("cuda"), non_blocking=True))
446
+ ori_length = cond.size(1)
447
+
448
+ # Pad to model alignment
449
+ cond_padded = cut_video_to_model(cond, sp_size)
450
+
451
+ # Move text embeds to GPU lazily right before use
452
+ for i, emb in enumerate(text_embeds["texts_pos"]):
453
+ text_embeds["texts_pos"][i] = emb.to("cuda", non_blocking=True)
454
+ for i, emb in enumerate(text_embeds["texts_neg"]):
455
+ text_embeds["texts_neg"][i] = emb.to("cuda", non_blocking=True)
456
+
457
+ # Encode โ†’ diffuse โ†’ decode
458
+ latent = runner.vae_encode([cond_padded])
459
+ sample = generation_step(runner, text_embeds, cond_latents=latent)[0]
460
+
461
+ # Trim padding
462
+ if ori_length < sample.shape[0]:
463
+ sample = sample[:ori_length]
464
+
465
+ # Color fix
466
+ input_pixel = rearrange(cond, "c t h w -> t c h w")
467
+ if use_colorfix:
468
+ sample = wavelet_reconstruction(sample.to("cpu"), input_pixel[:sample.size(0)].to("cpu"))
469
+ else:
470
+ sample = sample.to("cpu")
471
+
472
+ # Convert to uint8 numpy (T, H, W, C)
473
+ sample = rearrange(sample, "t c h w -> t h w c")
474
+ sample = sample.clip(-1,1).mul_(0.5).add_(0.5).mul_(255).round().to(torch.uint8).numpy()
475
+ all_output_frames.append(sample)
476
+
477
+ except RuntimeError as e:
478
+ if _is_cuda_memory_error(e):
479
+ raise RuntimeError(
480
+ f"GPU memoryไธ่ถณ๏ผšๅฝ“ๅ‰ๅˆ†่พจ็އ {res_w}ร—{res_h}ใ€ๅˆ†ๅ— {chunk_frames.shape[0]} ๅธงไป็„ถ่ถ…ๅ‡บๆ˜พๅญ˜ใ€‚"
481
+
482
+ f"请改为更低输出分辨率（建议 720p/1080p）、更小 upscale_factor，或继续降低 safe_chunk_frames。"
483
+
484
+ f"原始错误: {e}"
485
+ ) from e
486
+ raise
487
+ finally:
488
+ del latent, cond, cond_padded, sample
489
+ for k in ("texts_pos", "texts_neg"):
490
+ for i, emb in enumerate(text_embeds[k]):
491
+ if isinstance(emb, torch.Tensor):
492
+ text_embeds[k][i] = emb.to("cpu")
493
+ gc.collect()
494
+ if torch.cuda.is_available():
495
+ torch.cuda.empty_cache()
496
 
497
  # โ”€โ”€ Concatenate chunks and write โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
498
  import numpy as np