File size: 1,755 Bytes
12fd5f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
"""
LoRA adapter configuration and management.
Wraps PEFT LoRA utilities for applying parameter-efficient
fine-tuning to the base model.
"""

from peft import LoraConfig, TaskType, get_peft_model
from typing import List, Optional
from loguru import logger


def create_lora_config(
    task_type: TaskType,
    r: int = 16,
    lora_alpha: int = 32,
    target_modules: Optional[List[str]] = None,
    lora_dropout: float = 0.05,
) -> LoraConfig:
    """Build a training-mode ``LoraConfig`` for ``task_type``.

    Args:
        task_type: PEFT task type the adapters are created for.
        r: LoRA rank passed to ``LoraConfig``.
        lora_alpha: LoRA alpha passed to ``LoraConfig``.
        target_modules: Names of the submodules to wrap with adapters.
            When omitted, ``["q", "v"]`` is used.
        lora_dropout: Dropout probability passed to ``LoraConfig``.

    Returns:
        A ``LoraConfig`` with ``bias="none"`` and ``inference_mode=False``.
    """
    # Fall back to the default module list only when the caller passed nothing.
    modules = ["q", "v"] if target_modules is None else target_modules

    lora_config = LoraConfig(
        task_type=task_type,
        r=r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        target_modules=modules,
        bias="none",
        inference_mode=False,
    )
    logger.info(f"Created LoRA config: r={r}, alpha={lora_alpha}, dropout={lora_dropout}")
    return lora_config


def apply_lora(model, lora_config: LoraConfig):
    """Wrap ``model`` with LoRA adapters and log the trainable-parameter share.

    Args:
        model: Base model, passed unchanged to ``get_peft_model``.
        lora_config: Adapter configuration (for example from
            ``create_lora_config``).

    Returns:
        The PEFT-wrapped model returned by ``get_peft_model``.
    """
    peft_model = get_peft_model(model, lora_config)

    # Tally total and trainable parameter counts in a single pass instead of
    # iterating ``parameters()`` twice.
    trainable = 0
    total = 0
    for param in peft_model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count

    # Guard the percentage: a model reporting no parameters would otherwise
    # raise ZeroDivisionError here, after the model was already wrapped.
    percent = 100 * trainable / total if total else 0.0
    logger.info(f"LoRA applied: {trainable:,}/{total:,} trainable params ({percent:.2f}%)")
    return peft_model


def merge_lora_weights(model):
    """Fold the LoRA adapter weights into the base model for inference.

    Delegates to ``model.merge_and_unload()``. The result behaves like a
    regular model with the LoRA modifications baked in, without the
    separate adapter overhead.

    Args:
        model: A PEFT-wrapped model exposing ``merge_and_unload()``.

    Returns:
        The object returned by ``model.merge_and_unload()``.
    """
    logger.info("Merging LoRA weights into base model...")
    base_model = model.merge_and_unload()
    logger.info("LoRA weights merged successfully")
    return base_model