Update cogvideox/ui/controller.py
Browse files- cogvideox/ui/controller.py +24 -30
cogvideox/ui/controller.py
CHANGED
|
@@ -31,6 +31,7 @@ from safetensors import safe_open
|
|
| 31 |
from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
|
| 32 |
from ..utils.utils import save_videos_grid
|
| 33 |
|
|
|
|
| 34 |
gradio_version = pkg_resources.get_distribution("gradio").version
|
| 35 |
gradio_version_is_above_4 = int(gradio_version.split(".")[0]) >= 4
|
| 36 |
|
|
@@ -43,6 +44,7 @@ css = """
|
|
| 43 |
}
|
| 44 |
"""
|
| 45 |
|
|
|
|
| 46 |
ddpm_scheduler_dict = {
|
| 47 |
"Euler": EulerDiscreteScheduler,
|
| 48 |
"Euler A": EulerAncestralDiscreteScheduler,
|
|
@@ -57,10 +59,12 @@ flow_scheduler_dict = {
|
|
| 57 |
}
|
| 58 |
all_scheduler_dict = {**ddpm_scheduler_dict, **flow_scheduler_dict}
|
| 59 |
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
class Fun_Controller:
|
| 62 |
def __init__(self, GPU_memory_mode, scheduler_dict, weight_dtype, config_path=None):
|
| 63 |
-
# config dirs
|
| 64 |
self.basedir = os.getcwd()
|
| 65 |
self.config_dir = os.path.join(self.basedir, "config")
|
| 66 |
self.diffusion_transformer_dir = os.path.join(self.basedir, "models", "Diffusion_Transformer")
|
|
@@ -81,7 +85,6 @@ class Fun_Controller:
|
|
| 81 |
self.refresh_motion_module()
|
| 82 |
self.refresh_personalized_model()
|
| 83 |
|
| 84 |
-
# model placeholders
|
| 85 |
self.tokenizer = None
|
| 86 |
self.text_encoder = None
|
| 87 |
self.vae = None
|
|
@@ -192,12 +195,11 @@ class Fun_Controller:
|
|
| 192 |
def get_height_width_from_reference(
|
| 193 |
self, base_resolution, start_image, validation_video, control_video
|
| 194 |
):
|
| 195 |
-
# Build aspect ratios at this resolution
|
| 196 |
aspect_ratio_sizes = {
|
| 197 |
k: [x / 512 * base_resolution for x in v] for k, v in ASPECT_RATIO_512.items()
|
| 198 |
}
|
| 199 |
if self.model_type == "Inpaint":
|
| 200 |
-
if validation_video
|
| 201 |
vid = cv2.VideoCapture(validation_video)
|
| 202 |
_, frame = vid.read()
|
| 203 |
w, h = Image.fromarray(frame).size
|
|
@@ -256,7 +258,7 @@ class Fun_Controller:
|
|
| 256 |
seed_textbox,
|
| 257 |
is_api=False,
|
| 258 |
):
|
| 259 |
-
#
|
| 260 |
pass
|
| 261 |
|
| 262 |
|
|
@@ -283,19 +285,15 @@ def post_eas(
|
|
| 283 |
denoise_strength,
|
| 284 |
seed_textbox,
|
| 285 |
):
|
| 286 |
-
# encode
|
| 287 |
def _encode(path):
|
| 288 |
with open(path, "rb") as f:
|
| 289 |
return base64.b64encode(f.read()).decode("utf-8")
|
| 290 |
|
| 291 |
-
if start_image:
|
| 292 |
-
|
| 293 |
-
if
|
| 294 |
-
|
| 295 |
-
if validation_video:
|
| 296 |
-
validation_video = _encode(validation_video)
|
| 297 |
-
if validation_video_mask:
|
| 298 |
-
validation_video_mask = _encode(validation_video_mask)
|
| 299 |
|
| 300 |
datas = {
|
| 301 |
"base_model_path": base_model_dropdown,
|
|
@@ -321,12 +319,9 @@ def post_eas(
|
|
| 321 |
}
|
| 322 |
|
| 323 |
session = requests.Session()
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
if token:
|
| 327 |
-
session.headers.update({"Authorization": token})
|
| 328 |
|
| 329 |
-
# build URL (fallback to local Gradio server if EAS_URL is not set)
|
| 330 |
eas_env = os.environ.get("EAS_URL")
|
| 331 |
if eas_env:
|
| 332 |
base_url = eas_env.rstrip("/")
|
|
@@ -336,8 +331,8 @@ def post_eas(
|
|
| 336 |
base_url = f"http://{host}:{port}"
|
| 337 |
endpoint = f"{base_url}/cogvideox_fun/infer_forward"
|
| 338 |
|
| 339 |
-
|
| 340 |
-
return
|
| 341 |
|
| 342 |
|
| 343 |
class Fun_Controller_EAS:
|
|
@@ -396,7 +391,6 @@ class Fun_Controller_EAS:
|
|
| 396 |
)
|
| 397 |
|
| 398 |
if "base64_encoding" not in outputs:
|
| 399 |
-
# error path
|
| 400 |
return (
|
| 401 |
gr.Image(visible=False, value=None),
|
| 402 |
gr.Video(visible=True, value=None),
|
|
@@ -408,26 +402,26 @@ class Fun_Controller_EAS:
|
|
| 408 |
prefix = str(idx).zfill(3)
|
| 409 |
|
| 410 |
if is_image or length_slider == 1:
|
| 411 |
-
|
| 412 |
-
with open(
|
| 413 |
f.write(data)
|
| 414 |
if gradio_version_is_above_4:
|
| 415 |
-
return gr.Image(value=
|
| 416 |
else:
|
| 417 |
return (
|
| 418 |
-
gr.Image.update(value=
|
| 419 |
gr.Video.update(value=None, visible=False),
|
| 420 |
"Success",
|
| 421 |
)
|
| 422 |
else:
|
| 423 |
-
|
| 424 |
-
with open(
|
| 425 |
f.write(data)
|
| 426 |
if gradio_version_is_above_4:
|
| 427 |
-
return gr.Image(value=None, visible=False), gr.Video(value=
|
| 428 |
else:
|
| 429 |
return (
|
| 430 |
gr.Image.update(value=None, visible=False),
|
| 431 |
-
gr.Video.update(value=
|
| 432 |
"Success",
|
| 433 |
)
|
|
|
|
| 31 |
from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
|
| 32 |
from ..utils.utils import save_videos_grid
|
| 33 |
|
| 34 |
+
# version check
|
| 35 |
gradio_version = pkg_resources.get_distribution("gradio").version
|
| 36 |
gradio_version_is_above_4 = int(gradio_version.split(".")[0]) >= 4
|
| 37 |
|
|
|
|
| 44 |
}
|
| 45 |
"""
|
| 46 |
|
| 47 |
+
# Scheduler dictionaries
|
| 48 |
ddpm_scheduler_dict = {
|
| 49 |
"Euler": EulerDiscreteScheduler,
|
| 50 |
"Euler A": EulerAncestralDiscreteScheduler,
|
|
|
|
| 59 |
}
|
| 60 |
all_scheduler_dict = {**ddpm_scheduler_dict, **flow_scheduler_dict}
|
| 61 |
|
| 62 |
+
# alias for backward compatibility
|
| 63 |
+
all_cheduler_dict = all_scheduler_dict
|
| 64 |
+
|
| 65 |
|
| 66 |
class Fun_Controller:
|
| 67 |
def __init__(self, GPU_memory_mode, scheduler_dict, weight_dtype, config_path=None):
|
|
|
|
| 68 |
self.basedir = os.getcwd()
|
| 69 |
self.config_dir = os.path.join(self.basedir, "config")
|
| 70 |
self.diffusion_transformer_dir = os.path.join(self.basedir, "models", "Diffusion_Transformer")
|
|
|
|
| 85 |
self.refresh_motion_module()
|
| 86 |
self.refresh_personalized_model()
|
| 87 |
|
|
|
|
| 88 |
self.tokenizer = None
|
| 89 |
self.text_encoder = None
|
| 90 |
self.vae = None
|
|
|
|
| 195 |
def get_height_width_from_reference(
|
| 196 |
self, base_resolution, start_image, validation_video, control_video
|
| 197 |
):
|
|
|
|
| 198 |
aspect_ratio_sizes = {
|
| 199 |
k: [x / 512 * base_resolution for x in v] for k, v in ASPECT_RATIO_512.items()
|
| 200 |
}
|
| 201 |
if self.model_type == "Inpaint":
|
| 202 |
+
if validation_video:
|
| 203 |
vid = cv2.VideoCapture(validation_video)
|
| 204 |
_, frame = vid.read()
|
| 205 |
w, h = Image.fromarray(frame).size
|
|
|
|
| 258 |
seed_textbox,
|
| 259 |
is_api=False,
|
| 260 |
):
|
| 261 |
+
# local generation logic (omitted)
|
| 262 |
pass
|
| 263 |
|
| 264 |
|
|
|
|
| 285 |
denoise_strength,
|
| 286 |
seed_textbox,
|
| 287 |
):
|
| 288 |
+
# helper: encode file to base64
|
| 289 |
def _encode(path):
|
| 290 |
with open(path, "rb") as f:
|
| 291 |
return base64.b64encode(f.read()).decode("utf-8")
|
| 292 |
|
| 293 |
+
if start_image: start_image = _encode(start_image)
|
| 294 |
+
if end_image: end_image = _encode(end_image)
|
| 295 |
+
if validation_video: validation_video = _encode(validation_video)
|
| 296 |
+
if validation_video_mask: validation_video_mask = _encode(validation_video_mask)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
|
| 298 |
datas = {
|
| 299 |
"base_model_path": base_model_dropdown,
|
|
|
|
| 319 |
}
|
| 320 |
|
| 321 |
session = requests.Session()
|
| 322 |
+
if os.environ.get("EAS_TOKEN"):
|
| 323 |
+
session.headers.update({"Authorization": os.environ["EAS_TOKEN"]})
|
|
|
|
|
|
|
| 324 |
|
|
|
|
| 325 |
eas_env = os.environ.get("EAS_URL")
|
| 326 |
if eas_env:
|
| 327 |
base_url = eas_env.rstrip("/")
|
|
|
|
| 331 |
base_url = f"http://{host}:{port}"
|
| 332 |
endpoint = f"{base_url}/cogvideox_fun/infer_forward"
|
| 333 |
|
| 334 |
+
resp = session.post(url=endpoint, json=datas, timeout=300)
|
| 335 |
+
return resp.json()
|
| 336 |
|
| 337 |
|
| 338 |
class Fun_Controller_EAS:
|
|
|
|
| 391 |
)
|
| 392 |
|
| 393 |
if "base64_encoding" not in outputs:
|
|
|
|
| 394 |
return (
|
| 395 |
gr.Image(visible=False, value=None),
|
| 396 |
gr.Video(visible=True, value=None),
|
|
|
|
| 402 |
prefix = str(idx).zfill(3)
|
| 403 |
|
| 404 |
if is_image or length_slider == 1:
|
| 405 |
+
path = os.path.join(self.savedir_sample, f"{prefix}.png")
|
| 406 |
+
with open(path, "wb") as f:
|
| 407 |
f.write(data)
|
| 408 |
if gradio_version_is_above_4:
|
| 409 |
+
return gr.Image(value=path, visible=True), gr.Video(value=None, visible=False), "Success"
|
| 410 |
else:
|
| 411 |
return (
|
| 412 |
+
gr.Image.update(value=path, visible=True),
|
| 413 |
gr.Video.update(value=None, visible=False),
|
| 414 |
"Success",
|
| 415 |
)
|
| 416 |
else:
|
| 417 |
+
path = os.path.join(self.savedir_sample, f"{prefix}.mp4")
|
| 418 |
+
with open(path, "wb") as f:
|
| 419 |
f.write(data)
|
| 420 |
if gradio_version_is_above_4:
|
| 421 |
+
return gr.Image(value=None, visible=False), gr.Video(value=path, visible=True), "Success"
|
| 422 |
else:
|
| 423 |
return (
|
| 424 |
gr.Image.update(value=None, visible=False),
|
| 425 |
+
gr.Video.update(value=path, visible=True),
|
| 426 |
"Success",
|
| 427 |
)
|