File size: 4,363 Bytes
332ad95
 
ade5830
4ec83db
 
8f43770
0d5f536
ade5830
332ad95
ade5830
54839fc
ade5830
4ec83db
 
 
ade5830
 
 
 
29dbc6d
4ec83db
 
 
ade5830
4ec83db
4a1c292
ade5830
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332ad95
4ec83db
332ad95
4ec83db
 
 
 
 
 
 
 
 
 
 
 
8f43770
4ec83db
 
 
ade5830
332ad95
5ba4a34
332ad95
6625197
ade5830
 
 
5ba4a34
8f43770
4ec83db
8f43770
 
 
 
 
0d5f536
5ba4a34
332ad95
0d5f536
 
 
 
 
332ad95
 
0d5f536
332ad95
0d5f536
332ad95
 
5ba4a34
332ad95
0d5f536
5ba4a34
332ad95
 
 
0d5f536
332ad95
 
 
8f43770
 
 
4ec83db
8f43770
 
 
5ba4a34
332ad95
 
ade5830
332ad95
 
 
4ec83db
 
 
 
5ba4a34
 
4ec83db
332ad95
 
ade5830
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import os
import tempfile
import torch
import torchvision.transforms as T
from PIL import Image
import imageio.v3 as iio
import imageio
from tqdm import tqdm
import gradio as gr
import spaces
from diffusers import AsymmetricAutoencoderKL

# -------------------------------------------------------------
# Model settings
# -------------------------------------------------------------
MODEL_ID = "AiArtLab/simplevae2x"  # Hugging Face Hub repo id of the VAE

# Prefer the GPU when CUDA is available; otherwise run on the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# Precision used for the VAE weights and for the frame tensors fed to it.
# NOTE(review): float16 is used even on CPU; some torch versions lack CPU
# kernels for half-precision convolutions — confirm if a CPU fallback matters.
DTYPE = torch.float16


# -------------------------------------------------------------
# VAE loading
# -------------------------------------------------------------
def load_vae(model_id=MODEL_ID, device=DEVICE, dtype=DTYPE):
    """Load the asymmetric VAE, falling back to the ``vae5`` subfolder.

    The repository root is tried first, then the ``vae5`` subfolder. The first
    location that loads is moved to *device*, cast to *dtype* and switched to
    eval mode.

    Args:
        model_id: Hugging Face Hub repo id (or local path) of the model.
        device: Target device, e.g. ``"cuda"`` or ``"cpu"``.
        dtype: Torch dtype for the weights. Defaults to the module-level
            ``DTYPE`` so weights always match the dtype used for inputs.

    Returns:
        The loaded ``AsymmetricAutoencoderKL`` in eval mode.

    Raises:
        RuntimeError: If no location could be loaded. The message lists the
            error from every attempt; the last one is chained as the cause.
    """
    errors = []
    for subfolder in (None, "vae5"):
        kwargs = {"torch_dtype": dtype}
        if subfolder is not None:
            kwargs["subfolder"] = subfolder
        try:
            vae = AsymmetricAutoencoderKL.from_pretrained(model_id, **kwargs)
        except Exception as e:  # best effort: remember the error, try next location
            errors.append((subfolder, e))
            continue
        # Device placement is outside the try on purpose: a CUDA/OOM error here
        # is a real failure, not a reason to retry a different subfolder.
        # Cast with `dtype` (not a hard-coded .half()) to stay consistent with DTYPE.
        vae.to(device=device, dtype=dtype)
        return vae.eval()

    details = "; ".join(f"{sub or '<root>'}: {err}" for sub, err in errors)
    raise RuntimeError(f"Failed to load VAE {model_id}: {details}") from errors[-1][1]

_vae = None  # cached VAE instance; populated on the first get_vae() call


def get_vae():
    """Return the shared VAE, creating it with ``load_vae()`` on first use.

    Later calls return the same cached object without reloading.

    NOTE(review): this is called from a ``@spaces.GPU`` function; if that
    decorator executes work in a separate process, the cache may not persist
    between requests — confirm.
    """
    global _vae
    if _vae is not None:
        return _vae
    _vae = load_vae()
    return _vae

