rahul7star commited on
Commit
4f347c8
·
verified ·
1 Parent(s): 3a531cb

Create app_fast_CPU.py

Browse files
Files changed (1) hide show
  1. app_fast_CPU.py +213 -0
app_fast_CPU.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import spaces
3
+ import torch
4
+ from diffusers import AutoencoderKLWan, WanPipeline, WanImageToVideoPipeline, UniPCMultistepScheduler
5
+ from diffusers.utils import export_to_video
6
+ import gradio as gr
7
+ import tempfile
8
+ from huggingface_hub import hf_hub_download
9
+ import numpy as np
10
+ from PIL import Image
11
+ import random
12
+
13
# Hugging Face repo id of the FastWan 2.2 TI2V diffusers checkpoint used by both pipelines.
MODEL_ID = "FastVideo/FastWan2.2-TI2V-5B-FullAttn-Diffusers"
# Repo that receives the archived input image + prompt; overridable via HF_UPLOAD_REPO.
HF_MODEL = os.environ.get("HF_UPLOAD_REPO", "rahul7star/wan22TITV5B-image-analysis")
15
+
16
# --- CPU-only upload function ---
def upload_image_and_prompt_cpu(input_image, prompt_text) -> str:
    """Archive the input image and its prompt to the HF_MODEL repo.

    Files land under ``YYYY-MM-DD/Upload-Image-<8 hex chars>/`` as
    ``input_image.png`` and ``summary.txt``. Returns the repo-relative
    folder path. Authentication uses the HUGGINGFACE_HUB_TOKEN env var.
    """
    from datetime import datetime
    import tempfile, os, uuid
    from huggingface_hub import upload_file
    import shutil

    today_str = datetime.now().strftime("%Y-%m-%d")
    unique_subfolder = f"Upload-Image-{uuid.uuid4().hex[:8]}"
    hf_folder = f"{today_str}/{unique_subfolder}"
    token = os.environ.get("HUGGINGFACE_HUB_TOKEN")

    # Save image temporarily (accept either a filesystem path or a PIL image).
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_img:
        if isinstance(input_image, str):
            shutil.copy(input_image, tmp_img.name)
        else:
            input_image.save(tmp_img.name, format="PNG")
        tmp_img_path = tmp_img.name

    # Save prompt as summary.txt
    summary_file = tempfile.NamedTemporaryFile(delete=False, suffix=".txt").name
    with open(summary_file, "w", encoding="utf-8") as f:
        f.write(prompt_text)

    try:
        # upload_file's parameters are keyword-only in huggingface_hub;
        # the previous positional calls raised TypeError at runtime.
        upload_file(path_or_fileobj=tmp_img_path,
                    path_in_repo=f"{hf_folder}/input_image.png",
                    repo_id=HF_MODEL, repo_type="model", token=token)
        upload_file(path_or_fileobj=summary_file,
                    path_in_repo=f"{hf_folder}/summary.txt",
                    repo_id=HF_MODEL, repo_type="model", token=token)
    finally:
        # Always remove the temp files, even when an upload fails.
        os.remove(tmp_img_path)
        os.remove(summary_file)
    return hf_folder
49
+
50
# --- Load pipelines ---
# VAE stays in float32 for decode quality; the rest of each pipeline runs in bfloat16.
vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
# Both pipelines share the same VAE instance and checkpoint.
text_to_video_pipe = WanPipeline.from_pretrained(MODEL_ID, vae=vae, torch_dtype=torch.bfloat16)
image_to_video_pipe = WanImageToVideoPipeline.from_pretrained(MODEL_ID, vae=vae, torch_dtype=torch.bfloat16)

# Swap in UniPC scheduling (flow_shift tuned for this model) and move everything to the GPU.
for pipe in [text_to_video_pipe, image_to_video_pipe]:
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
    pipe.to("cuda")
58
+
59
# --- Constants ---
MOD_VALUE = 32  # output height/width must be multiples of this
DEFAULT_H_SLIDER_VALUE = 896
DEFAULT_W_SLIDER_VALUE = 896
NEW_FORMULA_MAX_AREA = 720 * 1024  # target pixel area when auto-fitting dims to an uploaded image
SLIDER_MIN_H, SLIDER_MAX_H = 256, 1024
SLIDER_MIN_W, SLIDER_MAX_W = 256, 1024
MAX_SEED = np.iinfo(np.int32).max  # upper bound for the seed slider / randomized seeds
FIXED_FPS = 24  # frame rate used both for frame-count math and for export
MIN_FRAMES_MODEL = 25   # model's supported frame-count range
MAX_FRAMES_MODEL = 193

