Spaces:
Running
on
Zero
Running
on
Zero
陈硕
committed on
Commit
·
3a2f1ee
1
Parent(s):
f8acb76
update orbit lora
Browse files
app.py
CHANGED
|
@@ -55,11 +55,21 @@ pipe_image = CogVideoXImageToVideoPipeline.from_pretrained(
|
|
| 55 |
text_encoder=pipe.text_encoder,
|
| 56 |
torch_dtype=torch.bfloat16,
|
| 57 |
)
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
|
| 65 |
# pipe.transformer.to(memory_format=torch.channels_last)
|
|
@@ -213,6 +223,7 @@ def convert_prompt(prompt: str, retry_times: int = 3) -> str:
|
|
| 213 |
@spaces.GPU
|
| 214 |
def infer(
|
| 215 |
prompt: str,
|
|
|
|
| 216 |
image_input: str,
|
| 217 |
num_inference_steps: int,
|
| 218 |
guidance_scale: float,
|
|
@@ -235,6 +246,16 @@ def infer(
|
|
| 235 |
# guidance_scale=guidance_scale,
|
| 236 |
# generator=torch.Generator(device="cpu").manual_seed(seed),
|
| 237 |
# ).frames
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
if image_input is not None:
|
| 239 |
image_input = Image.fromarray(image_input).resize(size=(720, 480)) # Convert to PIL
|
| 240 |
image = load_image(image_input)
|
|
@@ -301,6 +322,12 @@ with gr.Blocks() as demo:
|
|
| 301 |
</div>
|
| 302 |
""")
|
| 303 |
with gr.Row():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
with gr.Column():
|
| 305 |
with gr.Accordion("I2V: Image Input (cannot be used simultaneously with video input)", open=False):
|
| 306 |
image_input = gr.Image(label="Input Image (will be cropped to 720 * 480)")
|
|
@@ -340,6 +367,7 @@ with gr.Blocks() as demo:
|
|
| 340 |
|
| 341 |
def generate(
|
| 342 |
prompt,
|
|
|
|
| 343 |
image_input,
|
| 344 |
# video_input,
|
| 345 |
# video_strength,
|
|
@@ -350,6 +378,7 @@ with gr.Blocks() as demo:
|
|
| 350 |
):
|
| 351 |
latents, seed = infer(
|
| 352 |
prompt,
|
|
|
|
| 353 |
image_input,
|
| 354 |
# video_input,
|
| 355 |
# video_strength,
|
|
@@ -386,7 +415,7 @@ with gr.Blocks() as demo:
|
|
| 386 |
|
| 387 |
generate_button.click(
|
| 388 |
generate,
|
| 389 |
-
inputs=[prompt, image_input, seed_param, enable_scale, enable_rife],
|
| 390 |
outputs=[video_output, download_video_button, download_gif_button, seed_text],
|
| 391 |
)
|
| 392 |
|
|
|
|
| 55 |
text_encoder=pipe.text_encoder,
|
| 56 |
torch_dtype=torch.bfloat16,
|
| 57 |
)
|
| 58 |
+
|
| 59 |
+
os.makedirs("checkpoints", exist_ok=True)
|
| 60 |
+
|
| 61 |
+
# Download LoRA weights
|
| 62 |
+
hf_hub_download(
|
| 63 |
+
repo_id="wenqsun/DimensionX",
|
| 64 |
+
filename="orbit_left_lora_weights.safetensors",
|
| 65 |
+
local_dir="checkpoints"
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
hf_hub_download(
|
| 69 |
+
repo_id="wenqsun/DimensionX",
|
| 70 |
+
filename="orbit_up_lora_weights.safetensors",
|
| 71 |
+
local_dir="checkpoints"
|
| 72 |
+
)
|
| 73 |
|
| 74 |
|
| 75 |
# pipe.transformer.to(memory_format=torch.channels_last)
|
|
|
|
| 223 |
@spaces.GPU
|
| 224 |
def infer(
|
| 225 |
prompt: str,
|
| 226 |
+
orbit_type: str,
|
| 227 |
image_input: str,
|
| 228 |
num_inference_steps: int,
|
| 229 |
guidance_scale: float,
|
|
|
|
| 246 |
# guidance_scale=guidance_scale,
|
| 247 |
# generator=torch.Generator(device="cpu").manual_seed(seed),
|
| 248 |
# ).frames
|
| 249 |
+
|
| 250 |
+
lora_path = "checkpoints/"
|
| 251 |
+
weight_name = "orbit_left_lora_weights.safetensors" if orbit_type == "Left" else "orbit_up_lora_weights.safetensors"
|
| 252 |
+
lora_rank = 256
|
| 253 |
+
adapter_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 254 |
+
|
| 255 |
+
# Load LoRA weights on CPU
|
| 256 |
+
pipe.load_lora_weights(lora_path, weight_name=weight_name, adapter_name=f"adapter_{adapter_timestamp}")
|
| 257 |
+
pipe.fuse_lora(lora_scale=1 / lora_rank)
|
| 258 |
+
|
| 259 |
if image_input is not None:
|
| 260 |
image_input = Image.fromarray(image_input).resize(size=(720, 480)) # Convert to PIL
|
| 261 |
image = load_image(image_input)
|
|
|
|
| 322 |
</div>
|
| 323 |
""")
|
| 324 |
with gr.Row():
|
| 325 |
+
with gr.Column():
|
| 326 |
+
image_in = gr.Image(label="Image Input", type="filepath")
|
| 327 |
+
prompt = gr.Textbox(label="Prompt")
|
| 328 |
+
orbit_type = gr.Radio(label="Orbit type", choices=["Left", "Up"], value="Left", interactive=True)
|
| 329 |
+
submit_btn = gr.Button("Submit")
|
| 330 |
+
|
| 331 |
with gr.Column():
|
| 332 |
with gr.Accordion("I2V: Image Input (cannot be used simultaneously with video input)", open=False):
|
| 333 |
image_input = gr.Image(label="Input Image (will be cropped to 720 * 480)")
|
|
|
|
| 367 |
|
| 368 |
def generate(
|
| 369 |
prompt,
|
| 370 |
+
orbit_type,
|
| 371 |
image_input,
|
| 372 |
# video_input,
|
| 373 |
# video_strength,
|
|
|
|
| 378 |
):
|
| 379 |
latents, seed = infer(
|
| 380 |
prompt,
|
| 381 |
+
orbit_type,
|
| 382 |
image_input,
|
| 383 |
# video_input,
|
| 384 |
# video_strength,
|
|
|
|
| 415 |
|
| 416 |
generate_button.click(
|
| 417 |
generate,
|
| 418 |
+
inputs=[prompt, orbit_type, image_input, seed_param, enable_scale, enable_rife],
|
| 419 |
outputs=[video_output, download_video_button, download_gif_button, seed_text],
|
| 420 |
)
|
| 421 |
|