File size: 14,677 Bytes
76a4048 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 | """
PEFT Utilities for Parameter-Efficient Fine-Tuning
Supports LoRA, AdaLoRA, IA3, Prefix Tuning, and Prompt Tuning
"""
import os
import json
import logging
from typing import Dict, List, Optional, Union, Any
from dataclasses import dataclass, field
import torch
from transformers import PreTrainedModel, PreTrainedTokenizer
# Module-level logger named after this module; the PEFT helpers below emit
# their warnings, errors and progress messages through it.
logger = logging.getLogger(__name__)
# PEFT configuration classes
@dataclass
class LoRAConfig:
    """
    LoRA configuration.

    Attributes:
        r: Rank of the LoRA update matrices; must be positive.
        lora_alpha: LoRA alpha (scaling) parameter.
        lora_dropout: Dropout probability for the LoRA branch, in [0, 1].
        target_modules: Names of the submodules to adapt.
        bias: Bias handling mode; one of "none", "all", "lora_only"
            (the values peft's LoraConfig supports).
        modules_to_save: Extra modules kept trainable and saved with the adapter.
    """
    r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])
    bias: str = "none"
    modules_to_save: List[str] = field(default_factory=list)

    def __post_init__(self) -> None:
        """Reject invalid values at construction instead of deep inside peft."""
        if self.r <= 0:
            raise ValueError(f"r must be a positive integer, got {self.r}")
        if not 0.0 <= self.lora_dropout <= 1.0:
            raise ValueError(f"lora_dropout must be in [0, 1], got {self.lora_dropout}")
        if self.bias not in ("none", "all", "lora_only"):
            raise ValueError(
                f"bias must be one of 'none', 'all', 'lora_only', got {self.bias!r}"
            )
@dataclass
class AdaLoRAConfig:
    """
    AdaLoRA configuration.

    All fields are forwarded to peft's AdaLoraConfig by ``apply_peft_to_model``.

    Attributes:
        target_r: Target rank for AdaLoRA's budget schedule; must be positive.
        init_r: Initial rank per adapted module; must be >= ``target_r``.
        tinit: Schedule parameter (initial steps) passed through to peft.
        tfinal: Schedule parameter (final steps) passed through to peft.
        deltaT: Step interval used by the rank schedule; must be positive.
        beta1: EMA coefficient passed through to peft.
        beta2: EMA coefficient passed through to peft.
        lora_alpha: LoRA alpha (scaling) parameter.
        lora_dropout: Dropout probability for the LoRA branch, in [0, 1].
        target_modules: Names of the submodules to adapt.
        modules_to_save: Extra modules kept trainable and saved with the adapter.
    """
    target_r: int = 8
    init_r: int = 12
    tinit: int = 200
    tfinal: int = 1000
    deltaT: int = 10
    beta1: float = 0.85
    beta2: float = 0.85
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])
    modules_to_save: List[str] = field(default_factory=list)

    def __post_init__(self) -> None:
        """Reject inconsistent rank/schedule settings at construction time."""
        if self.target_r <= 0:
            raise ValueError(f"target_r must be positive, got {self.target_r}")
        if self.init_r < self.target_r:
            # AdaLoRA prunes from init_r towards target_r; the reverse is meaningless.
            raise ValueError(
                f"init_r ({self.init_r}) must be >= target_r ({self.target_r})"
            )
        if self.deltaT <= 0:
            raise ValueError(f"deltaT must be positive, got {self.deltaT}")
        if not 0.0 <= self.lora_dropout <= 1.0:
            raise ValueError(f"lora_dropout must be in [0, 1], got {self.lora_dropout}")
@dataclass
class IA3Config:
    """
    IA3 configuration.

    Attributes:
        target_modules: Names of the submodules to adapt. By default this
            includes the default feed-forward modules, so that the default
            configuration is self-consistent.
        feedforward_modules: Subset of ``target_modules`` treated as
            feed-forward layers.
        modules_to_save: Extra modules kept trainable and saved with the adapter.

    NOTE(review): peft's IA3Config rejects ``feedforward_modules`` that are not
    a subset of ``target_modules``; the check below surfaces that at
    construction time. Confirm against the pinned peft version.
    """
    target_modules: List[str] = field(
        default_factory=lambda: ["q_proj", "v_proj", "k_proj", "up_proj", "down_proj"]
    )
    feedforward_modules: List[str] = field(default_factory=lambda: ["up_proj", "down_proj"])
    modules_to_save: List[str] = field(default_factory=list)

    def __post_init__(self) -> None:
        """Ensure every feed-forward module is also a target module."""
        if isinstance(self.target_modules, str) or isinstance(self.feedforward_modules, str):
            # String values are regex patterns; membership can't be checked statically.
            return
        missing = [m for m in self.feedforward_modules if m not in self.target_modules]
        if missing:
            raise ValueError(
                f"feedforward_modules must be a subset of target_modules; missing: {missing}"
            )
