Commit · 97a5f5d
Parent: c7eaeb0
dmd + cof

Files changed:
- app.py (+14 -6)
- inference.py (+22 -1)
- videox_fun/ui/ui.py (+2 -2)
- videox_fun/utils/lora_utils.py (+18 -40)
app.py CHANGED

@@ -79,6 +79,7 @@ def load_video_frames(video_path: str, source_frames: int):
             pil_frame = Image.fromarray(frame)
             if original_height is None:
                 original_width, original_height = pil_frame.size
+                print(f"Original video dimensions: {original_width}x{original_height}")
             frames.append(pil_frame)
         except IndexError:
             break
@@ -92,6 +93,9 @@ def load_video_frames(video_path: str, source_frames: int):
         w, h = (original_width, original_height) if original_width else (832, 480)
         frames.append(Image.new('RGB', (w, h), (0, 0, 0)))
 
+    assert len(frames) == source_frames, f"Loaded {len(frames)} frames, expected {source_frames}"
+    print(f"Loaded {source_frames} source frames")
+
     input_video = torch.from_numpy(np.array(frames))
     input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0).float()
     input_video = input_video * (2.0 / 255.0) - 1.0
@@ -143,6 +147,8 @@ class VideoCoF_Controller(Wan_Controller):
         # Ensure model is on CUDA inside the zero-gpu decorated function
         if torch.cuda.is_available():
             self.device = torch.device("cuda")
+        else:
+            self.device = torch.device("cpu")
         # If pipeline is not on cuda, move it (if possible, but usually accelerate handles this or it's handled by parts)
         # However, Wan_Controller logic might rely on `self.device`.
         # We explicitly set `self.device` to cuda here.
@@ -166,7 +172,7 @@ class VideoCoF_Controller(Wan_Controller):
         # 1. Merge VideoCoF LoRA
         if self.lora_model_path != "none":
             print(f"Merge VideoCoF Lora: {self.lora_model_path}")
-            self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
+            self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider, device=self.device)
 
         # 2. Merge Acceleration LoRA (FusionX) if enabled
         acc_lora_path = os.path.join(self.personalized_model_dir, "Wan2.1_Text_to_Video_14B_FusionX_LoRA.safetensors")
@@ -174,7 +180,7 @@ class VideoCoF_Controller(Wan_Controller):
         if os.path.exists(acc_lora_path):
             print(f"Merge Acceleration LoRA: {acc_lora_path}")
             # FusionX LoRA generally uses multiplier 1.0
-            self.pipeline = merge_lora(self.pipeline, acc_lora_path, multiplier=1.0)
+            self.pipeline = merge_lora(self.pipeline, acc_lora_path, multiplier=1.0, device=self.device)
         else:
             print(f"Warning: Acceleration LoRA not found at {acc_lora_path}")
 
@@ -217,6 +223,7 @@ class VideoCoF_Controller(Wan_Controller):
         print(f"Input video dimensions: {w}x{h}")
 
         print(f"Running pipeline with frames={length_slider}, source={source_frames_slider}, reasoning={reasoning_frames_slider}")
+        shift = 3
 
         sample = self.pipeline(
             video=input_video_tensor,
@@ -230,6 +237,7 @@ class VideoCoF_Controller(Wan_Controller):
             generator=generator,
             guidance_scale=cfg_scale_slider,
             num_inference_steps=sample_step_slider,
+            shift=shift,
             repeat_rope=repeat_rope_checkbox,
             cot=True,
         ).videos
@@ -241,21 +249,21 @@ class VideoCoF_Controller(Wan_Controller):
             # Unmerge in case of error (LIFO order)
             if enable_acceleration and os.path.exists(acc_lora_path):
                 print("Unmerging Acceleration LoRA (due to error)")
-                self.pipeline = unmerge_lora(self.pipeline, acc_lora_path, multiplier=1.0)
+                self.pipeline = unmerge_lora(self.pipeline, acc_lora_path, multiplier=1.0, device=self.device)
 
             if self.lora_model_path != "none":
                 print("Unmerging VideoCoF LoRA (due to error)")
-                self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
+                self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider, device=self.device)
             return gr.update(), gr.update(), f"Error: {str(e)}"
 
         # Unmerge LoRAs (LIFO order)
         if enable_acceleration and os.path.exists(acc_lora_path):
             print("Unmerging Acceleration LoRA")
-            self.pipeline = unmerge_lora(self.pipeline, acc_lora_path, multiplier=1.0)
+            self.pipeline = unmerge_lora(self.pipeline, acc_lora_path, multiplier=1.0, device=self.device)
 
         if self.lora_model_path != "none":
             print("Unmerging VideoCoF LoRA")
-            self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
+            self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider, device=self.device)
 
         # Save output
         save_sample_path = self.save_outputs(
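The app.py changes merge up to two LoRAs before sampling and unmerge them in reverse (LIFO) order on both the success and the error path, now passing device=self.device so the weight math runs on the GPU the zero-gpu decorator assigned. Below is a minimal sketch of that bracket as a reusable helper, assuming only the merge_lora/unmerge_lora signatures shown in this diff; generate_with_loras and run_pipeline are hypothetical names, not part of the codebase.

    from videox_fun.utils.lora_utils import merge_lora, unmerge_lora

    def generate_with_loras(pipeline, lora_specs, run_pipeline, device):
        # lora_specs: list of (path, multiplier) tuples, merged in order.
        merged = []
        try:
            for path, multiplier in lora_specs:
                pipeline = merge_lora(pipeline, path, multiplier=multiplier, device=device)
                merged.append((path, multiplier))
            return run_pipeline(pipeline)
        finally:
            # Unmerge in reverse (LIFO) order on success and error alike,
            # so the base weights are restored exactly once.
            for path, multiplier in reversed(merged):
                pipeline = unmerge_lora(pipeline, path, multiplier=multiplier, device=device)

A try/finally covers both exits with a single unmerge path; the diff instead repeats the unmerge in the except branch and after it, which works but has to be kept in sync by hand.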
inference.py CHANGED

@@ -89,6 +89,9 @@ def parse_args():
     parser.add_argument("--output_dir", type=str, required=True, help="Output directory for generated videos")
     parser.add_argument("--seed", type=int, default=0, help="Random seed for reproducible generation")
     parser.add_argument("--videocof_path", type=str, default=None, help="Path to videocof weight checkpoint")
+    parser.add_argument("--lora_path", type=str, default=None, help="Path to LoRA checkpoint")
+    parser.add_argument("--enable_acceleration_lora", action="store_true", help="Enable loading the acceleration (FusionX) LoRA")
+    parser.add_argument("--acceleration_lora_path", type=str, default=None, help="Optional path to acceleration LoRA; defaults to FusionX under model directory")
     parser.add_argument("--num_frames", type=int, default=65, help="Total number of frames (input + generated)")
     parser.add_argument("--source_frames", type=int, default=33, help="Number of source frames; default 33")
     parser.add_argument("--reasoning_frames", type=int, default=4, help="Grounding frames in the middle segment (pixel-space)")
@@ -320,7 +323,25 @@ def main():
     else:
         pipeline.to(device=device)
 
-    # LoRA
+    # Acceleration LoRA (FusionX) mirrors app.py behavior
+    if args.enable_acceleration_lora:
+        default_acc_path = os.path.join(model_name, "Wan2.1_Text_to_Video_14B_FusionX_LoRA.safetensors")
+        acc_lora_path = args.acceleration_lora_path or default_acc_path
+        if os.path.exists(acc_lora_path):
+            print(f"[GPU {rank}] Merge Acceleration LoRA: {acc_lora_path}")
+            pipeline = merge_lora(pipeline, acc_lora_path, multiplier=1.0, device=device)
+        else:
+            print(f"[GPU {rank}] Warning: Acceleration LoRA not found at {acc_lora_path}")
+
+    # Custom LoRA
+    if args.lora_path is not None:
+        if os.path.exists(args.lora_path):
+            print(f"[GPU {rank}] Loading custom LoRA: {args.lora_path}")
+            pipeline = merge_lora(pipeline, args.lora_path, lora_weight, device=device)
+        else:
+            print(f"[GPU {rank}] Warning: Provided lora_path not found: {args.lora_path}")
+
+    # VideoCoF LoRA
     if args.videocof_path is not None:
         pipeline = merge_lora(pipeline, args.videocof_path, lora_weight, device=device)
         print(f"[GPU {rank}] Loaded LoRA from {args.videocof_path}")
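The new flags compose in a fixed order: the acceleration (FusionX) LoRA first, then the optional custom --lora_path, then --videocof_path. When --enable_acceleration_lora is set without an explicit --acceleration_lora_path, the path falls back to the FusionX checkpoint under the model directory. A self-contained sketch of that fallback, re-declaring only the three new arguments; the model directory name here is hypothetical:

    import argparse
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument("--lora_path", type=str, default=None)
    parser.add_argument("--enable_acceleration_lora", action="store_true")
    parser.add_argument("--acceleration_lora_path", type=str, default=None)
    args = parser.parse_args(["--enable_acceleration_lora"])

    model_name = "models/Wan2.1-T2V-14B"  # hypothetical model directory
    default_acc_path = os.path.join(model_name, "Wan2.1_Text_to_Video_14B_FusionX_LoRA.safetensors")
    acc_lora_path = args.acceleration_lora_path or default_acc_path
    print(acc_lora_path)
    # -> models/Wan2.1-T2V-14B/Wan2.1_Text_to_Video_14B_FusionX_LoRA.safetensors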
videox_fun/ui/ui.py CHANGED

@@ -194,10 +194,10 @@ def create_prompts(
     negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value=negative_prompt)
     return prompt_textbox, negative_prompt_textbox
 
-def create_samplers(controller, maximum_step=
+def create_samplers(controller, maximum_step=50):
     with gr.Row():
         sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(controller.scheduler_dict.keys()), value=list(controller.scheduler_dict.keys())[0])
-        sample_step_slider = gr.Slider(label="Sampling steps", value=
+        sample_step_slider = gr.Slider(label="Sampling steps", value=4, minimum=1, maximum=maximum_step, step=1)
 
     return sampler_dropdown, sample_step_slider
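The slider now defaults to 4 steps, consistent with few-step sampling once the acceleration/distillation LoRA is merged (the "dmd" in the commit message), while maximum_step=50 keeps higher step counts reachable from the UI. A hedged sketch of how the helper is typically wired into a Blocks app; DummyController and the "Flow" scheduler name are stand-ins, since create_samplers only reads controller.scheduler_dict:

    import gradio as gr
    from videox_fun.ui.ui import create_samplers

    class DummyController:
        scheduler_dict = {"Flow": None}  # placeholder scheduler entry

    with gr.Blocks() as demo:
        sampler_dropdown, sample_step_slider = create_samplers(DummyController(), maximum_step=50)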
videox_fun/utils/lora_utils.py CHANGED

@@ -389,28 +389,9 @@ def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32):
         key = key.replace(".self_attn.", "_self_attn_")
         key = key.replace(".cross_attn.", "_cross_attn_")
         key = key.replace(".ffn.", "_ffn_")
-        key = key.replace("text_embedding.", "text_embedding_")
-        key = key.replace("time_embedding.", "time_embedding_")
         key = key.replace(".lora_A.default.", ".lora_down.")
         key = key.replace(".lora_B.default.", ".lora_up.")
-
-        key = key.replace(".lora_B.weight", ".lora_up.weight")
-
-        if key.endswith(".lora_down.weight"):
-            layer = key[:-len(".lora_down.weight")]
-            elem = "lora_down.weight"
-        elif key.endswith(".lora_up.weight"):
-            layer = key[:-len(".lora_up.weight")]
-            elem = "lora_up.weight"
-        elif key.endswith(".alpha"):
-            layer = key[:-len(".alpha")]
-            elem = "alpha"
-        else:
-            continue
-
-        if layer.endswith("."):
-            layer = layer[:-1]
-
+        layer, elem = key.split('.', 1)
         updates[layer][elem] = value
 
     sequential_cpu_offload_flag = False
@@ -484,10 +465,20 @@ def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32):
         if error_flag:
             continue
 
+        # Some resolved modules (e.g., container blocks/norm-only) may not have a weight parameter.
+        if not hasattr(curr_layer, "weight"):
+            # Skip incompatible / non-leaf modules
+            continue
+
         origin_dtype = curr_layer.weight.data.dtype
         origin_device = curr_layer.weight.data.device
 
         curr_layer = curr_layer.to(device, dtype)
+        # Some checkpoints (e.g., norm-only entries) may not contain both weights.
+        if 'lora_up.weight' not in elems or 'lora_down.weight' not in elems:
+            # Skip incompatible layer instead of raising KeyError
+            curr_layer = curr_layer.to(origin_device, origin_dtype)
+            continue
         weight_up = elems['lora_up.weight'].to(device, dtype)
         weight_down = elems['lora_down.weight'].to(device, dtype)
 
@@ -529,28 +520,9 @@ def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.float32):
         key = key.replace(".self_attn.", "_self_attn_")
         key = key.replace(".cross_attn.", "_cross_attn_")
         key = key.replace(".ffn.", "_ffn_")
-        key = key.replace("text_embedding.", "text_embedding_")
-        key = key.replace("time_embedding.", "time_embedding_")
         key = key.replace(".lora_A.default.", ".lora_down.")
         key = key.replace(".lora_B.default.", ".lora_up.")
-
-        key = key.replace(".lora_B.weight", ".lora_up.weight")
-
-        if key.endswith(".lora_down.weight"):
-            layer = key[:-len(".lora_down.weight")]
-            elem = "lora_down.weight"
-        elif key.endswith(".lora_up.weight"):
-            layer = key[:-len(".lora_up.weight")]
-            elem = "lora_up.weight"
-        elif key.endswith(".alpha"):
-            layer = key[:-len(".alpha")]
-            elem = "alpha"
-        else:
-            continue
-
-        if layer.endswith("."):
-            layer = layer[:-1]
-
+        layer, elem = key.split('.', 1)
         updates[layer][elem] = value
 
     sequential_cpu_offload_flag = False
@@ -617,10 +589,16 @@ def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.float32):
         if error_flag:
             continue
 
+        if not hasattr(curr_layer, "weight"):
+            continue
+
         origin_dtype = curr_layer.weight.data.dtype
         origin_device = curr_layer.weight.data.device
 
         curr_layer = curr_layer.to(device, dtype)
+        if 'lora_up.weight' not in elems or 'lora_down.weight' not in elems:
+            curr_layer = curr_layer.to(origin_device, origin_dtype)
+            continue
         weight_up = elems['lora_up.weight'].to(device, dtype)
         weight_down = elems['lora_down.weight'].to(device, dtype)
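The collapsed parsing relies on every dot inside a layer name having already been rewritten to an underscore by the replace chain, so split('.', 1) cleanly separates the layer from its element ("lora_down.weight", "lora_up.weight", or "alpha"); replaces earlier in the function, outside this hunk, are assumed to handle block indices. A quick illustration with a hypothetical PEFT-style key:

    # Hypothetical key; "blocks.0." -> "blocks_0." is assumed to be handled
    # by replaces above this hunk.
    key = "blocks_0.self_attn.q.lora_A.default.weight"
    key = key.replace(".self_attn.", "_self_attn_")
    key = key.replace(".cross_attn.", "_cross_attn_")
    key = key.replace(".ffn.", "_ffn_")
    key = key.replace(".lora_A.default.", ".lora_down.")
    key = key.replace(".lora_B.default.", ".lora_up.")
    layer, elem = key.split('.', 1)
    print(layer, elem)  # blocks_0_self_attn_q lora_down.weight

The two new guards make merge and unmerge tolerant of modules that resolve to containers without a weight parameter and of checkpoint entries missing one of the up/down matrices (for example, alpha-only entries), skipping them instead of raising.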