"""
LoRA Utilities for ACE-Step

Provides utilities for injecting LoRA adapters into the DiT decoder model.
Uses PEFT (Parameter-Efficient Fine-Tuning) library for LoRA implementation.
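
Typical workflow (sketch; assumes a loaded AceStepConditionGenerationModel
and a populated LoRAConfig; paths and values are illustrative):

    model, info = inject_lora_into_dit(model, lora_config)
    # ... training loop updates only the LoRA parameters ...
    save_lora_weights(model, "outputs/my_lora")
    model = merge_lora_weights(model)  # optional: bake adapters in for inference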
"""

import os
from typing import Optional, List, Dict, Any, Tuple
from loguru import logger

import torch
import torch.nn as nn

try:
    from peft import (
        get_peft_model,
        LoraConfig,
        TaskType,
        PeftModel,
    )
    PEFT_AVAILABLE = True
except ImportError:
    PEFT_AVAILABLE = False
    logger.warning("PEFT library not installed. LoRA training will not be available.")

from acestep.training.configs import LoRAConfig


def check_peft_available() -> bool:
    """Check if PEFT library is available."""
    return PEFT_AVAILABLE


def get_dit_target_modules(model) -> List[str]:
    """Get the list of module names in the DiT decoder that can have LoRA applied.
    
    Args:
        model: The AceStepConditionGenerationModel
        
    Returns:
        List of module names suitable for LoRA
    """
    target_modules = []
    
    # Focus on the decoder (DiT) attention layers
    if hasattr(model, 'decoder'):
        for name, module in model.decoder.named_modules():
            # Target attention projection layers
            if any(proj in name for proj in ['q_proj', 'k_proj', 'v_proj', 'o_proj']):
                if isinstance(module, nn.Linear):
                    target_modules.append(name)
    
    return target_modules


def freeze_non_lora_parameters(model, freeze_encoder: bool = True) -> None:
    """Freeze all non-LoRA parameters in the model.

    Args:
        model: The model to freeze parameters for
        freeze_encoder: Whether to freeze the condition encoder. If False,
            parameters under ``model.encoder`` are left trainable (assumes
            the condition encoder is exposed as the ``encoder`` attribute).
    """
    # Count frozen and trainable parameters while freezing
    total_params = 0
    trainable_params = 0

    for name, param in model.named_parameters():
        is_lora = 'lora_' in name
        keep_encoder = (not freeze_encoder) and name.startswith('encoder.')
        # Freeze everything except LoRA adapters (and, optionally, the encoder)
        if not (is_lora or keep_encoder):
            param.requires_grad = False

        total_params += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()

    logger.info(f"Frozen parameters: {total_params - trainable_params:,}")
    logger.info(f"Trainable parameters: {trainable_params:,}")


def inject_lora_into_dit(
    model,
    lora_config: LoRAConfig,
) -> Tuple[Any, Dict[str, Any]]:
    """Inject LoRA adapters into the DiT decoder of the model.
    
    Args:
        model: The AceStepConditionGenerationModel
        lora_config: LoRA configuration
        
    Returns:
        Tuple of (peft_model, info_dict)
    """
    if not PEFT_AVAILABLE:
        raise ImportError("PEFT library is required for LoRA training. Install with: pip install peft")
    
    # Get the decoder (DiT model)
    decoder = model.decoder
    
    # Create PEFT LoRA config
    peft_lora_config = LoraConfig(
        r=lora_config.r,
        lora_alpha=lora_config.alpha,
        lora_dropout=lora_config.dropout,
        target_modules=lora_config.target_modules,
        bias=lora_config.bias,
        task_type=TaskType.FEATURE_EXTRACTION,  # For diffusion models
    )
    
    # Apply LoRA to the decoder
    peft_decoder = get_peft_model(decoder, peft_lora_config)
    
    # Replace the decoder in the original model
    model.decoder = peft_decoder
    
    # Freeze all non-LoRA parameters
    # Freeze encoder, tokenizer, detokenizer
    for name, param in model.named_parameters():
        # Only keep LoRA parameters trainable
        if 'lora_' not in name:
            param.requires_grad = False
    
    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    info = {
        "total_params": total_params,
        "trainable_params": trainable_params,
        "trainable_ratio": trainable_params / total_params if total_params > 0 else 0,
        "lora_r": lora_config.r,
        "lora_alpha": lora_config.alpha,
        "target_modules": lora_config.target_modules,
    }
    
    logger.info(f"LoRA injected into DiT decoder:")
    logger.info(f"  Total parameters: {total_params:,}")
    logger.info(f"  Trainable parameters: {trainable_params:,} ({info['trainable_ratio']:.2%})")
    logger.info(f"  LoRA rank: {lora_config.r}, alpha: {lora_config.alpha}")
    
    return model, info


def save_lora_weights(
    model,
    output_dir: str,
    save_full_model: bool = False,
) -> str:
    """Save LoRA adapter weights.
    
    Args:
        model: Model with LoRA adapters
        output_dir: Directory to save weights
        save_full_model: Whether to save the full model state dict (used only
            when no PEFT adapter is attached to the decoder)
        
    Returns:
        Path to saved weights
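
    Example (sketch; the adapter branch is taken when the decoder is a
    PEFT-wrapped model, as produced by inject_lora_into_dit)::

        path = save_lora_weights(model, "outputs/my_lora")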
    """
    os.makedirs(output_dir, exist_ok=True)
    
    if hasattr(model, 'decoder') and hasattr(model.decoder, 'save_pretrained'):
        # Save PEFT adapter
        adapter_path = os.path.join(output_dir, "adapter")
        model.decoder.save_pretrained(adapter_path)
        logger.info(f"LoRA adapter saved to {adapter_path}")
        return adapter_path
    elif save_full_model:
        # Save full model state dict (larger file)
        model_path = os.path.join(output_dir, "model.pt")
        torch.save(model.state_dict(), model_path)
        logger.info(f"Full model state dict saved to {model_path}")
        return model_path
    else:
        # Extract only LoRA parameters
        lora_state_dict = {}
        for name, param in model.named_parameters():
            if 'lora_' in name:
                lora_state_dict[name] = param.data.clone()
        
        if not lora_state_dict:
            logger.warning("No LoRA parameters found to save!")
            return ""
        
        lora_path = os.path.join(output_dir, "lora_weights.pt")
        torch.save(lora_state_dict, lora_path)
        logger.info(f"LoRA weights saved to {lora_path}")
        return lora_path


def load_lora_weights(
    model,
    lora_path: str,
    lora_config: Optional[LoRAConfig] = None,
) -> Any:
    """Load LoRA adapter weights into the model.
    
    Args:
        model: The base model (without LoRA)
        lora_path: Path to saved LoRA weights (adapter or .pt file)
        lora_config: LoRA configuration (required if loading from .pt file)
        
    Returns:
        Model with LoRA weights loaded
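
    Example (sketch; paths are illustrative)::

        # PEFT adapter directory produced by save_lora_weights
        model = load_lora_weights(model, "outputs/my_lora/adapter")
        # A raw .pt state dict additionally needs the original LoRA config
        model = load_lora_weights(
            model, "outputs/my_lora/lora_weights.pt", lora_config=my_lora_config
        )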
    """
    if not os.path.exists(lora_path):
        raise FileNotFoundError(f"LoRA weights not found: {lora_path}")
    
    # Check if it's a PEFT adapter directory
    if os.path.isdir(lora_path):
        if not PEFT_AVAILABLE:
            raise ImportError("PEFT library is required to load adapter. Install with: pip install peft")
        
        # Load PEFT adapter (PeftModel reads the adapter config from disk)
        model.decoder = PeftModel.from_pretrained(model.decoder, lora_path)
        logger.info(f"LoRA adapter loaded from {lora_path}")
    
    elif lora_path.endswith('.pt'):
        # Load from PyTorch state dict
        if lora_config is None:
            raise ValueError("lora_config is required when loading from .pt file")
        
        # First inject LoRA structure
        model, _ = inject_lora_into_dit(model, lora_config)
        
        # Load weights (weights_only avoids executing arbitrary pickled code;
        # requires PyTorch >= 1.13)
        lora_state_dict = torch.load(lora_path, map_location='cpu', weights_only=True)
        
        # Copy into the model; state_dict() tensors share storage with the
        # live parameters, so copy_() updates them in place
        model_state = model.state_dict()
        for name, param in lora_state_dict.items():
            if name in model_state:
                model_state[name].copy_(param)
            else:
                logger.warning(f"Unexpected key in LoRA state dict: {name}")
        
        logger.info(f"LoRA weights loaded from {lora_path}")
    
    else:
        raise ValueError(f"Unsupported LoRA weight format: {lora_path}")
    
    return model


def merge_lora_weights(model) -> Any:
    """Merge LoRA weights into the base model.
    
    This permanently integrates the LoRA adaptations into the model weights.
    After merging, the model can be used without PEFT.
    
    Args:
        model: Model with LoRA adapters
        
    Returns:
        Model with merged weights
    """
    if hasattr(model, 'decoder') and hasattr(model.decoder, 'merge_and_unload'):
        # PEFT model - merge and unload
        model.decoder = model.decoder.merge_and_unload()
        logger.info("LoRA weights merged into base model")
    else:
        logger.warning("Model does not support LoRA merging")
    
    return model


def get_lora_info(model) -> Dict[str, Any]:
    """Get information about LoRA adapters in the model.
    
    Args:
        model: Model to inspect
        
    Returns:
        Dictionary with LoRA information
    """
    info = {
        "has_lora": False,
        "lora_params": 0,
        "total_params": 0,
        "modules_with_lora": [],
        "lora_ratio": 0.0,
    }
    
    total_params = 0
    lora_params = 0
    lora_modules = []
    
    for name, param in model.named_parameters():
        total_params += param.numel()
        if 'lora_' in name:
            lora_params += param.numel()
            # Extract module name
            module_name = name.rsplit('.lora_', 1)[0]
            if module_name not in lora_modules:
                lora_modules.append(module_name)
    
    info["total_params"] = total_params
    info["lora_params"] = lora_params
    info["has_lora"] = lora_params > 0
    info["modules_with_lora"] = lora_modules
    
    if total_params > 0:
        info["lora_ratio"] = lora_params / total_params
    
    return info
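

if __name__ == "__main__":
    # Minimal smoke test (illustrative only, not part of the library API).
    # Builds a toy stand-in with the attribute layout these helpers expect
    # (a ``decoder`` holding q/k/v/o projection Linears) and exercises
    # target-module discovery and LoRA inspection; the real model is
    # AceStepConditionGenerationModel.
    class _ToyAttention(nn.Module):
        def __init__(self, dim: int = 32) -> None:
            super().__init__()
            self.q_proj = nn.Linear(dim, dim)
            self.k_proj = nn.Linear(dim, dim)
            self.v_proj = nn.Linear(dim, dim)
            self.o_proj = nn.Linear(dim, dim)

    class _ToyModel(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.decoder = _ToyAttention()

    toy = _ToyModel()
    logger.info(f"PEFT available: {check_peft_available()}")
    logger.info(f"Candidate target modules: {get_dit_target_modules(toy)}")
    logger.info(f"LoRA info (no adapters injected yet): {get_lora_info(toy)}")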