@dataclass
class PrefixTuningConfig:
    """
    Settings for prefix tuning.

    Attributes:
        num_virtual_tokens: Number of virtual prefix tokens.
        encoder_hidden_size: Optional hidden size of the prefix encoder;
            None leaves it unset.
        prefix_projection: Whether the prefix is passed through a projection.
        projection_dim: Width of that projection.
        dropout: Dropout probability for the prefix.

    NOTE(review): peft's PrefixTuningConfig does not appear to define
    ``projection_dim`` or ``dropout``. Confirm that they are accepted before
    forwarding every field to it.
    """
    num_virtual_tokens: int = field(default=20)
    encoder_hidden_size: Union[int, None] = field(default=None)
    prefix_projection: bool = field(default=False)
    projection_dim: int = field(default=128)
    dropout: float = field(default=0.0)
@dataclass
class PromptTuningConfig:
    """
    Settings for prompt tuning.

    ``apply_peft_to_model`` forwards these fields to peft's PromptTuningConfig.

    Attributes:
        num_virtual_tokens: Number of virtual prompt tokens.
        tokenizer_name_or_path: Optional tokenizer identifier; None leaves it unset.
        num_layers: Optional layer count; None leaves it unset.
        token_dim: Optional per-token embedding width; None leaves it unset.
    """
    num_virtual_tokens: int = field(default=20)
    tokenizer_name_or_path: Union[str, None] = field(default=None)
    num_layers: Union[int, None] = field(default=None)
    token_dim: Union[int, None] = field(default=None)
# Lookup table from a lower-case PEFT method name to this module's
# configuration dataclass (used by get_peft_config).
PEFT_CONFIG_MAP = dict(
    lora=LoRAConfig,
    adalora=AdaLoRAConfig,
    ia3=IA3Config,
    prefix_tuning=PrefixTuningConfig,
    prompt_tuning=PromptTuningConfig,
)
def get_peft_config(peft_type: str, **kwargs) -> Any:
    """
    Instantiate this module's configuration dataclass for a PEFT method.

    Args:
        peft_type: Method name, matched case-insensitively against 'lora',
            'adalora', 'ia3', 'prefix_tuning' and 'prompt_tuning'.
        **kwargs: Field values forwarded to the dataclass constructor.

    Returns:
        A new instance of the matching configuration dataclass.

    Raises:
        ValueError: If ``peft_type`` is not a supported name.
        TypeError: If ``kwargs`` contains a name that is not a field of the dataclass.
    """
    key = peft_type.lower()
    config_cls = PEFT_CONFIG_MAP.get(key)
    if config_cls is None:
        raise ValueError(f"Unknown PEFT type: {key}. Available: {list(PEFT_CONFIG_MAP.keys())}")
    return config_cls(**kwargs)
def _resolve_task_type(task_type: Any, task_type_enum: Any) -> Any:
    """Translate a task-type alias or a peft ``TaskType`` value into a ``TaskType`` member (None if unknown)."""
    aliases = {
        'causal-lm': task_type_enum.CAUSAL_LM,
        'seq2seq': task_type_enum.SEQ_2_SEQ_LM,
        'token-classification': task_type_enum.TOKEN_CLS,
        'text-classification': task_type_enum.SEQ_CLS,
        'question-answering': task_type_enum.QUESTION_ANS,
    }
    if isinstance(task_type, str) and task_type in aliases:
        return aliases[task_type]
    try:
        # TaskType members and their raw values (e.g. "CAUSAL_LM") are kept
        # instead of being silently discarded.
        return task_type_enum(task_type)
    except (ValueError, TypeError):
        logger.warning("Unknown task_type %r; building PEFT config without a task type.", task_type)
        return None


