| """ | |
| LoRA Utilities for ACE-Step | |
| Provides utilities for injecting LoRA adapters into the DiT decoder model. | |
| Uses PEFT (Parameter-Efficient Fine-Tuning) library for LoRA implementation. | |
| """ | |
| import os | |
| from typing import Optional, List, Dict, Any, Tuple | |
| from loguru import logger | |
| import torch | |
| import torch.nn as nn | |
try:
    from peft import (
        get_peft_model,
        LoraConfig,
        TaskType,
        PeftModel,
        PeftConfig,
    )
    PEFT_AVAILABLE = True
except ImportError:
    PEFT_AVAILABLE = False
    logger.warning("PEFT library not installed. LoRA training will not be available.")

from acestep.training.configs import LoRAConfig

def check_peft_available() -> bool:
    """Check if PEFT library is available."""
    return PEFT_AVAILABLE

def get_dit_target_modules(model) -> List[str]:
    """Get the list of module names in the DiT decoder that can have LoRA applied.

    Args:
        model: The AceStepConditionGenerationModel

    Returns:
        List of module names suitable for LoRA
    """
    target_modules = []
    # Focus on the decoder (DiT) attention layers
    if hasattr(model, 'decoder'):
        for name, module in model.decoder.named_modules():
            # Target attention projection layers
            if any(proj in name for proj in ['q_proj', 'k_proj', 'v_proj', 'o_proj']):
                if isinstance(module, nn.Linear):
                    target_modules.append(name)
    return target_modules
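
# Usage sketch (illustrative, not part of the original API): discover the LoRA-capable
# attention projections on the decoder and feed them into a LoRAConfig. Assumes `model`
# is an AceStepConditionGenerationModel and that LoRAConfig accepts the r / alpha /
# dropout / target_modules / bias fields this module reads; the rank and alpha values
# below are placeholders.
def _example_build_lora_config(model) -> LoRAConfig:
    target_modules = get_dit_target_modules(model)
    if not target_modules:
        # Fall back to suffix patterns, which PEFT also accepts as target_modules
        target_modules = ['q_proj', 'k_proj', 'v_proj', 'o_proj']
    return LoRAConfig(
        r=16,
        alpha=32,
        dropout=0.05,
        target_modules=target_modules,
        bias="none",
    )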

def freeze_non_lora_parameters(model, freeze_encoder: bool = True) -> None:
    """Freeze all non-LoRA parameters in the model.

    Args:
        model: The model to freeze parameters for
        freeze_encoder: Whether to freeze the encoder (condition encoder)
    """
    # Freeze everything except LoRA adapter parameters
    for name, param in model.named_parameters():
        if 'lora_' not in name:
            param.requires_grad = False

    # Count frozen and trainable parameters
    total_params = 0
    trainable_params = 0
    for name, param in model.named_parameters():
        total_params += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()

    logger.info(f"Frozen parameters: {total_params - trainable_params:,}")
    logger.info(f"Trainable parameters: {trainable_params:,}")

def inject_lora_into_dit(
    model,
    lora_config: LoRAConfig,
) -> Tuple[Any, Dict[str, Any]]:
    """Inject LoRA adapters into the DiT decoder of the model.

    Args:
        model: The AceStepConditionGenerationModel
        lora_config: LoRA configuration

    Returns:
        Tuple of (peft_model, info_dict)
    """
    if not PEFT_AVAILABLE:
        raise ImportError("PEFT library is required for LoRA training. Install with: pip install peft")

    # Get the decoder (DiT model)
    decoder = model.decoder

    # Create PEFT LoRA config
    peft_lora_config = LoraConfig(
        r=lora_config.r,
        lora_alpha=lora_config.alpha,
        lora_dropout=lora_config.dropout,
        target_modules=lora_config.target_modules,
        bias=lora_config.bias,
        task_type=TaskType.FEATURE_EXTRACTION,  # For diffusion models
    )

    # Apply LoRA to the decoder
    peft_decoder = get_peft_model(decoder, peft_lora_config)

    # Replace the decoder in the original model
    model.decoder = peft_decoder

    # Freeze all non-LoRA parameters (encoder, tokenizer, detokenizer)
    for name, param in model.named_parameters():
        # Only keep LoRA parameters trainable
        if 'lora_' not in name:
            param.requires_grad = False

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    info = {
        "total_params": total_params,
        "trainable_params": trainable_params,
        "trainable_ratio": trainable_params / total_params if total_params > 0 else 0,
        "lora_r": lora_config.r,
        "lora_alpha": lora_config.alpha,
        "target_modules": lora_config.target_modules,
    }

    logger.info("LoRA injected into DiT decoder:")
    logger.info(f"  Total parameters: {total_params:,}")
    logger.info(f"  Trainable parameters: {trainable_params:,} ({info['trainable_ratio']:.2%})")
    logger.info(f"  LoRA rank: {lora_config.r}, alpha: {lora_config.alpha}")

    return model, info
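
# Usage sketch (illustrative): inject LoRA into a loaded model and log the trainable
# footprint. `model` and `lora_config` are assumed to come from the caller; the
# LoRAConfig field names follow how inject_lora_into_dit reads them above.
def _example_inject_lora(model, lora_config: LoRAConfig):
    model, info = inject_lora_into_dit(model, lora_config)
    logger.info(
        f"Example: {info['trainable_params']:,} of {info['total_params']:,} parameters "
        f"trainable ({info['trainable_ratio']:.2%})"
    )
    return model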

def save_lora_weights(
    model,
    output_dir: str,
    save_full_model: bool = False,
) -> str:
    """Save LoRA adapter weights.

    Args:
        model: Model with LoRA adapters
        output_dir: Directory to save weights
        save_full_model: Whether to save the full model state dict

    Returns:
        Path to saved weights
    """
    os.makedirs(output_dir, exist_ok=True)

    if hasattr(model, 'decoder') and hasattr(model.decoder, 'save_pretrained'):
        # Save PEFT adapter
        adapter_path = os.path.join(output_dir, "adapter")
        model.decoder.save_pretrained(adapter_path)
        logger.info(f"LoRA adapter saved to {adapter_path}")
        return adapter_path
    elif save_full_model:
        # Save full model state dict (larger file)
        model_path = os.path.join(output_dir, "model.pt")
        torch.save(model.state_dict(), model_path)
        logger.info(f"Full model state dict saved to {model_path}")
        return model_path
    else:
        # Extract only LoRA parameters
        lora_state_dict = {}
        for name, param in model.named_parameters():
            if 'lora_' in name:
                lora_state_dict[name] = param.data.clone()
        if not lora_state_dict:
            logger.warning("No LoRA parameters found to save!")
            return ""
        lora_path = os.path.join(output_dir, "lora_weights.pt")
        torch.save(lora_state_dict, lora_path)
        logger.info(f"LoRA weights saved to {lora_path}")
        return lora_path

def load_lora_weights(
    model,
    lora_path: str,
    lora_config: Optional[LoRAConfig] = None,
) -> Any:
    """Load LoRA adapter weights into the model.

    Args:
        model: The base model (without LoRA)
        lora_path: Path to saved LoRA weights (adapter directory or .pt file)
        lora_config: LoRA configuration (required if loading from a .pt file)

    Returns:
        Model with LoRA weights loaded
    """
    if not os.path.exists(lora_path):
        raise FileNotFoundError(f"LoRA weights not found: {lora_path}")

    # Check if it's a PEFT adapter directory
    if os.path.isdir(lora_path):
        if not PEFT_AVAILABLE:
            raise ImportError("PEFT library is required to load adapter. Install with: pip install peft")
        # Load PEFT adapter
        peft_config = PeftConfig.from_pretrained(lora_path)
        model.decoder = PeftModel.from_pretrained(model.decoder, lora_path)
        logger.info(f"LoRA adapter loaded from {lora_path}")
    elif lora_path.endswith('.pt'):
        # Load from PyTorch state dict
        if lora_config is None:
            raise ValueError("lora_config is required when loading from .pt file")
        # First inject LoRA structure
        model, _ = inject_lora_into_dit(model, lora_config)
        # Load weights
        lora_state_dict = torch.load(lora_path, map_location='cpu', weights_only=True)
        # Copy LoRA tensors into the model's state dict (the tensors are shared, so
        # this updates the live parameters in place)
        model_state = model.state_dict()
        for name, param in lora_state_dict.items():
            if name in model_state:
                model_state[name].copy_(param)
            else:
                logger.warning(f"Unexpected key in LoRA state dict: {name}")
        logger.info(f"LoRA weights loaded from {lora_path}")
    else:
        raise ValueError(f"Unsupported LoRA weight format: {lora_path}")

    return model
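
# Round-trip sketch (illustrative): save adapter weights after training, then load them
# back into a freshly constructed base model. The paths are placeholders; the .pt branch
# of load_lora_weights needs the same LoRAConfig that was used for training.
def _example_save_and_reload(trained_model, fresh_model, lora_config: LoRAConfig, output_dir: str):
    saved_path = save_lora_weights(trained_model, output_dir)
    # saved_path is either a PEFT adapter directory or a lora_weights.pt file;
    # load_lora_weights handles both and only requires lora_config for the .pt case
    return load_lora_weights(fresh_model, saved_path, lora_config=lora_config)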

def save_training_checkpoint(
    model,
    optimizer,
    scheduler,
    epoch: int,
    global_step: int,
    output_dir: str,
) -> str:
    """Save a training checkpoint including LoRA weights and training state.

    Args:
        model: Model with LoRA adapters
        optimizer: Optimizer whose state will be saved
        scheduler: Scheduler whose state will be saved
        epoch: Current epoch number
        global_step: Current global step
        output_dir: Directory to save checkpoint

    Returns:
        Path to saved checkpoint directory
    """
    os.makedirs(output_dir, exist_ok=True)

    # Save LoRA adapter weights
    adapter_path = save_lora_weights(model, output_dir)

    # Save training state (optimizer, scheduler, epoch, step)
    training_state = {
        "epoch": epoch,
        "global_step": global_step,
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
    }
    state_path = os.path.join(output_dir, "training_state.pt")
    torch.save(training_state, state_path)

    logger.info(f"Training checkpoint saved to {output_dir} (epoch {epoch}, step {global_step})")
    return output_dir
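
# Training-loop sketch (illustrative): checkpoint LoRA weights plus optimizer/scheduler
# state once per epoch. The epoch_<n> directory naming is a suggestion that matches the
# epoch_(\d+) fallback used by load_training_checkpoint below.
def _example_checkpoint_epoch(model, optimizer, scheduler, epoch: int, global_step: int, base_dir: str) -> str:
    checkpoint_dir = os.path.join(base_dir, f"epoch_{epoch}")
    return save_training_checkpoint(model, optimizer, scheduler, epoch, global_step, checkpoint_dir)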

def load_training_checkpoint(
    checkpoint_dir: str,
    optimizer=None,
    scheduler=None,
    device: Optional[torch.device] = None,
) -> Dict[str, Any]:
    """Load a training checkpoint.

    Args:
        checkpoint_dir: Directory containing checkpoint files
        optimizer: Optimizer instance to load state into (optional)
        scheduler: Scheduler instance to load state into (optional)
        device: Device to load tensors to

    Returns:
        Dictionary with checkpoint info:
            - epoch: Saved epoch number
            - global_step: Saved global step
            - adapter_path: Path to adapter weights
            - loaded_optimizer: Whether optimizer state was loaded
            - loaded_scheduler: Whether scheduler state was loaded
    """
    result = {
        "epoch": 0,
        "global_step": 0,
        "adapter_path": None,
        "loaded_optimizer": False,
        "loaded_scheduler": False,
    }

    # Find adapter path
    adapter_path = os.path.join(checkpoint_dir, "adapter")
    if os.path.exists(adapter_path):
        result["adapter_path"] = adapter_path
    elif os.path.exists(checkpoint_dir):
        result["adapter_path"] = checkpoint_dir

    # Load training state
    state_path = os.path.join(checkpoint_dir, "training_state.pt")
    if os.path.exists(state_path):
        map_location = device if device else "cpu"
        training_state = torch.load(state_path, map_location=map_location, weights_only=True)
        result["epoch"] = training_state.get("epoch", 0)
        result["global_step"] = training_state.get("global_step", 0)

        # Load optimizer state if provided
        if optimizer is not None and "optimizer_state_dict" in training_state:
            try:
                optimizer.load_state_dict(training_state["optimizer_state_dict"])
                result["loaded_optimizer"] = True
                logger.info("Optimizer state loaded from checkpoint")
            except Exception as e:
                logger.warning(f"Failed to load optimizer state: {e}")

        # Load scheduler state if provided
        if scheduler is not None and "scheduler_state_dict" in training_state:
            try:
                scheduler.load_state_dict(training_state["scheduler_state_dict"])
                result["loaded_scheduler"] = True
                logger.info("Scheduler state loaded from checkpoint")
            except Exception as e:
                logger.warning(f"Failed to load scheduler state: {e}")

        logger.info(f"Loaded checkpoint from epoch {result['epoch']}, step {result['global_step']}")
    else:
        # Fallback: extract the epoch number from the checkpoint path
        import re
        match = re.search(r'epoch_(\d+)', checkpoint_dir)
        if match:
            result["epoch"] = int(match.group(1))
            logger.info(f"No training_state.pt found, extracted epoch {result['epoch']} from path")

    return result
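
# Resume sketch (illustrative): restore adapter weights and optimizer/scheduler state
# before continuing a run. Assumes the checkpoint was produced by save_training_checkpoint
# and that lora_config matches the configuration used when the adapter was created.
def _example_resume_training(model, optimizer, scheduler, checkpoint_dir: str, lora_config: LoRAConfig):
    state = load_training_checkpoint(checkpoint_dir, optimizer=optimizer, scheduler=scheduler)
    if state["adapter_path"]:
        model = load_lora_weights(model, state["adapter_path"], lora_config=lora_config)
    start_epoch = state["epoch"] + 1
    return model, start_epoch, state["global_step"]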

def merge_lora_weights(model) -> Any:
    """Merge LoRA weights into the base model.

    This permanently integrates the LoRA adaptations into the model weights.
    After merging, the model can be used without PEFT.

    Args:
        model: Model with LoRA adapters

    Returns:
        Model with merged weights
    """
    if hasattr(model, 'decoder') and hasattr(model.decoder, 'merge_and_unload'):
        # PEFT model - merge and unload
        model.decoder = model.decoder.merge_and_unload()
        logger.info("LoRA weights merged into base model")
    else:
        logger.warning("Model does not support LoRA merging")
    return model
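
# Export sketch (illustrative): merge the LoRA deltas into the base weights so the model
# can be saved and served without PEFT installed. The merged_model.pt filename is a
# placeholder, not an ACE-Step convention.
def _example_merge_for_inference(model, output_dir: str) -> str:
    model = merge_lora_weights(model)
    os.makedirs(output_dir, exist_ok=True)
    merged_path = os.path.join(output_dir, "merged_model.pt")
    torch.save(model.state_dict(), merged_path)
    return merged_path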

def get_lora_info(model) -> Dict[str, Any]:
    """Get information about LoRA adapters in the model.

    Args:
        model: Model to inspect

    Returns:
        Dictionary with LoRA information
    """
    info = {
        "has_lora": False,
        "lora_params": 0,
        "total_params": 0,
        "modules_with_lora": [],
    }

    total_params = 0
    lora_params = 0
    lora_modules = []
    for name, param in model.named_parameters():
        total_params += param.numel()
        if 'lora_' in name:
            lora_params += param.numel()
            # Extract the module name that owns this LoRA parameter
            module_name = name.rsplit('.lora_', 1)[0]
            if module_name not in lora_modules:
                lora_modules.append(module_name)

    info["total_params"] = total_params
    info["lora_params"] = lora_params
    info["has_lora"] = lora_params > 0
    info["modules_with_lora"] = lora_modules
    if total_params > 0:
        info["lora_ratio"] = lora_params / total_params

    return info
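
# Inspection sketch (illustrative): log a one-line summary of the adapters currently
# attached to a model, e.g. after injection or after loading a checkpoint.
def _example_log_lora_summary(model) -> None:
    info = get_lora_info(model)
    if not info["has_lora"]:
        logger.info("No LoRA adapters found in model")
        return
    logger.info(
        f"LoRA: {info['lora_params']:,} adapter params across "
        f"{len(info['modules_with_lora'])} modules "
        f"({info.get('lora_ratio', 0):.2%} of {info['total_params']:,} total params)"
    )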