LTTEAM commited on
Commit
89906f2
·
verified ·
1 Parent(s): 9c0bfc9

Update cogvideox/ui/controller.py

Browse files
Files changed (1) hide show
  1. cogvideox/ui/controller.py +24 -30
cogvideox/ui/controller.py CHANGED
@@ -31,6 +31,7 @@ from safetensors import safe_open
31
  from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
32
  from ..utils.utils import save_videos_grid
33
 
 
34
  gradio_version = pkg_resources.get_distribution("gradio").version
35
  gradio_version_is_above_4 = int(gradio_version.split(".")[0]) >= 4
36
 
@@ -43,6 +44,7 @@ css = """
43
  }
44
  """
45
 
 
46
  ddpm_scheduler_dict = {
47
  "Euler": EulerDiscreteScheduler,
48
  "Euler A": EulerAncestralDiscreteScheduler,
@@ -57,10 +59,12 @@ flow_scheduler_dict = {
57
  }
58
  all_scheduler_dict = {**ddpm_scheduler_dict, **flow_scheduler_dict}
59
 
 
 
 
60
 
61
  class Fun_Controller:
62
  def __init__(self, GPU_memory_mode, scheduler_dict, weight_dtype, config_path=None):
63
- # config dirs
64
  self.basedir = os.getcwd()
65
  self.config_dir = os.path.join(self.basedir, "config")
66
  self.diffusion_transformer_dir = os.path.join(self.basedir, "models", "Diffusion_Transformer")
@@ -81,7 +85,6 @@ class Fun_Controller:
81
  self.refresh_motion_module()
82
  self.refresh_personalized_model()
83
 
84
- # model placeholders
85
  self.tokenizer = None
86
  self.text_encoder = None
87
  self.vae = None
@@ -192,12 +195,11 @@ class Fun_Controller:
192
  def get_height_width_from_reference(
193
  self, base_resolution, start_image, validation_video, control_video
194
  ):
195
- # Build aspect ratios at this resolution
196
  aspect_ratio_sizes = {
197
  k: [x / 512 * base_resolution for x in v] for k, v in ASPECT_RATIO_512.items()
198
  }
199
  if self.model_type == "Inpaint":
200
- if validation_video is not None:
201
  vid = cv2.VideoCapture(validation_video)
202
  _, frame = vid.read()
203
  w, h = Image.fromarray(frame).size
@@ -256,7 +258,7 @@ class Fun_Controller:
256
  seed_textbox,
257
  is_api=False,
258
  ):
259
- # implementation omitted for brevity when running locally
260
  pass
261
 
262
 
@@ -283,19 +285,15 @@ def post_eas(
283
  denoise_strength,
284
  seed_textbox,
285
  ):
286
- # encode files to base64
287
  def _encode(path):
288
  with open(path, "rb") as f:
289
  return base64.b64encode(f.read()).decode("utf-8")
290
 
291
- if start_image:
292
- start_image = _encode(start_image)
293
- if end_image:
294
- end_image = _encode(end_image)
295
- if validation_video:
296
- validation_video = _encode(validation_video)
297
- if validation_video_mask:
298
- validation_video_mask = _encode(validation_video_mask)
299
 
