Matrix-Game-2 / app.py
laloadrianmorales's picture
Update app.py
c37c52c verified
raw
history blame
4.02 kB
import gradio as gr
import torch
import spaces
from PIL import Image
import tempfile
import subprocess
import sys
import os
from huggingface_hub import snapshot_download
import shutil
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Hugging Face repository that holds the model weights.
MODEL_REPO = "Skywork/Matrix-Game-2.0"

# Prefer the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

print(f"Device: {DEVICE}")
print(f"CUDA Available: {torch.cuda.is_available()}")

# ---------------------------------------------------------------------------
# Global variables
# ---------------------------------------------------------------------------
# Set to True once the weights are downloaded and the code repo is cloned.
model_loaded = False
# Local path of the downloaded model snapshot (None until setup succeeds).
model_path = None
def download_and_setup_model():
    """Download the model weights and clone the inference code, once.

    On success the module globals ``model_path`` (local snapshot directory)
    and ``model_loaded`` (True) are updated, so later calls return
    immediately.

    Returns:
        bool: True when the snapshot is downloaded and the ``Matrix-Game``
        checkout is present; False on any failure (the reason is printed).
    """
    global model_loaded, model_path
    if model_loaded:
        return True

    repo_dir = "Matrix-Game"
    try:
        print("Downloading model...")
        model_path = snapshot_download(
            repo_id=MODEL_REPO,
            cache_dir="./model_cache",
        )

        if not os.path.exists(repo_dir):
            try:
                result = subprocess.run(
                    ["git", "clone", "https://github.com/SkyworkAI/Matrix-Game.git"],
                    capture_output=True,
                    text=True,
                    timeout=180,
                )
            except Exception:
                # A timed-out/aborted clone can leave a partial checkout that
                # the existence check above would treat as complete next time.
                shutil.rmtree(repo_dir, ignore_errors=True)
                raise
            if result.returncode != 0:
                # Surface git's own message instead of failing silently.
                print(
                    f"Setup failed: git clone exited with code "
                    f"{result.returncode}: {result.stderr.strip()}"
                )
                shutil.rmtree(repo_dir, ignore_errors=True)
                return False

        model_loaded = True
        return True
    except Exception as e:
        # Top-level boundary for the UI: report and signal failure.
        print(f"Setup failed: {e}")
        return False
@spaces.GPU(duration=120)
def generate_video(input_image, num_frames, seed):
    """Run the Matrix-Game-2 inference script on one image and return a video.

    Args:
        input_image: PIL image supplied by the UI, or None if nothing was
            uploaded.
        num_frames: Requested number of output frames (truncated with int()).
        seed: Random seed forwarded to the script (truncated with int()).

    Returns:
        tuple: ``(video_path, status_message)``. ``video_path`` is None when
        no video was produced; ``status_message`` explains the outcome.
    """
    if input_image is None:
        return None, "Please upload an input image first"
    if not download_and_setup_model():
        return None, "Failed to setup model"

    # NOTE(review): the subprocess timeout below (300 s) is longer than the
    # GPU window requested by the decorator (120 s) - confirm which limit is
    # intended.
    temp_dir = None
    try:
        frame_count = int(num_frames)
        seed_value = int(seed)

        temp_dir = tempfile.mkdtemp()
        output_dir = os.path.join(temp_dir, "outputs")
        os.makedirs(output_dir, exist_ok=True)

        # Resize so the longest side is at most 512 px, keeping aspect ratio.
        longest_side = max(input_image.size)
        if longest_side > 512:
            ratio = 512 / longest_side
            # Clamp to 1 px so extreme aspect ratios never yield a zero side.
            new_size = (
                max(1, int(input_image.size[0] * ratio)),
                max(1, int(input_image.size[1] * ratio)),
            )
            input_image = input_image.resize(new_size, Image.Resampling.LANCZOS)

        # JPEG cannot store alpha/palette modes; normalise before saving.
        if input_image.mode != "RGB":
            input_image = input_image.convert("RGB")
        input_path = os.path.join(temp_dir, "input.jpg")
        input_image.save(input_path, "JPEG")

        # The child process runs with cwd=matrix_dir, so every path handed to
        # it must be absolute; a relative script path would be resolved
        # against the new working directory and not be found.
        matrix_dir = os.path.abspath(os.path.join("Matrix-Game", "Matrix-Game-2"))
        cmd = [
            sys.executable,
            os.path.join(matrix_dir, "inference.py"),
            "--img_path", input_path,
            "--output_folder", output_dir,
            "--num_output_frames", str(frame_count),
            "--seed", str(seed_value),
        ]
        if model_path:
            cmd.extend(["--pretrained_model_path", os.path.abspath(model_path)])

        process = subprocess.run(
            cmd, capture_output=True, text=True, timeout=300, cwd=matrix_dir
        )

        # Collect produced videos; sorted so the pick is deterministic.
        video_files = sorted(
            os.path.join(root, name)
            for root, _dirs, names in os.walk(output_dir)
            for name in names
            if name.lower().endswith(('.mp4', '.avi', '.mov'))
        )
        if not video_files:
            # The tail of stderr carries the actual exception of a traceback.
            return None, f"Generation failed: {process.stderr[-200:]}"

        # Copy out of the temp dir, which is removed in the finally clause.
        final_output = f"output_{seed_value}.mp4"
        shutil.copy(video_files[0], final_output)
        return final_output, f"Success! Generated {frame_count} frames with seed {seed_value}"
    except Exception as e:
        # UI boundary: turn any failure into a status message.
        return None, f"Error: {str(e)}"
    finally:
        if temp_dir is not None and os.path.exists(temp_dir):
            shutil.rmtree(temp_dir, ignore_errors=True)
# Ultra-minimal interface: inputs on the left, generated video and status on
# the right.
with gr.Blocks() as demo:
    gr.HTML("<h1>Matrix-Game-2.0</h1><p>Interactive World Model</p>")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input image")
            # step=1 keeps the slider on whole frame counts; generate_video
            # truncates the value with int() before passing it on.
            num_frames = gr.Slider(
                minimum=25, maximum=100, value=50, step=1, label="Number of frames"
            )
            # precision=0 makes the seed an integer at the UI level.
            seed = gr.Number(value=42, precision=0, label="Seed")
            btn = gr.Button("Generate")
        with gr.Column():
            output_video = gr.Video(label="Generated video")
            # Read-only: this box only reports the outcome of a run.
            status = gr.Textbox(label="Status", interactive=False)
    btn.click(
        fn=generate_video,
        inputs=[input_image, num_frames, seed],
        outputs=[output_video, status],
    )

if __name__ == "__main__":
    demo.launch(share=True)