"""
LoRA adapter configuration and management.

Wraps PEFT LoRA utilities for applying parameter-efficient
fine-tuning to the base model.
"""
from typing import List, Optional

from loguru import logger
from peft import LoraConfig, TaskType, get_peft_model


def create_lora_config(
    task_type: TaskType,
    r: int = 16,
    lora_alpha: int = 32,
    target_modules: Optional[List[str]] = None,
    lora_dropout: float = 0.05,
) -> LoraConfig:
    """Create a LoRA configuration for the given task type."""
    if target_modules is None:
        # Default to T5-style attention projection module names.
        target_modules = ["q", "v"]
    config = LoraConfig(
        task_type=task_type,
        r=r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        target_modules=target_modules,
        bias="none",
        inference_mode=False,
    )
    logger.info(f"Created LoRA config: r={r}, alpha={lora_alpha}, dropout={lora_dropout}")
    return config
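
# A minimal usage sketch (hypothetical values). LLaMA-style models name the
# attention projections "q_proj"/"v_proj", so override the default there:
#
#   config = create_lora_config(
#       TaskType.CAUSAL_LM, r=8, target_modules=["q_proj", "v_proj"]
#   )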


def apply_lora(model, lora_config: LoraConfig):
    """Apply LoRA adapters to a model and return the wrapped model."""
    peft_model = get_peft_model(model, lora_config)
    # Count trainable vs. total parameters to confirm the adapter footprint.
    trainable = sum(p.numel() for p in peft_model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in peft_model.parameters())
    logger.info(f"LoRA applied: {trainable:,}/{total:,} trainable params ({100*trainable/total:.2f}%)")
    return peft_model
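
# For a quick cross-check, PEFT's built-in helper reports the same statistic:
#
#   peft_model.print_trainable_parameters()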


def merge_lora_weights(model):
    """Merge LoRA weights into the base model for inference.

    After merging, the model behaves like a regular model with
    LoRA modifications baked in, removing the adapter overhead.
    """
    logger.info("Merging LoRA weights into base model...")
    # merge_and_unload() folds the adapter weights into the base modules
    # and returns the unwrapped base model.
    merged = model.merge_and_unload()
    logger.info("LoRA weights merged successfully")
    return merged
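

if __name__ == "__main__":
    # End-to-end smoke-test sketch: config -> wrap -> (train) -> merge.
    # Assumes the `transformers` package is installed; "t5-small" is only an
    # illustrative checkpoint and "merged-model" a hypothetical output path.
    from transformers import AutoModelForSeq2SeqLM

    base = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
    peft_model = apply_lora(base, create_lora_config(TaskType.SEQ_2_SEQ_LM))
    # ... fine-tuning would happen here ...
    merged = merge_lora_weights(peft_model)
    merged.save_pretrained("merged-model")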