Spaces:
Paused
Paused
| import gradio as gr, subprocess, tempfile, sys, os, shutil | |
| from PIL import Image | |
| from huggingface_hub import snapshot_download | |
| import spaces, torch | |
# Hub repository id handed to snapshot_download() when fetching the weights.
MODEL_REPO = "Skywork/Matrix-Game-2.0"

# Pick the compute device once at import time: CUDA when torch reports it as
# available, CPU otherwise.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Device:", DEVICE)
| # ----- one-time model + code download ----- | |
def setup():
    """Fetch the model weights and make sure the inference code is checked out.

    Downloads a snapshot of ``MODEL_REPO`` into the local ``model_cache``
    directory, then clones the Matrix-Game repository into ``./Matrix-Game``
    unless a path with that name already exists.

    Returns:
        The local snapshot path reported by ``snapshot_download``.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` exits with a non-zero
            status.
    """
    print("‣ downloading weights …")
    weights_path = snapshot_download(MODEL_REPO, cache_dir="model_cache")

    # An existing checkout is reused as-is; only clone when it is missing.
    if os.path.exists("Matrix-Game"):
        return weights_path

    clone_cmd = [
        "git",
        "clone",
        "https://github.com/SkyworkAI/Matrix-Game.git",
    ]
    subprocess.check_call(clone_cmd)
    return weights_path
| # ----------------------------------------- | |
def run(img, frames, seed):
    """Generate a video from a start frame via the Matrix-Game-2 inference script.

    The image is written to a private temporary directory, the repository's
    ``inference.py`` is executed in a subprocess, and the first video file
    found in the output folder is moved to ``result.mp4`` in the current
    working directory.

    Args:
        img: Start frame as a PIL image, or ``None`` when nothing was uploaded.
        frames: Requested number of output frames (int, or a float/str that
            ``int()`` accepts).
        seed: Random seed (int, or a float/str that ``int()`` accepts).

    Returns:
        A ``(video_path, status_message)`` tuple. ``video_path`` is
        ``"result.mp4"`` on success and ``None`` otherwise.
    """
    if img is None:
        return None, "Upload an image first!"

    # Numeric UI widgets may hand over floats (42.0) or None for an emptied
    # field. Normalise to int so the CLI receives "42" rather than "42.0" or
    # "None" (presumably the script parses these flags as integers).
    try:
        n_frames = int(frames)
        seed_value = int(seed)
    except (TypeError, ValueError, OverflowError):
        return None, "Frames and seed must be numbers"

    m2 = os.path.join("Matrix-Game", "Matrix-Game-2")
    # The child process runs with cwd=m2, so every path passed on its command
    # line must be absolute; a path relative to *our* cwd would be resolved
    # against m2 and point at a non-existent location.
    model_dir = os.path.abspath(setup())
    script = os.path.abspath(os.path.join(m2, "inference.py"))

    tmp = tempfile.mkdtemp()  # mkdtemp returns an absolute path
    try:
        inp = os.path.join(tmp, "input.jpg")
        outd = os.path.join(tmp, "outputs")
        os.makedirs(outd, exist_ok=True)

        # down-size to <=512 px to keep VRAM happy
        longest = max(img.size)
        if longest > 512:
            r = 512 / longest
            # Clamp to 1 px so extreme aspect ratios never yield a 0-sized side.
            new_size = (max(1, int(img.size[0] * r)),
                        max(1, int(img.size[1] * r)))
            img = img.resize(new_size, Image.Resampling.LANCZOS)
        # JPEG cannot store alpha/palette modes; convert those before saving.
        if img.mode not in ("L", "RGB", "CMYK"):
            img = img.convert("RGB")
        img.save(inp)

        cmd = [sys.executable, script,
               "--img_path", inp,
               "--output_folder", outd,
               "--num_output_frames", str(n_frames),
               "--seed", str(seed_value),
               "--pretrained_model_path", model_dir]
        print("‣ running:", " ".join(cmd))
        proc = subprocess.run(cmd, capture_output=True, text=True, cwd=m2)
        # Log both streams: tracebacks land on stderr even when stdout has text.
        if proc.stdout:
            print(proc.stdout)
        if proc.stderr:
            print(proc.stderr)
        if proc.returncode != 0:
            print("‣ inference exited with code", proc.returncode)

        # grab first video file we find
        for root, _, files in os.walk(outd):
            for f in files:
                if f.lower().endswith((".mp4", ".webm", ".mov")):
                    shutil.move(os.path.join(root, f), "result.mp4")
                    return "result.mp4", "✔ Done!"
        return None, "Generation failed – see logs"
    finally:
        # Always remove the scratch directory, including on failure paths.
        shutil.rmtree(tmp, ignore_errors=True)
# Build the UI: one row holding two columns, the first with the inputs and
# the trigger button, the second with the generated video and a status line.
with gr.Blocks() as demo:
    gr.Markdown(value="# Matrix-Game 2.0 demo")
    with gr.Row():
        with gr.Column():
            img = gr.Image(label="Start frame (jpg/png)", type="pil")
            nfrm = gr.Slider(
                minimum=25,
                maximum=150,
                value=60,
                step=1,
                label="Frames",
            )
            s = gr.Number(value=42, label="Seed")
            go = gr.Button(value="Generate")
        with gr.Column():
            vid = gr.Video(label="Output")
            stat = gr.Textbox(label="Status")
    # Clicking the button calls run(image, frames, seed) and routes its
    # (video, status) result into the two output widgets.
    go.click(fn=run, inputs=[img, nfrm, s], outputs=[vid, stat])

# Start the web server only when executed as a script.
if __name__ == "__main__":
    demo.launch()