rahul7star committed
Commit d57c4c9 · verified · 1 Parent(s): 798901c

Update app_t2v.py

Files changed (1): app_t2v.py (+537 −194)

app_t2v.py CHANGED
--- a/app_t2v.py
@@ -1,67 +1,77 @@
- import spaces
  import os
  import sys
  sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

- # wan2.2-main/gradio_ti2v.py
  import gradio as gr
  import torch
  from huggingface_hub import snapshot_download
  from PIL import Image
  import random
  import numpy as np

- from huggingface_hub import hf_hub_download
  import wan
  from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
  from wan.utils.utils import cache_video

- import gc

- # --- 1. Global Setup and Model Loading ---

- print("Starting Gradio App for Wan 2.2 TI2V-5B...")

- # Download model snapshots from Hugging Face Hub
  repo_id = "Wan-AI/Wan2.2-TI2V-5B"
  print(f"Downloading/loading checkpoints for {repo_id}...")
  ckpt_dir = snapshot_download(repo_id, local_dir_use_symlinks=False)
  print(f"Using checkpoints from {ckpt_dir}")

  # Load the model configuration
- TASK_NAME = 't2v-A14B'
  cfg = WAN_CONFIGS[TASK_NAME]
  FIXED_FPS = 24
  MIN_FRAMES_MODEL = 8
- MAX_FRAMES_MODEL = 121

- # Dimension calculation constants
- MOD_VALUE = 32
- DEFAULT_H_SLIDER_VALUE = 704
- DEFAULT_W_SLIDER_VALUE = 1280
- NEW_FORMULA_MAX_AREA = 1280.0 * 704.0
-
- SLIDER_MIN_H, SLIDER_MAX_H = 128, 1280
- SLIDER_MIN_W, SLIDER_MAX_W = 128, 1280
-
- # Instantiate the pipeline in the global scope
- print("Initializing WanTI2V pipeline...")
- device = "cuda" if torch.cuda.is_available() else "cpu"
  device_id = 0 if torch.cuda.is_available() else -1
-
- #lora
- LORA_REPO_ID = "Kijai/WanVideo_comfy"
- LORA_FILENAME = "Lightx2v/lightx2v_T2V_14B_cfg_step_distill_v2_lora_rank256_bf16.safetensors"
-
- pipeline = wan.WanTI2V(
      config=cfg,
      checkpoint_dir=ckpt_dir,
      device_id=device_id,
@@ -73,206 +83,539 @@ pipeline = wan.WanTI2V(
      init_on_cpu=False,
      convert_model_dtype=True,
  )

-
- # causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
- # pipeline.load_lora_weights(causvid_path, adapter_name="causvid_lora")
- # pipeline.set_adapters(["causvid_lora"], adapter_weights=[0.95])
- # pipeline.fuse_lora()
-
- print("Pipeline initialized and ready.")
-
- # --- Helper Functions (from Wan 2.1 Fast demo) ---
- def _calculate_new_dimensions_wan(pil_image, mod_val, calculation_max_area,
-                                   min_slider_h, max_slider_h,
-                                   min_slider_w, max_slider_w,
-                                   default_h, default_w):
-     orig_w, orig_h = pil_image.size
-     if orig_w <= 0 or orig_h <= 0:
-         return default_h, default_w
-
-     aspect_ratio = orig_h / orig_w
-
-     calc_h = round(np.sqrt(calculation_max_area * aspect_ratio))
-     calc_w = round(np.sqrt(calculation_max_area / aspect_ratio))
-
-     calc_h = max(mod_val, (calc_h // mod_val) * mod_val)
-     calc_w = max(mod_val, (calc_w // mod_val) * mod_val)
-
-     new_h = int(np.clip(calc_h, min_slider_h, (max_slider_h // mod_val) * mod_val))
-     new_w = int(np.clip(calc_w, min_slider_w, (max_slider_w // mod_val) * mod_val))
-
-     return new_h, new_w
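For reference, the removed helper solved h·w ≈ max_area while keeping h/w equal to the image's aspect ratio, then snapped both sides down to multiples of mod_val. A worked example (values only for illustration): a 1920x1080 upload has aspect ratio 0.5625, so √(901120·0.5625) ≈ 712 and √(901120/0.5625) ≈ 1266, which the mod-32 snap turns into 704x1248:

    # illustrative only; mirrors the deleted _calculate_new_dimensions_wan
    h = max(32, (round((901120 * 0.5625) ** 0.5) // 32) * 32)   # -> 704
    w = max(32, (round((901120 / 0.5625) ** 0.5) // 32) * 32)   # -> 1248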
 
- def handle_image_upload_for_dims_wan(uploaded_pil_image, current_h_val, current_w_val):
-     if uploaded_pil_image is None:
-         return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
      try:
-         # Convert numpy array to PIL Image if needed
-         if hasattr(uploaded_pil_image, 'shape'):  # numpy array
-             pil_image = Image.fromarray(uploaded_pil_image).convert("RGB")
-         else:  # already a PIL Image
-             pil_image = uploaded_pil_image
-
-         new_h, new_w = _calculate_new_dimensions_wan(
-             pil_image, MOD_VALUE, NEW_FORMULA_MAX_AREA,
-             SLIDER_MIN_H, SLIDER_MAX_H, SLIDER_MIN_W, SLIDER_MAX_W,
-             DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE
-         )
-         return gr.update(value=new_h), gr.update(value=new_w)
-     except Exception as e:
-         gr.Warning("Error attempting to calculate new dimensions")
-         return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
-
- def get_duration(image,
-                  prompt,
-                  height,
-                  width,
-                  duration_seconds,
-                  sampling_steps,
-                  guide_scale,
-                  shift,
-                  seed,
-                  progress):
-     """Calculate dynamic GPU duration based on parameters."""
-     if duration_seconds >= 3:
-         return 220
-     elif sampling_steps > 35 and duration_seconds >= 2:
-         return 180
      elif sampling_steps < 35 or duration_seconds < 2:
          return 105
      else:
          return 90

- # --- 2. Gradio Inference Function ---
- @spaces.GPU(duration=get_duration)
  def generate_video(
      image,
      prompt,
-     height,
-     width,
      duration_seconds,
-     sampling_steps=38,
-     guide_scale=cfg.sample_guide_scale,
-     shift=cfg.sample_shift,
-     seed=42,
      progress=gr.Progress(track_tqdm=True)
  ):
-     """The main function to generate video, called by the Gradio interface."""
      if seed == -1:
          seed = random.randint(0, sys.maxsize)

-     # Ensure dimensions are multiples of MOD_VALUE
-     target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
-     target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)
-
      input_image = None
      if image is not None:
-         input_image = Image.fromarray(image).convert("RGB")
-         # Resize image to match target dimensions
-         input_image = input_image.resize((target_w, target_h))

      # Calculate number of frames based on duration
      num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)

-     # Create size string for the pipeline
-     size_str = f"{target_h}*{target_w}"
-
-     video_tensor = pipeline.generate(
-         input_prompt=prompt,
-         img=input_image,  # Pass None for T2V, an Image for I2V
-         size=SIZE_CONFIGS.get(size_str, (target_h, target_w)),
-         max_area=MAX_AREA_CONFIGS.get(size_str, target_h * target_w),
-         frame_num=num_frames,  # Use calculated frames instead of cfg.frame_num
-         shift=shift,
-         sample_solver='unipc',
-         sampling_steps=int(sampling_steps),
-         guide_scale=guide_scale,
-         seed=seed,
-         offload_model=True
-     )

-     # Save the video to a temporary file
-     video_path = cache_video(
-         tensor=video_tensor[None],  # Add a batch dimension
-         save_file=None,  # cache_video will create a temp file
-         fps=cfg.sample_fps,
-         normalize=True,
-         value_range=(-1, 1)
-     )
-     del video_tensor
-     gc.collect()
      return video_path
- # --- 3. Gradio Interface ---
- css = ".gradio-container {max-width: 1100px !important; margin: 0 auto} #output_video {height: 500px;} #input_image {height: 500px;}"
-
- with gr.Blocks(css=css, theme=gr.themes.Soft(), delete_cache=(60, 900)) as demo:
-     gr.Markdown("# Wan 2.2 TI2V 5B")
-     gr.Markdown("generate high quality videos using **Wan 2.2 5B Text-Image-to-Video model**, [[model]](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B), [[paper]](https://arxiv.org/abs/2503.20314)")
-
-     with gr.Row():
-         with gr.Column(scale=2):
-             image_input = gr.Image(type="numpy", label="Optional (blank = text-to-image)", elem_id="input_image")
-             prompt_input = gr.Textbox(label="Prompt", value="A beautiful waterfall in a lush jungle, cinematic.", lines=3)
-             duration_input = gr.Slider(
-                 minimum=round(MIN_FRAMES_MODEL/FIXED_FPS, 1),
-                 maximum=round(MAX_FRAMES_MODEL/FIXED_FPS, 1),
-                 step=0.1,
-                 value=2.0,
-                 label="Duration (seconds)",
-                 info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps."
-             )
-
-             with gr.Accordion("Advanced Settings", open=False):
-                 with gr.Row():
-                     height_input = gr.Slider(minimum=SLIDER_MIN_H, maximum=SLIDER_MAX_H, step=MOD_VALUE, value=DEFAULT_H_SLIDER_VALUE, label=f"Output Height (multiple of {MOD_VALUE})")
-                     width_input = gr.Slider(minimum=SLIDER_MIN_W, maximum=SLIDER_MAX_W, step=MOD_VALUE, value=DEFAULT_W_SLIDER_VALUE, label=f"Output Width (multiple of {MOD_VALUE})")
-                 steps_input = gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=38, step=1)
-                 scale_input = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, value=cfg.sample_guide_scale, step=0.1)
-                 shift_input = gr.Slider(label="Sample Shift", minimum=1.0, maximum=20.0, value=cfg.sample_shift, step=0.1)
-                 seed_input = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
-
-         with gr.Column(scale=2):
-             video_output = gr.Video(label="Generated Video", elem_id="output_video")
-             run_button = gr.Button("Generate Video", variant="primary")

-     # Add image upload handler
-     image_input.upload(
-         fn=handle_image_upload_for_dims_wan,
-         inputs=[image_input, height_input, width_input],
-         outputs=[height_input, width_input]
      )

-     image_input.clear(
-         fn=handle_image_upload_for_dims_wan,
-         inputs=[image_input, height_input, width_input],
-         outputs=[height_input, width_input]
      )
-
-     example_image_path = os.path.join(os.path.dirname(__file__), "examples/i2v_input.JPG")
      gr.Examples(
          examples=[
-             [example_image_path, "The cat removes the glasses from its eyes.", 1088, 800, 1.5],
-             [None, "A cinematic shot of a boat sailing on a calm sea at sunset.", 704, 1280, 2.0],
-             [None, "Drone footage flying over a futuristic city with flying cars.", 704, 1280, 2.0],
          ],
-         inputs=[image_input, prompt_input, height_input, width_input, duration_input],
-         outputs=video_output,
-         fn=generate_video,
-         cache_examples="lazy",
-     )
-
-     run_button.click(
-         fn=generate_video,
-         inputs=[image_input, prompt_input, height_input, width_input, duration_input, steps_input, scale_input, shift_input, seed_input],
-         outputs=video_output
      )

  if __name__ == "__main__":
+++ b/app_t2v.py
  import os
  import sys
  sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

  import gradio as gr
  import torch
  from huggingface_hub import snapshot_download
  from PIL import Image
  import random
  import numpy as np
+ import spaces
+ import gc

+ # Import for Stable Diffusion XL
+ from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
+ from compel import Compel, ReturnedEmbeddingsType
+
+ # Import for Wan2.2
  import wan
  from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
  from wan.utils.utils import cache_video

+ # --- Global Setup ---
+ print("Starting Integrated Text-to-Image-to-Video App...")
+
+ # --- 1. Setup Text-to-Image Model (SDXL) ---
+ print("Loading Stable Diffusion XL model...")
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

+ # Initialize SDXL pipeline
+ sdxl_pipe = StableDiffusionXLPipeline.from_pretrained(
+     "votepurchase/pornmasterPro_noobV3VAE",
+     torch_dtype=torch.float16,
+     variant="fp16",
+     use_safetensors=True
+ )
+
+ sdxl_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sdxl_pipe.scheduler.config)
+ sdxl_pipe.to(device)
+
+ # Force all components to use the same dtype
+ sdxl_pipe.text_encoder.to(torch.float16)
+ sdxl_pipe.text_encoder_2.to(torch.float16)
+ sdxl_pipe.vae.to(torch.float16)
+ sdxl_pipe.unet.to(torch.float16)
+
+ # Initialize Compel for long prompt processing
+ compel = Compel(
+     tokenizer=[sdxl_pipe.tokenizer, sdxl_pipe.tokenizer_2],
+     text_encoder=[sdxl_pipe.text_encoder, sdxl_pipe.text_encoder_2],
+     returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
+     requires_pooled=[False, True],
+     truncate_long_prompts=False
+ )

+ # --- 2. Setup Image-to-Video Model (Wan2.2) ---
+ print("Loading Wan 2.2 TI2V-5B model...")

+ # Download model snapshots
  repo_id = "Wan-AI/Wan2.2-TI2V-5B"
  print(f"Downloading/loading checkpoints for {repo_id}...")
  ckpt_dir = snapshot_download(repo_id, local_dir_use_symlinks=False)
  print(f"Using checkpoints from {ckpt_dir}")

  # Load the model configuration
+ TASK_NAME = 'ti2v-5B'
  cfg = WAN_CONFIGS[TASK_NAME]
  FIXED_FPS = 24
  MIN_FRAMES_MODEL = 8
+ MAX_FRAMES_MODEL = 121

+ # Instantiate the pipeline
  device_id = 0 if torch.cuda.is_available() else -1
+ wan_pipeline = wan.WanTI2V(
      config=cfg,
      checkpoint_dir=ckpt_dir,
      device_id=device_id,

      init_on_cpu=False,
      convert_model_dtype=True,
  )
+ print("All models loaded and ready.")
+
+ # --- Constants ---
+ MAX_SEED = np.iinfo(np.int32).max
+ MAX_IMAGE_SIZE = 1216
+
+ # --- Helper Functions ---
+ def clear_gpu_memory():
+     """Clear GPU memory more thoroughly"""
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+         torch.cuda.ipc_collect()
+     gc.collect()

+ def process_long_prompt(prompt, negative_prompt=""):
+     """Simple long prompt processing using Compel"""
+     try:
+         conditioning, pooled = compel([prompt, negative_prompt])
+         return conditioning, pooled
+     except Exception as e:
+         print(f"Long prompt processing failed: {e}, falling back to standard processing")
+         return None, None
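With `truncate_long_prompts=False`, Compel encodes past CLIP's 77-token window, and calling it with a two-element list batches the positive and negative embeddings together, which is why `generate_image` below slices rows [0:1] and [1:2]. A minimal sketch of the same fallback pattern (prompt text is hypothetical; assumes the `compel` and `sdxl_pipe` objects built above):

    long_prompt = ", ".join(["ornate clockwork garden"] * 40)   # well past 77 tokens
    conditioning, pooled = process_long_prompt(long_prompt, "blurry, low quality")
    if conditioning is None:
        image = sdxl_pipe(prompt=long_prompt, negative_prompt="blurry, low quality").images[0]
    else:
        image = sdxl_pipe(prompt_embeds=conditioning[0:1], pooled_prompt_embeds=pooled[0:1],
                          negative_prompt_embeds=conditioning[1:2],
                          negative_pooled_prompt_embeds=pooled[1:2]).images[0]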
 
+ def select_best_size_for_image(image, available_sizes):
+     """Select the size option with aspect ratio closest to the input image."""
+     if image is None:
+         return available_sizes[0]
+
+     img_width, img_height = image.size
+     img_aspect_ratio = img_height / img_width
+
+     best_size = available_sizes[0]
+     best_diff = float('inf')
+
+     for size_str in available_sizes:
+         height, width = map(int, size_str.split('*'))
+         size_aspect_ratio = height / width
+         diff = abs(img_aspect_ratio - size_aspect_ratio)
+
+         if diff < best_diff:
+             best_diff = diff
+             best_size = size_str
+
+     return best_size
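A worked example of the selection: a square 1024x1024 image has aspect ratio 1.0; against choices '704*1280' (704/1280 = 0.55) and '1280*704' (1280/704 ≈ 1.82) the differences are 0.45 and 0.82, so '704*1280' wins. The two sizes here are only illustrative; the real list comes from SUPPORTED_SIZES[TASK_NAME]:

    sizes = ["704*1280", "1280*704"]
    img = Image.new("RGB", (1024, 1024))
    print(select_best_size_for_image(img, sizes))   # -> "704*1280"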
 
+ def validate_video_inputs(image, prompt, duration_seconds):
+     """Validate user inputs for video generation"""
+     errors = []
+
+     if not prompt or len(prompt.strip()) < 5:
+         errors.append("Prompt must be at least 5 characters long.")
+
+     if image is not None:
+         if isinstance(image, np.ndarray):
+             img = Image.fromarray(image)
+         else:
+             img = image
+         if img.size[0] * img.size[1] > 4096 * 4096:
+             errors.append("Image size is too large (maximum 4096x4096).")
+
+     if duration_seconds > 5.0 and image is None:
+         errors.append("Videos longer than 5 seconds require an input image.")
+
+     return errors
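The validator accumulates messages instead of raising on the first failure, so the UI can report everything at once, as `generate_video` does below. A quick illustration with hypothetical inputs:

    errors = validate_video_inputs(image=None, prompt="hi", duration_seconds=6.0)
    # -> ["Prompt must be at least 5 characters long.",
    #     "Videos longer than 5 seconds require an input image."]
    if errors:
        raise gr.Error("\n".join(errors))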
 
+ # --- Text-to-Image Generation Function ---
+ @spaces.GPU(duration=30)
+ def generate_image(
+     prompt,
+     negative_prompt,
+     seed,
+     randomize_seed,
+     width,
+     height,
+     guidance_scale,
+     num_inference_steps,
+     progress=gr.Progress(track_tqdm=True)
+ ):
+     """Generate image from text prompt"""
+     progress(0, desc="Initializing image generation...")

+     use_long_prompt = len(prompt.split()) > 60 or len(prompt) > 300

+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)

+     generator = torch.Generator(device=device).manual_seed(seed)
+
      try:
+         progress(0.3, desc="Processing prompt...")
+
+         if use_long_prompt:
+             print("Using long prompt processing...")
+             conditioning, pooled = process_long_prompt(prompt, negative_prompt)

+             if conditioning is not None:
+                 progress(0.5, desc="Generating image...")
+                 output_image = sdxl_pipe(
+                     prompt_embeds=conditioning[0:1],
+                     pooled_prompt_embeds=pooled[0:1],
+                     negative_prompt_embeds=conditioning[1:2],
+                     negative_pooled_prompt_embeds=pooled[1:2],
+                     guidance_scale=guidance_scale,
+                     num_inference_steps=num_inference_steps,
+                     width=width,
+                     height=height,
+                     generator=generator
+                 ).images[0]
+                 progress(1.0, desc="Complete!")
+                 return output_image, seed
+
+         # Fall back to standard processing
+         progress(0.5, desc="Generating image...")
+         output_image = sdxl_pipe(
+             prompt=prompt,
+             negative_prompt=negative_prompt,
+             guidance_scale=guidance_scale,
+             num_inference_steps=num_inference_steps,
+             width=width,
+             height=height,
+             generator=generator
+         ).images[0]
+
+         progress(1.0, desc="Complete!")
+         return output_image, seed
+
+     except RuntimeError as e:
+         print(f"Error during generation: {e}")
+         error_img = Image.new('RGB', (width, height), color=(0, 0, 0))
+         return error_img, seed
+     finally:
+         clear_gpu_memory()
+
+ # --- Image-to-Video Generation Function ---
+ def get_video_duration(image, prompt, size, duration_seconds, sampling_steps, guide_scale, shift, seed, progress):
+     """Calculate dynamic GPU duration for video generation"""
+     if sampling_steps > 35 and duration_seconds >= 2:
+         return 120
      elif sampling_steps < 35 or duration_seconds < 2:
          return 105
      else:
          return 90
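Note that `spaces.GPU` accepts a callable for `duration`: ZeroGPU invokes it with the same arguments as the decorated function and uses the returned number of seconds for the allocation window (the removed `get_duration` in the old version relied on the same mechanism). Tracing the branches above with illustrative values:

    # sampling_steps=38, duration_seconds=2.0 -> 120 s window
    # sampling_steps=30, duration_seconds=1.5 -> 105 s
    # sampling_steps=35, duration_seconds=3.0 -> 90 s (neither branch matches)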
 
+ @spaces.GPU(duration=get_video_duration)
  def generate_video(
      image,
      prompt,
+     size,
      duration_seconds,
+     sampling_steps,
+     guide_scale,
+     shift,
+     seed,
      progress=gr.Progress(track_tqdm=True)
  ):
+     """Generate video from image and prompt"""
+     errors = validate_video_inputs(image, prompt, duration_seconds)
+     if errors:
+         raise gr.Error("\n".join(errors))
+
+     progress(0, desc="Setting up video generation...")
+
      if seed == -1:
          seed = random.randint(0, sys.maxsize)

+     progress(0.1, desc="Processing image...")
+
      input_image = None
      if image is not None:
+         if isinstance(image, np.ndarray):
+             input_image = Image.fromarray(image).convert("RGB")
+         else:
+             input_image = image.convert("RGB")
+         # Resize image to match selected size
+         target_height, target_width = map(int, size.split('*'))
+         input_image = input_image.resize((target_width, target_height))

      # Calculate number of frames based on duration
      num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
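+     # e.g. duration_seconds=2.0 at FIXED_FPS=24 -> round(2.0 * 24) = 48 frames;
+     # np.clip bounds the result to [MIN_FRAMES_MODEL, MAX_FRAMES_MODEL] = [8, 121],
+     # i.e. roughly 0.3 to 5.0 seconds of output at 24 fps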

+     progress(0.2, desc="Generating video...")
+
+     try:
+         video_tensor = wan_pipeline.generate(
+             input_prompt=prompt,
+             img=input_image,
+             size=SIZE_CONFIGS[size],
+             max_area=MAX_AREA_CONFIGS[size],
+             frame_num=num_frames,
+             shift=shift,
+             sample_solver='unipc',
+             sampling_steps=int(sampling_steps),
+             guide_scale=guide_scale,
+             seed=seed,
+             offload_model=True
+         )

+         progress(0.9, desc="Saving video...")
+
+         video_path = cache_video(
+             tensor=video_tensor[None],
+             save_file=None,
+             fps=cfg.sample_fps,
+             normalize=True,
+             value_range=(-1, 1)
+         )
+
+         progress(1.0, desc="Complete!")
+
+     except torch.cuda.OutOfMemoryError:
+         clear_gpu_memory()
+         raise gr.Error("GPU out of memory. Please try with lower settings.")
+     except Exception as e:
+         raise gr.Error(f"Video generation failed: {str(e)}")
+     finally:
+         if 'video_tensor' in locals():
+             del video_tensor
+         clear_gpu_memory()
+
      return video_path

+ # --- Combined Generation Function ---
+ def generate_image_to_video(
+     img_prompt,
+     img_negative_prompt,
+     img_seed,
+     img_randomize_seed,
+     img_width,
+     img_height,
+     img_guidance_scale,
+     img_num_inference_steps,
+     video_prompt,
+     video_size,
+     video_duration,
+     video_sampling_steps,
+     video_guide_scale,
+     video_shift,
+     video_seed
+ ):
+     """Generate image from text, then use it to generate video"""
+     # First generate image
+     generated_image, used_seed = generate_image(
+         img_prompt,
+         img_negative_prompt,
+         img_seed,
+         img_randomize_seed,
+         img_width,
+         img_height,
+         img_guidance_scale,
+         img_num_inference_steps
+     )
+
+     # Update the best video size based on the generated image
+     available_sizes = list(SUPPORTED_SIZES[TASK_NAME])
+     best_size = select_best_size_for_image(generated_image, available_sizes)
+
+     # Then generate the video using the generated image
+     video_path = generate_video(
+         generated_image,
+         video_prompt,
+         best_size,  # Use auto-selected size
+         video_duration,
+         video_sampling_steps,
+         video_guide_scale,
+         video_shift,
+         video_seed
+     )
+
+     return generated_image, video_path, used_seed, best_size
+
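Note that the user-selected `video_size` argument is intentionally unused here: the dropdown value is overridden by the aspect-ratio match so the video resolution fits the freshly generated image. A hypothetical call showing the returned tuple (argument values are illustrative only):

    image, video, seed, size = generate_image_to_video(
        "a koi pond", "blurry", 0, True, 1024, 1024, 7.0, 28,
        "gentle ripples", "704*1280", 2.0, 38, cfg.sample_guide_scale, cfg.sample_shift, -1
    )
    # `size` reports which SUPPORTED_SIZES entry was actually used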
+ # --- Gradio Interface ---
+ css = """
+ .gradio-container {max-width: 1400px !important; margin: 0 auto}
+ #output_video {height: 500px;}
+ #input_image {height: 400px;}
+ #generated_image {height: 400px;}
+ .tab-nav button {font-size: 18px !important; padding: 10px 20px !important;}
+ """
+
+ # Prompt templates
+ video_templates = {
+     "Cinematic": "cinematic shot of {subject}, professional lighting, smooth camera movement, 4k quality",
+     "Animation": "animated style {subject}, vibrant colors, fluid motion, dynamic movement",
+     "Nature": "nature documentary footage of {subject}, wildlife photography, natural movement",
+     "Slow Motion": "slow motion capture of {subject}, high speed camera, detailed motion",
+     "Action": "dynamic action shot of {subject}, fast paced movement, energetic motion"
+ }
+
+ def apply_template(template, current_prompt):
+     """Apply prompt template"""
+     if "{subject}" in template:
+         subject = current_prompt.split(",")[0] if "," in current_prompt else current_prompt
+         return template.replace("{subject}", subject)
+     return template + " " + current_prompt
+
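For example, with the current prompt "a red fox, autumn forest", the Cinematic template keeps only the first comma-separated chunk as the subject:

    apply_template(video_templates["Cinematic"], "a red fox, autumn forest")
    # -> "cinematic shot of a red fox, professional lighting, smooth camera movement, 4k quality"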
+ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
+     gr.Markdown("""
+     # 🎨 Integrated Text-to-Image-to-Video Generator
+
+     Generate images from text and convert them to high-quality videos using:
+     - **Stable Diffusion XL** for Text-to-Image generation
+     - **Wan 2.2 5B** for Image-to-Video generation
+
+     ### ✨ Features:
+     - 📝 **Text-to-Image**: Generate images from text descriptions
+     - 🎬 **Image-to-Video**: Convert images (uploaded or generated) to videos
+     - 🔄 **Text-to-Image-to-Video**: Complete pipeline from text to video
+     """)
+
+     # Badge section
+     gr.HTML(
+         """
+         <div style="display: flex; justify-content: center; align-items: center; gap: 20px; margin: 20px 0;">
+             <a href="https://huggingface.co/spaces/Heartsync/Wan-2.2-ADULT" target="_blank">
+                 <img src="https://img.shields.io/static/v1?label=T2I%20%26%20TI2V&message=Wan-2.2-ADULT&color=%230000ff&labelColor=%23800080&logo=huggingface&logoColor=white&style=for-the-badge" alt="badge">
+             </a>
+             <a href="https://huggingface.co/spaces/Heartsync/PornHUB" target="_blank">
+                 <img src="https://img.shields.io/static/v1?label=T2I%20&message=PornHUB&color=%230000ff&labelColor=%23800080&logo=huggingface&logoColor=white&style=for-the-badge" alt="badge">
+             </a>
+             <a href="https://huggingface.co/spaces/Heartsync/Hentai-Adult" target="_blank">
+                 <img src="https://img.shields.io/static/v1?label=T2I%20&message=Hentai-Adult&color=%230000ff&labelColor=%23800080&logo=huggingface&logoColor=white&style=for-the-badge" alt="badge">
+             </a>
+         </div>
+         """
+     )
+
+     with gr.Tabs() as tabs:
+         # Tab 1: Text-to-Image
+         with gr.Tab("Text to Image", id="t2i_tab"):
+             with gr.Row():
+                 with gr.Column(scale=1):
+                     t2i_prompt = gr.Textbox(
+                         label="Prompt",
+                         placeholder="Describe the image you want to generate...",
+                         lines=3
+                     )
+                     t2i_negative_prompt = gr.Textbox(
+                         label="Negative Prompt",
+                         value="nsfw, (low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn",
+                         lines=2
+                     )
+
+                     with gr.Row():
+                         t2i_width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
+                         t2i_height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
+
+                     with gr.Accordion("Advanced Settings", open=False):
+                         t2i_seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
+                         t2i_randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+                         t2i_guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=20.0, step=0.1, value=7)
+                         t2i_num_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=28)
+
+                     t2i_generate_btn = gr.Button("Generate Image", variant="primary", size="lg")
+
+                 with gr.Column(scale=1):
+                     t2i_output = gr.Image(label="Generated Image", elem_id="generated_image")
+                     t2i_seed_output = gr.Number(label="Used Seed", interactive=False)
+
+         # Tab 2: Image-to-Video
+         with gr.Tab("Image to Video", id="i2v_tab"):
+             with gr.Row():
+                 with gr.Column(scale=1):
+                     i2v_image = gr.Image(type="numpy", label="Input Image", elem_id="input_image")
+                     i2v_prompt = gr.Textbox(
+                         label="Video Prompt",
+                         value="Generate a video with smooth and natural movement. Objects should have visible motion while maintaining fluid transitions.",
+                         lines=3
+                     )
+
+                     with gr.Accordion("Prompt Templates", open=False):
+                         gr.Markdown("Click a template to apply it to your prompt:")
+                         template_buttons = {}
+                         for name, template in video_templates.items():
+                             btn = gr.Button(name, size="sm")
+                             template_buttons[name] = (btn, template)
+
+                     i2v_duration = gr.Slider(
+                         label="Duration (seconds)",
+                         minimum=round(MIN_FRAMES_MODEL/FIXED_FPS, 1),
+                         maximum=round(MAX_FRAMES_MODEL/FIXED_FPS, 1),
+                         step=0.1,
+                         value=2.0
+                     )
+                     i2v_size = gr.Dropdown(
+                         label="Output Resolution",
+                         choices=list(SUPPORTED_SIZES[TASK_NAME]),
+                         value="704*1280"
+                     )
+
+                     with gr.Accordion("Advanced Settings", open=False):
+                         i2v_steps = gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=38, step=1)
+                         i2v_guide_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, value=cfg.sample_guide_scale, step=0.1)
+                         i2v_shift = gr.Slider(label="Sample Shift", minimum=1.0, maximum=20.0, value=cfg.sample_shift, step=0.1)
+                         i2v_seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
+
+                     i2v_generate_btn = gr.Button("Generate Video", variant="primary", size="lg")
+
+                 with gr.Column(scale=1):
+                     i2v_output = gr.Video(label="Generated Video", elem_id="output_video")
+
+         # Tab 3: Text-to-Image-to-Video
+         with gr.Tab("Text to Image to Video", id="t2i2v_tab"):
+             gr.Markdown("### 🎯 Complete Pipeline: Generate an image from text, then convert it to video")

+             with gr.Row():
+                 with gr.Column(scale=1):
+                     gr.Markdown("#### Step 1: Image Generation Settings")
+                     t2i2v_img_prompt = gr.Textbox(
+                         label="Image Prompt",
+                         placeholder="Describe the image to generate...",
+                         lines=3
+                     )
+                     t2i2v_img_negative = gr.Textbox(
+                         label="Negative Prompt",
+                         value="nsfw, (low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn",
+                         lines=2
+                     )
+
+                     with gr.Row():
+                         t2i2v_img_width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
+                         t2i2v_img_height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
+
+                     with gr.Accordion("Image Advanced Settings", open=False):
+                         t2i2v_img_seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
+                         t2i2v_img_randomize = gr.Checkbox(label="Randomize seed", value=True)
+                         t2i2v_img_guidance = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=20.0, step=0.1, value=7)
+                         t2i2v_img_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=28)
+
+                     gr.Markdown("#### Step 2: Video Generation Settings")
+                     t2i2v_video_prompt = gr.Textbox(
+                         label="Video Prompt",
+                         value="Generate a video with smooth and natural movement. Objects should have visible motion while maintaining fluid transitions.",
+                         lines=3
+                     )
+                     t2i2v_video_duration = gr.Slider(
+                         label="Duration (seconds)",
+                         minimum=round(MIN_FRAMES_MODEL/FIXED_FPS, 1),
+                         maximum=round(MAX_FRAMES_MODEL/FIXED_FPS, 1),
+                         step=0.1,
+                         value=2.0
+                     )
+
+                     # Video size dropdown (auto-adjusted after image generation)
+                     t2i2v_video_size = gr.Dropdown(
+                         label="Video Output Resolution",
+                         choices=list(SUPPORTED_SIZES[TASK_NAME]),
+                         value="704*1280",
+                         info="This will be auto-adjusted based on the generated image's aspect ratio"
+                     )
+
+                     with gr.Accordion("Video Advanced Settings", open=False):
+                         t2i2v_video_steps = gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=38, step=1)
+                         t2i2v_video_guide = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, value=cfg.sample_guide_scale, step=0.1)
+                         t2i2v_video_shift = gr.Slider(label="Sample Shift", minimum=1.0, maximum=20.0, value=cfg.sample_shift, step=0.1)
+                         t2i2v_video_seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
+
+                     t2i2v_generate_btn = gr.Button("Generate Image → Video", variant="primary", size="lg")
+
+                 with gr.Column(scale=1):
+                     gr.Markdown("#### Results")
+                     t2i2v_image_output = gr.Image(label="Generated Image", elem_id="generated_image")
+                     t2i2v_video_output = gr.Video(label="Generated Video", elem_id="output_video")
+                     with gr.Row():
+                         t2i2v_seed_output = gr.Number(label="Image Seed Used", interactive=False)
+                         t2i2v_size_output = gr.Textbox(label="Video Size Used", interactive=False)
+
+     # Event handlers
+
+     # Tab 1: Text-to-Image
+     t2i_generate_btn.click(
+         fn=generate_image,
+         inputs=[
+             t2i_prompt, t2i_negative_prompt, t2i_seed, t2i_randomize_seed,
+             t2i_width, t2i_height, t2i_guidance_scale, t2i_num_steps
+         ],
+         outputs=[t2i_output, t2i_seed_output]
      )

+     # Tab 2: Image-to-Video
+     # Connect the template buttons; the current prompt arrives as the Gradio
+     # input, and each template is captured via a default argument
+     for name, (btn, template) in template_buttons.items():
+         btn.click(
+             fn=lambda p, t=template: apply_template(t, p),
+             inputs=[i2v_prompt],
+             outputs=i2v_prompt
+         )
+
+     # Auto-select the best size when an image is uploaded
+     def handle_image_upload(image):
+         if image is None:
+             return gr.update()
+         pil_image = Image.fromarray(image).convert("RGB")
+         available_sizes = list(SUPPORTED_SIZES[TASK_NAME])
+         best_size = select_best_size_for_image(pil_image, available_sizes)
+         return gr.update(value=best_size)
+
+     i2v_image.upload(
+         fn=handle_image_upload,
+         inputs=[i2v_image],
+         outputs=[i2v_size]
      )
+
+     i2v_generate_btn.click(
+         fn=generate_video,
+         inputs=[
+             i2v_image, i2v_prompt, i2v_size, i2v_duration,
+             i2v_steps, i2v_guide_scale, i2v_shift, i2v_seed
+         ],
+         outputs=i2v_output
+     )
+
+     # Tab 3: Text-to-Image-to-Video
+     t2i2v_generate_btn.click(
+         fn=generate_image_to_video,
+         inputs=[
+             t2i2v_img_prompt, t2i2v_img_negative, t2i2v_img_seed, t2i2v_img_randomize,
+             t2i2v_img_width, t2i2v_img_height, t2i2v_img_guidance, t2i2v_img_steps,
+             t2i2v_video_prompt, t2i2v_video_size, t2i2v_video_duration,
+             t2i2v_video_steps, t2i2v_video_guide, t2i2v_video_shift, t2i2v_video_seed
+         ],
+         outputs=[t2i2v_image_output, t2i2v_video_output, t2i2v_seed_output, t2i2v_size_output]
+     )
+
+     # Examples
      gr.Examples(
          examples=[
+             ["A majestic lion sitting on a rock at sunset, golden hour lighting, photorealistic", "Generate a video with the lion slowly turning its head and mane flowing in the wind"],
+             ["A futuristic cyberpunk city with neon lights and flying cars", "Cinematic shot with smooth camera movement through the city streets"],
+             ["A serene Japanese garden with cherry blossoms and a koi pond", "Gentle breeze causing cherry blossoms to fall, ripples in the pond"],
          ],
+         inputs=[t2i2v_img_prompt, t2i2v_video_prompt],
+         label="Example Prompts"
      )

  if __name__ == "__main__":