def _init_field_names(config_class: Any) -> Optional[set]:
    """Return the constructor field names of a dataclass config class, or None if it is not a dataclass."""
    from dataclasses import fields, is_dataclass
    if not is_dataclass(config_class):
        return None
    return {f.name for f in fields(config_class) if f.init}


def _collect_config_data(config: Any, kwargs: Dict[str, Any], peft_config_class: Any) -> Dict[str, Any]:
    """Merge ``config`` (dict or config object) with ``kwargs``; ``kwargs`` win on conflicts."""
    if config is None:
        return dict(kwargs)
    if isinstance(config, dict):
        return {**config, **kwargs}
    data = {k: v for k, v in vars(config).items() if not k.startswith('_')}
    # Config objects (e.g. this module's dataclasses) may carry fields that the
    # peft class does not accept. Drop them with a warning instead of letting
    # the constructor fail with TypeError. Explicit dict/kwargs input stays strict.
    accepted = _init_field_names(peft_config_class)
    if accepted is not None:
        unsupported = sorted(k for k in data if k not in accepted)
        if unsupported:
            logger.warning(
                "Ignoring config fields not supported by %s: %s",
                getattr(peft_config_class, '__name__', peft_config_class), unsupported,
            )
            for key in unsupported:
                del data[key]
    data.update(kwargs)
    return data


def apply_peft_to_model(
    model: PreTrainedModel,
    peft_type: str,
    config: Optional[Union[Dict, Any]] = None,
    **kwargs
) -> PreTrainedModel:
    """
    Apply PEFT to a model.

    If the ``peft`` package cannot be imported, a warning is logged and
    ``model`` is returned unchanged.

    Args:
        model: The base model to apply PEFT to. If it reports
            ``is_loaded_in_8bit`` or ``is_loaded_in_4bit``, it is first passed
            through peft's ``prepare_model_for_kbit_training``.
        peft_type: PEFT method ('lora', 'adalora', 'ia3', 'prefix_tuning',
            'prompt_tuning'), matched case-insensitively.
        config: PEFT configuration as a dict or a config object. Object
            attributes starting with '_' are ignored.
        **kwargs: Additional configuration values; these override ``config``.
            A ``task_type`` entry may be an alias ('causal-lm', 'seq2seq',
            'token-classification', 'text-classification',
            'question-answering') or a peft ``TaskType`` value.

    Returns:
        Model with PEFT applied.

    Raises:
        ValueError: If ``peft_type`` is unknown (only when peft is installed).
    """
    try:
        from peft import (
            LoraConfig, AdaLoraConfig,
            IA3Config as PeftIA3Config,
            PrefixTuningConfig as PeftPrefixTuningConfig,
            PromptTuningConfig as PeftPromptTuningConfig,
            get_peft_model, TaskType, prepare_model_for_kbit_training,
        )
    except ImportError:
        logger.warning("PEFT library not installed. Returning original model.")
        return model

    peft_type = peft_type.lower()
    # Aliased imports avoid shadowing this module's config dataclasses.
    peft_config_map = {
        "lora": LoraConfig,
        "adalora": AdaLoraConfig,
        "ia3": PeftIA3Config,
        "prefix_tuning": PeftPrefixTuningConfig,
        "prompt_tuning": PeftPromptTuningConfig,
    }
    if peft_type not in peft_config_map:
        raise ValueError(f"Unknown PEFT type: {peft_type}")
    peft_config_class = peft_config_map[peft_type]

    config_data = _collect_config_data(config, kwargs, peft_config_class)

    task_type = config_data.pop('task_type', None)
    if task_type:
        resolved = _resolve_task_type(task_type, TaskType)
        if resolved is not None:
            config_data['task_type'] = resolved

    peft_config = peft_config_class(**config_data)

    # Quantized models need peft's k-bit preparation before adapters are attached.
    if getattr(model, 'is_loaded_in_8bit', False) or getattr(model, 'is_loaded_in_4bit', False):
        model = prepare_model_for_kbit_training(model)

    model = get_peft_model(model, peft_config)

    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    all_params = sum(p.numel() for p in model.parameters())
    percent = 100 * trainable_params / all_params if all_params else 0.0
    logger.info(
        "Trainable params: %s / %s (%.2f%%)",
        f"{trainable_params:,}", f"{all_params:,}", percent,
    )
    return model
