File size: 4,021 Bytes
4e077e1
d9237e4
 
 
 
 
 
c37c52c
a9779fe
d9237e4
4e077e1
d9237e4
 
 
4e077e1
c37c52c
 
d9237e4
c37c52c
d9237e4
 
 
 
 
 
 
 
 
 
c37c52c
d9237e4
 
c37c52c
d9237e4
 
 
 
 
 
 
 
 
 
 
 
 
 
c37c52c
d9237e4
 
a9779fe
 
d9237e4
c37c52c
d9237e4
 
c37c52c
d9237e4
 
c37c52c
d9237e4
 
 
c37c52c
a9779fe
d9237e4
 
 
 
 
c37c52c
d9237e4
 
 
 
 
 
 
 
c37c52c
 
d9237e4
 
 
 
 
c37c52c
d9237e4
 
 
 
 
c37c52c
d9237e4
 
 
c37c52c
d9237e4
c37c52c
d9237e4
c37c52c
d9237e4
 
c37c52c
d9237e4
 
 
 
c37c52c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d9237e4
c37c52c
d9237e4
 
a9779fe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import gradio as gr
import torch
import spaces
from PIL import Image
import tempfile
import subprocess
import sys
import os
from huggingface_hub import snapshot_download
import shutil

# --- Configuration ---------------------------------------------------------
# Hugging Face Hub repository that holds the Matrix-Game-2.0 weights.
MODEL_REPO = "Skywork/Matrix-Game-2.0"
# Device label reported at startup; "cuda" only when torch can see a GPU.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"

print("Device: " + DEVICE)
print("CUDA Available: " + str(torch.cuda.is_available()))

# --- Lazily-initialised state (written by download_and_setup_model) --------
model_loaded = False  # becomes True once weights are downloaded and the repo is cloned
model_path = None  # local snapshot directory returned by snapshot_download

def download_and_setup_model():
    """Download the model weights and clone the inference repo, once per process.

    On success this sets the module globals ``model_path`` (the directory
    returned by ``snapshot_download``) and ``model_loaded`` (True). Later calls
    return immediately.

    Returns:
        bool: True if setup succeeded (on this call or an earlier one). False if
        the download or the ``git clone`` failed. Failures are printed, not raised.
    """
    global model_loaded, model_path

    if model_loaded:
        return True

    repo_dir = "Matrix-Game"

    try:
        print("Downloading model...")
        model_path = snapshot_download(
            repo_id=MODEL_REPO,
            cache_dir="./model_cache"
        )

        if not os.path.exists(repo_dir):
            try:
                result = subprocess.run(
                    ['git', 'clone', 'https://github.com/SkyworkAI/Matrix-Game.git', repo_dir],
                    capture_output=True, text=True, timeout=180,
                )
                error = None if result.returncode == 0 else (
                    f"exit code {result.returncode}: {result.stderr.strip()}"
                )
            except subprocess.TimeoutExpired:
                error = "timed out after 180 s"

            if error is not None:
                print(f"git clone failed ({error})")
                # An interrupted or failed clone can leave a partial directory.
                # Without this cleanup, the os.path.exists() guard above would
                # skip the clone on every later call.
                shutil.rmtree(repo_dir, ignore_errors=True)
                return False

        model_loaded = True
        return True

    except Exception as e:
        # Top-level boundary for the UI: report the failure and let the caller
        # show a friendly message instead of a traceback.
        print(f"Setup failed: {e}")
        return False

def _prepare_input_image(image, max_side=512):
    """Return an RGB copy of *image* whose longer side is at most *max_side* px."""
    # Convert first: JPEG cannot store alpha/palette modes (RGBA, LA, P, ...), and
    # PIL resizes "P" images with NEAREST regardless of the requested filter.
    if image.mode != "RGB":
        image = image.convert("RGB")
    longest = max(image.size)
    if longest > max_side:
        ratio = max_side / longest
        # max(1, ...) stops extreme aspect ratios from truncating a side to 0.
        new_size = tuple(max(1, int(side * ratio)) for side in image.size)
        image = image.resize(new_size, Image.Resampling.LANCZOS)
    return image


