# Upload metadata (mrfakename, via huggingface_hub, commit 9f5c8f7 verified)
# NOTE(review): stray scrape header converted to a comment so the module parses.
"""
LoRA Utilities for ACE-Step
Provides utilities for injecting LoRA adapters into the DiT decoder model.
Uses PEFT (Parameter-Efficient Fine-Tuning) library for LoRA implementation.
"""
import os
from typing import Optional, List, Dict, Any, Tuple
from loguru import logger
import torch
import torch.nn as nn
try:
from peft import (
get_peft_model,
LoraConfig,
TaskType,
PeftModel,
PeftConfig,
)
PEFT_AVAILABLE = True
except ImportError:
PEFT_AVAILABLE = False
logger.warning("PEFT library not installed. LoRA training will not be available.")
from acestep.training.configs import LoRAConfig
def check_peft_available() -> bool:
    """Return True when the optional PEFT dependency was importable."""
    return PEFT_AVAILABLE
def get_dit_target_modules(model) -> List[str]:
    """Collect names of DiT decoder modules eligible for LoRA injection.

    Scans ``model.decoder`` (when present) for attention projection
    layers (q/k/v/o) implemented as ``nn.Linear``.

    Args:
        model: The AceStepConditionGenerationModel

    Returns:
        List of fully-qualified module names suitable for LoRA
        (empty when the model has no ``decoder`` attribute).
    """
    # Only the decoder (DiT) attention projections are targeted.
    if not hasattr(model, 'decoder'):
        return []
    projections = ('q_proj', 'k_proj', 'v_proj', 'o_proj')
    return [
        name
        for name, module in model.decoder.named_modules()
        if any(proj in name for proj in projections)
        and isinstance(module, nn.Linear)
    ]
def freeze_non_lora_parameters(model, freeze_encoder: bool = True) -> None:
    """Freeze all non-LoRA parameters in the model.

    Every parameter is frozen first; when ``freeze_encoder`` is False and
    the model exposes an ``encoder`` submodule, the encoder's parameters
    are re-enabled for training.

    Args:
        model: The model to freeze parameters for
        freeze_encoder: Whether to freeze the encoder (condition encoder)
    """
    # Freeze everything first so only explicitly re-enabled params train.
    for param in model.parameters():
        param.requires_grad = False
    # Honor the freeze_encoder flag — the previous implementation accepted
    # it but silently ignored it, freezing the encoder unconditionally.
    if not freeze_encoder and hasattr(model, 'encoder'):
        for param in model.encoder.parameters():
            param.requires_grad = True
    # Count frozen and trainable parameters for the log report.
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"Frozen parameters: {total_params - trainable_params:,}")
    logger.info(f"Trainable parameters: {trainable_params:,}")
def inject_lora_into_dit(
    model,
    lora_config: LoRAConfig,
) -> Tuple[Any, Dict[str, Any]]:
    """Inject LoRA adapters into the DiT decoder of the model.

    Wraps ``model.decoder`` with a PEFT LoRA model, freezes every
    non-LoRA parameter (encoder, tokenizer, detokenizer, base decoder),
    and reports parameter statistics.

    Args:
        model: The AceStepConditionGenerationModel
        lora_config: LoRA configuration

    Returns:
        Tuple of (peft_model, info_dict)

    Raises:
        ImportError: If the PEFT library is not installed.
    """
    if not PEFT_AVAILABLE:
        raise ImportError("PEFT library is required for LoRA training. Install with: pip install peft")
    # Translate our LoRAConfig into a PEFT LoraConfig and wrap the decoder.
    peft_settings = LoraConfig(
        r=lora_config.r,
        lora_alpha=lora_config.alpha,
        lora_dropout=lora_config.dropout,
        target_modules=lora_config.target_modules,
        bias=lora_config.bias,
        task_type=TaskType.FEATURE_EXTRACTION,  # For diffusion models
    )
    model.decoder = get_peft_model(model.decoder, peft_settings)
    # Only the LoRA adapter weights remain trainable.
    for param_name, param in model.named_parameters():
        if 'lora_' not in param_name:
            param.requires_grad = False
    # Parameter statistics for the info dict and log report.
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        numel = param.numel()
        total_params += numel
        if param.requires_grad:
            trainable_params += numel
    info = {
        "total_params": total_params,
        "trainable_params": trainable_params,
        "trainable_ratio": trainable_params / total_params if total_params > 0 else 0,
        "lora_r": lora_config.r,
        "lora_alpha": lora_config.alpha,
        "target_modules": lora_config.target_modules,
    }
    logger.info(f"LoRA injected into DiT decoder:")
    logger.info(f" Total parameters: {total_params:,}")
    logger.info(f" Trainable parameters: {trainable_params:,} ({info['trainable_ratio']:.2%})")
    logger.info(f" LoRA rank: {lora_config.r}, alpha: {lora_config.alpha}")
    return model, info
