File size: 27,437 Bytes
42f2c22
 
 
 
 
 
 
 
 
 
 
 
 
6cbcc74
a6ea067
6cbcc74
42f2c22
 
 
 
 
 
 
 
 
 
c5715dc
42f2c22
 
 
 
 
 
 
 
 
 
 
657048f
42f2c22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6237836
e3b4db8
 
367bee6
67c0e7b
657048f
d0b8777
367bee6
e276ced
 
 
 
e3b4db8
de5b2a1
 
 
 
 
 
e3b4db8
657048f
e3b4db8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42f2c22
 
 
 
 
 
 
 
c3a2d6e
 
 
 
 
 
42f2c22
367bee6
c3a2d6e
42f2c22
c3a2d6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42f2c22
 
 
 
 
 
d62797d
367bee6
c3a2d6e
 
367bee6
d62797d
367bee6
42f2c22
a3e4180
 
 
 
 
 
 
 
 
 
 
 
 
42f2c22
1fd3071
 
 
42f2c22
273be34
c3a2d6e
 
 
 
 
42f2c22
 
 
c3a2d6e
1fd3071
 
c3a2d6e
42f2c22
 
 
 
 
 
273be34
42f2c22
 
512f3c8
42f2c22
 
 
 
 
 
 
 
 
 
 
 
512f3c8
42f2c22
 
512f3c8
42f2c22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273be34
42f2c22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d0005e8
 
 
 
 
 
 
 
15726ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d0005e8
 
273be34
d0005e8
 
 
 
c3a2d6e
42f2c22
d0005e8
 
 
15726ee
 
d0005e8
42f2c22
15726ee
 
d0005e8
42f2c22
d0005e8
 
 
42f2c22
d0005e8
 
 
42f2c22
d0005e8
 
 
 
 
42f2c22
d0005e8
42f2c22
 
 
 
d0005e8
42f2c22
c9102eb
42f2c22
d0005e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15726ee
 
 
 
 
 
d0005e8
 
 
 
 
42f2c22
 
 
 
d0005e8
42f2c22
d0005e8
42f2c22
d0005e8
 
 
 
 
 
 
 
42f2c22
d0005e8
42f2c22
d0005e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15726ee
 
 
 
 
 
 
d0005e8
15726ee
 
d0005e8
 
 
15726ee
d0005e8
 
 
 
 
 
 
15726ee
 
 
 
7fe120b
15726ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d0005e8
 
 
 
 
 
 
 
 
 
42f2c22
 
0d8f0d4
 
 
 
 
 
d0005e8
0d8f0d4
 
 
 
 
 
d0005e8
0d8f0d4
4e38b01
d0005e8
c3a2d6e
 
d0005e8
 
c3a2d6e
 
0d8f0d4
d0005e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d8f0d4
d0005e8
 
 
0d8f0d4
d0005e8
c3a2d6e
d0005e8
 
 
 
c3a2d6e
 
0d8f0d4
 
 
 
d0005e8
 
 
0d8f0d4
d0005e8
0d8f0d4
 
 
 
 
 
 
 
 
 
 
 
 
6571c22
 
0d8f0d4
 
 
 
 
 
 
ad4fb65
 
 
 
0d8f0d4
ad4fb65
0d8f0d4
ad4fb65
 
 
 
0d8f0d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42f2c22
ab75ac1
c3a2d6e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# //     http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.
import spaces
import contextlib
import subprocess
import sys

import os
import torch
import mediapy
from einops import rearrange
from omegaconf import OmegaConf
print(os.getcwd())
import datetime
from tqdm import tqdm
import gc

from data.image.transforms.divisible_crop import DivisibleCrop
from data.image.transforms.na_resize import NaResize
from data.video.transforms.rearrange import Rearrange
if os.path.exists("./projects/video_diffusion_sr/color_fix.py"):
    from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
    use_colorfix=True
else:
    use_colorfix = False
    print('Note!!!!!! Color fix is not avaliable!')
from torchvision.transforms import Compose, Lambda, Normalize
from torchvision.io.video import read_video
import argparse
from PIL import Image

from common.distributed import (
    get_device,
    init_torch,
)