def _find_video(output_dir):
    """Return the first video file under *output_dir* (sorted by path), or None."""
    # os.walk order is filesystem-dependent; sort so the choice is deterministic.
    matches = sorted(
        os.path.join(root, name)
        for root, _dirs, files in os.walk(output_dir)
        for name in files
        if name.lower().endswith(('.mp4', '.avi', '.mov'))
    )
    return matches[0] if matches else None


@spaces.GPU(duration=120)
def generate_video(input_image, num_frames, seed):
    """Run Matrix-Game-2 inference on one image and return the rendered video.

    Args:
        input_image: PIL image from the Gradio ``Image(type="pil")`` input, or None.
        num_frames: requested frame count; cast to int before use.
        seed: random seed; cast to int before use.

    Returns:
        tuple: ``(video_path_or_None, status_message)`` for the Gradio outputs.
    """
    if input_image is None:
        return None, "Please upload an input image first"

    if not download_and_setup_model():
        return None, "Failed to setup model"

    temp_dir = None
    try:
        temp_dir = tempfile.mkdtemp()
        output_dir = os.path.join(temp_dir, "outputs")
        os.makedirs(output_dir, exist_ok=True)

        input_path = os.path.join(temp_dir, "input.jpg")
        _prepare_input_image(input_image).save(input_path, "JPEG")

        # The child process runs with cwd=matrix_dir, so every path it receives
        # must be absolute. Otherwise a relative script or model path is resolved
        # against matrix_dir a second time and is not found.
        matrix_dir = os.path.abspath(os.path.join("Matrix-Game", "Matrix-Game-2"))

        cmd = [
            sys.executable,
            os.path.join(matrix_dir, "inference.py"),
            "--img_path", input_path,
            "--output_folder", output_dir,
            "--num_output_frames", str(int(num_frames)),
            "--seed", str(int(seed)),
        ]
        if model_path:
            cmd.extend(["--pretrained_model_path", os.path.abspath(model_path)])

        # NOTE(review): timeout=300 exceeds the @spaces.GPU duration of 120 s;
        # confirm which limit is meant to apply.
        process = subprocess.run(cmd, capture_output=True, text=True, timeout=300, cwd=matrix_dir)

        video = _find_video(output_dir)
        if video is None:
            # Python tracebacks end with the actual error, so report the tail.
            detail = process.stderr.strip()[-200:] or f"exit code {process.returncode}"
            return None, f"Generation failed: {detail}"

        final_output = f"output_{int(seed)}.mp4"
        shutil.copy(video, final_output)
        return final_output, f"Success! Generated {int(num_frames)} frames with seed {int(seed)}"

    except Exception as e:
        # UI boundary: surface any failure as a status message instead of crashing.
        return None, f"Error: {str(e)}"
    finally:
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)

# Minimal Gradio UI. The left column holds the inputs (image, frame count, seed)
# and the Generate button. The right column shows the generated video and a
# status message.
with gr.Blocks() as demo:

    gr.HTML("<h1>Matrix-Game-2.0</h1><p>Interactive World Model</p>")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input image")
            # generate_video casts both values with int(), so offer only whole
            # numbers instead of silently truncating fractional input.
            num_frames = gr.Slider(
                minimum=25, maximum=100, value=50, step=1, label="Number of frames"
            )
            seed = gr.Number(value=42, precision=0, label="Seed")
            btn = gr.Button("Generate")

        with gr.Column():
            output_video = gr.Video(label="Generated video")
            # Read-only: this box only displays generate_video's status message.
            status = gr.Textbox(label="Status", interactive=False)

    btn.click(generate_video, [input_image, num_frames, seed], [output_video, status])

if __name__ == "__main__":
    demo.launch(share=True)