BestWishYsh commited on
Commit
57dd52c
·
verified ·
1 Parent(s): 2bbc03c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -10
app.py CHANGED
@@ -17,7 +17,6 @@ from utils.unet import UNet3DConditionModel
17
  from utils.pipeline_magictime import MagicTimePipeline
18
  from utils.util import save_videos_grid, convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint, load_diffusers_lora_unet, convert_ldm_clip_text_model
19
 
20
-
21
  from huggingface_hub import snapshot_download
22
 
23
  model_path = "ckpts"
@@ -77,7 +76,8 @@ os.system(f"rm -rf gradio_cached_examples/")
77
  device = "cuda"
78
 
79
  def random_seed():
80
- return random.randint(1, 10**16)
 
81
 
82
  class MagicTimeController:
83
  def __init__(self):
@@ -152,15 +152,12 @@ class MagicTimeController:
152
  self.unet = Swift.from_pretrained(self.unet, magic_adapter_t_path)
153
  self.text_encoder = Swift.from_pretrained(self.text_encoder, magic_text_encoder_path)
154
 
155
- return gr.Dropdown()
156
-
157
  def update_motion_module(self, motion_module_dropdown):
158
  self.selected_motion_module = motion_module_dropdown
159
  motion_module_dropdown = os.path.join(self.motion_module_dir, motion_module_dropdown)
160
  motion_module_state_dict = torch.load(motion_module_dropdown, map_location="cpu")
161
  _, unexpected = self.unet.load_state_dict(motion_module_state_dict, strict=False)
162
  assert len(unexpected) == 0
163
- return gr.Dropdown()
164
 
165
  def update_motion_module_2(self, motion_module_dropdown):
166
  self.selected_motion_module = motion_module_dropdown
@@ -168,7 +165,6 @@ class MagicTimeController:
168
  motion_module_state_dict = torch.load(motion_module_dropdown, map_location="cpu")
169
  _, unexpected = self.unet_model.load_state_dict(motion_module_state_dict, strict=False)
170
  assert len(unexpected) == 0
171
- return gr.Dropdown()
172
 
173
  @spaces.GPU(duration=300)
174
  def magictime(
@@ -199,8 +195,8 @@ class MagicTimeController:
199
  ).to(device)
200
 
201
  if int(seed_textbox) > 0: seed = int(seed_textbox)
202
- else: seed = random_seed()
203
- torch.manual_seed(int(seed))
204
 
205
  assert seed == torch.initial_seed()
206
  print(f"### seed: {seed}")
@@ -233,7 +229,9 @@ class MagicTimeController:
233
 
234
  torch.cuda.empty_cache()
235
  time.sleep(1)
236
- return gr.Video(value=save_sample_path), gr.Json(value=json_config)
 
 
237
 
238
  controller = MagicTimeController()
239
 
@@ -292,4 +290,4 @@ def ui():
292
  if __name__ == "__main__":
293
  demo = ui()
294
  demo.queue(max_size=20)
295
- demo.launch()
 
17
  from utils.pipeline_magictime import MagicTimePipeline
18
  from utils.util import save_videos_grid, convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint, load_diffusers_lora_unet, convert_ldm_clip_text_model
19
 
 
20
  from huggingface_hub import snapshot_download
21
 
22
  model_path = "ckpts"
 
76
  device = "cuda"
77
 
78
  def random_seed():
79
+ # 修复:转成字符串以匹配 Textbox 组件的预期输入类型
80
+ return str(random.randint(1, 10**16))
81
 
82
  class MagicTimeController:
83
  def __init__(self):
 
152
  self.unet = Swift.from_pretrained(self.unet, magic_adapter_t_path)
153
  self.text_encoder = Swift.from_pretrained(self.text_encoder, magic_text_encoder_path)
154
 
 
 
155
  def update_motion_module(self, motion_module_dropdown):
156
  self.selected_motion_module = motion_module_dropdown
157
  motion_module_dropdown = os.path.join(self.motion_module_dir, motion_module_dropdown)
158
  motion_module_state_dict = torch.load(motion_module_dropdown, map_location="cpu")
159
  _, unexpected = self.unet.load_state_dict(motion_module_state_dict, strict=False)
160
  assert len(unexpected) == 0
 
161
 
162
  def update_motion_module_2(self, motion_module_dropdown):
163
  self.selected_motion_module = motion_module_dropdown
 
165
  motion_module_state_dict = torch.load(motion_module_dropdown, map_location="cpu")
166
  _, unexpected = self.unet_model.load_state_dict(motion_module_state_dict, strict=False)
167
  assert len(unexpected) == 0
 
168
 
169
  @spaces.GPU(duration=300)
170
  def magictime(
 
195
  ).to(device)
196
 
197
  if int(seed_textbox) > 0: seed = int(seed_textbox)
198
+ else: seed = int(random_seed())
199
+ torch.manual_seed(seed)
200
 
201
  assert seed == torch.initial_seed()
202
  print(f"### seed: {seed}")
 
229
 
230
  torch.cuda.empty_cache()
231
  time.sleep(1)
232
+
233
+ # 修复:直接返回文件路径和字典,不要返回包裹的 gr 组件对象
234
+ return save_sample_path, json_config
235
 
236
  controller = MagicTimeController()
237
 
 
290
  if __name__ == "__main__":
291
  demo = ui()
292
  demo.queue(max_size=20)
293
+ demo.launch(share=True) # 加上了 share=True 用于创建公开链接