Account00a committed on
Commit
60f0957
·
verified ·
1 Parent(s): 7aa0371

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -140
app.py CHANGED
@@ -2,159 +2,47 @@ import spaces
2
  import torch
3
  import gc
4
  import gradio as gr
5
- from diffusers import AutoModel, WanPipeline
6
  from diffusers.utils import export_to_video
7
  import tempfile
8
  import time
9
 
10
- # ============================================================
11
- # SHOTARCH VIDEO GEN β€” Wan2.1-1.3B (Fully Optimized)
12
- # ============================================================
13
- # Optimizations applied:
14
- # 1. bfloat16 transformer (native on Ampere/Hopper GPUs)
15
- # 2. float32 VAE (required for sharp decode)
16
- # 3. VAE tiling (low peak VRAM during decode)
17
- # 4. torch.inference_mode (faster than no_grad)
18
- # 5. Pre-loaded on CPU β†’ instant GPU transfer
19
- # ============================================================
20
-
21
- print("πŸ“¦ Loading Wan2.1-1.3B on CPU (one-time)...")
22
- load_start = time.time()
23
-
24
- vae = AutoModel.from_pretrained(
25
- "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
26
- subfolder="vae",
27
- torch_dtype=torch.float32,
28
- )
29
-
30
- pipe = WanPipeline.from_pretrained(
31
- "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
32
- vae=vae,
33
- torch_dtype=torch.bfloat16,
34
- )
35
-
36
- print(f"βœ… Model loaded in {time.time()-load_start:.0f}s")
37
-
38
 
39
  @spaces.GPU(duration=240)
40
- def generate_video(
41
- prompt,
42
- negative_prompt,
43
- num_frames,
44
- height,
45
- width,
46
- num_inference_steps,
47
- guidance_scale,
48
- ):
49
- """Generate video with Wan2.1 on ZeroGPU."""
50
-
51
- # Move to GPU
52
- pipe.to("cuda")
53
- pipe.vae.enable_tiling()
54
-
55
- print(f"πŸŽ₯ Generating: {width}x{height}, {num_frames} frames, {num_inference_steps} steps")
56
- start = time.time()
57
-
58
  with torch.inference_mode():
59
- result = pipe(
60
- prompt=prompt,
61
- negative_prompt=negative_prompt,
62
- num_frames=int(num_frames),
63
- height=int(height),
64
- width=int(width),
65
- num_inference_steps=int(num_inference_steps),
66
- guidance_scale=float(guidance_scale),
67
- ).frames[0]
68
-
69
- elapsed = time.time() - start
70
- print(f"βœ… Generated in {elapsed:.1f}s")
71
-
72
- # Save video
73
  output_path = tempfile.mktemp(suffix=".mp4")
74
  export_to_video(result, output_path, fps=16)
75
-
76
- # Cleanup GPU memory
77
- gc.collect()
78
- torch.cuda.empty_cache()
79
-
80
  return output_path
81
 
82
-
83
- # ============================================================
84
- # GRADIO UI
85
- # ============================================================
86
- css = """
87
- #main-container { max-width: 1200px; margin: auto; }
88
- .generate-btn { height: 50px !important; font-size: 18px !important; }
89
- """
90
-
91
- with gr.Blocks(title="Shotarch Video Gen", css=css, theme=gr.themes.Soft()) as demo:
92
- gr.Markdown(
93
- """
94
- # 🎬 Shotarch Video Generator
95
- ### Powered by Wan2.1-1.3B β€” Ultra-Fast AI Video Generation
96
- """
97
- )
98
-
99
  with gr.Row():
100
- with gr.Column(scale=1):
101
- prompt = gr.Textbox(
102
- label="✏️ Prompt",
103
- lines=4,
104
- placeholder="A cinematic slow-motion shot of a futuristic cyberpunk sports car drifting through a neon-lit rain-soaked city at night...",
105
- )
106
- negative_prompt = gr.Textbox(
107
- label="🚫 Negative Prompt",
108
- lines=2,
109
- value="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards",
110
- )
111
-
112
  with gr.Row():
113
- width = gr.Slider(
114
- minimum=480, maximum=1280, value=1280, step=16, label="Width"
115
- )
116
- height = gr.Slider(
117
- minimum=320, maximum=720, value=720, step=16, label="Height"
118
- )
119
-
120
  with gr.Row():
