# Source: rewrite/src/model/lora_adapter.py
# Uploaded by morpheuslord ("Add files using upload-large-folder tool"),
# revision 12fd5f2 (verified).
"""
LoRA adapter configuration and management.
Wraps PEFT LoRA utilities for applying parameter-efficient
fine-tuning to the base model.
"""
from peft import LoraConfig, TaskType, get_peft_model
from typing import List, Optional
from loguru import logger
def create_lora_config(
    task_type: TaskType,
    r: int = 16,
    lora_alpha: int = 32,
    target_modules: Optional[List[str]] = None,
    lora_dropout: float = 0.05,
    *,
    bias: str = "none",
) -> LoraConfig:
    """Create a LoRA configuration for the given task type.

    Args:
        task_type: PEFT task type the adapters are built for
            (e.g. ``TaskType.SEQ_2_SEQ_LM``).
        r: Rank of the low-rank update matrices.
        lora_alpha: LoRA scaling numerator; PEFT scales the update by
            ``lora_alpha / r`` by default.
        target_modules: Names of the submodules to wrap with LoRA adapters.
            Defaults to ``["q", "v"]``. NOTE(review): these names match
            T5-style attention projections; other architectures name them
            differently (e.g. ``q_proj`` / ``v_proj``) — confirm per model.
        lora_dropout: Dropout probability applied on the LoRA branch.
        bias: Which bias terms to train: ``"none"``, ``"all"`` or
            ``"lora_only"``. Keyword-only; defaults to ``"none"``, the value
            that was previously hard-coded.

    Returns:
        A ``LoraConfig`` built with ``inference_mode=False`` so the adapters
        are trainable.
    """
    if target_modules is None:
        # Fresh list per call; never share a mutable default between calls.
        target_modules = ["q", "v"]

    config = LoraConfig(
        task_type=task_type,
        r=r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        target_modules=target_modules,
        bias=bias,
        inference_mode=False,
    )
    logger.info(f"Created LoRA config: r={r}, alpha={lora_alpha}, dropout={lora_dropout}")
    return config
def apply_lora(model, lora_config: LoraConfig):
    """Apply LoRA adapters to a model and return the wrapped model.

    Also logs how many of the wrapped model's parameters are trainable.

    Args:
        model: Base model to wrap; passed straight to ``get_peft_model``.
        lora_config: LoRA configuration, e.g. one built by
            ``create_lora_config``.

    Returns:
        The PEFT-wrapped model returned by ``get_peft_model``.
    """
    peft_model = get_peft_model(model, lora_config)

    # Count total and trainable parameters in a single pass.
    trainable = 0
    total = 0
    for param in peft_model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count

    # A model with no parameters would otherwise raise ZeroDivisionError here.
    pct = 100 * trainable / total if total else 0.0
    logger.info(f"LoRA applied: {trainable:,}/{total:,} trainable params ({pct:.2f}%)")
    return peft_model
def merge_lora_weights(model):
    """Fold the LoRA adapter weights into the base model for inference.

    Delegates to ``model.merge_and_unload()`` and hands back its result, so
    ``model`` is presumably the PEFT-wrapped model produced by
    ``get_peft_model`` (any object exposing ``merge_and_unload`` works).
    The returned model is meant to behave like a plain model with the LoRA
    updates baked in, without the adapter overhead.

    Args:
        model: PEFT model that holds the LoRA adapters.

    Returns:
        The object returned by ``model.merge_and_unload()``.
    """
    logger.info("Merging LoRA weights into base model...")
    base_model = model.merge_and_unload()
    logger.info("LoRA weights merged successfully")
    return base_model