# -------------------------------------------------------------
# Upscale a single frame through the VAE
# -------------------------------------------------------------
def upscale_frame_via_vae(frame, vae):
    """Pass one frame through the VAE (encode -> decode) and return the result.

    The decoder of the asymmetric VAE is expected to output a larger image
    than its input (2x per the model name / UI title — not verifiable here).

    Args:
        frame: Image array accepted by ``PIL.Image.fromarray`` (frames come
            from ``imageio.v3.imiter`` in this file); converted to RGB.
        vae: Model exposing ``encode``/``decode`` and ``device``/``dtype``.

    Returns:
        ``numpy.ndarray`` of dtype uint8 in HWC layout (presumably 3 channels,
        assuming the decoder outputs RGB — TODO confirm).
    """
    img = Image.fromarray(frame).convert("RGB")
    # ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1].
    tfm = T.Compose([
        T.ToTensor(),
        T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    # Follow the model's own placement/precision instead of the module globals,
    # so a VAE loaded with a non-default device/dtype still gets matching inputs.
    t = tfm(img).unsqueeze(0).to(device=vae.device, dtype=vae.dtype)

    with torch.no_grad():
        # Use the posterior mean (no sampling) so output is deterministic.
        lat = vae.encode(t).latent_dist.mean
        dec = vae.decode(lat).sample

    # Map [-1, 1] back to [0, 255] in float32 to avoid half-precision rounding.
    x = (dec.float().clamp(-1, 1) + 1) * 127.5
    return x.round().to(torch.uint8).squeeze(0).permute(1, 2, 0).cpu().numpy()

# -------------------------------------------------------------
# Main video upscaling function (audio is not preserved)
# -------------------------------------------------------------
def _resolve_video_path(video_file):
    """Return the file path from a Gradio video value (path string or dict with "name")."""
    if isinstance(video_file, dict):
        return video_file.get("name")
    return video_file


def _probe_fps_and_duration(video_path, default_fps=24, default_duration=5.0):
    """Return ``(fps, duration_seconds)`` of the first video stream via ffprobe.

    Best effort: if probing or parsing fails, the error is printed and
    ``(default_fps, default_duration)`` is returned instead of raising.
    """
    from fractions import Fraction

    import ffmpeg

    try:
        probe = ffmpeg.probe(video_path)
        stream = next(s for s in probe["streams"] if s["codec_type"] == "video")
        # r_frame_rate is a rational string such as "30000/1001". Parse it with
        # Fraction instead of eval(): the text derives from an uploaded file.
        fps = float(Fraction(stream["r_frame_rate"]))
        if fps <= 0:
            raise ValueError(f"invalid frame rate {stream['r_frame_rate']!r}")
        # Some containers report duration only at the format level.
        raw_duration = stream.get("duration") or probe.get("format", {}).get("duration")
        duration = default_duration if raw_duration is None else float(raw_duration)
    except Exception as e:
        print("⚠️ Не удалось извлечь метаданные:", e)
        return default_fps, default_duration
    return fps, duration


@spaces.GPU
def upscale_video(video_file):
    """Upscale up to the first 5 seconds of a video frame-by-frame through the VAE.

    Args:
        video_file: Gradio ``Video`` value — a file path, or a dict holding
            the path under ``"name"``.

    Returns:
        Path to an H.264 MP4 written into a fresh temporary directory. The
        directory is not removed here, since the returned path must remain
        readable by the caller. Audio is not copied.

    Raises:
        gr.Error: If no video was provided.
    """
    from contextlib import closing
    from itertools import islice

    video_path = _resolve_video_path(video_file)
    if not video_path:
        raise gr.Error("Please upload a video.")

    vae = get_vae()
    tmp_dir = tempfile.mkdtemp()
    output_video_path = os.path.join(tmp_dir, "upscaled_video.mp4")

    fps, duration = _probe_fps_and_duration(video_path)
    duration = min(duration, 5.0)
    print(f"🎞️ Video FPS: {fps}, duration: {duration:.2f}s")

    # Write at least one frame so the encoder always produces a valid file.
    max_frames = max(1, int(fps * duration))

    reader = iio.imiter(video_path)
    writer = imageio.get_writer(
        output_video_path,
        fps=fps,
        codec="libx264",
        quality=8,
        format="FFMPEG",
    )
    # closing()/with release the decoder and finalize the encoder even if a
    # frame fails; islice stops decoding exactly at max_frames.
    with closing(reader), writer:
        frames = islice(reader, max_frames)
        for frame in tqdm(frames, total=max_frames, desc="Upscaling frames"):
            writer.append_data(upscale_frame_via_vae(frame, vae))

    return output_video_path

# -------------------------------------------------------------
# Gradio UI
# -------------------------------------------------------------
# Components are built input-first, the same order in which the original
# keyword arguments were evaluated.
_video_input = gr.Video(label="🎥 Upload a video")
_video_output = gr.Video(label="🎬 Upscaled video")

demo = gr.Interface(
    fn=upscale_video,
    inputs=_video_input,
    outputs=_video_output,
    title="🧠 Asymmetric VAE 2× Video Upscaler",
    # Russian text: "Upscales the video (first 5 seconds) via <model>.
    # Audio is not preserved."
    description=(
        "Апскейлит видео (первые 5 секунд) через [`AiArtLab/simplevae2x`](https://huggingface.co/AiArtLab/simplevae2x). "
        "Аудио не сохраняется."
    ),
    allow_flagging="never",
)

if __name__ == "__main__":
    # Launch the web app only when executed as a script, not on import.
    demo.launch()