laloadrianmorales commited on
Commit
d9237e4
ยท
verified ยท
1 Parent(s): 4e077e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +270 -4
app.py CHANGED
@@ -1,7 +1,273 @@
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
  import gradio as gr
3
+ import torch
4
+ import spaces
5
+ from PIL import Image
6
+ import tempfile
7
+ import subprocess
8
+ import sys
9
+ from huggingface_hub import snapshot_download, hf_hub_download
10
+ import shutil
11
 
12
+ # Configuration
13
+ MODEL_REPO = "Skywork/Matrix-Game-2.0"
14
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
15
 
16
+ print(f"๐Ÿš€ Matrix-Game-2.0 Streamlined")
17
+ print(f"๐Ÿ“ฑ Device: {DEVICE}")
18
+ print(f"๐Ÿ”ฅ CUDA Available: {torch.cuda.is_available()}")
19
+
20
+ # Global variables for model loading
21
+ model_loaded = False
22
+ model_path = None
23
+
24
+ def download_and_setup_model():
25
+ """Download model and setup environment - run once"""
26
+ global model_loaded, model_path
27
+
28
+ if model_loaded:
29
+ return True
30
+
31
+ try:
32
+ print("๐Ÿ“ฅ Downloading Matrix-Game-2.0 model...")
33
+
34
+ # Download the model to cache
35
+ model_path = snapshot_download(
36
+ repo_id=MODEL_REPO,
37
+ cache_dir="./model_cache",
38
+ allow_patterns=["*.safetensors", "*.bin", "*.json", "*.yaml", "*.yml", "*.py"],
39
+ )
40
+
41
+ print(f"โœ… Model downloaded to: {model_path}")
42
+
43
+ # Clone the inference code repository
44
+ if not os.path.exists("Matrix-Game"):
45
+ print("๐Ÿ“ฅ Cloning Matrix-Game repository...")
46
+ result = subprocess.run([
47
+ 'git', 'clone', 'https://github.com/SkyworkAI/Matrix-Game.git'
48
+ ], capture_output=True, text=True, timeout=180)
49
+
50
+ if result.returncode != 0:
51
+ print(f"โŒ Git clone failed: {result.stderr}")
52
+ return False
53
+
54
+ # Setup Python path to include Matrix-Game modules
55
+ matrix_game_path = os.path.join(os.getcwd(), "Matrix-Game", "Matrix-Game-2")
56
+ if matrix_game_path not in sys.path:
57
+ sys.path.insert(0, matrix_game_path)
58
+
59
+ model_loaded = True
60
+ return True
61
+
62
+ except Exception as e:
63
+ print(f"โŒ Setup failed: {e}")
64
+ return False
65
+
66
+ @spaces.GPU(duration=120) # Allocate GPU for 2 minutes max
67
+ def generate_video(input_image, num_frames, seed, progress=gr.Progress()):
68
+ """Generate video using Matrix-Game-2.0"""
69
+
70
+ if input_image is None:
71
+ return None, "โŒ Please upload an input image first"
72
+
73
+ # Setup model if not already done
74
+ progress(0.1, desc="๐Ÿ”ง Setting up model...")
75
+ if not download_and_setup_model():
76
+ return None, "โŒ Failed to setup model"
77
+
78
+ progress(0.2, desc="๐Ÿ“ท Processing input image...")
79
+
80
+ try:
81
+ # Create temporary directories
82
+ temp_dir = tempfile.mkdtemp(prefix="matrix_gen_")
83
+ output_dir = os.path.join(temp_dir, "outputs")
84
+ os.makedirs(output_dir, exist_ok=True)
85
+
86
+ # Prepare input image
87
+ if max(input_image.size) > 512: # Resize for faster processing
88
+ ratio = 512 / max(input_image.size)
89
+ new_size = (int(input_image.size[0] * ratio), int(input_image.size[1] * ratio))
90
+ input_image = input_image.resize(new_size, Image.Resampling.LANCZOS)
91
+
92
+ input_path = os.path.join(temp_dir, "input.jpg")
93
+ input_image.save(input_path, "JPEG", quality=95)
94
+
95
+ progress(0.4, desc="๐Ÿš€ Generating video...")
96
+
97
+ # Find the inference script and config
98
+ matrix_dir = os.path.join("Matrix-Game", "Matrix-Game-2")
99
+
100
+ # Basic inference command (simplified)
101
+ cmd = [
102
+ sys.executable,
103
+ os.path.join(matrix_dir, "inference.py"),
104
+ "--img_path", input_path,
105
+ "--output_folder", output_dir,
106
+ "--num_output_frames", str(min(num_frames, 100)), # Limit frames for HF Spaces
107
+ "--seed", str(seed)
108
+ ]
109
+
110
+ # Add model and config paths if found
111
+ config_files = []
112
+ for root, dirs, files in os.walk(matrix_dir):
113
+ for file in files:
114
+ if file.endswith(('.yaml', '.yml')) and 'config' in file.lower():
115
+ config_files.append(os.path.join(root, file))
116
+
117
+ if config_files:
118
+ cmd.extend(["--config_path", config_files[0]])
119
+
120
+ if model_path:
121
+ cmd.extend(["--pretrained_model_path", model_path])
122
+
123
+ progress(0.6, desc="๐ŸŽฌ Running inference...")
124
+
125
+ # Execute with timeout
126
+ process = subprocess.run(
127
+ cmd,
128
+ capture_output=True,
129
+ text=True,
130
+ timeout=300, # 5 minute timeout
131
+ cwd=matrix_dir
132
+ )
133
+
134
+ progress(0.9, desc="๐Ÿ“น Finalizing video...")
135
+
136
+ # Find output video
137
+ video_files = []
138
+ for root, dirs, files in os.walk(output_dir):
139
+ for file in files:
140
+ if file.lower().endswith(('.mp4', '.avi', '.mov', '.gif')):
141
+ video_files.append(os.path.join(root, file))
142
+
143
+ if video_files:
144
+ # Copy to a permanent location
145
+ final_output = f"output_{seed}.mp4"
146
+ shutil.copy(video_files[0], final_output)
147
+
148
+ log = f"""
149
+ โœ… **Generation Successful!**
150
+ ๐Ÿ“Š Input: {input_image.size}
151
+ ๐ŸŽฌ Frames: {num_frames}
152
+ ๐ŸŽฒ Seed: {seed}
153
+ ๐Ÿ“ Output: {final_output}
154
+ """
155
+
156
+ progress(1.0, desc="โœ… Complete!")
157
+ return final_output, log
158
+ else:
159
+ error_log = f"""
160
+ โŒ **Generation Failed**
161
+ ๐Ÿ“ Error output: {process.stderr[:500] if process.stderr else 'No error details'}
162
+ ๐Ÿ’ญ Try adjusting parameters or using a different input image
163
+ """
164
+ return None, error_log
165
+
166
+ except subprocess.TimeoutExpired:
167
+ return None, "โŒ Generation timed out (>5 minutes). Try fewer frames."
168
+ except Exception as e:
169
+ return None, f"โŒ Error during generation: {str(e)}"
170
+ finally:
171
+ # Cleanup
172
+ if 'temp_dir' in locals() and os.path.exists(temp_dir):
173
+ shutil.rmtree(temp_dir, ignore_errors=True)
174
+
175
+ # Gradio Interface
176
+ def create_interface():
177
+ with gr.Blocks(
178
+ title="Matrix-Game-2.0",
179
+ theme=gr.themes.Soft(),
180
+ css="""
181
+ .gradio-container {
182
+ max-width: 1200px !important;
183
+ margin: auto !important;
184
+ }
185
+ """
186
+ ) as interface:
187
+
188
+ gr.HTML("""
189
+ <div style="text-align: center; padding: 30px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; border-radius: 15px; margin-bottom: 20px;">
190
+ <h1>๐ŸŽฎ Matrix-Game-2.0</h1>
191
+ <p style="font-size: 18px;">Interactive World Model for Real-Time Video Generation</p>
192
+ <p style="opacity: 0.8;">Upload an image and generate interactive video content!</p>
193
+ </div>
194
+ """)
195
+
196
+ with gr.Row():
197
+ with gr.Column(scale=1):
198
+ gr.Markdown("### ๐Ÿ“ธ Input")
199
+ input_image = gr.Image(
200
+ label="Input Image",
201
+ type="pil",
202
+ height=300
203
+ )
204
+
205
+ gr.Markdown("### โš™๏ธ Settings")
206
+ with gr.Row():
207
+ num_frames = gr.Slider(
208
+ minimum=25,
209
+ maximum=100,
210
+ value=50,
211
+ step=25,
212
+ label="Number of Frames"
213
+ )
214
+ seed = gr.Number(
215
+ value=42,
216
+ label="Seed",
217
+ precision=0
218
+ )
219
+
220
+ generate_btn = gr.Button(
221
+ "๐Ÿš€ Generate Video",
222
+ variant="primary",
223
+ size="lg"
224
+ )
225
+
226
+ gr.Markdown("""
227
+ ### ๐Ÿ’ก Tips
228
+ - Use clear, well-lit images
229
+ - Landscapes and scenes work best
230
+ - Lower frame counts = faster generation
231
+ - Try different seeds for variety
232
+ """)
233
+
234
+ with gr.Column(scale=1):
235
+ gr.Markdown("### ๐ŸŽฌ Generated Video")
236
+ output_video = gr.Video(
237
+ label="Result",
238
+ height=400
239
+ )
240
+
241
+ status_log = gr.Textbox(
242
+ label="Status Log",
243
+ lines=8,
244
+ max_lines=10
245
+ )
246
+
247
+ # Event handlers
248
+ generate_btn.click(
249
+ fn=generate_video,
250
+ inputs=[input_image, num_frames, seed],
251
+ outputs=[output_video, status_log]
252
+ )
253
+
254
+ # Example inputs
255
+ gr.Examples(
256
+ examples=[
257
+ ["https://images.unsplash.com/photo-1506905925346-21bda4d32df4", 50, 42],
258
+ ["https://images.unsplash.com/photo-1441974231531-c6227db76b6e", 75, 123],
259
+ ],
260
+ inputs=[input_image, num_frames, seed],
261
+ label="Example Images"
262
+ )
263
+
264
+ return interface
265
+
266
+ # Launch the app
267
+ if __name__ == "__main__":
268
+ demo = create_interface()
269
+ demo.launch(
270
+ server_name="0.0.0.0",
271
+ server_port=7860,
272
+ share=False
273
+ )