def save_lora_weights(
    model,
    output_dir: str,
    save_full_model: bool = False,
) -> str:
    """Save LoRA adapter weights.

    Prefers the PEFT ``save_pretrained`` adapter format; otherwise either
    dumps the full state dict or extracts only the LoRA tensors.

    Args:
        model: Model with LoRA adapters
        output_dir: Directory to save weights
        save_full_model: Whether to save the full model state dict

    Returns:
        Path to saved weights (empty string when nothing was saved)
    """
    os.makedirs(output_dir, exist_ok=True)

    # Preferred path: a PEFT adapter directory via save_pretrained.
    decoder = getattr(model, 'decoder', None)
    if hasattr(decoder, 'save_pretrained'):
        adapter_path = os.path.join(output_dir, "adapter")
        decoder.save_pretrained(adapter_path)
        logger.info(f"LoRA adapter saved to {adapter_path}")
        return adapter_path

    if save_full_model:
        # Fallback: dump the entire state dict (larger file).
        model_path = os.path.join(output_dir, "model.pt")
        torch.save(model.state_dict(), model_path)
        logger.info(f"Full model state dict saved to {model_path}")
        return model_path

    # Last resort: collect only the LoRA tensors.
    lora_state_dict = {
        name: param.data.clone()
        for name, param in model.named_parameters()
        if 'lora_' in name
    }
    if not lora_state_dict:
        logger.warning("No LoRA parameters found to save!")
        return ""
    lora_path = os.path.join(output_dir, "lora_weights.pt")
    torch.save(lora_state_dict, lora_path)
    logger.info(f"LoRA weights saved to {lora_path}")
    return lora_path
def load_lora_weights(
    model,
    lora_path: str,
    lora_config: Optional[LoRAConfig] = None,
) -> Any:
    """Load LoRA adapter weights into the model.

    Supports two formats: a PEFT adapter directory (loaded with
    ``PeftModel.from_pretrained``) or a ``.pt`` state-dict file produced by
    :func:`save_lora_weights` (requires ``lora_config`` to rebuild the
    adapter structure first).

    Args:
        model: The base model (without LoRA)
        lora_path: Path to saved LoRA weights (adapter or .pt file)
        lora_config: LoRA configuration (required if loading from .pt file)

    Returns:
        Model with LoRA weights loaded

    Raises:
        FileNotFoundError: If ``lora_path`` does not exist.
        ImportError: If PEFT is needed but not installed.
        ValueError: If the path is neither an adapter directory nor a
            .pt file, or ``lora_config`` is missing for a .pt file.
    """
    if not os.path.exists(lora_path):
        raise FileNotFoundError(f"LoRA weights not found: {lora_path}")
    # A directory is treated as a PEFT adapter checkpoint.
    if os.path.isdir(lora_path):
        if not PEFT_AVAILABLE:
            raise ImportError("PEFT library is required to load adapter. Install with: pip install peft")
        model.decoder = PeftModel.from_pretrained(model.decoder, lora_path)
        logger.info(f"LoRA adapter loaded from {lora_path}")
    elif lora_path.endswith('.pt'):
        # Load from a PyTorch state dict.
        if lora_config is None:
            raise ValueError("lora_config is required when loading from .pt file")
        # Inject the LoRA structure first so the target keys exist.
        model, _ = inject_lora_into_dit(model, lora_config)
        # weights_only=True: the checkpoint is a plain tensor dict, and this
        # prevents arbitrary pickled code in an untrusted file from running.
        lora_state_dict = torch.load(lora_path, map_location='cpu', weights_only=True)
        model_state = model.state_dict()
        for name, param in lora_state_dict.items():
            if name in model_state:
                # copy_ writes through the state-dict reference into the model.
                model_state[name].copy_(param)
            else:
                logger.warning(f"Unexpected key in LoRA state dict: {name}")
        logger.info(f"LoRA weights loaded from {lora_path}")
    else:
        raise ValueError(f"Unsupported LoRA weight format: {lora_path}")
    return model
def merge_lora_weights(model) -> Any:
    """Merge LoRA weights into the base model.

    Permanently folds the LoRA adaptations into the base weights, so the
    merged model can be used without PEFT afterwards.

    Args:
        model: Model with LoRA adapters

    Returns:
        Model with merged weights
    """
    can_merge = hasattr(model, 'decoder') and hasattr(model.decoder, 'merge_and_unload')
    if not can_merge:
        logger.warning("Model does not support LoRA merging")
        return model
    # PEFT model — merge adapters into the base weights and drop PEFT wrappers.
    model.decoder = model.decoder.merge_and_unload()
    logger.info("LoRA weights merged into base model")
    return model
def get_lora_info(model) -> Dict[str, Any]:
    """Get information about LoRA adapters in the model.

    Args:
        model: Model to inspect

    Returns:
        Dictionary with keys:
            has_lora (bool): Whether any LoRA parameters exist.
            lora_params (int): Number of LoRA parameter elements.
            total_params (int): Total number of parameter elements.
            modules_with_lora (list): Module names owning LoRA parameters,
                in first-seen order.
            lora_ratio (float): lora_params / total_params, 0.0 when the
                model has no parameters (previously this key was missing
                for zero-parameter models).
    """
    total_params = 0
    lora_params = 0
    # dict preserves insertion order and gives O(1) de-duplication; the
    # previous list-membership check was O(n) per parameter.
    lora_modules: Dict[str, None] = {}
    for name, param in model.named_parameters():
        count = param.numel()
        total_params += count
        if 'lora_' in name:
            lora_params += count
            # Strip the trailing ".lora_*" suffix to get the owning module.
            lora_modules.setdefault(name.rsplit('.lora_', 1)[0], None)
    return {
        "has_lora": lora_params > 0,
        "lora_params": lora_params,
        "total_params": total_params,
        "modules_with_lora": list(lora_modules),
        "lora_ratio": lora_params / total_params if total_params > 0 else 0.0,
    }