def get_target_modules_for_architecture(model_name: str) -> List[str]:
    """
    Get recommended LoRA target modules based on model architecture.

    The architecture is inferred from ``model_name`` by case-insensitive
    substring checks, evaluated top to bottom. More specific families are
    tested before the families whose names they contain (DistilBERT before
    BERT, GPT-NeoX before GPT-Neo).

    Args:
        model_name: Model name or path, e.g. a Hugging Face hub id. ``None``
            or an empty string falls through to the default.

    Returns:
        A new list of module names; ``['q_proj', 'v_proj']`` when no family matches.
    """
    name = (model_name or '').lower()

    # LLaMA, Alpaca, Vicuna and Mistral share the same projection names
    if any(tag in name for tag in ('llama', 'alpaca', 'vicuna', 'mistral')):
        return ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj']
    # DistilBERT (checked before BERT because its name contains "bert")
    if 'distilbert' in name:
        return ['q_lin', 'k_lin', 'v_lin', 'out_lin', 'lin1', 'lin2']
    # BERT, RoBERTa, DeBERTa
    if any(tag in name for tag in ('bert', 'roberta', 'deberta')):
        return ['query', 'key', 'value', 'dense']
    # T5, Flan-T5
    if 't5' in name:
        return ['q', 'k', 'v', 'o', 'wi', 'wo']
    # GPT-NeoX / Pythia: fused QKV projection; must precede the "gpt-neo" check
    if 'gpt-neox' in name or 'pythia' in name:
        return ['query_key_value', 'dense', 'dense_h_to_4h', 'dense_4h_to_h']
    # GPT-J: separate q/k/v/out projections and an fc_in/fc_out MLP
    if 'gpt-j' in name or 'gptj' in name:
        return ['q_proj', 'k_proj', 'v_proj', 'out_proj', 'fc_in', 'fc_out']
    # GPT-Neo: separate q/k/v/out projections and a c_fc/c_proj MLP
    if 'gpt-neo' in name:
        return ['q_proj', 'k_proj', 'v_proj', 'out_proj', 'c_fc', 'c_proj']
    # GPT-2
    if 'gpt2' in name:
        return ['c_attn', 'c_proj', 'mlp.c_fc', 'mlp.c_proj']
    # Bloom
    if 'bloom' in name:
        return ['query_key_value', 'dense', 'dense_h_to_4h', 'dense_4h_to_h']
    # OPT
    if 'opt' in name:
        return ['q_proj', 'k_proj', 'v_proj', 'out_proj', 'fc1', 'fc2']
    # Falcon
    if 'falcon' in name:
        return ['query_key_value', 'dense', 'dense_h_to_4h', 'dense_4h_to_h']
    # Default for transformer models
    return ['q_proj', 'v_proj']
def estimate_lora_parameters(
    base_model_params: int,
    r: int,
    target_modules: List[str],
    lora_alpha: int = 16,
    num_layers: int = 1,
    hidden_size: Optional[int] = None,
) -> Dict[str, float]:
    """
    Estimate the number of trainable parameters for LoRA.

    Each adapted module is assumed to get two LoRA matrices (A: hidden x r,
    B: r x hidden), i.e. ``2 * hidden_size * r`` parameters per module.

    Args:
        base_model_params: Number of parameters in the base model (non-negative).
        r: LoRA rank.
        target_modules: List of target module names (one entry per module type).
        lora_alpha: LoRA alpha parameter. It does not change the parameter
            count and is accepted for API symmetry.
        num_layers: How many times each target module is repeated (e.g. the
            number of transformer layers). The default of 1 counts each
            module once.
        hidden_size: Hidden size to use. If None, it is approximated as
            ``int(sqrt(base_model_params) * 0.5)``.

    Returns:
        Dictionary with 'estimated_trainable_params', 'params_per_module',
        'compression_ratio' (base / LoRA params, 0 if no LoRA params) and
        'memory_reduction_percent' (0 if the base model has no params).

    Raises:
        ValueError: If ``base_model_params`` is negative.
    """
    if base_model_params < 0:
        # A negative count would otherwise produce a complex square root below.
        raise ValueError(f"base_model_params must be non-negative, got {base_model_params}")

    if hidden_size is None:
        # Rough heuristic when the real hidden size is unknown
        hidden_size = int((base_model_params ** 0.5) * 0.5)

    params_per_module = 2 * hidden_size * r
    total_lora_params = params_per_module * len(target_modules) * num_layers

    return {
        'estimated_trainable_params': total_lora_params,
        'params_per_module': params_per_module,
        'compression_ratio': base_model_params / total_lora_params if total_lora_params > 0 else 0,
        'memory_reduction_percent': 100 * (1 - total_lora_params / base_model_params) if base_model_params > 0 else 0
    }
