File size: 2,661 Bytes
3ac36b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import functools
import os
import shutil
import subprocess
import sys
import tempfile

import gradio as gr
from huggingface_hub import snapshot_download
from PIL import Image
import spaces
import torch

# Hugging Face Hub repository that hosts the Matrix-Game 2.0 weights.
MODEL_REPO = "Skywork/Matrix-Game-2.0"

# Use the GPU when torch reports one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Device:", DEVICE)

# ----- one-time model + code download -----
# functools.lru_cache memoises the result, so the download/clone work runs at
# most once per Python process (a stdlib replacement for `@spaces.cached`,
# which does not appear to be part of the HF `spaces` package API).
@functools.lru_cache(maxsize=None)
def setup():
    """Fetch the model weights and the Matrix-Game source code.

    Downloads a snapshot of ``MODEL_REPO`` into ``./model_cache`` and clones
    the Matrix-Game GitHub repository into ``./Matrix-Game`` unless that
    directory already exists.

    Returns:
        The local snapshot path returned by ``snapshot_download`` (may be
        relative, since ``cache_dir`` is relative).

    Raises:
        subprocess.CalledProcessError: if ``git clone`` exits non-zero.
    """
    print("‣ downloading weights …")
    model_dir = snapshot_download(MODEL_REPO, cache_dir="model_cache")
    # Reuse an existing checkout instead of cloning again.
    if not os.path.exists("Matrix-Game"):
        subprocess.check_call(["git", "clone",
                               "https://github.com/SkyworkAI/Matrix-Game.git"])
    return model_dir
# -----------------------------------------

# File extensions treated as generated videos when scanning the output folder.
_VIDEO_EXTS = (".mp4", ".webm", ".mov")


def _prepare_image(img, max_side=512):
    """Return *img* as RGB, downscaled so its longest side is at most *max_side* px.

    The image is later written as JPEG, which cannot store alpha or palette
    images, hence the mode conversion. Downscaling keeps VRAM usage in check.
    """
    if img.mode != "RGB":
        img = img.convert("RGB")
    longest = max(img.size)
    if longest > max_side:
        r = max_side / longest
        # max(1, ...) prevents a zero-sized side for extreme aspect ratios.
        new_size = (max(1, int(img.size[0] * r)), max(1, int(img.size[1] * r)))
        img = img.resize(new_size, Image.Resampling.LANCZOS)
    return img


def _find_video(folder):
    """Return the first video file under *folder* in sorted walk order, or None."""
    for root, dirs, files in os.walk(folder):
        dirs.sort()  # os.walk honours in-place edits; sorting makes the pick deterministic
        for name in sorted(files):
            if name.lower().endswith(_VIDEO_EXTS):
                return os.path.join(root, name)
    return None


@spaces.GPU(duration=120)
def run(img, frames, seed):
    """Generate a video from a start frame via Matrix-Game-2's ``inference.py``.

    Args:
        img: start frame as a PIL image, or ``None`` if nothing was uploaded.
        frames: number of output frames (Gradio may deliver a float).
        seed: random seed (``gr.Number`` delivers a float such as ``42.0``
            unless ``precision=0`` is set), or ``None`` if the field is empty.

    Returns:
        ``(video_path, status_message)``; ``video_path`` is ``None`` on failure.
    """
    if img is None:
        return None, "Upload an image first!"
    if seed is None:
        return None, "Enter a seed first!"

    # The child process runs with cwd=m2, so every path handed to it must be
    # absolute; relative paths would be resolved against m2, not our cwd.
    model_dir = os.path.abspath(setup())
    m2 = os.path.abspath(os.path.join("Matrix-Game", "Matrix-Game-2"))

    tmp = tempfile.mkdtemp()
    try:
        inp = os.path.join(tmp, "input.jpg")
        outd = os.path.join(tmp, "outputs")
        os.makedirs(outd, exist_ok=True)
        _prepare_image(img).save(inp)

        # int() turns Gradio's floats (e.g. 42.0) into plain integer strings.
        cmd = [sys.executable, os.path.join(m2, "inference.py"),
               "--img_path", inp,
               "--output_folder", outd,
               "--num_output_frames", str(int(frames)),
               "--seed", str(int(seed)),
               "--pretrained_model_path", model_dir]

        print("‣ running:", " ".join(cmd))
        proc = subprocess.run(cmd, capture_output=True, text=True, cwd=m2)
        # Print both streams: tracebacks land on stderr even when stdout is non-empty.
        if proc.stdout:
            print(proc.stdout)
        if proc.stderr:
            print(proc.stderr)
        if proc.returncode != 0:
            print(f"‣ inference exited with code {proc.returncode}")
            return None, "Generation failed – see logs"

        video = _find_video(outd)
        if video is None:
            return None, "Generation failed – see logs"
        # Keep the real container extension instead of renaming everything to .mp4.
        result = "result" + os.path.splitext(video)[1].lower()
        shutil.move(video, result)
        return result, "✔ Done!"
    finally:
        # Remove the scratch directory on every path, including failures.
        shutil.rmtree(tmp, ignore_errors=True)

# Gradio UI: start frame, frame count and seed on the left; result on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Matrix-Game 2.0 demo")
    with gr.Row():
        with gr.Column():
            img = gr.Image(label="Start frame (jpg/png)", type="pil")
            nfrm = gr.Slider(25, 150, 60, step=1, label="Frames")
            # precision=0 makes Gradio pass an int to run(); with the default
            # precision the value arrives as a float (e.g. 42.0).
            s = gr.Number(42, label="Seed", precision=0)
            go = gr.Button("Generate")
        with gr.Column():
            vid = gr.Video(label="Output")
            stat = gr.Textbox(label="Status")
    go.click(run, [img, nfrm, s], [vid, stat])

def _main():
    """Start the Gradio server for the demo UI."""
    demo.launch()


# Launch only when executed as a script, not when imported.
if __name__ == "__main__":
    _main()