|
|
""" |
|
|
LoRA Utilities for ACE-Step |
|
|
|
|
|
Provides utilities for injecting LoRA adapters into the DiT decoder model. |
|
|
Uses PEFT (Parameter-Efficient Fine-Tuning) library for LoRA implementation. |
|
|
""" |
|
|
|
|
|
import os |
|
|
from typing import Optional, List, Dict, Any, Tuple |
|
|
from loguru import logger |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
try: |
|
|
from peft import ( |
|
|
get_peft_model, |
|
|
LoraConfig, |
|
|
TaskType, |
|
|
PeftModel, |
|
|
PeftConfig, |
|
|
) |
|
|
PEFT_AVAILABLE = True |
|
|
except ImportError: |
|
|
PEFT_AVAILABLE = False |
|
|
logger.warning("PEFT library not installed. LoRA training will not be available.") |
|
|
|
|
|
from acestep.training.configs import LoRAConfig |
|
|
|
|
|
|
|
|
def check_peft_available() -> bool:
    """Report whether the optional PEFT dependency was importable.

    Returns:
        True when the ``peft`` package loaded successfully at import time.
    """
    return bool(PEFT_AVAILABLE)
|
|
|
|
|
|
|
|
def get_dit_target_modules(model) -> List[str]:
    """Collect names of DiT decoder linear layers eligible for LoRA.

    Scans the decoder (when the model has one) for attention projection
    layers — q/k/v/o projections — implemented as ``nn.Linear``.

    Args:
        model: The AceStepConditionGenerationModel

    Returns:
        List of module names suitable for LoRA (empty if no decoder).
    """
    if not hasattr(model, 'decoder'):
        return []

    projection_keys = ('q_proj', 'k_proj', 'v_proj', 'o_proj')
    # Only nn.Linear modules whose name mentions a projection qualify.
    return [
        name
        for name, module in model.decoder.named_modules()
        if isinstance(module, nn.Linear)
        and any(key in name for key in projection_keys)
    ]
|
|
|
|
|
|
|
|
def freeze_non_lora_parameters(model, freeze_encoder: bool = True) -> None:
    """Freeze all non-LoRA parameters in the model.

    Args:
        model: The model to freeze parameters for
        freeze_encoder: Whether to freeze the encoder (condition encoder).
            When False, parameters under ``model.encoder`` keep their
            current ``requires_grad`` state instead of being frozen.
    """
    # Fix: the freeze_encoder flag was previously ignored — every
    # parameter was frozen unconditionally. Honor it by exempting the
    # encoder's parameters (matched by object identity).
    # NOTE(review): assumes the condition encoder lives at
    # ``model.encoder`` — confirm against the model definition.
    exempt_ids = set()
    if not freeze_encoder and hasattr(model, 'encoder'):
        exempt_ids = {id(p) for p in model.encoder.parameters()}

    for param in model.parameters():
        if id(param) not in exempt_ids:
            param.requires_grad = False

    # Report the resulting split so training logs show what stays trainable.
    total_params = 0
    trainable_params = 0

    for name, param in model.named_parameters():
        total_params += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()

    logger.info(f"Frozen parameters: {total_params - trainable_params:,}")
    logger.info(f"Trainable parameters: {trainable_params:,}")
|
|
|
|
|
|
|
|
def inject_lora_into_dit(
    model,
    lora_config: LoRAConfig,
) -> Tuple[Any, Dict[str, Any]]:
    """Wrap the model's DiT decoder with PEFT LoRA adapters.

    Args:
        model: The AceStepConditionGenerationModel
        lora_config: LoRA configuration

    Returns:
        Tuple of (peft_model, info_dict)

    Raises:
        ImportError: If the PEFT library is not installed.
    """
    if not PEFT_AVAILABLE:
        raise ImportError("PEFT library is required for LoRA training. Install with: pip install peft")

    # Translate the project-level LoRAConfig into PEFT's LoraConfig.
    peft_settings = LoraConfig(
        r=lora_config.r,
        lora_alpha=lora_config.alpha,
        lora_dropout=lora_config.dropout,
        target_modules=lora_config.target_modules,
        bias=lora_config.bias,
        task_type=TaskType.FEATURE_EXTRACTION,
    )

    # Replace the decoder with its PEFT-wrapped counterpart in place.
    model.decoder = get_peft_model(model.decoder, peft_settings)

    # Belt-and-braces: make sure only LoRA parameters remain trainable
    # anywhere in the model, not just inside the wrapped decoder.
    for name, param in model.named_parameters():
        if 'lora_' not in name:
            param.requires_grad = False

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    info = {
        "total_params": total_params,
        "trainable_params": trainable_params,
        "trainable_ratio": trainable_params / total_params if total_params > 0 else 0,
        "lora_r": lora_config.r,
        "lora_alpha": lora_config.alpha,
        "target_modules": lora_config.target_modules,
    }

    logger.info(f"LoRA injected into DiT decoder:")
    logger.info(f"  Total parameters: {total_params:,}")
    logger.info(f"  Trainable parameters: {trainable_params:,} ({info['trainable_ratio']:.2%})")
    logger.info(f"  LoRA rank: {lora_config.r}, alpha: {lora_config.alpha}")

    return model, info
