"""
PEFT Trainer Module

General Parameter-Efficient Fine-Tuning trainer supporting multiple PEFT methods.
"""

from typing import Optional
from pathlib import Path
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer
)
from peft import (
    get_peft_model,
    PeftConfig,
    PeftModel,
    prepare_model_for_kbit_training
)


class PEFTTrainer:
    """
    General PEFT Trainer supporting multiple parameter-efficient fine-tuning methods.

    Supports:
    - LoRA (Low-Rank Adaptation)
    - Prefix Tuning
    - P-Tuning
    - Prompt Tuning
    - IA3 (Infused Adapter by Inhibiting and Amplifying Inner Activations)
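
    Example (illustrative sketch; the model name and LoRA hyperparameters
    are placeholders, not recommendations):
        >>> from peft import LoraConfig, TaskType
        >>> config = LoraConfig(task_type=TaskType.CAUSAL_LM, r=8, lora_alpha=16)
        >>> trainer = PEFTTrainer("gpt2", peft_config=config)
        >>> trainer.load_model()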
    """

    def __init__(
        self,
        model_name: str,
        peft_config: PeftConfig,
        output_dir: str = "./models/peft_output"
    ):
        """
        Initialize PEFT Trainer.

        Args:
            model_name: HuggingFace model path or name
            peft_config: PEFT configuration (LoraConfig, PrefixTuningConfig, etc.)
            output_dir: Directory for saving checkpoints and final model
        """
        self.model_name = model_name
        self.peft_config = peft_config
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.model = None
        self.tokenizer = None
        self.trainer = None

    def load_model(
        self,
        use_4bit: bool = False,
        use_8bit: bool = False,
        device_map: str = "auto"
    ) -> None:
        """
        Load model with PEFT configuration.

        Args:
            use_4bit: Use 4-bit (NF4) quantization via bitsandbytes
            use_8bit: Use 8-bit quantization (ignored when use_4bit is set)
            device_map: Device mapping strategy
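
        Example (illustrative; quantized loading requires the bitsandbytes package):
            >>> trainer.load_model(use_4bit=True, device_map="auto")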
        """
        print(f"Loading model: {self.model_name}")

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name,
            trust_remote_code=True
        )

        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Quantization config
        quantization_config = None
        if use_4bit or use_8bit:
            from transformers import BitsAndBytesConfig
            if use_4bit:
                quantization_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4"
                )
            else:
                quantization_config = BitsAndBytesConfig(load_in_8bit=True)

        # Load base model
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            quantization_config=quantization_config,
            device_map=device_map,
            trust_remote_code=True
        )

        # Prepare for k-bit training if quantized
        if use_4bit or use_8bit:
            self.model = prepare_model_for_kbit_training(self.model)

        # Apply PEFT
        self.model = get_peft_model(self.model, self.peft_config)

        # Print trainable parameters
        self.model.print_trainable_parameters()

        print("βœ… Model loaded with PEFT")

    def save_model(self, save_path: Optional[str] = None) -> None:
        """
        Save PEFT adapter weights.

        Args:
            save_path: Path to save adapters (defaults to <output_dir>/final_model)
        """
        if save_path is None:
            save_path = str(self.output_dir / "final_model")

        Path(save_path).mkdir(parents=True, exist_ok=True)

        self.model.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)

        print(f"βœ… PEFT model saved to: {save_path}")

    def load_adapter(self, adapter_path: str) -> None:
        """
        Load pre-trained PEFT adapter weights onto the current model.

        The model should be the loaded base model; wrapping a model that has
        already been through ``get_peft_model`` may stack a second adapter
        on top of the first rather than replace it.

        Args:
            adapter_path: Path to adapter weights
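
        Example (hypothetical adapter path, matching save_model's default):
            >>> trainer.load_adapter("./models/peft_output/final_model")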
        """
        print(f"Loading PEFT adapter from: {adapter_path}")

        self.model = PeftModel.from_pretrained(
            self.model,
            adapter_path
        )

        print("βœ… Adapter loaded")