import os
import numpy as np
import pandas as pd
import json

from typing import Optional, Union

import evaluate
from datasets import Dataset, DatasetDict, load_from_disk

from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    EarlyStoppingCallback,
    set_seed,
)


class SummarizationTrainer:
    """
    Fine-tune mô hình tóm tắt (Seq2Seq) đa dụng — thống nhất interface:
    run(Checkpoint, ModelPath, DataPath | dataset, tokenizer)
    """

    def __init__(
        self,
        Max_Input_Length: int = 1024,
        Max_Target_Length: int = 256,
        prefix: str = "",
        input_column: str = "article",
        target_column: str = "summary",
        Learning_Rate: float = 3e-5,
        Weight_Decay: float = 0.01,
        Batch_Size: int = 8,
        Num_Train_Epochs: int = 3,
        gradient_accumulation_steps: int = 1,
        warmup_ratio: float = 0.05,
        lr_scheduler_type: str = "linear",
        seed: int = 42,
        num_beams: int = 4,
        generation_max_length: Optional[int] = None,
        fp16: bool = True,
        early_stopping_patience: int = 2,
        logging_steps: int = 200,
        report_to: str = "none",
    ):
        # Hyperparams
        self.Max_Input_Length = Max_Input_Length
        self.Max_Target_Length = Max_Target_Length
        self.prefix = prefix
        self.input_column = input_column
        self.target_column = target_column

        self.Learning_Rate = Learning_Rate
        self.Weight_Decay = Weight_Decay
        self.Batch_Size = Batch_Size
        self.Num_Train_Epochs = Num_Train_Epochs
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.warmup_ratio = warmup_ratio
        self.lr_scheduler_type = lr_scheduler_type
        self.seed = seed

        self.num_beams = num_beams
        self.generation_max_length = generation_max_length
        self.fp16 = fp16
        self.early_stopping_patience = early_stopping_patience
        self.logging_steps = logging_steps
        self.report_to = report_to

        self._rouge = evaluate.load("rouge")
        self._tokenizer = None
        self._model = None

    # =========================================================
    # 1️⃣  Load data from JSONL or Arrow
    # =========================================================
    def _load_jsonl_to_datasetdict(self, DataPath: str) -> DatasetDict:
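        """Read a JSONL file where each line is one JSON object containing the
        input and target columns; with the default column names a record looks
        like (illustrative example):

            {"article": "Full document text ...", "summary": "Short summary ..."}

        Blank or malformed lines are skipped, and the data is returned as a
        90/10 train/validation DatasetDict.
        """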
        print(f"Đang tải dữ liệu từ {DataPath} ...")
        data_list = []
        with open(DataPath, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                try:
                    data_list.append(json.loads(line))
                except json.JSONDecodeError:
                    continue

        df = pd.DataFrame(data_list)
        if self.input_column not in df or self.target_column not in df:
            raise ValueError(f"File {DataPath} thiếu cột {self.input_column}/{self.target_column}")
        df = df[[self.input_column, self.target_column]].dropna()

        dataset = Dataset.from_pandas(df, preserve_index=False)
        split = dataset.train_test_split(test_size=0.1, seed=self.seed)
        print(f"✔ Dữ liệu chia: {len(split['train'])} train / {len(split['test'])} validation")
        return DatasetDict({"train": split["train"], "validation": split["test"]})

    def _ensure_datasetdict(self, dataset: Optional[Union[Dataset, DatasetDict]], DataPath: Optional[str]) -> DatasetDict:
        if dataset is not None:
            if isinstance(dataset, DatasetDict):
                return dataset
            if isinstance(dataset, Dataset):
                split = dataset.train_test_split(test_size=0.1, seed=self.seed)
                return DatasetDict({"train": split["train"], "validation": split["test"]})
            raise TypeError("dataset phải là datasets.Dataset hoặc datasets.DatasetDict.")
        if DataPath:
            if os.path.isdir(DataPath):
                print(f"Load DatasetDict từ thư mục Arrow: {DataPath}")
                return load_from_disk(DataPath)
            return self._load_jsonl_to_datasetdict(DataPath)
        raise ValueError("Cần truyền dataset hoặc DataPath")

    # =========================================================
    # 2️⃣  Tokenization
    # =========================================================
    def _preprocess_function(self, examples):
        inputs = examples[self.input_column]
        if self.prefix:
            inputs = [self.prefix + x for x in inputs]
        model_inputs = self._tokenizer(inputs, max_length=self.Max_Input_Length, truncation=True)
        # text_target replaces the deprecated as_target_tokenizer() context manager
        labels = self._tokenizer(text_target=examples[self.target_column], max_length=self.Max_Target_Length, truncation=True)
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    # =========================================================
    # 3️⃣  Compute ROUGE scores
    # =========================================================
    def _compute_metrics(self, eval_pred):
        preds, labels = eval_pred
        if isinstance(preds, tuple):  # some models return (generated_ids, ...) tuples
            preds = preds[0]
        decoded_preds = self._tokenizer.batch_decode(preds, skip_special_tokens=True)
        # -100 marks ignored label positions; restore pad_token_id before decoding
        labels = np.where(labels != -100, labels, self._tokenizer.pad_token_id)
        decoded_labels = self._tokenizer.batch_decode(labels, skip_special_tokens=True)
        result = self._rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
        return {k: round(v * 100, 4) for k, v in result.items()}

    # =========================================================
    # 4️⃣  Run training
    # =========================================================
    def run(
        self,
        Checkpoint: str,
        ModelPath: str,
        DataPath: Optional[str] = None,
        dataset: Optional[Union[Dataset, DatasetDict]] = None,
        tokenizer: Optional[AutoTokenizer] = None,
    ):
        set_seed(self.seed)
        ds = self._ensure_datasetdict(dataset, DataPath)
        self._tokenizer = tokenizer or AutoTokenizer.from_pretrained(Checkpoint)
        print(f"Tải model checkpoint: {Checkpoint}")
        self._model = AutoModelForSeq2SeqLM.from_pretrained(Checkpoint)

        print("Tokenizing dữ liệu ...")
        tokenized = ds.map(self._preprocess_function, batched=True, remove_columns=ds["train"].column_names)
        data_collator = DataCollatorForSeq2Seq(tokenizer=self._tokenizer, model=self._model)
        gen_max_len = self.generation_max_length or self.Max_Target_Length

        training_args = Seq2SeqTrainingArguments(
            output_dir=ModelPath,
            evaluation_strategy="epoch",
            save_strategy="epoch",
            learning_rate=self.Learning_Rate,
            per_device_train_batch_size=self.Batch_Size,
            per_device_eval_batch_size=self.Batch_Size,
            weight_decay=self.Weight_Decay,
            num_train_epochs=self.Num_Train_Epochs,
            predict_with_generate=True,
            generation_max_length=gen_max_len,
            generation_num_beams=self.num_beams,
            fp16=self.fp16,
            gradient_accumulation_steps=self.gradient_accumulation_steps,
            warmup_ratio=self.warmup_ratio,
            lr_scheduler_type=self.lr_scheduler_type,
            logging_steps=self.logging_steps,
            load_best_model_at_end=True,
            metric_for_best_model="rougeL",
            greater_is_better=True,
            save_total_limit=3,
            report_to=self.report_to,
        )

        trainer = Seq2SeqTrainer(
            model=self._model,
            args=training_args,
            train_dataset=tokenized["train"],
            eval_dataset=tokenized["validation"],
            tokenizer=self._tokenizer,
            data_collator=data_collator,
            compute_metrics=self._compute_metrics,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=self.early_stopping_patience)],
        )

        print("\n🚀 BẮT ĐẦU HUẤN LUYỆN ...")
        trainer.train()
        print("✅ HUẤN LUYỆN HOÀN TẤT.")
        trainer.save_model(ModelPath)
        self._tokenizer.save_pretrained(ModelPath)
        print(f"💾 Đã lưu model & tokenizer tại: {ModelPath}")
        return trainer

    # =========================================================
    # 5️⃣  Generate summaries
    # =========================================================
    def generate(self, text: str, max_new_tokens: Optional[int] = None) -> str:
        if self._model is None or self._tokenizer is None:
            raise RuntimeError("Model/tokenizer not initialized; call run() first.")
        prompt = (self.prefix + text) if self.prefix else text
        inputs = self._tokenizer(prompt, return_tensors="pt", truncation=True, max_length=self.Max_Input_Length)
        inputs = inputs.to(self._model.device)  # move input tensors to the model's device
        gen_len = max_new_tokens or self.Max_Target_Length
        outputs = self._model.generate(**inputs, max_new_tokens=gen_len, num_beams=self.num_beams)
        return self._tokenizer.decode(outputs[0], skip_special_tokens=True)

    # =========================================================
    # 6️⃣  Reload an Arrow Dataset
    # =========================================================
    @staticmethod
    def load_local_dataset(DataPath: str) -> DatasetDict:
        return load_from_disk(DataPath)
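

# ---------------------------------------------------------------------
# Usage sketch (illustrative only): the checkpoint name, file paths, and
# hyperparameter values below are assumptions for the example, not part
# of the class.
# ---------------------------------------------------------------------
if __name__ == "__main__":
    summarizer = SummarizationTrainer(
        prefix="summarize: ",  # T5-style task prefix; use "" for BART-style models
        Batch_Size=4,
        Num_Train_Epochs=1,
        fp16=False,  # enable only when training on a CUDA-capable GPU
    )
    summarizer.run(
        Checkpoint="t5-small",             # hypothetical base checkpoint
        ModelPath="./outputs/summarizer",  # hypothetical output directory
        DataPath="./data/train.jsonl",     # hypothetical JSONL data file
    )
    print(summarizer.generate("Paste a long article here to get a summary."))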