121
- num_frames = gr.Slider(
122
- minimum=17,
123
- maximum=81,
124
- value=81,
125
- step=4,
126
- label="Frames (81 = 5 sec @ 16fps)",
127
- )
128
- steps = gr.Slider(
129
- minimum=10,
130
- maximum=50,
131
- value=25,
132
- step=1,
133
- label="Inference Steps",
134
- )
135
-
136
- guidance = gr.Slider(
137
- minimum=1.0,
138
- maximum=15.0,
139
- value=5.0,
140
- step=0.5,
141
- label="Guidance Scale",
142
- )
143
-
144
- generate_btn = gr.Button(
145
- "🎬 Generate Video",
146
- variant="primary",
147
- elem_classes="generate-btn",
148
- )
149
-
150
- with gr.Column(scale=1):
151
- output_video = gr.Video(label="πŸŽ₯ Generated Video")
152
-
153
- # API-friendly: this function is also callable via /api/predict
154
- generate_btn.click(
155
- fn=generate_video,
156
- inputs=[prompt, negative_prompt, num_frames, height, width, steps, guidance],
157
- outputs=output_video,
158
- )
159
 
160
  demo.launch()
 
2
  import torch
3
  import gc
4
  import gradio as gr
5
+ from diffusers import WanPipeline
6
  from diffusers.utils import export_to_video
7
  import tempfile
8
  import time
9
 
10
+ pipe = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
@spaces.GPU(duration=240)
def generate_video(prompt, negative_prompt, num_frames, height, width, num_inference_steps, guidance_scale):
    """Generate a video with Wan2.1-1.3B on ZeroGPU and return the mp4 path.

    Parameters
    ----------
    prompt, negative_prompt : str
        Text conditioning passed straight to the diffusion pipeline.
    num_frames, height, width, num_inference_steps : numeric
        Gradio sliders deliver floats; cast to int before calling the pipe.
    guidance_scale : numeric
        Classifier-free guidance strength (cast to float).

    Returns
    -------
    str
        Filesystem path of the exported .mp4 (16 fps).
    """
    global pipe
    # Lazy-load once per worker: on ZeroGPU the model must be instantiated
    # inside the GPU-allocated call, not at import time.
    if pipe is None:
        print("📦 Loading Wan2.1-1.3B to GPU...")
        pipe = WanPipeline.from_pretrained(
            "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        )
        pipe.to("cuda")
        # Tiled VAE decode keeps peak VRAM low at 720p/81 frames.
        # NOTE(review): the VAE now runs in fp16; the previous revision kept
        # it in fp32 "for sharp decode" — confirm output quality is acceptable.
        pipe.vae.enable_tiling()
        print("✅ Loaded!")

    # inference_mode is slightly faster than no_grad and forbids autograd state.
    with torch.inference_mode():
        result = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_frames=int(num_frames),
            height=int(height),
            width=int(width),
            num_inference_steps=int(num_inference_steps),
            guidance_scale=float(guidance_scale),
        ).frames[0]

    # tempfile.mktemp() is deprecated and race-prone (the name can be taken
    # between creation and use); NamedTemporaryFile creates the file
    # atomically. delete=False so Gradio can serve it after we return.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        output_path = tmp.name
    export_to_video(result, output_path, fps=16)

    # Release cached CUDA memory between ZeroGPU invocations.
    gc.collect()
    torch.cuda.empty_cache()
    return output_path
29
 
30
# ------------------------------------------------------------------
# Gradio UI — prompt/settings column on the left, rendered video on
# the right. Slider defaults match Wan2.1's 720p / 81-frame output.
# ------------------------------------------------------------------
with gr.Blocks(title="Shotarch Video Gen", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎬 Shotarch Video Generator\n### Wan2.1-1.3B on ZeroGPU")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                lines=3,
                placeholder="Describe your video...",
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                lines=2,
                value="Bright tones, overexposed, static, blurred details, worst quality, low quality, ugly, deformed, still picture",
            )
            with gr.Row():
                width = gr.Slider(minimum=480, maximum=1280, value=1280, step=16, label="Width")
                height = gr.Slider(minimum=320, maximum=720, value=720, step=16, label="Height")
            with gr.Row():
                num_frames = gr.Slider(minimum=17, maximum=81, value=81, step=4, label="Frames (81=5sec)")
                steps = gr.Slider(minimum=10, maximum=50, value=25, step=1, label="Steps")
            guidance = gr.Slider(minimum=1.0, maximum=15.0, value=5.0, step=0.5, label="Guidance Scale")
            generate_btn = gr.Button("🎬 Generate Video", variant="primary")
        with gr.Column():
            output_video = gr.Video(label="Generated Video")

    # Wire the button to the ZeroGPU-decorated generator defined above.
    generate_btn.click(
        fn=generate_video,
        inputs=[prompt, negative_prompt, num_frames, height, width, steps, guidance],
        outputs=output_video,
    )

demo.launch()