import os
import shutil
import subprocess
import sys
import tempfile

import gradio as gr
import spaces
import torch
from huggingface_hub import snapshot_download
from PIL import Image

MODEL_REPO = "Skywork/Matrix-Game-2.0"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)


# ----- one-time model + code download -----
@spaces.cached
def setup():
    """Download the model weights and clone the inference repo (idempotent).

    Returns:
        str: local directory containing the downloaded model snapshot.
    """
    print("‣ downloading weights …")
    model_dir = snapshot_download(MODEL_REPO, cache_dir="model_cache")
    if not os.path.exists("Matrix-Game"):
        subprocess.check_call(
            ["git", "clone", "https://github.com/SkyworkAI/Matrix-Game.git"]
        )
    return model_dir
# -----------------------------------------


@spaces.GPU(duration=120)
def run(img, frames, seed):
    """Generate a short video from a single start frame via the
    Matrix-Game-2 inference script.

    Args:
        img: PIL image uploaded by the user, or None.
        frames: number of output frames to request.
        seed: RNG seed forwarded to the inference script.

    Returns:
        (video_path_or_None, status_message) tuple for the Gradio outputs.
    """
    if img is None:
        return None, "Upload an image first!"
    model_dir = setup()
    tmp = tempfile.mkdtemp()
    try:
        inp = os.path.join(tmp, "input.jpg")
        outd = os.path.join(tmp, "outputs")
        os.makedirs(outd, exist_ok=True)
        # down-size to <=512 px to keep VRAM happy
        if max(img.size) > 512:
            r = 512 / max(img.size)
            img = img.resize(
                (int(img.size[0] * r), int(img.size[1] * r)),
                Image.Resampling.LANCZOS,
            )
        # convert: PNG uploads may be RGBA, which cannot be saved as JPEG
        img.convert("RGB").save(inp)
        m2 = os.path.join("Matrix-Game", "Matrix-Game-2")
        # abspath: the child runs with cwd=m2, so a relative script path
        # would be resolved against m2 and not be found
        cmd = [
            sys.executable, os.path.abspath(os.path.join(m2, "inference.py")),
            "--img_path", inp,
            "--output_folder", outd,
            "--num_output_frames", str(frames),
            "--seed", str(seed),
            "--pretrained_model_path", model_dir,
        ]
        print("‣ running:", " ".join(cmd))
        proc = subprocess.run(cmd, capture_output=True, text=True, cwd=m2)
        print(proc.stdout or proc.stderr)
        if proc.returncode != 0:
            # inference crashed — don't bother scanning for output files
            return None, "Generation failed – see logs"
        # grab first video file we find
        for root, _, files in os.walk(outd):
            for f in files:
                if f.lower().endswith((".mp4", ".webm", ".mov")):
                    shutil.move(os.path.join(root, f), "result.mp4")
                    return "result.mp4", "✔ Done!"
        return None, "Generation failed – see logs"
    finally:
        # always clean up the scratch dir, even on failure
        shutil.rmtree(tmp, ignore_errors=True)


with gr.Blocks() as demo:
    gr.Markdown("# Matrix-Game 2.0 demo")
    with gr.Row():
        with gr.Column():
            img = gr.Image(label="Start frame (jpg/png)", type="pil")
            nfrm = gr.Slider(25, 150, 60, step=1, label="Frames")
            s = gr.Number(42, label="Seed")
            go = gr.Button("Generate")
        with gr.Column():
            vid = gr.Video(label="Output")
            stat = gr.Textbox(label="Status")
    go.click(run, [img, nfrm, s], [vid, stat])

if __name__ == "__main__":
    demo.launch()