|
|
|
|
|
|
|
|
def save_lora_weights(
    model,
    output_dir: str,
    save_full_model: bool = False,
) -> str:
    """Persist LoRA adapter weights to disk.

    Prefers the PEFT ``save_pretrained`` adapter format when the decoder
    supports it; otherwise either dumps the full state dict or saves
    just the LoRA tensors, depending on ``save_full_model``.

    Args:
        model: Model with LoRA adapters
        output_dir: Directory to save weights
        save_full_model: Whether to save the full model state dict

    Returns:
        Path to saved weights (empty string when no LoRA tensors exist)
    """
    os.makedirs(output_dir, exist_ok=True)

    if hasattr(model, 'decoder') and hasattr(model.decoder, 'save_pretrained'):
        # PEFT-native adapter directory (config + adapter weights).
        adapter_path = os.path.join(output_dir, "adapter")
        model.decoder.save_pretrained(adapter_path)
        logger.info(f"LoRA adapter saved to {adapter_path}")
        return adapter_path

    if save_full_model:
        # Fall back to dumping everything when PEFT saving is unavailable.
        model_path = os.path.join(output_dir, "model.pt")
        torch.save(model.state_dict(), model_path)
        logger.info(f"Full model state dict saved to {model_path}")
        return model_path

    # Last resort: filter out only the LoRA tensors by parameter name.
    lora_state_dict = {
        name: param.data.clone()
        for name, param in model.named_parameters()
        if 'lora_' in name
    }

    if not lora_state_dict:
        logger.warning("No LoRA parameters found to save!")
        return ""

    lora_path = os.path.join(output_dir, "lora_weights.pt")
    torch.save(lora_state_dict, lora_path)
    logger.info(f"LoRA weights saved to {lora_path}")
    return lora_path
|
|
|
|
|
|
|
|
def load_lora_weights(
    model,
    lora_path: str,
    lora_config: Optional[LoRAConfig] = None,
) -> Any:
    """Load LoRA adapter weights into the model.

    Args:
        model: The base model (without LoRA)
        lora_path: Path to saved LoRA weights (adapter dir or .pt file)
        lora_config: LoRA configuration (required if loading from .pt file)

    Returns:
        Model with LoRA weights loaded

    Raises:
        FileNotFoundError: If ``lora_path`` does not exist.
        ImportError: If an adapter directory is given but PEFT is missing.
        ValueError: If the path format is unsupported, or ``lora_config``
            is missing when loading a ``.pt`` file.
    """
    if not os.path.exists(lora_path):
        raise FileNotFoundError(f"LoRA weights not found: {lora_path}")

    if os.path.isdir(lora_path):
        # PEFT adapter directory (produced by save_pretrained).
        if not PEFT_AVAILABLE:
            raise ImportError("PEFT library is required to load adapter. Install with: pip install peft")

        # Fix: dropped an unused `PeftConfig.from_pretrained` call;
        # PeftModel.from_pretrained reads the adapter config itself.
        model.decoder = PeftModel.from_pretrained(model.decoder, lora_path)
        logger.info(f"LoRA adapter loaded from {lora_path}")

    elif lora_path.endswith('.pt'):
        # Raw state-dict format: LoRA modules must be injected first so
        # the saved tensor names line up with the model's parameters.
        if lora_config is None:
            raise ValueError("lora_config is required when loading from .pt file")

        model, _ = inject_lora_into_dit(model, lora_config)

        # weights_only=True: the file holds only tensors (see
        # save_lora_weights) and this blocks arbitrary-code pickle payloads.
        lora_state_dict = torch.load(lora_path, map_location='cpu', weights_only=True)

        # Copy each saved tensor into the matching live parameter.
        model_state = model.state_dict()
        for name, param in lora_state_dict.items():
            if name in model_state:
                model_state[name].copy_(param)
            else:
                logger.warning(f"Unexpected key in LoRA state dict: {name}")

        logger.info(f"LoRA weights loaded from {lora_path}")

    else:
        raise ValueError(f"Unsupported LoRA weight format: {lora_path}")

    return model
|
|
|
|
|
|
|
|
def merge_lora_weights(model) -> Any:
    """Fold LoRA adapter deltas permanently into the base weights.

    After merging, the decoder is a plain module again and no longer
    needs PEFT at inference time.

    Args:
        model: Model with LoRA adapters

    Returns:
        Model with merged weights
    """
    decoder = getattr(model, 'decoder', None)
    if hasattr(decoder, 'merge_and_unload'):
        model.decoder = decoder.merge_and_unload()
        logger.info("LoRA weights merged into base model")
    else:
        logger.warning("Model does not support LoRA merging")

    return model
|
|
|
|
|
|
|
|
def get_lora_info(model) -> Dict[str, Any]:
    """Summarize LoRA adapter usage within a model.

    Args:
        model: Model to inspect

    Returns:
        Dictionary with LoRA information: parameter counts, the list of
        modules carrying LoRA weights, and (when the model has any
        parameters) the LoRA-to-total ratio.
    """
    total = 0
    lora_total = 0
    lora_owners: List[str] = []

    for name, param in model.named_parameters():
        count = param.numel()
        total += count
        if 'lora_' not in name:
            continue
        lora_total += count
        # Strip the trailing ".lora_..." segment to get the owning module.
        owner = name.rsplit('.lora_', 1)[0]
        if owner not in lora_owners:
            lora_owners.append(owner)

    info: Dict[str, Any] = {
        "has_lora": lora_total > 0,
        "lora_params": lora_total,
        "total_params": total,
        "modules_with_lora": lora_owners,
    }

    if total > 0:
        info["lora_ratio"] = lora_total / total

    return info
|
|
|