|
|
import torch |
|
|
import torch.nn as nn |
|
|
from typing import Dict, Tuple, Optional, Union |
|
|
import warnings |
|
|
|
|
|
try: |
|
|
from safetensors.torch import save_file as save_safetensors |
|
|
SAFETENSORS_AVAILABLE = True |
|
|
except ImportError: |
|
|
SAFETENSORS_AVAILABLE = False |
|
|
warnings.warn("safetensors not available. Install with: pip install safetensors") |
|
|
|
|
|
class LoRAExtractor: |
|
|
""" |
|
|
Extract LoRA tensors from the difference between original and fine-tuned models. |
|
|
|
|
|
LoRA (Low-Rank Adaptation) decomposes weight updates as ΔW = B @ A where: |
|
|
- A (lora_down): [rank, input_dim] matrix (saved as diffusion_model.param_name.lora_down.weight) |
|
|
- B (lora_up): [output_dim, rank] matrix (saved as diffusion_model.param_name.lora_up.weight) |
|
|
|
|
|
The decomposition uses SVD: ΔW = U @ S @ V^T ≈ (U @ S) @ V^T where: |
|
|
- lora_up = U @ S (contains all singular values) |
|
|
- lora_down = V^T (orthogonal matrix) |
|
|
|
|
|
Parameter handling based on name AND dimension: |
|
|
- 2D weight tensors: LoRA decomposition (.lora_down.weight, .lora_up.weight) |
|
|
- Any bias tensors: direct difference (.diff_b) |
|
|
- Other weight tensors (1D, 3D, 4D): full difference (.diff) |
|
|
|
|
|
Progress tracking and test mode are available for format validation and debugging. |
|
|
""" |
|
|
|
|
|
def __init__(self, rank: int = 128, threshold: float = 1e-6, test_mode: bool = False, show_reconstruction_errors: bool = False): |
|
|
""" |
|
|
Initialize LoRA extractor. |
|
|
|
|
|
Args: |
|
|
rank: Target rank for LoRA decomposition (default: 128) |
|
|
threshold: Minimum singular value threshold for decomposition |
|
|
test_mode: If True, creates zero tensors without computation for format testing |
|
|
show_reconstruction_errors: If True, calculates and displays reconstruction error for each LoRA pair |
|
|
""" |
|
|
self.rank = rank |
|
|
self.threshold = threshold |
|
|
self.test_mode = test_mode |
|
|
self.show_reconstruction_errors = show_reconstruction_errors |
|
|
|
|
|
def extract_lora_from_state_dicts( |
|
|
self, |
|
|
original_state_dict: Dict[str, torch.Tensor], |
|
|
finetuned_state_dict: Dict[str, torch.Tensor], |
|
|
device: str = 'cpu', |
|
|
show_progress: bool = True |
|
|
) -> Dict[str, torch.Tensor]: |
|
|
""" |
|
|
Extract LoRA tensors for all matching parameters between two state dictionaries. |
|
|
|
|
|
Args: |
|
|
original_state_dict: State dict of the original model |
|
|
finetuned_state_dict: State dict of the fine-tuned model |
|
|
device: Device to perform computations on |
|
|
show_progress: Whether to display progress information |
|
|
|
|
|
Returns: |
|
|
Dictionary mapping parameter names to their LoRA components: |
|
|
- For 2D weight tensors: 'diffusion_model.layer.lora_down.weight', 'diffusion_model.layer.lora_up.weight' |
|
|
- For any bias tensors: 'diffusion_model.layer.diff_b' |
|
|
- For other weight tensors (1D, 3D, 4D): 'diffusion_model.layer.diff' |
|
|
""" |
|
|
lora_tensors = {} |
|
|
|
|
|
|
|
|
common_keys = sorted(set(original_state_dict.keys()) & set(finetuned_state_dict.keys())) |
|
|
total_params = len(common_keys) |
|
|
processed_params = 0 |
|
|
extracted_components = 0 |
|
|
|
|
|
if show_progress: |
|
|
print(f"Starting LoRA extraction for {total_params} parameters on {device}...") |
|
|
|
|
|
|
|
|
threshold_tensor = torch.tensor(self.threshold, device=device) |
|
|
|
|
|
for param_name in common_keys: |
|
|
if show_progress: |
|
|
processed_params += 1 |
|
|
progress_pct = (processed_params / total_params) * 100 |
|
|
print(f"[{processed_params:4d}/{total_params}] ({progress_pct:5.1f}%) Processing: {param_name}") |
|
|
|
|
|
|
|
|
original_tensor = original_state_dict[param_name] |
|
|
finetuned_tensor = finetuned_state_dict[param_name] |
|
|
|
|
|
|
|
|
if original_tensor.shape != finetuned_tensor.shape: |
|
|
if show_progress: |
|
|
print(f" → Shape mismatch: {original_tensor.shape} vs {finetuned_tensor.shape}. Skipping.") |
|
|
continue |
|
|
|
|
|
|
|
|
if not self.test_mode: |
|
|
if original_tensor.device != torch.device(device): |
|
|
original_tensor = original_tensor.to(device, non_blocking=True) |
|
|
if finetuned_tensor.device != torch.device(device): |
|
|
finetuned_tensor = finetuned_tensor.to(device, non_blocking=True) |
|
|
|
|
|
|
|
|
delta_tensor = finetuned_tensor - original_tensor |
|
|
|
|
|
|
|
|
max_abs_diff = torch.max(torch.abs(delta_tensor)) |
|
|
if max_abs_diff <= threshold_tensor: |
|
|
if show_progress: |
|
|
print(f" → No significant changes detected (max diff: {max_abs_diff:.2e}), skipping") |
|
|
continue |
|
|
else: |
|
|
|
|
|
delta_tensor = torch.zeros_like(original_tensor) |
|
|
if device != 'cpu': |
|
|
delta_tensor = delta_tensor.to(device) |
|
|
|
|
|
|
|
|
extracted_tensors = self._extract_lora_components(delta_tensor, param_name) |
|
|
|
|
|
if extracted_tensors: |
|
|
lora_tensors.update(extracted_tensors) |
|
|
extracted_components += len(extracted_tensors) |
|
|
if show_progress: |
|
|
|
|
|
component_names = [] |
|
|
for key in extracted_tensors.keys(): |
|
|
if key.endswith('.lora_down.weight'): |
|
|
component_names.append('lora_down') |
|
|
elif key.endswith('.lora_up.weight'): |
|
|
component_names.append('lora_up') |
|
|
elif key.endswith('.diff_b'): |
|
|
component_names.append('diff_b') |
|
|
elif key.endswith('.diff'): |
|
|
component_names.append('diff') |
|
|
else: |
|
|
component_names.append(key.split('.')[-1]) |
|
|
print(f" → Extracted {len(extracted_tensors)} components: {component_names}") |
|
|
|
|
|
if show_progress: |
|
|
print(f"\nExtraction completed!") |
|
|
print(f"Processed: {processed_params}/{total_params} parameters") |
|
|
print(f"Extracted: {extracted_components} LoRA components") |
|
|
print(f"LoRA rank: {self.rank}") |
|
|
|
|
|
|
|
|
lora_down_count = sum(1 for k in lora_tensors.keys() if k.endswith('.lora_down.weight')) |
|
|
lora_up_count = sum(1 for k in lora_tensors.keys() if k.endswith('.lora_up.weight')) |
|
|
diff_b_count = sum(1 for k in lora_tensors.keys() if k.endswith('.diff_b')) |
|
|
diff_count = sum(1 for k in lora_tensors.keys() if k.endswith('.diff')) |
|
|
|
|
|
print(f"Summary: {lora_down_count} lora_down, {lora_up_count} lora_up, {diff_b_count} diff_b, {diff_count} diff") |
|
|
|
|
|
return lora_tensors |
|
|
|
|
|
def _extract_lora_components( |
|
|
self, |
|
|
delta_tensor: torch.Tensor, |
|
|
param_name: str |
|
|
) -> Optional[Dict[str, torch.Tensor]]: |
|
|
""" |
|
|
Extract LoRA components from a delta tensor. |
|
|
|
|
|
Args: |
|
|
delta_tensor: Difference between fine-tuned and original tensor |
|
|
param_name: Name of the parameter (for generating output keys) |
|
|
|
|
|
Returns: |
|
|
Dictionary with modified parameter names as keys and tensors as values |
|
|
""" |
|
|
|
|
|
is_weight = 'weight' in param_name.lower() |
|
|
is_bias = 'bias' in param_name.lower() |
|
|
|
|
|
|
|
|
base_name = param_name |
|
|
if base_name.endswith('.weight'): |
|
|
base_name = base_name[:-7] |
|
|
elif base_name.endswith('.bias'): |
|
|
base_name = base_name[:-5] |
|
|
|
|
|
|
|
|
base_name = f"diffusion_model.{base_name}" |
|
|
|
|
|
if self.test_mode: |
|
|
|
|
|
if delta_tensor.dim() == 2 and is_weight: |
|
|
|
|
|
output_dim, input_dim = delta_tensor.shape |
|
|
rank = min(self.rank, min(input_dim, output_dim)) |
|
|
return { |
|
|
f"{base_name}.lora_down.weight": torch.zeros(rank, input_dim, dtype=delta_tensor.dtype, device=delta_tensor.device), |
|
|
f"{base_name}.lora_up.weight": torch.zeros(output_dim, rank, dtype=delta_tensor.dtype, device=delta_tensor.device) |
|
|
} |
|
|
elif is_bias: |
|
|
|
|
|
return {f"{base_name}.diff_b": torch.zeros_like(delta_tensor)} |
|
|
else: |
|
|
|
|
|
return {f"{base_name}.diff": torch.zeros_like(delta_tensor)} |
|
|
|
|
|
|
|
|
if delta_tensor.dim() == 2 and is_weight: |
|
|
|
|
|
return self._decompose_2d_tensor(delta_tensor, base_name) |
|
|
|
|
|
elif is_bias: |
|
|
|
|
|
return {f"{base_name}.diff_b": delta_tensor.clone()} |
|
|
|
|
|
else: |
|
|
|
|
|
return {f"{base_name}.diff": delta_tensor.clone()} |
|
|
|
|
|
def _decompose_2d_tensor(self, delta_tensor: torch.Tensor, base_name: str) -> Dict[str, torch.Tensor]: |
|
|
""" |
|
|
Decompose a 2D tensor using SVD on GPU for maximum performance. |
|
|
|
|
|
Args: |
|
|
delta_tensor: 2D tensor to decompose (output_dim × input_dim) |
|
|
base_name: Base name for the parameter (already processed, with diffusion_model prefix) |
|
|
|
|
|
Returns: |
|
|
Dictionary with lora_down and lora_up tensors: |
|
|
- lora_down: [rank, input_dim] |
|
|
- lora_up: [output_dim, rank] |
|
|
""" |
|
|
|
|
|
dtype = delta_tensor.dtype |
|
|
device = delta_tensor.device |
|
|
|
|
|
|
|
|
delta_float = delta_tensor.float() if delta_tensor.dtype != torch.float32 else delta_tensor |
|
|
U, S, Vt = torch.linalg.svd(delta_float, full_matrices=False) |
|
|
|
|
|
|
|
|
|
|
|
significant_mask = S > self.threshold |
|
|
effective_rank = min(self.rank, torch.sum(significant_mask).item()) |
|
|
effective_rank = self.rank |
|
|
|
|
|
if effective_rank == 0: |
|
|
warnings.warn(f"No significant singular values found for {base_name}") |
|
|
effective_rank = 1 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lora_up = U[:, :effective_rank] * S[:effective_rank].unsqueeze(0) |
|
|
lora_down = Vt[:effective_rank, :] |
|
|
|
|
|
|
|
|
lora_up = lora_up.to(dtype) |
|
|
lora_down = lora_down.to(dtype) |
|
|
|
|
|
|
|
|
if self.show_reconstruction_errors: |
|
|
with torch.no_grad(): |
|
|
|
|
|
reconstructed = lora_up @ lora_down |
|
|
|
|
|
|
|
|
mse_error = torch.mean((delta_tensor - reconstructed) ** 2).item() |
|
|
max_error = torch.max(torch.abs(delta_tensor - reconstructed)).item() |
|
|
|
|
|
|
|
|
original_norm = torch.norm(delta_tensor).item() |
|
|
relative_error = (torch.norm(delta_tensor - reconstructed).item() / original_norm * 100) if original_norm > 0 else 0 |
|
|
|
|
|
|
|
|
delta_flat = delta_tensor.flatten() |
|
|
reconstructed_flat = reconstructed.flatten() |
|
|
if torch.norm(delta_flat) > 0 and torch.norm(reconstructed_flat) > 0: |
|
|
cosine_sim = torch.nn.functional.cosine_similarity( |
|
|
delta_flat.unsqueeze(0), |
|
|
reconstructed_flat.unsqueeze(0) |
|
|
).item() |
|
|
else: |
|
|
cosine_sim = 0.0 |
|
|
|
|
|
|
|
|
display_name = base_name[16:] if base_name.startswith('diffusion_model.') else base_name |
|
|
|
|
|
print(f" LoRA Error [{display_name}]: MSE={mse_error:.2e}, Max={max_error:.2e}, Rel={relative_error:.2f}%, Cos={cosine_sim:.4f}, Rank={effective_rank}") |
|
|
|
|
|
return { |
|
|
f"{base_name}.lora_down.weight": lora_down, |
|
|
f"{base_name}.lora_up.weight": lora_up |
|
|
} |
|
|
|
|
|
def verify_reconstruction( |
|
|
self, |
|
|
lora_tensors: Dict[str, torch.Tensor], |
|
|
original_deltas: Dict[str, torch.Tensor] |
|
|
) -> Dict[str, float]: |
|
|
""" |
|
|
Verify the quality of LoRA reconstruction for 2D tensors. |
|
|
|
|
|
Args: |
|
|
lora_tensors: Dictionary with LoRA tensors (flat structure with diffusion_model prefix) |
|
|
original_deltas: Dictionary with original delta tensors (without prefix) |
|
|
|
|
|
Returns: |
|
|
Dictionary mapping parameter names to reconstruction errors |
|
|
""" |
|
|
reconstruction_errors = {} |
|
|
|
|
|
|
|
|
lora_pairs = {} |
|
|
for key, tensor in lora_tensors.items(): |
|
|
if key.endswith('.lora_down.weight'): |
|
|
base_name = key[:-18] |
|
|
|
|
|
if base_name.startswith('diffusion_model.'): |
|
|
original_key = base_name[16:] |
|
|
else: |
|
|
original_key = base_name |
|
|
if base_name not in lora_pairs: |
|
|
lora_pairs[base_name] = {'original_key': original_key} |
|
|
lora_pairs[base_name]['lora_down'] = tensor |
|
|
elif key.endswith('.lora_up.weight'): |
|
|
base_name = key[:-16] |
|
|
|
|
|
if base_name.startswith('diffusion_model.'): |
|
|
original_key = base_name[16:] |
|
|
else: |
|
|
original_key = base_name |
|
|
if base_name not in lora_pairs: |
|
|
lora_pairs[base_name] = {'original_key': original_key} |
|
|
lora_pairs[base_name]['lora_up'] = tensor |
|
|
|
|
|
|
|
|
for base_name, components in lora_pairs.items(): |
|
|
if 'lora_down' in components and 'lora_up' in components and 'original_key' in components: |
|
|
original_key = components['original_key'] |
|
|
if original_key in original_deltas: |
|
|
lora_down = components['lora_down'] |
|
|
lora_up = components['lora_up'] |
|
|
original_delta = original_deltas[original_key] |
|
|
|
|
|
|
|
|
effective_rank = min(lora_up.shape[1], lora_down.shape[0]) |
|
|
|
|
|
|
|
|
reconstructed = lora_up @ lora_down |
|
|
|
|
|
|
|
|
mse_error = torch.mean((original_delta - reconstructed) ** 2).item() |
|
|
reconstruction_errors[base_name] = mse_error |
|
|
|
|
|
return reconstruction_errors |
|
|
|
|
|
def compute_reconstruction_errors( |
|
|
original_tensor: torch.Tensor, |
|
|
reconstructed_tensor: torch.Tensor, |
|
|
target_tensor: torch.Tensor |
|
|
) -> Dict[str, float]: |
|
|
""" |
|
|
Compute various error metrics between original, reconstructed, and target tensors. |
|
|
|
|
|
Args: |
|
|
original_tensor: Original tensor before fine-tuning |
|
|
reconstructed_tensor: Reconstructed tensor from LoRA (original + LoRA_reconstruction) |
|
|
target_tensor: Target tensor (fine-tuned) |
|
|
|
|
|
Returns: |
|
|
Dictionary with error metrics |
|
|
""" |
|
|
|
|
|
device = original_tensor.device |
|
|
reconstructed_tensor = reconstructed_tensor.to(device) |
|
|
target_tensor = target_tensor.to(device) |
|
|
|
|
|
|
|
|
delta_original = target_tensor - original_tensor |
|
|
delta_reconstructed = reconstructed_tensor - original_tensor |
|
|
reconstruction_error = target_tensor - reconstructed_tensor |
|
|
|
|
|
|
|
|
errors = {} |
|
|
|
|
|
|
|
|
errors['mse_delta'] = torch.mean((delta_original - delta_reconstructed) ** 2).item() |
|
|
errors['mse_final'] = torch.mean(reconstruction_error ** 2).item() |
|
|
|
|
|
|
|
|
errors['mae_delta'] = torch.mean(torch.abs(delta_original - delta_reconstructed)).item() |
|
|
errors['mae_final'] = torch.mean(torch.abs(reconstruction_error)).item() |
|
|
|
|
|
|
|
|
original_norm = torch.norm(original_tensor).item() |
|
|
target_norm = torch.norm(target_tensor).item() |
|
|
delta_norm = torch.norm(delta_original).item() |
|
|
|
|
|
if original_norm > 0: |
|
|
errors['relative_error_original'] = (torch.norm(reconstruction_error).item() / original_norm) * 100 |
|
|
if target_norm > 0: |
|
|
errors['relative_error_target'] = (torch.norm(reconstruction_error).item() / target_norm) * 100 |
|
|
if delta_norm > 0: |
|
|
errors['relative_error_delta'] = (torch.norm(delta_original - delta_reconstructed).item() / delta_norm) * 100 |
|
|
|
|
|
|
|
|
delta_flat = delta_original.flatten() |
|
|
reconstructed_flat = delta_reconstructed.flatten() |
|
|
|
|
|
if torch.norm(delta_flat) > 0 and torch.norm(reconstructed_flat) > 0: |
|
|
cosine_sim = torch.nn.functional.cosine_similarity( |
|
|
delta_flat.unsqueeze(0), |
|
|
reconstructed_flat.unsqueeze(0) |
|
|
).item() |
|
|
errors['cosine_similarity'] = cosine_sim |
|
|
else: |
|
|
errors['cosine_similarity'] = 0.0 |
|
|
|
|
|
|
|
|
if errors['mse_final'] > 0: |
|
|
signal_power = torch.mean(target_tensor ** 2).item() |
|
|
errors['snr_db'] = 10 * torch.log10(signal_power / errors['mse_final']).item() |
|
|
else: |
|
|
errors['snr_db'] = float('inf') |
|
|
|
|
|
return errors |
|
|
|
|
|
|
|
|
def load_and_extract_lora( |
|
|
original_model_path: str, |
|
|
finetuned_model_path: str, |
|
|
rank: int = 128, |
|
|
device: str = 'cuda' if torch.cuda.is_available() else 'cpu', |
|
|
show_progress: bool = True, |
|
|
test_mode: bool = False, |
|
|
show_reconstruction_errors: bool = False |
|
|
) -> Dict[str, torch.Tensor]: |
|
|
""" |
|
|
Convenience function to load models and extract LoRA tensors with GPU acceleration. |
|
|
|
|
|
Args: |
|
|
original_model_path: Path to original model state dict |
|
|
finetuned_model_path: Path to fine-tuned model state dict |
|
|
rank: Target LoRA rank (default: 128) |
|
|
device: Device for computation (defaults to GPU if available) |
|
|
show_progress: Whether to display progress information |
|
|
test_mode: If True, creates zero tensors without computation for format testing |
|
|
show_reconstruction_errors: If True, calculates and displays reconstruction error for each LoRA pair |
|
|
|
|
|
Returns: |
|
|
Dictionary of LoRA tensors with modified parameter names as keys |
|
|
""" |
|
|
|
|
|
if show_progress: |
|
|
print(f"Loading original model from: {original_model_path}") |
|
|
original_state_dict = torch.load(original_model_path, map_location='cpu') |
|
|
|
|
|
if show_progress: |
|
|
print(f"Loading fine-tuned model from: {finetuned_model_path}") |
|
|
finetuned_state_dict = torch.load(finetuned_model_path, map_location='cpu') |
|
|
|
|
|
|
|
|
if 'state_dict' in original_state_dict: |
|
|
original_state_dict = original_state_dict['state_dict'] |
|
|
if 'state_dict' in finetuned_state_dict: |
|
|
finetuned_state_dict = finetuned_state_dict['state_dict'] |
|
|
|
|
|
|
|
|
extractor = LoRAExtractor(rank=rank, test_mode=test_mode, show_reconstruction_errors=show_reconstruction_errors) |
|
|
lora_tensors = extractor.extract_lora_from_state_dicts( |
|
|
original_state_dict, |
|
|
finetuned_state_dict, |
|
|
device=device, |
|
|
show_progress=show_progress |
|
|
) |
|
|
|
|
|
return lora_tensors |
|
|
|
|
|
def save_lora_tensors(lora_tensors: Dict[str, torch.Tensor], save_path: str): |
|
|
"""Save extracted LoRA tensors to disk.""" |
|
|
torch.save(lora_tensors, save_path) |
|
|
print(f"LoRA tensors saved to {save_path}") |
|
|
|
|
|
def save_lora_safetensors(lora_tensors: Dict[str, torch.Tensor], save_path: str, rank: int = None): |
|
|
"""Save extracted LoRA tensors as safetensors format with metadata.""" |
|
|
if not SAFETENSORS_AVAILABLE: |
|
|
raise ImportError("safetensors not available. Install with: pip install safetensors") |
|
|
|
|
|
|
|
|
contiguous_tensors = {k: v.contiguous() if v.is_floating_point() else v.contiguous() |
|
|
for k, v in lora_tensors.items()} |
|
|
|
|
|
|
|
|
metadata = {} |
|
|
if rank is not None: |
|
|
metadata["rank"] = str(rank) |
|
|
|
|
|
save_safetensors(contiguous_tensors, save_path, metadata=metadata if metadata else None) |
|
|
print(f"LoRA tensors saved as safetensors to {save_path}") |
|
|
if metadata: |
|
|
print(f"Metadata: {metadata}") |
|
|
|
|
|
def analyze_lora_tensors(lora_tensors: Dict[str, torch.Tensor]): |
|
|
"""Analyze the extracted LoRA tensors.""" |
|
|
print(f"Extracted LoRA tensors ({len(lora_tensors)} components):") |
|
|
|
|
|
|
|
|
lora_down_tensors = {k: v for k, v in lora_tensors.items() if k.endswith('.lora_down.weight')} |
|
|
lora_up_tensors = {k: v for k, v in lora_tensors.items() if k.endswith('.lora_up.weight')} |
|
|
diff_b_tensors = {k: v for k, v in lora_tensors.items() if k.endswith('.diff_b')} |
|
|
diff_tensors = {k: v for k, v in lora_tensors.items() if k.endswith('.diff')} |
|
|
|
|
|
if lora_down_tensors: |
|
|
print(f"\nLinear LoRA down matrices ({len(lora_down_tensors)}):") |
|
|
for name, tensor in lora_down_tensors.items(): |
|
|
print(f" {name}: {tensor.shape}") |
|
|
|
|
|
if lora_up_tensors: |
|
|
print(f"\nLinear LoRA up matrices ({len(lora_up_tensors)}):") |
|
|
for name, tensor in lora_up_tensors.items(): |
|
|
print(f" {name}: {tensor.shape}") |
|
|
|
|
|
if diff_b_tensors: |
|
|
print(f"\nBias differences ({len(diff_b_tensors)}):") |
|
|
for name, tensor in diff_b_tensors.items(): |
|
|
print(f" {name}: {tensor.shape}") |
|
|
|
|
|
if diff_tensors: |
|
|
print(f"\nFull weight differences ({len(diff_tensors)}):") |
|
|
print(" (Includes conv, modulation, and other multi-dimensional tensors)") |
|
|
for name, tensor in diff_tensors.items(): |
|
|
print(f" {name}: {tensor.shape}") |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
|
|
|
from safetensors.torch import load_file as load_safetensors |
|
|
|
|
|
|
|
|
|
|
|
original_state_dict = load_safetensors("ckpts/wan2.2_text2video_14B_high_mbf16.safetensors") |
|
|
finetuned_state_dict = load_safetensors("ckpts/wan2.2_text2video_14B_low_mbf16.safetensors") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print(f"Loaded original model with {len(original_state_dict)} parameters") |
|
|
print(f"Loaded fine-tuned model with {len(finetuned_state_dict)} parameters") |
|
|
|
|
|
|
|
|
|
|
|
extractor_test = LoRAExtractor(show_reconstruction_errors=True, rank=128) |
|
|
|
|
|
lora_tensors_test = extractor_test.extract_lora_from_state_dicts( |
|
|
original_state_dict, |
|
|
finetuned_state_dict, |
|
|
device='cuda', |
|
|
show_progress=True |
|
|
) |
|
|
|
|
|
print("\nTest mode tensor keys (first 10):") |
|
|
for i, key in enumerate(sorted(lora_tensors_test.keys())): |
|
|
if i < 10: |
|
|
print(f" {key}: {lora_tensors_test[key].shape}") |
|
|
elif i == 10: |
|
|
print(f" ... and {len(lora_tensors_test) - 10} more") |
|
|
break |
|
|
|
|
|
|
|
|
save_lora_safetensors(lora_tensors_test, "extracted_lora.safetensors") |
|
|
|
|
|
|