# Default prompt shown in the UI for image-to-video runs.
default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
default_negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards, watermark, text, signature"
73
+
74
+ # --- Utility functions ---
75
+ def _calculate_new_dimensions_wan(pil_image, mod_val, calculation_max_area, min_slider_h, max_slider_h, min_slider_w, max_slider_w, default_h, default_w):
76
+ orig_w, orig_h = pil_image.size
77
+ if orig_w <= 0 or orig_h <= 0:
78
+ return default_h, default_w
79
+ aspect_ratio = orig_h / orig_w
80
+ calc_h = round(np.sqrt(calculation_max_area * aspect_ratio))
81
+ calc_w = round(np.sqrt(calculation_max_area / aspect_ratio))
82
+ calc_h = max(mod_val, (calc_h // mod_val) * mod_val)
83
+ calc_w = max(mod_val, (calc_w // mod_val) * mod_val)
84
+ new_h = int(np.clip(calc_h, min_slider_h, (max_slider_h // mod_val) * mod_val))
85
+ new_w = int(np.clip(calc_w, min_slider_w, (max_slider_w // mod_val) * mod_val))
86
+ return new_h, new_w
87
+
88
def handle_image_upload_for_dims_wan(uploaded_pil_image, current_h_val, current_w_val):
    """Sync the H/W sliders to a freshly uploaded (or cleared) image.

    Returns a pair of gr.update objects for the height and width sliders;
    falls back to the defaults when no image is present or fitting fails.
    """
    defaults = (gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE))
    if uploaded_pil_image is None:
        return defaults
    try:
        new_h, new_w = _calculate_new_dimensions_wan(
            uploaded_pil_image, MOD_VALUE, NEW_FORMULA_MAX_AREA,
            SLIDER_MIN_H, SLIDER_MAX_H, SLIDER_MIN_W, SLIDER_MAX_W,
            DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE,
        )
    except Exception:
        # Any failure (e.g. a corrupt image) just keeps the default dims.
        return defaults
    return gr.update(value=new_h), gr.update(value=new_w)
100
+
101
def get_duration(*args, **kwargs):
    """Fixed GPU allocation for @spaces.GPU: always 60 seconds, regardless of inputs."""
    return 60  # simplified for example
103
+
104
# --- GPU video generation ---
@spaces.GPU(duration=get_duration)
def generate_video(input_image, prompt, height, width,
                   negative_prompt=default_negative_prompt,
                   duration_seconds=2, guidance_scale=0, steps=4,
                   seed=44, randomize_seed=False,
                   progress=gr.Progress(track_tqdm=True)):
    """Render a video with the FastWan pipelines.

    When *input_image* is given it is resized to the mod-aligned target size
    and the image-to-video pipeline runs; otherwise the text-to-video one.
    Returns (path to the exported .mp4, the seed actually used).
    """
    # Snap the requested dimensions down to multiples of MOD_VALUE.
    target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
    target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)
    # Clamp the frame count to the model's supported range.
    num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)

    # Arguments shared by both pipelines.
    call_kwargs = dict(
        prompt=prompt, negative_prompt=negative_prompt,
        height=target_h, width=target_w, num_frames=num_frames,
        guidance_scale=float(guidance_scale), num_inference_steps=int(steps),
        generator=torch.Generator(device="cuda").manual_seed(current_seed),
    )

    with torch.inference_mode():
        if input_image is not None:
            call_kwargs["image"] = input_image.resize((target_w, target_h))
            output_frames_list = image_to_video_pipe(**call_kwargs).frames[0]
        else:
            output_frames_list = text_to_video_pipe(**call_kwargs).frames[0]

    # NamedTemporaryFile reserves a unique path; export_to_video writes to it.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        video_path = tmpfile.name
    export_to_video(output_frames_list, video_path, fps=FIXED_FPS)
    return video_path, current_seed
139
+
140
# --- Wrapper to upload image/prompt on CPU before GPU generation ---
def generate_video_with_upload(input_image, prompt, height, width,
                               negative_prompt=default_negative_prompt,
                               duration_seconds=2, guidance_scale=0, steps=4,
                               seed=44, randomize_seed=False):
    """Best-effort archive of the inputs, then delegate to generate_video."""
    # Upload on CPU (hidden, no UI); a failed upload must never block generation.
    try:
        upload_image_and_prompt_cpu(input_image, prompt)
    except Exception as exc:
        print("Upload failed:", exc)

    # Proceed with GPU video generation
    return generate_video(
        input_image, prompt, height, width,
        negative_prompt=negative_prompt,
        duration_seconds=duration_seconds,
        guidance_scale=guidance_scale,
        steps=steps,
        seed=seed,
        randomize_seed=randomize_seed,
    )
155
+
156
# --- Gradio UI ---
# NOTE(review): the source diff is flattened, so the nesting below follows the
# standard Gradio layout for this app — confirm against the rendered Space.
with gr.Blocks() as demo:
    gr.Markdown("# Fast Wan 2.2 TI2V 5B Demo")
    gr.Markdown("""This Demo is using [FastWan2.2-TI2V-5B](https://huggingface.co/FastVideo/FastWan2.2-TI2V-5B-FullAttn-Diffusers) fine-tuned with Sparse-distill for fast high-quality video generation.""")

    with gr.Row():
        with gr.Column():
            # Leaving the image empty switches the app to text-to-video mode.
            input_image_component = gr.Image(type="pil", label="Input Image (optional, auto-resized to target H/W)")
            prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
            # Slider bounds mirror the model's frame-count limits at the fixed FPS.
            duration_seconds_input = gr.Slider(minimum=round(MIN_FRAMES_MODEL/FIXED_FPS,1),
                                               maximum=round(MAX_FRAMES_MODEL/FIXED_FPS,1),
                                               step=0.1, value=2, label="Duration (seconds)")
            with gr.Accordion("Advanced Settings", open=False):
                negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
                seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
                randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True)
                with gr.Row():
                    # Step matches MOD_VALUE so slider values are always mod-aligned.
                    height_input = gr.Slider(minimum=SLIDER_MIN_H, maximum=SLIDER_MAX_H, step=MOD_VALUE, value=DEFAULT_H_SLIDER_VALUE, label=f"Output Height")
                    width_input = gr.Slider(minimum=SLIDER_MIN_W, maximum=SLIDER_MAX_W, step=MOD_VALUE, value=DEFAULT_W_SLIDER_VALUE, label=f"Output Width")
                steps_slider = gr.Slider(minimum=1, maximum=8, step=1, value=4, label="Inference Steps")
                guidance_scale_input = gr.Slider(minimum=0.0, maximum=5.0, step=0.01, value=0.0, label="Guidance Scale")
            generate_button = gr.Button("Generate Video", variant="primary")
        with gr.Column():
            video_output = gr.Video(label="Generated Video", autoplay=True, interactive=False)

    # Re-fit the H/W sliders whenever an image is uploaded or cleared.
    input_image_component.upload(
        fn=handle_image_upload_for_dims_wan,
        inputs=[input_image_component, height_input, width_input],
        outputs=[height_input, width_input]
    )
    input_image_component.clear(
        fn=handle_image_upload_for_dims_wan,
        inputs=[input_image_component, height_input, width_input],
        outputs=[height_input, width_input]
    )

    # Order must match generate_video_with_upload's positional parameters.
    ui_inputs = [
        input_image_component, prompt_input, height_input, width_input,
        negative_prompt_input, duration_seconds_input,
        guidance_scale_input, steps_slider, seed_input, randomize_seed_checkbox
    ]
    generate_button.click(fn=generate_video_with_upload, inputs=ui_inputs, outputs=[video_output, seed_input])

    # Example rows: [image, prompt, height, width]; image paths are repo-local
    # files — presumably present in the Space; verify they exist.
    gr.Examples(
        examples=[
            [None, "A person eating spaghetti", 1024, 720],
            ["cat.png", "The cat removes the glasses from its eyes.", 1088, 800],
            [None, "A penguin playfully dancing in the snow, Antarctica", 1024, 720],
            ["peng.png", "A penguin running towards camera joyfully, Antarctica", 896, 512],
        ],
        inputs=[input_image_component, prompt_input, height_input, width_input],
        outputs=[video_output, seed_input],
        fn=generate_video_with_upload,
        cache_examples="lazy"
    )

if __name__ == "__main__":
    demo.queue().launch()