from common.distributed.advanced import (
    get_data_parallel_rank,
    get_data_parallel_world_size,
    get_sequence_parallel_rank,
    get_sequence_parallel_world_size,
    init_sequence_parallel,
)

from projects.video_diffusion_sr.infer import VideoDiffusionInfer
from common.config import load_config
from common.distributed.ops import sync_data
from common.seed import set_seed
from common.partition import partition_by_groups, partition_by_size

import gradio as gr
from pathlib import Path
from urllib.parse import urlparse
from torch.hub import download_url_to_file, get_dir
import shlex
import uuid
import mimetypes
import torchvision.transforms as T

# Single-process distributed settings (rank 0 of a world of size 1 on localhost);
# presumably consumed by init_torch inside configure_runner.
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "12355"
os.environ["RANK"] = str(0)
os.environ["WORLD_SIZE"] = str(1)

# Install flash-attn with FLASH_ATTENTION_SKIP_CUDA_BUILD set, so its setup skips
# compiling the CUDA extension. The child must inherit the current environment
# (PATH, HOME, CUDA and pip settings); passing only the flag as ``env`` would
# replace the whole environment. ``sys.executable -m pip`` installs into the
# interpreter actually running this app.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
)

def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Return a local path for ``url``, downloading the file first if it is not cached.

    Adapted from
    https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py

    Args:
        url (str): Address of the file to fetch.
        model_dir (str): Directory the file is stored in. It is created if missing.
            When None, ``<torch hub dir>/checkpoints`` is used. Default: None.
        progress (bool): Forwarded to the downloader to show a progress bar.
            Default: True.
        file_name (str): Name to store the file under. When None, the last path
            component of the URL is used. Default: None.

    Returns:
        str: Absolute path of the (possibly freshly downloaded) file.
    """
    target_dir = model_dir
    if target_dir is None:
        # Fall back to the "checkpoints" folder of the torch hub cache.
        target_dir = os.path.join(get_dir(), 'checkpoints')
    os.makedirs(target_dir, exist_ok=True)

    url_name = os.path.basename(urlparse(url).path)
    local_name = url_name if file_name is None else file_name
    cached_file = os.path.abspath(os.path.join(target_dir, local_name))

    if os.path.exists(cached_file):
        return cached_file
    print(f'Downloading: "{url}" to {cached_file}\n')
    download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file


# Local folder for model checkpoints.
ckpt_dir = Path('./ckpts')
if not ckpt_dir.exists():
    ckpt_dir.mkdir()

# Download URLs of every asset the app may need, keyed by a short asset name.
pretrain_model_url = {
    # 3B model
    'vae':    'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
    'dit_3b': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
    # 7B model
    'dit_7b': 'https://huggingface.co/ByteDance-Seed/SeedVR2-7B/resolve/main/seedvr2_ema_7b.pth',
    # Shared by both sizes
    'pos_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/pos_emb.pt',
    'neg_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt',
    'apex':    'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/apex-0.1-cp310-cp310-linux_x86_64.whl',
}

# Per-size settings. Both sizes follow the same naming scheme:
# config directory, local DiT checkpoint path, and the key of its download URL above.
MODEL_CONFIGS = {
    size: {
        "config_dir": f"./configs_{size}",
        "checkpoint": f"./ckpts/seedvr2_ema_{size}.pth",
        "ckpt_url_key": f"dit_{size}",
    }
    for size in ("3b", "7b")
}

def ensure_model_weights(model_size: str):
    """Download the DiT checkpoint for ``model_size`` unless it already exists on disk.

    Args:
        model_size: A key of ``MODEL_CONFIGS`` (``"3b"`` or ``"7b"``).

    Raises:
        KeyError: If ``model_size`` is not a key of ``MODEL_CONFIGS``.
    """
    cfg = MODEL_CONFIGS[model_size]
    checkpoint = cfg["checkpoint"]
    if os.path.exists(checkpoint):
        return
    print(f"Downloading {model_size.upper()} checkpoint …")
    # Save to exactly the path checked above instead of relying on the URL's
    # basename matching it, so the existence check and the download cannot diverge.
    load_file_from_url(
        url=pretrain_model_url[cfg["ckpt_url_key"]],
        model_dir=os.path.dirname(checkpoint) or '.',
        progress=True,
        file_name=os.path.basename(checkpoint),
    )

# ── Always-needed assets (VAE, text embeddings, apex wheel) ──────────
if not os.path.exists('./ckpts/ema_vae.pth'):
    load_file_from_url(url=pretrain_model_url['vae'], model_dir='./ckpts/', progress=True, file_name=None)
if not os.path.exists('./pos_emb.pt'):
    load_file_from_url(url=pretrain_model_url['pos_emb'], model_dir='./', progress=True, file_name=None)
if not os.path.exists('./neg_emb.pt'):
    load_file_from_url(url=pretrain_model_url['neg_emb'], model_dir='./', progress=True, file_name=None)
if not os.path.exists('./apex-0.1-cp310-cp310-linux_x86_64.whl'):
    load_file_from_url(url=pretrain_model_url['apex'], model_dir='./', progress=True, file_name=None)
# Pre-download 3B by default (7B is fetched lazily on first use).
ensure_model_weights("3b")

# Install apex into the interpreter running this app (a bare "pip" on PATH may
# belong to a different Python). Installation stays best-effort, but the log now
# reports whether it actually succeeded.
_apex_install = subprocess.run(
    [sys.executable, "-m", "pip", "install", "apex-0.1-cp310-cp310-linux_x86_64.whl"]
)
if _apex_install.returncode == 0:
    print("✅ setup completed Apex")
else:
    print(f"⚠️  Apex installation failed (pip exit code {_apex_install.returncode}); continuing")

# Download example videos (optional – the app still works if the network fails).
_example_videos = {
    '01.mp4': 'https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/23_1_lq.mp4',
    '02.mp4': 'https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/28_1_lq.mp4',
    '03.mp4': 'https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/2_1_lq.mp4',
}
for _fname, _url in _example_videos.items():
    if os.path.exists(_fname):
        continue
    try:
        torch.hub.download_url_to_file(_url, _fname)
        print(f"✅ Downloaded example video: {_fname}")
    except Exception as _e:  # deliberately best-effort: examples are optional
        print(f"⚠️  Could not download example video {_fname}: {_e}  (skipping)")

def configure_sequence_parallel(sp_size):
    """Set up sequence parallelism when a sequence is split across several ranks.

    With ``sp_size`` of 1 (the single-GPU case) nothing is initialised.
    """
    needs_sharding = sp_size > 1
    if needs_sharding:
        init_sequence_parallel(sp_size)

@spaces.GPU(duration=100)
def configure_runner(sp_size, model_size="3b"):
    """Build a ``VideoDiffusionInfer`` runner with DiT and VAE loaded for ``model_size``.

    Steps:
    1. Make sure the DiT checkpoint is on disk.
    2. Load ``main.yaml`` from the size's config directory.
    3. Initialise torch and, optionally, sequence parallelism.
    4. Load the DiT (on CUDA) and the VAE.

    Args:
        sp_size: Sequence-parallel size forwarded to ``configure_sequence_parallel``.
        model_size: ``"3b"`` or ``"7b"``; case and surrounding whitespace are ignored.

    Returns:
        The configured runner, with its config left writable.
    """
    size_key = model_size.lower().strip()
    ensure_model_weights(size_key)
    model_cfg = MODEL_CONFIGS[size_key]

    yaml_path = os.path.join(model_cfg["config_dir"], 'main.yaml')
    runner = VideoDiffusionInfer(load_config(yaml_path))
    # Unlock the config: callers overwrite diffusion settings after construction.
    OmegaConf.set_readonly(runner.config, False)

    init_torch(cudnn_benchmark=False, timeout=datetime.timedelta(seconds=3600))
    configure_sequence_parallel(sp_size)
    runner.configure_dit_model(device="cuda", checkpoint=model_cfg["checkpoint"])
    runner.configure_vae_model()

    # Apply the configured VAE memory limit when the VAE implementation supports one.
    vae = runner.vae
    if hasattr(vae, "set_memory_limit"):
        vae.set_memory_limit(**runner.config.vae.memory_limit)
    return runner

@spaces.GPU(duration=100)
def generation_step(runner, text_embeds_dict, cond_latents):
    """Run one restoration pass of the diffusion runner over conditioning latents.

    Args:
        runner: Configured ``VideoDiffusionInfer`` (see ``configure_runner``).
        text_embeds_dict: Keyword arguments forwarded to ``runner.inference``.
            Callers in this file pass ``{"texts_pos": [...], "texts_neg": [...]}``.
        cond_latents: List of conditioning latents, one per clip, as produced by
            ``runner.vae_encode`` in the callers.

    Returns:
        List of output tensors, one per input latent, each laid out (T, C, H, W).
        A 3-D output (C, H, W) is treated as a single frame.
    """
    def _move_to_cuda(x):
        return [i.to(torch.device("cuda")) for i in x]

    # Draw one sampling noise and one augmentation noise per conditioning latent.
    # Keep the draw order (all noises, then all aug_noises): it decides which random
    # numbers each tensor receives for a given seed.
    noises = [torch.randn_like(latent) for latent in cond_latents]
    aug_noises = [torch.randn_like(latent) for latent in cond_latents]
    print(f"Generating with noise shape: {noises[0].size()}.")
    # NOTE(review): sync_data(..., 0) presumably broadcasts these tensors from rank 0
    # so every rank uses identical values -- confirm in common.distributed.ops.
    noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0)
    noises, aug_noises, cond_latents = list(
        map(lambda x: _move_to_cuda(x), (noises, aug_noises, cond_latents))
    )
    # Conditioning-noise strength: the raw timestep is 1000 * 0.1 = 100 before
    # runner.timestep_transform is applied.
    cond_noise_scale = 0.1

    def _add_noise(x, aug_noise):
        """Mix ``aug_noise`` into latent ``x`` at timestep ``1000 * cond_noise_scale``.

        The timestep is first passed through ``runner.timestep_transform`` together
        with ``x.shape[1:]``; ``runner.schedule.forward`` then produces the noised latent.
        """
        t = (
            torch.tensor([1000.0], device=torch.device("cuda"))
            * cond_noise_scale
        )
        shape = torch.tensor(x.shape[1:], device=torch.device("cuda"))[None]
        t = runner.timestep_transform(t, shape)
        print(
            f"Timestep shifting from"
            f" {1000.0 * cond_noise_scale} to {t}."
        )
        x = runner.schedule.forward(x, aug_noise, t)
        return x

    # One super-resolution condition per clip, built from its noise and its
    # lightly-noised conditioning latent.
    conditions = [
        runner.get_condition(
            noise,
            task="sr",
            latent_blur=_add_noise(latent_blur, aug_noise),
        )
        for noise, aug_noise, latent_blur in zip(noises, aug_noises, cond_latents)
    ]

    # Inference without autograd, under bfloat16 autocast.
    with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True):
        video_tensors = runner.inference(
            noises=noises,
            conditions=conditions,
            dit_offload=False,
            **text_embeds_dict,
        )

    # Normalise every output to (T, C, H, W); a 3-D (C, H, W) output gets T = 1.
    samples = [
        (
            rearrange(video[:, None], "c t h w -> t c h w")
            if video.ndim == 3
            else rearrange(video, "c t h w -> t c h w")
        )
        for video in video_tensors
    ]
    del video_tensors

    return samples


# โ”€โ”€ Resolution presets โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
PRESET_RESOLUTIONS = {
    "720p  (1280ร—720)":  (1280,  720),
    "1080p (1920ร—1080)": (1920, 1080),
    "1440p (2560ร—1440)": (2560, 1440),
    "4K    (3840ร—2160)": (3840, 2160),
}
CHUNK_FRAMES = 121   # absolute model hard limit per forward pass

def _choose_safe_chunk_frames(h: int, w: int, requested: int = CHUNK_FRAMES) -> int:
    """
    Pick a safer temporal chunk size for high-resolution videos to avoid allocator/NVML crashes.
    720p can usually use the full 121 frames; above that we shrink aggressively.
    """
    pixels = int(h) * int(w)
    if pixels >= 3840 * 2160:   # 4K+
        return min(requested, 8)
    if pixels >= 2560 * 1440:   # 1440p
        return min(requested, 12)
    if pixels >= 1920 * 1080:   # 1080p
        return min(requested, 16)
    if pixels >= 1280 * 720:    # 720p
        return min(requested, 32)
    return min(requested, 64)

def _is_cuda_memory_error(exc: BaseException) -> bool:
    msg = str(exc)
    keys = (
        "out of memory",
        "cuda out of memory",
        "cudacachingallocator",
        "nvml_success == r internal assert failed",
        "allocator",
    )
    msg_low = msg.lower()
    return any(k in msg_low for k in keys)

# ── Chunked video SR ────────────────────────────────────────────────
@spaces.GPU(duration=100)
def generation_loop(video_path, seed=666, fps_out=24, model_size="3b",
                    res_mode="Preset", preset_res="1080p (1920ร—1080)", upscale_factor=2,
                    batch_size=1, cfg_scale=1.0, cfg_rescale=0.0, sample_steps=1, sp_size=1):

    runner = configure_runner(1, model_size=model_size)

    def _extract_text_embeds(n_chunks):
        embeds = []
        for _ in range(n_chunks):
            text_pos_embeds = torch.load('pos_emb.pt', map_location='cpu', weights_only=True)
            text_neg_embeds = torch.load('neg_emb.pt', map_location='cpu', weights_only=True)
            embeds.append({"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]})
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return embeds

    def cut_video_to_model(video, sp_size):
        """Pad temporal dim to satisfy model alignment (no hard cap โ€“ chunks are pre-split)."""
        t = video.size(1)
        if t <= 4 * sp_size:
            padding = [video[:, -1].unsqueeze(1)] * (4 * sp_size - t + 1)
            video = torch.cat([video] + padding, dim=1)
            return video
        if (t - 1) % (4 * sp_size) == 0:
            return video
        n_pad = 4 * sp_size - ((t - 1) % (4 * sp_size))
        padding = [video[:, -1].unsqueeze(1)] * n_pad
        video = torch.cat([video] + padding, dim=1)
        return video

    # โ”€โ”€ Config โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
    runner.config.diffusion.cfg.scale = cfg_scale
    runner.config.diffusion.cfg.rescale = cfg_rescale
    runner.config.diffusion.timesteps.sampling.steps = sample_steps
    runner.configure_diffusion()
    seed = seed % (2**32)
    set_seed(seed, same_across_ranks=True)
    os.makedirs('output/', exist_ok=True)

    # โ”€โ”€ Detect media type โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
    media_type, _ = mimetypes.guess_type(video_path)
    is_image = media_type and media_type.startswith("image")
    is_video = media_type and media_type.startswith("video")

    # โ”€โ”€ Read full video โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
    if is_video:
        video_data, _, video_info = read_video(os.path.join(video_path), output_format="TCHW")
        full_video = video_data / 255.0                     # (T, C, H, W)
        fps_out = float(video_info.get("video_fps", fps_out))
        T_total, _, in_H, in_W = full_video.shape
        print(f"Input video: {T_total} frames @ {fps_out:.3f} fps, {in_W}ร—{in_H}")
    else:
        img = Image.open(video_path).convert("RGB")
        img_tensor = T.ToTensor()(img).unsqueeze(0)         # (1, C, H, W)
        full_video = img_tensor
        _, _, in_H, in_W = full_video.shape
        T_total = 1
        print(f"Input image: {in_W}ร—{in_H}")

    # โ”€โ”€ Compute target resolution โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
    if res_mode == "Preset":
        res_h, res_w = PRESET_RESOLUTIONS.get(preset_res, (1920, 1080))
    else:  # Upscale Factor
        scale = float(upscale_factor)
        res_h = int(in_H * scale)
        res_w = int(in_W * scale)
    print(f"Target resolution: {res_w}ร—{res_h}  (mode={res_mode})")

    if is_video and (res_h * res_w) > (1920 * 1080):
        print(
            "โš ๏ธ High-memory mode detected. 2K/4K video restoration is very likely to fail on limited GPU "
            "memory; the code will use smaller temporal chunks automatically."
        )

    target_resolution = (res_h * res_w) ** 0.5

    def make_transform(target_res):
        return Compose([
            NaResize(resolution=target_res, mode="area", downsample_only=False),
            Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
            DivisibleCrop((16, 16)),
            Normalize(0.5, 0.5),
            Rearrange("t c h w -> c t h w"),
        ])

    video_transform = make_transform(target_resolution)

    output_dir = 'output/' + str(uuid.uuid4()) + ('.png' if is_image else '.mp4')

    # โ”€โ”€ Process image (single pass) โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
    if is_image:
        img_transform = make_transform((2560 * 1440) ** 0.5)
        cond = img_transform(full_video.to(torch.device("cuda")))  # (C,1,H,W)
        ori_length = cond.size(1)
        text_embeds = _extract_text_embeds(1)[0]
        for i, emb in enumerate(text_embeds["texts_pos"]):
            text_embeds["texts_pos"][i] = emb.to("cuda")
        for i, emb in enumerate(text_embeds["texts_neg"]):
            text_embeds["texts_neg"][i] = emb.to("cuda")
        latent = runner.vae_encode([cond])
        sample = generation_step(runner, text_embeds, cond_latents=latent)[0]
        if ori_length < sample.shape[0]:
            sample = sample[:ori_length]
        input_pixel = rearrange(cond[:, None], "c t h w -> t c h w") if cond.ndim == 3 else rearrange(cond, "c t h w -> t c h w")
        if use_colorfix:
            sample = wavelet_reconstruction(sample.to("cpu"), input_pixel[:sample.size(0)].to("cpu"))
        else:
            sample = sample.to("cpu")
        sample = rearrange(sample, "t c h w -> t h w c")
        sample = sample.clip(-1,1).mul_(0.5).add_(0.5).mul_(255).round().to(torch.uint8).numpy()
        mediapy.write_image(output_dir, sample[0])
        gc.collect(); torch.cuda.empty_cache()
        return output_dir, None, output_dir

    # โ”€โ”€ Chunked video processing โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
    safe_chunk_frames = _choose_safe_chunk_frames(res_h, res_w, CHUNK_FRAMES)
    if safe_chunk_frames != CHUNK_FRAMES:
        print(
            f"Reducing chunk size from {CHUNK_FRAMES} to {safe_chunk_frames} "
            f"for safer memory usage at {res_w}ร—{res_h}."
        )

    frame_chunks = []
    for start in range(0, T_total, safe_chunk_frames):
        end = min(start + safe_chunk_frames, T_total)
        frame_chunks.append(full_video[start:end])   # each: (t_chunk, C, H, W)

    n_chunks = len(frame_chunks)
    print(f"Processing {n_chunks} chunk(s) of up to {safe_chunk_frames} frames each โ€ฆ")
    text_embeds_list = _extract_text_embeds(n_chunks)

    all_output_frames = []   # will collect numpy uint8 frames

    for chunk_idx, (chunk_frames, text_embeds) in enumerate(zip(frame_chunks, text_embeds_list)):
        print(f"  Chunk {chunk_idx+1}/{n_chunks}: {chunk_frames.shape[0]} frames")

        cond = None
        cond_padded = None
        latent = None
        sample = None

        try:
            # Transform to model input space
            cond = video_transform(chunk_frames.to(torch.device("cuda"), non_blocking=True))
            ori_length = cond.size(1)

            # Pad to model alignment
            cond_padded = cut_video_to_model(cond, sp_size)

            # Move text embeds to GPU lazily right before use
            for i, emb in enumerate(text_embeds["texts_pos"]):
                text_embeds["texts_pos"][i] = emb.to("cuda", non_blocking=True)
            for i, emb in enumerate(text_embeds["texts_neg"]):
                text_embeds["texts_neg"][i] = emb.to("cuda", non_blocking=True)

            # Encode โ†’ diffuse โ†’ decode
            latent = runner.vae_encode([cond_padded])
            sample = generation_step(runner, text_embeds, cond_latents=latent)[0]

            # Trim padding
            if ori_length < sample.shape[0]:
                sample = sample[:ori_length]

            # Color fix
            input_pixel = rearrange(cond, "c t h w -> t c h w")
            if use_colorfix:
                sample = wavelet_reconstruction(sample.to("cpu"), input_pixel[:sample.size(0)].to("cpu"))
            else:
                sample = sample.to("cpu")

            # Convert to uint8 numpy (T, H, W, C)
            sample = rearrange(sample, "t c h w -> t h w c")
            sample = sample.clip(-1,1).mul_(0.5).add_(0.5).mul_(255).round().to(torch.uint8).numpy()
            all_output_frames.append(sample)

        except RuntimeError as e:
            if _is_cuda_memory_error(e):
                raise RuntimeError(
                    f"GPU memoryไธ่ถณ๏ผšๅฝ“ๅ‰ๅˆ†่พจ็އ {res_w}ร—{res_h}ใ€ๅˆ†ๅ— {chunk_frames.shape[0]} ๅธงไป็„ถ่ถ…ๅ‡บๆ˜พๅญ˜ใ€‚"

                    f"่ฏทๆ”นไธบๆ›ดไฝŽ่พ“ๅ‡บๅˆ†่พจ็އ๏ผˆๅปบ่ฎฎ 720p/1080p๏ผ‰ใ€ๆ›ดๅฐ upscale_factor๏ผŒๆˆ–็ปง็ปญ้™ไฝŽ safe_chunk_framesใ€‚"

                    f"ๅŽŸๅง‹้”™่ฏฏ: {e}"
                ) from e
            raise
        finally:
            del latent, cond, cond_padded, sample
            for k in ("texts_pos", "texts_neg"):
                for i, emb in enumerate(text_embeds[k]):
                    if isinstance(emb, torch.Tensor):
                        text_embeds[k][i] = emb.to("cpu")
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # โ”€โ”€ Concatenate chunks and write โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
    import numpy as np
    final_frames = np.concatenate(all_output_frames, axis=0)
    print(f"Total output frames: {final_frames.shape[0]} @ {fps_out:.3f} fps โ†’ {output_dir}")
    mediapy.write_video(output_dir, final_frames, fps=fps_out)

    gc.collect()
    torch.cuda.empty_cache()
    return None, output_dir, output_dir


# Preset labels keyed by their (width, height) value, so the dropdown default and the
# examples reuse the exact keys of PRESET_RESOLUTIONS instead of duplicating them.
_PRESET_LABEL_BY_SIZE = {size: label for label, size in PRESET_RESOLUTIONS.items()}

with gr.Blocks(title="SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training") as demo:
    # Top logo and title
    gr.HTML("""
        <div style='text-align:center; margin-bottom: 10px;'>
            <img src='https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/assets/seedvr_logo.png' style='height:40px;' alt='SeedVR logo'/>
        </div>
        <p><b>Official Gradio demo</b> for
        <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>
        <b>SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training</b></a>.<br>
        🔥 <b>SeedVR2</b> is a one-step image and video restoration algorithm for real-world and AIGC content.
        </p>
    """)

    # ── Row 1: inputs + model settings ──────────────────────────────
    with gr.Row():
        input_video = gr.File(label="Upload image or video", type="filepath")

        with gr.Column():
            model_selector = gr.Radio(
                choices=["3b", "7b"], value="3b", label="Model Size",
                info="3B: faster · lower VRAM  |  7B: higher quality · more VRAM",
            )
            # precision=0 makes Gradio deliver the seed as an int rather than a float.
            seed = gr.Number(label="Seed", value=666, precision=0)

    # ── Row 2: resolution mode ──────────────────────────────────────
    with gr.Row():
        res_mode = gr.Radio(
            choices=["Preset", "Upscale Factor"],
            value="Preset",
            label="Output Resolution Mode",
            info="Preset: fixed target resolution  |  Upscale Factor: multiply input resolution",
        )

    with gr.Row():
        preset_res = gr.Dropdown(
            choices=list(PRESET_RESOLUTIONS.keys()),
            value=_PRESET_LABEL_BY_SIZE[(1920, 1080)],
            label="Preset Resolution",
            visible=True,
        )
        upscale_factor = gr.Slider(
            minimum=1, maximum=4, step=0.5, value=2,
            label="Upscale Factor  (e.g. 2 = 2× width & height)",
            visible=False,
        )

    # Show only the control that matches the selected mode.
    def toggle_res_mode(mode):
        return gr.update(visible=(mode == "Preset")), gr.update(visible=(mode == "Upscale Factor"))
    res_mode.change(toggle_res_mode, inputs=res_mode, outputs=[preset_res, upscale_factor])

    gr.Markdown("ℹ️ Output FPS and total duration are **automatically matched** to the input video (full length, no frame cap).")

    # ── Row 3: outputs ──────────────────────────────────────────────
    with gr.Row():
        output_video = gr.Video(label="Output Video")
        output_image = gr.Image(label="Output Image")
        download_link = gr.File(label="Download")

    run_button = gr.Button("▶  Run Super-Resolution", variant="primary")
    # generation_loop returns (image, video, file). fps_out=24 is only a fallback
    # for videos whose metadata carries no frame rate.
    run_button.click(
        fn=lambda path, sd, model, mode, preset, scale: generation_loop(
            path, sd, 24, model, mode, preset, scale
        ),
        inputs=[input_video, seed, model_selector, res_mode, preset_res, upscale_factor],
        outputs=[output_image, output_video, download_link],
    )

    # Examples: only offer clips that were actually downloaded at startup (that
    # download is best-effort), so a network failure does not break the demo.
    _examples = [
        example
        for example in (
            ["./01.mp4", 4, "3b", "Preset", _PRESET_LABEL_BY_SIZE[(1920, 1080)], 2],
            ["./02.mp4", 4, "3b", "Upscale Factor", _PRESET_LABEL_BY_SIZE[(1920, 1080)], 2],
            ["./03.mp4", 4, "7b", "Preset", _PRESET_LABEL_BY_SIZE[(2560, 1440)], 2],
        )
        if os.path.exists(example[0])
    ]
    if _examples:
        gr.Examples(
            examples=_examples,
            inputs=[input_video, seed, model_selector, res_mode, preset_res, upscale_factor],
        )

    # Article/Footer
    gr.HTML("""
        <hr>
        <p>If you find SeedVR helpful, please ⭐ the
        <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>GitHub repository</a>:</p>

        <a href="https://github.com/ByteDance-Seed/SeedVR" target="_blank">
            <img src="https://img.shields.io/github/stars/ByteDance-Seed/SeedVR?style=social" alt="GitHub Stars">
        </a>

        <h4>Notice</h4>
        <p>Videos are restored at full length in temporal chunks; output resolutions above 1080p use
        smaller chunks and may still run out of GPU memory. Images are restored at roughly
        <b>2K (2560×1440-equivalent area)</b>.
        For heavier workloads (e.g. images beyond 2K or long high-resolution videos), check the <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>GitHub repo</a>.</p>

        <h4>Limitations</h4>
        <p>May fail on heavy degradations or small-motion AIGC clips, causing oversharpening or poor restoration.</p>

        <h4>Citation</h4>
        <pre style="font-size: 12px;">
        @article{wang2025seedvr2,
            title={SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training},
            author={Wang, Jianyi and Lin, Shanchuan and Lin, Zhijie and Ren, Yuxi and Wei, Meng and Yue, Zongsheng and Zhou, Shangchen and Chen, Hao and Zhao, Yang and Yang, Ceyuan and Xiao, Xuefeng and Loy, Chen Change and Jiang, Lu},
            booktitle={arXiv preprint arXiv:2506.05301},
            year={2025}
        }

        @inproceedings{wang2025seedvr,
            title={SeedVR: Seeding Infinity in Diffusion Transformer Towards Generic Video Restoration},
            author={Wang, Jianyi and Lin, Zhijie and Wei, Meng and Zhao, Yang and Yang, Ceyuan and Loy, Chen Change and Jiang, Lu},
            booktitle={CVPR},
            year={2025}
        }
        </pre>

        <h4>License</h4>
        <p>Licensed under the
        <a href="http://www.apache.org/licenses/LICENSE-2.0" target="_blank">Apache 2.0 License</a>.</p>

        <h4>Contact</h4>
        <p>Email: <b>iceclearwjy@gmail.com</b></p>

        <p>
        <a href="https://twitter.com/Iceclearwjy">
        <img src="https://img.shields.io/twitter/follow/Iceclearwjy?label=%40Iceclearwjy&style=social" alt="Twitter Follow">
        </a>
        <a href="https://github.com/IceClear">
        <img src="https://img.shields.io/github/followers/IceClear?style=social" alt="GitHub Follow">
        </a>
        </p>

        <p style="text-align:center;">
        <img src="https://visitor-badge.laobi.icu/badge?page_id=ByteDance-Seed/SeedVR" alt="visitors">
        </p>
    """)

demo.queue()
demo.launch()