def _existing_paths(directory: str, names: tuple) -> List[str]:
    """Return ``directory/name`` for each name in ``names`` that exists on disk."""
    candidates = (os.path.join(directory, name) for name in names)
    return [path for path in candidates if os.path.exists(path)]


def save_peft_model(
    model,
    output_dir: str,
    tokenizer: Optional[PreTrainedTokenizer] = None,
    save_merged: bool = False
) -> Dict[str, Any]:
    """
    Save PEFT adapters and associated files.

    Writes the adapters (``model.save_pretrained``), optionally the tokenizer,
    optionally a merged model under ``output_dir/merged``, and a
    ``training_config.json`` with the PEFT type and parameter counts.

    Args:
        model: PEFT model to save.
        output_dir: Directory to save to (created if missing).
        tokenizer: Optional tokenizer to save alongside the adapters.
        save_merged: Whether to also save a merged model. This calls
            ``model.merge_and_unload()`` on ``model`` itself. NOTE(review):
            in peft this folds the adapter weights into the wrapped base model,
            so treat ``model`` as consumed afterwards (confirm for your peft version).

    Returns:
        Dictionary with 'saved_files' (paths that exist after saving) and 'output_dir'.

    Raises:
        Any error while saving adapters, tokenizer or the config file is
        logged and re-raised. A failure while merging is only logged as a warning.
    """
    os.makedirs(output_dir, exist_ok=True)
    saved_files: List[str] = []

    try:
        # Collect stats before any merge: merge_and_unload() rewires the
        # model's modules, which would corrupt these counts.
        peft_type = 'unknown'
        if hasattr(model, 'active_peft_config'):
            raw_type = model.active_peft_config.peft_type
            peft_type = getattr(raw_type, 'value', raw_type)
        config = {
            'peft_type': peft_type,
            'trainable_params': sum(p.numel() for p in model.parameters() if p.requires_grad),
            'total_params': sum(p.numel() for p in model.parameters()),
        }

        # Save PEFT adapters; report whichever weight format was actually written
        model.save_pretrained(output_dir)
        saved_files.extend(_existing_paths(
            output_dir,
            ("adapter_config.json", "adapter_model.safetensors", "adapter_model.bin"),
        ))

        if tokenizer is not None:
            tokenizer_files = tokenizer.save_pretrained(output_dir)
            saved_files.extend(str(p) for p in (tokenizer_files or ()))

        if save_merged:
            merged_dir = os.path.join(output_dir, "merged")
            try:
                merged_model = model.merge_and_unload()
                merged_model.save_pretrained(merged_dir)
                if tokenizer is not None:
                    tokenizer.save_pretrained(merged_dir)
                saved_files.extend(_existing_paths(
                    merged_dir,
                    ("config.json", "model.safetensors", "pytorch_model.bin",
                     "model.safetensors.index.json", "pytorch_model.bin.index.json"),
                ))
            except Exception as e:
                # Merging is best-effort; the adapters above are already saved.
                logger.warning(f"Could not merge model: {e}")

        config_path = os.path.join(output_dir, "training_config.json")
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2)
        saved_files.append(config_path)

        logger.info(f"Saved PEFT model to {output_dir}")
    except Exception as e:
        logger.error(f"Error saving PEFT model: {e}")
        raise

    return {'saved_files': saved_files, 'output_dir': output_dir}