300
  datas = {
301
  "base_model_path": base_model_dropdown,
@@ -321,12 +319,9 @@ def post_eas(
321
  }
322
 
323
  session = requests.Session()
324
- # propagate EAS token if provided
325
- token = os.environ.get("EAS_TOKEN")
326
- if token:
327
- session.headers.update({"Authorization": token})
328
 
329
- # build URL (fallback to local Gradio server if EAS_URL is not set)
330
  eas_env = os.environ.get("EAS_URL")
331
  if eas_env:
332
  base_url = eas_env.rstrip("/")
@@ -336,8 +331,8 @@ def post_eas(
336
  base_url = f"http://{host}:{port}"
337
  endpoint = f"{base_url}/cogvideox_fun/infer_forward"
338
 
339
- response = session.post(url=endpoint, json=datas, timeout=300)
340
- return response.json()
341
 
342
 
343
  class Fun_Controller_EAS:
@@ -396,7 +391,6 @@ class Fun_Controller_EAS:
396
  )
397
 
398
  if "base64_encoding" not in outputs:
399
- # error path
400
  return (
401
  gr.Image(visible=False, value=None),
402
  gr.Video(visible=True, value=None),
@@ -408,26 +402,26 @@ class Fun_Controller_EAS:
408
  prefix = str(idx).zfill(3)
409
 
410
  if is_image or length_slider == 1:
411
- file_path = os.path.join(self.savedir_sample, f"{prefix}.png")
412
- with open(file_path, "wb") as f:
413
  f.write(data)
414
  if gradio_version_is_above_4:
415
- return gr.Image(value=file_path, visible=True), gr.Video(value=None, visible=False), "Success"
416
  else:
417
  return (
418
- gr.Image.update(value=file_path, visible=True),
419
  gr.Video.update(value=None, visible=False),
420
  "Success",
421
  )
422
  else:
423
- file_path = os.path.join(self.savedir_sample, f"{prefix}.mp4")
424
- with open(file_path, "wb") as f:
425
  f.write(data)
426
  if gradio_version_is_above_4:
427
- return gr.Image(value=None, visible=False), gr.Video(value=file_path, visible=True), "Success"
428
  else:
429
  return (
430
  gr.Image.update(value=None, visible=False),
431
- gr.Video.update(value=file_path, visible=True),
432
  "Success",
433
  )
 
31
  from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
32
  from ..utils.utils import save_videos_grid
33
 
34
+ # version check
35
  gradio_version = pkg_resources.get_distribution("gradio").version
36
  gradio_version_is_above_4 = int(gradio_version.split(".")[0]) >= 4
37
 
 
44
  }
45
  """
46
 
47
+ # Scheduler dictionaries
48
  ddpm_scheduler_dict = {
49
  "Euler": EulerDiscreteScheduler,
50
  "Euler A": EulerAncestralDiscreteScheduler,
 
59
  }
60
  all_scheduler_dict = {**ddpm_scheduler_dict, **flow_scheduler_dict}
61
 
62
+ # alias for backward compatibility
63
+ all_cheduler_dict = all_scheduler_dict
64
+
65
 
66
  class Fun_Controller:
67
  def __init__(self, GPU_memory_mode, scheduler_dict, weight_dtype, config_path=None):
 
68
  self.basedir = os.getcwd()
69
  self.config_dir = os.path.join(self.basedir, "config")
70
  self.diffusion_transformer_dir = os.path.join(self.basedir, "models", "Diffusion_Transformer")
 
85
  self.refresh_motion_module()
86
  self.refresh_personalized_model()
87
 
 
88
  self.tokenizer = None
89
  self.text_encoder = None
90
  self.vae = None
 
195
  def get_height_width_from_reference(
196
  self, base_resolution, start_image, validation_video, control_video
197
  ):
 
198
  aspect_ratio_sizes = {
199
  k: [x / 512 * base_resolution for x in v] for k, v in ASPECT_RATIO_512.items()
200
  }
201
  if self.model_type == "Inpaint":
202
+ if validation_video:
203
  vid = cv2.VideoCapture(validation_video)
204
  _, frame = vid.read()
205
  w, h = Image.fromarray(frame).size
 
258
  seed_textbox,
259
  is_api=False,
260
  ):
261
+ # local generation logic (omitted)
262
  pass
263
 
264
 
 
285
  denoise_strength,
286
  seed_textbox,
287
  ):
288
+ # helper: encode file to base64
289
  def _encode(path):
290
  with open(path, "rb") as f:
291
  return base64.b64encode(f.read()).decode("utf-8")
292
 
293
+ if start_image: start_image = _encode(start_image)
294
+ if end_image: end_image = _encode(end_image)
295
+ if validation_video: validation_video = _encode(validation_video)
296
+ if validation_video_mask: validation_video_mask = _encode(validation_video_mask)
 
 
 
 
297
 
298
  datas = {
299
  "base_model_path": base_model_dropdown,
 
319
  }
320
 
321
  session = requests.Session()
322
+ if os.environ.get("EAS_TOKEN"):
323
+ session.headers.update({"Authorization": os.environ["EAS_TOKEN"]})
 
 
324
 
 
325
  eas_env = os.environ.get("EAS_URL")
326
  if eas_env:
327
  base_url = eas_env.rstrip("/")
 
331
  base_url = f"http://{host}:{port}"
332
  endpoint = f"{base_url}/cogvideox_fun/infer_forward"
333
 
334
+ resp = session.post(url=endpoint, json=datas, timeout=300)
335
+ return resp.json()
336
 
337
 
338
  class Fun_Controller_EAS:
 
391
  )
392
 
393
  if "base64_encoding" not in outputs:
 
394
  return (
395
  gr.Image(visible=False, value=None),
396
  gr.Video(visible=True, value=None),
 
402
  prefix = str(idx).zfill(3)
403
 
404
  if is_image or length_slider == 1:
405
+ path = os.path.join(self.savedir_sample, f"{prefix}.png")
406
+ with open(path, "wb") as f:
407
  f.write(data)
408
  if gradio_version_is_above_4:
409
+ return gr.Image(value=path, visible=True), gr.Video(value=None, visible=False), "Success"
410
  else:
411
  return (
412
+ gr.Image.update(value=path, visible=True),
413
  gr.Video.update(value=None, visible=False),
414
  "Success",
415
  )
416
  else:
417
+ path = os.path.join(self.savedir_sample, f"{prefix}.mp4")
418
+ with open(path, "wb") as f:
419
  f.write(data)
420
  if gradio_version_is_above_4:
421
+ return gr.Image(value=None, visible=False), gr.Video(value=path, visible=True), "Success"
422
  else:
423
  return (
424
  gr.Image.update(value=None, visible=False),
425
+ gr.Video.update(value=path, visible=True),
426
  "Success",
427
  )