Hpsoyl committed on
Commit
fd7a328
·
1 Parent(s): 3fe0e76
Files changed (2) hide show
  1. app.py +13 -3
  2. models/pipeline_ddpm_text_encoder.py +22 -10
app.py CHANGED
@@ -406,7 +406,7 @@ def get_gallery_selection(evt: gr.SelectData):
406
 
407
  # --- Generation Functions ---
408
  @spaces.GPU(duration=120)
409
- def generate_t2i(prompt, num_inference_steps, num_images, current_color):
410
  """
411
  Generates multiple images for Text-to-Image and returns a gallery.
412
  """
@@ -415,7 +415,7 @@ def generate_t2i(prompt, num_inference_steps, num_images, current_color):
415
  target_model_path = PROMPT_TO_MODEL_MAP.get(prompt, f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/FULL-checkpoint-275000")
416
  t2i_pipe = swap_t2i_unet(t2i_pipe, target_model_path)
417
 
418
- print(f"\n🚀 T2I Task started... | Prompt: '{prompt}' | Count: {num_images}")
419
 
420
  generated_raw_list = []
421
  generated_display_images = []
@@ -425,7 +425,14 @@ def generate_t2i(prompt, num_inference_steps, num_images, current_color):
425
  # Generate Batch
426
  for i in range(int(num_images)):
427
  # Generate single image
428
- image_np = t2i_pipe(prompt.lower(), generator=None, num_inference_steps=int(num_inference_steps), output_type="np").images
 
 
 
 
 
 
 
429
  generated_raw_list.append(image_np)
430
 
431
  # Save raw to temp
@@ -652,6 +659,9 @@ with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
652
  t2i_steps = gr.Slider(10, 200, 50, step=1, label="Inference Steps")
653
  # Added: Number of Images Slider
654
  t2i_num_images = gr.Slider(1, 9, 3, step=1, label="Number of Images")
 
 
 
655
  t2i_btn = gr.Button("Generate", variant="primary")
656
  with gr.Column(scale=2):
657
  # Changed: Image to Gallery
 
406
 
407
  # --- Generation Functions ---
408
  @spaces.GPU(duration=120)
409
+ def generate_t2i(prompt, num_inference_steps, num_images, current_color, height, width):
410
  """
411
  Generates multiple images for Text-to-Image and returns a gallery.
412
  """
 
415
  target_model_path = PROMPT_TO_MODEL_MAP.get(prompt, f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/FULL-checkpoint-275000")
416
  t2i_pipe = swap_t2i_unet(t2i_pipe, target_model_path)
417
 
418
+ print(f"\n🚀 T2I Task started... | Prompt: '{prompt}' | Count: {num_images} | Size: {height}x{width}")
419
 
420
  generated_raw_list = []
421
  generated_display_images = []
 
425
  # Generate Batch
426
  for i in range(int(num_images)):
427
  # Generate single image
428
+ image_np = t2i_pipe(
429
+ prompt.lower(),
430
+ generator=None,
431
+ num_inference_steps=int(num_inference_steps),
432
+ output_type="np",
433
+ height=int(height),
434
+ width=int(width)
435
+ ).images
436
  generated_raw_list.append(image_np)
437
 
438
  # Save raw to temp
 
659
  t2i_steps = gr.Slider(10, 200, 50, step=1, label="Inference Steps")
660
  # Added: Number of Images Slider
661
  t2i_num_images = gr.Slider(1, 9, 3, step=1, label="Number of Images")
662
+ with gr.Row():
663
+ t2i_height = gr.Slider(256, 1024, value=512, step=64, label="Height")
664
+ t2i_width = gr.Slider(256, 1024, value=512, step=64, label="Width")
665
  t2i_btn = gr.Button("Generate", variant="primary")
666
  with gr.Column(scale=2):
667
  # Changed: Image to Gallery
models/pipeline_ddpm_text_encoder.py CHANGED
@@ -64,6 +64,8 @@ class DDPMPipeline(DiffusionPipeline):
64
  num_inference_steps: int = 1000,
65
  output_type: Optional[str] = "pil",
66
  return_dict: bool = True,
 
 
67
  ) -> Union[ImagePipelineOutput, Tuple]:
68
  r"""
69
  The call function to the pipeline for generation.
@@ -117,17 +119,27 @@ class DDPMPipeline(DiffusionPipeline):
117
  )
118
  text_input_ids = text_inputs.input_ids.to(self.device)
119
  encoder_hidden_states = self.text_encoder(text_input_ids, return_dict=False)[0]
120
-
 
 
 
 
 
 
 
 
 
 
121
  # Sample gaussian noise to begin loop
122
- if isinstance(self.unet.config.sample_size, int):
123
- image_shape = (
124
- batch_size,
125
- self.unet.config.in_channels,
126
- self.unet.config.sample_size,
127
- self.unet.config.sample_size,
128
- )
129
- else:
130
- image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)
131
 
132
  if self.device.type == "mps":
133
  # randn does not work reproducibly on mps
 
64
  num_inference_steps: int = 1000,
65
  output_type: Optional[str] = "pil",
66
  return_dict: bool = True,
67
+ height: Optional[int] = None, # <--- new parameter
68
+ width: Optional[int] = None,
69
  ) -> Union[ImagePipelineOutput, Tuple]:
70
  r"""
71
  The call function to the pipeline for generation.
 
119
  )
120
  text_input_ids = text_inputs.input_ids.to(self.device)
121
  encoder_hidden_states = self.text_encoder(text_input_ids, return_dict=False)[0]
122
+
123
+ if height is None:
124
+ height = self.unet.config.sample_size
125
+ if width is None:
126
+ width = self.unet.config.sample_size
127
+ image_shape = (
128
+ batch_size,
129
+ self.unet.config.in_channels,
130
+ height,
131
+ width,
132
+ )
133
  # Sample gaussian noise to begin loop
134
+ # if isinstance(self.unet.config.sample_size, int):
135
+ # image_shape = (
136
+ # batch_size,
137
+ # self.unet.config.in_channels,
138
+ # self.unet.config.sample_size,
139
+ # self.unet.config.sample_size,
140
+ # )
141
+ # else:
142
+ # image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)
143
 
144
  if self.device.type == "mps":
145
  # randn does not work reproducibly on mps