Matrix-Game-2 / app.py
laloadrianmorales's picture
Update app.py
a9779fe verified
raw
history blame
7.1 kB
import os
import gradio as gr
import torch
import spaces
from PIL import Image
import tempfile
import subprocess
import sys
from huggingface_hub import snapshot_download
import shutil
# ---------------------------------------------------------------------------
# Runtime configuration
# ---------------------------------------------------------------------------
# Hugging Face Hub repository the model files are downloaded from.
MODEL_REPO = "Skywork/Matrix-Game-2.0"

# Use the GPU when torch reports CUDA support, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Startup banner, printed once when the module is loaded.
print("๐Ÿš€ Matrix-Game-2.0 Streamlined")
print(f"๐Ÿ“ฑ Device: {DEVICE}")
print(f"๐Ÿ”ฅ CUDA Available: {torch.cuda.is_available()}")

# ---------------------------------------------------------------------------
# Lazy setup state shared by the functions below
# ---------------------------------------------------------------------------
# Set to True once the one-time download/clone step has succeeded.
model_loaded = False
# Local directory of the downloaded model snapshot (None until downloaded).
model_path = None
def download_and_setup_model():
    """Download the model files and the inference code once per process.

    Downloads the ``MODEL_REPO`` snapshot into ``./model_cache``, clones the
    Matrix-Game repository into ``./Matrix-Game`` when it is not present yet,
    and puts its ``Matrix-Game-2`` sub-directory on ``sys.path``.

    The result is remembered in the module globals ``model_loaded`` and
    ``model_path``, so successful setup is performed only once.

    Returns:
        bool: True when setup succeeded (now or on an earlier call), False
        when any step failed. Failures are printed, never raised.
    """
    global model_loaded, model_path
    if model_loaded:
        return True

    repo_dir = os.path.abspath("Matrix-Game")
    inference_dir = os.path.join(repo_dir, "Matrix-Game-2")
    try:
        print("๐Ÿ“ฅ Downloading Matrix-Game-2.0 model...")
        # Download the model to cache
        model_path = snapshot_download(
            repo_id=MODEL_REPO,
            cache_dir="./model_cache",
            allow_patterns=["*.safetensors", "*.bin", "*.json", "*.yaml", "*.yml", "*.py"],
        )
        print(f"โœ… Model downloaded to: {model_path}")

        # A checkout without the inference sub-directory is the leftover of an
        # interrupted clone; remove it so the clone below is attempted again
        # instead of being skipped forever.
        if os.path.isdir(repo_dir) and not os.path.isdir(inference_dir):
            shutil.rmtree(repo_dir, ignore_errors=True)

        # Clone the inference code repository
        if not os.path.exists(repo_dir):
            print("๐Ÿ“ฅ Cloning Matrix-Game repository...")
            cloned = False
            try:
                result = subprocess.run(
                    ['git', 'clone', 'https://github.com/SkyworkAI/Matrix-Game.git', repo_dir],
                    capture_output=True,
                    text=True,
                    timeout=180,
                )
                cloned = result.returncode == 0
                if not cloned:
                    print(f"โŒ Git clone failed: {result.stderr}")
                    return False
            finally:
                # Runs for a non-zero exit status and for exceptions such as
                # a timeout: never leave a half-written checkout behind.
                if not cloned:
                    shutil.rmtree(repo_dir, ignore_errors=True)

        if not os.path.isdir(inference_dir):
            print(f"โŒ Setup failed: {inference_dir} not found")
            return False

        # Setup Python path to include Matrix-Game modules
        if inference_dir not in sys.path:
            sys.path.insert(0, inference_dir)

        model_loaded = True
        return True
    except Exception as e:
        # Best-effort boundary: callers only need the boolean outcome.
        print(f"โŒ Setup failed: {e}")
        return False
def _find_files(root, suffixes, name_contains=""):
    """Return files under *root* matching *suffixes*, in a stable top-down order.

    A file matches when its lower-cased name ends with one of *suffixes* and
    contains *name_contains*. Directories and names are visited in sorted
    order so the first match does not depend on filesystem ordering.
    """
    matches = []
    for folder, subdirs, names in os.walk(root):
        subdirs.sort()
        for name in sorted(names):
            lowered = name.lower()
            if lowered.endswith(suffixes) and name_contains in lowered:
                matches.append(os.path.join(folder, name))
    return matches