def load_peft_model(
    base_model_name: str,
    peft_model_path: str,
    device: str = 'auto',
    torch_dtype: Optional[torch.dtype] = None,
):
    """
    Load a causal-LM base model and attach saved PEFT adapters to it.

    Args:
        base_model_name: Name or path of the base model
            (loaded with ``AutoModelForCausalLM``).
        peft_model_path: Path to the saved PEFT adapters.
        device: Passed as ``device_map`` to ``from_pretrained``.
        torch_dtype: Dtype for the base model weights. If None, float16 is
            used unless ``device == 'cpu'``, in which case float32 is used.

    Returns:
        Loaded PEFT model.

    Raises:
        Any error from importing or loading is logged and re-raised.
    """
    try:
        from peft import PeftModel
        from transformers import AutoModelForCausalLM

        if torch_dtype is None:
            torch_dtype = torch.float16 if device != 'cpu' else torch.float32

        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            torch_dtype=torch_dtype,
            device_map=device
        )
        return PeftModel.from_pretrained(base_model, peft_model_path)
    except Exception as e:
        logger.error(f"Error loading PEFT model: {e}")
        raise
def get_peft_memory_requirements(
    model_params: int,
    peft_type: str = 'lora',
    r: int = 8,
    batch_size: int = 1,
    seq_length: int = 512,
    gradient_checkpointing: bool = True
) -> Dict[str, float]:
    """
    Estimate memory requirements for PEFT training.

    This is a heuristic. The trainable fraction is approximated with a
    LoRA-style ``r / 512`` ratio, capped at 1.0, for every ``peft_type``.

    Args:
        model_params: Number of model parameters (non-negative).
        peft_type: Type of PEFT method. Currently not used by the estimate.
        r: LoRA rank (non-negative).
        batch_size: Training batch size.
        seq_length: Sequence length.
        gradient_checkpointing: Whether gradient checkpointing is enabled
            (scales activation memory by 0.2).

    Returns:
        Dictionary with memory estimates in GB (rounded to 2 decimals),
        including 10% ('peak_gb') and 20% ('recommended_gpu_vram') buffers.

    Raises:
        ValueError: If ``model_params`` or ``r`` is negative.
    """
    if model_params < 0:
        raise ValueError(f"model_params must be non-negative, got {model_params}")
    if r < 0:
        raise ValueError(f"r must be non-negative, got {r}")

    # Base model memory (FP16)
    base_memory = model_params * 2 / 1e9

    # Trainable fraction: capped so large ranks never exceed the model size
    trainable_ratio = min(r / 512, 1.0)
    trainable_params = model_params * trainable_ratio

    # Optimizer states (AdamW: 2 FP32 states per trainable param)
    optimizer_memory = trainable_params * 2 * 4 / 1e9

    # Gradients (FP16, trainable params only)
    gradient_memory = trainable_params * 2 / 1e9

    # Activations: rough estimate scaling with batch size, sequence length and sqrt(params)
    activation_memory = batch_size * seq_length * (model_params ** 0.5) * 0.1 / 1e9
    if gradient_checkpointing:
        activation_memory *= 0.2

    total_memory = base_memory + optimizer_memory + gradient_memory + activation_memory

    return {
        'base_model_gb': round(base_memory, 2),
        'optimizer_states_gb': round(optimizer_memory, 2),
        'gradients_gb': round(gradient_memory, 2),
        'activations_gb': round(activation_memory, 2),
        'total_gb': round(total_memory, 2),
        'peak_gb': round(total_memory * 1.1, 2),  # 10% buffer
        'recommended_gpu_vram': round(total_memory * 1.2, 2)  # 20% buffer
    }
# Convenience function for quick LoRA setup
def quick_lora_setup(
    model: PreTrainedModel,
    r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    target_modules: Optional[List[str]] = None
) -> PreTrainedModel:
    """
    Apply LoRA to ``model`` with minimal configuration.

    Args:
        model: Base model.
        r: LoRA rank.
        lora_alpha: LoRA alpha.
        lora_dropout: Dropout rate.
        target_modules: Modules to adapt. When None, they are chosen by
            ``get_target_modules_for_architecture`` from
            ``model.config._name_or_path`` (empty string if absent).

    Returns:
        The result of ``apply_peft_to_model(model, 'lora', ...)``.
    """
    modules = target_modules
    if modules is None:
        name_or_path = getattr(model.config, '_name_or_path', '')
        modules = get_target_modules_for_architecture(name_or_path)
    lora_options = dict(
        r=r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        target_modules=modules,
    )
    return apply_peft_to_model(model, 'lora', **lora_options)