# NOTE(review): the GPU duration (120) is shorter than the subprocess timeout
# (300 s) used below - confirm against the `spaces` documentation whether the
# call can be cut off before the timeout message is ever reached.
@spaces.GPU(duration=120)
def generate_video(input_image, num_frames, seed):
    """Generate video using Matrix-Game-2.0.

    Runs the repository's ``inference.py`` in a subprocess on a down-scaled
    copy of the input image and returns the first video file it produces.

    Args:
        input_image: PIL image to start from, or None when nothing was uploaded.
        num_frames: Requested number of output frames; capped at 100.
        seed: Random seed passed to the inference script; None falls back to 42.

    Returns:
        tuple: ``(video_path, log)`` on success, ``(None, message)`` on any
        failure. Errors are reported through the message, never raised.
    """
    if input_image is None:
        return None, "โŒ Please upload an input image first"

    # Setup model if not already done
    if not download_and_setup_model():
        return None, "โŒ Failed to setup model"

    temp_dir = None
    try:
        # UI widgets may deliver floats or None; the script expects integers.
        frame_count = min(int(num_frames), 100)
        seed = 42 if seed is None else int(seed)

        # Create temporary directories
        temp_dir = tempfile.mkdtemp(prefix="matrix_gen_")
        output_dir = os.path.join(temp_dir, "outputs")
        os.makedirs(output_dir, exist_ok=True)

        # Prepare input image: longest side limited to 512 pixels
        if max(input_image.size) > 512:
            ratio = 512 / max(input_image.size)
            new_size = (int(input_image.size[0] * ratio), int(input_image.size[1] * ratio))
            input_image = input_image.resize(new_size, Image.Resampling.LANCZOS)
        # JPEG cannot store alpha or palette images; convert before saving.
        if input_image.mode != "RGB":
            input_image = input_image.convert("RGB")
        input_path = os.path.join(temp_dir, "input.jpg")
        input_image.save(input_path, "JPEG", quality=95)

        # Absolute path: the subprocess runs with cwd=matrix_dir, so relative
        # script/config paths would be resolved against the wrong directory.
        matrix_dir = os.path.abspath(os.path.join("Matrix-Game", "Matrix-Game-2"))

        # Basic inference command
        cmd = [
            sys.executable,
            os.path.join(matrix_dir, "inference.py"),
            "--img_path", input_path,
            "--output_folder", output_dir,
            "--num_output_frames", str(frame_count),
            "--seed", str(seed),
        ]

        # Add model and config paths if found
        config_files = _find_files(matrix_dir, ('.yaml', '.yml'), 'config')
        if config_files:
            cmd.extend(["--config_path", config_files[0]])
        if model_path:
            cmd.extend(["--pretrained_model_path", os.path.abspath(model_path)])

        # Execute with timeout
        process = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            timeout=300,
            cwd=matrix_dir,
        )

        # Find output video
        video_files = _find_files(output_dir, ('.mp4', '.avi', '.mov', '.gif'))
        if not video_files:
            # The end of stderr carries the actual error; the start is
            # usually import-time warnings.
            details = process.stderr[-500:] if process.stderr else 'No error details'
            error_log = f"โŒ Generation Failed\n๐Ÿ“ Error output: {details}\n๐Ÿ’ญ Try adjusting parameters or using a different input image"
            return None, error_log

        # Copy out of the temporary directory (removed in ``finally``),
        # keeping the real extension of the produced file.
        extension = os.path.splitext(video_files[0])[1].lower()
        final_output = f"output_{seed}{extension}"
        shutil.copy(video_files[0], final_output)
        log = f"โœ… Generation Successful!\n๐Ÿ“Š Input: {input_image.size}\n๐ŸŽฌ Frames: {frame_count}\n๐ŸŽฒ Seed: {seed}\n๐Ÿ“ Output: {final_output}"
        return final_output, log
    except subprocess.TimeoutExpired:
        return None, "โŒ Generation timed out (>5 minutes). Try fewer frames."
    except Exception as e:
        # UI boundary: surface the problem in the status log instead of raising.
        return None, f"โŒ Error during generation: {str(e)}"
    finally:
        # Cleanup
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)
# Simple Gradio Interface
def create_interface():
    """Build the Gradio Blocks UI and return it without launching it.

    Layout: a banner, then a row with two columns - image upload, settings
    and the generate button on the left; the resulting video and a status
    log on the right. The button is wired to ``generate_video``.
    """
    banner_html = """
<div style="text-align: center; padding: 30px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; border-radius: 15px; margin-bottom: 20px;">
<h1>๐ŸŽฎ Matrix-Game-2.0</h1>
<p style="font-size: 18px;">Interactive World Model for Real-Time Video Generation</p>
<p style="opacity: 0.8;">Upload an image and generate interactive video content!</p>
</div>
"""

    with gr.Blocks(title="Matrix-Game-2.0") as ui:
        gr.HTML(banner_html)

        with gr.Row():
            # Left column: inputs
            with gr.Column():
                gr.Markdown("### ๐Ÿ“ธ Input")
                image_in = gr.Image(label="Input Image", type="pil")

                gr.Markdown("### โš™๏ธ Settings")
                with gr.Row():
                    frames_in = gr.Slider(
                        minimum=25,
                        maximum=100,
                        value=50,
                        step=25,
                        label="Number of Frames",
                    )
                    seed_in = gr.Number(value=42, label="Seed", precision=0)

                run_button = gr.Button("๐Ÿš€ Generate Video", variant="primary")

            # Right column: outputs
            with gr.Column():
                gr.Markdown("### ๐ŸŽฌ Generated Video")
                video_out = gr.Video(label="Result")
                log_out = gr.Textbox(label="Status Log", lines=8)

        # Clicking the button runs the generation and fills both outputs.
        run_button.click(
            fn=generate_video,
            inputs=[image_in, frames_in, seed_in],
            outputs=[video_out, log_out],
        )

    return ui
# Script entry point: build the UI and serve it when run directly.
if __name__ == "__main__":
    launch_options = {"share": True}
    demo = create_interface()
    demo.launch(**launch_options)