import os
import numpy as np
import pandas as pd
import json
from typing import Optional, Union
import evaluate
from datasets import Dataset, DatasetDict, load_from_disk
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    EarlyStoppingCallback,
    set_seed,
)


class SummarizationTrainer:
    """
    General-purpose fine-tuning of a Seq2Seq summarization model, with a
    single unified interface:
        run(Checkpoint, ModelPath, DataPath | dataset, tokenizer)
    """

    def __init__(
        self,
        Max_Input_Length: int = 1024,
        Max_Target_Length: int = 256,
        prefix: str = "",
        input_column: str = "article",
        target_column: str = "summary",
        Learning_Rate: float = 3e-5,
        Weight_Decay: float = 0.01,
        Batch_Size: int = 8,
        Num_Train_Epochs: int = 3,
        gradient_accumulation_steps: int = 1,
        warmup_ratio: float = 0.05,
        lr_scheduler_type: str = "linear",
        seed: int = 42,
        num_beams: int = 4,
        generation_max_length: Optional[int] = None,
        fp16: bool = True,
        early_stopping_patience: int = 2,
        logging_steps: int = 200,
        report_to: str = "none",
    ):
        # Hyperparams
        self.Max_Input_Length = Max_Input_Length
        self.Max_Target_Length = Max_Target_Length
        self.prefix = prefix
        self.input_column = input_column
        self.target_column = target_column
        self.Learning_Rate = Learning_Rate
        self.Weight_Decay = Weight_Decay
        self.Batch_Size = Batch_Size
        self.Num_Train_Epochs = Num_Train_Epochs
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.warmup_ratio = warmup_ratio
        self.lr_scheduler_type = lr_scheduler_type
        self.seed = seed
        self.num_beams = num_beams
        self.generation_max_length = generation_max_length
        self.fp16 = fp16
        self.early_stopping_patience = early_stopping_patience
        self.logging_steps = logging_steps
        self.report_to = report_to
        self._rouge = evaluate.load("rouge")
        self._tokenizer = None
        self._model = None

    # =========================================================
    # 1️⃣ Load JSONL or Arrow data
    # =========================================================
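    # Each input line is one JSON object; with the default column names this
    # is expected to look like (illustrative values, not from the source data):
    #   {"article": "long source text ...", "summary": "short reference summary ..."}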
    def _load_jsonl_to_datasetdict(self, DataPath: str) -> DatasetDict:
        print(f"Loading data from {DataPath} ...")
        data_list = []
        with open(DataPath, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                try:
                    data_list.append(json.loads(line))
                except json.JSONDecodeError:
                    # Skip malformed lines instead of aborting the whole load
                    continue
        df = pd.DataFrame(data_list)
        if self.input_column not in df or self.target_column not in df:
            raise ValueError(f"File {DataPath} is missing the {self.input_column}/{self.target_column} columns")
        df = df[[self.input_column, self.target_column]].dropna()
        dataset = Dataset.from_pandas(df, preserve_index=False)
        split = dataset.train_test_split(test_size=0.1, seed=self.seed)
        print(f"✔ Data split: {len(split['train'])} train / {len(split['test'])} validation")
        return DatasetDict({"train": split["train"], "validation": split["test"]})

    def _ensure_datasetdict(self, dataset: Optional[Union[Dataset, DatasetDict]], DataPath: Optional[str]) -> DatasetDict:
        if dataset is not None:
            if isinstance(dataset, DatasetDict):
                return dataset
            if isinstance(dataset, Dataset):
                split = dataset.train_test_split(test_size=0.1, seed=self.seed)
                return DatasetDict({"train": split["train"], "validation": split["test"]})
            raise TypeError("dataset must be a datasets.Dataset or a datasets.DatasetDict.")
        if DataPath:
            if os.path.isdir(DataPath):
                print(f"Loading DatasetDict from Arrow directory: {DataPath}")
                return load_from_disk(DataPath)
            return self._load_jsonl_to_datasetdict(DataPath)
        raise ValueError("Either dataset or DataPath must be provided")

    # =========================================================
    # 2️⃣ Tokenization
    # =========================================================
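    # Note: inputs and labels are only truncated here; DataCollatorForSeq2Seq
    # pads each batch dynamically and uses -100 for label padding so the loss
    # ignores padded positions.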
    def _preprocess_function(self, examples):
        inputs = examples[self.input_column]
        if self.prefix:
            inputs = [self.prefix + x for x in inputs]
        model_inputs = self._tokenizer(inputs, max_length=self.Max_Input_Length, truncation=True)
        # `text_target=` replaces the deprecated `as_target_tokenizer()` context manager
        labels = self._tokenizer(text_target=examples[self.target_column], max_length=self.Max_Target_Length, truncation=True)
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    # =========================================================
    # 3️⃣ Compute ROUGE scores
    # =========================================================
    def _compute_metrics(self, eval_pred):
        preds, labels = eval_pred
        if isinstance(preds, tuple):
            preds = preds[0]
        decoded_preds = self._tokenizer.batch_decode(preds, skip_special_tokens=True)
        # -100 marks ignored label positions; swap in pad_token_id so they can be decoded
        labels = np.where(labels != -100, labels, self._tokenizer.pad_token_id)
        decoded_labels = self._tokenizer.batch_decode(labels, skip_special_tokens=True)
        result = self._rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
        return {k: round(v * 100, 4) for k, v in result.items()}

    # =========================================================
    # 4️⃣ Run training
    # =========================================================
    def run(
        self,
        Checkpoint: str,
        ModelPath: str,
        DataPath: Optional[str] = None,
        dataset: Optional[Union[Dataset, DatasetDict]] = None,
        tokenizer: Optional[AutoTokenizer] = None,
    ):
        set_seed(self.seed)
        ds = self._ensure_datasetdict(dataset, DataPath)
        self._tokenizer = tokenizer or AutoTokenizer.from_pretrained(Checkpoint)
        print(f"Loading model checkpoint: {Checkpoint}")
        self._model = AutoModelForSeq2SeqLM.from_pretrained(Checkpoint)
        print("Tokenizing data ...")
        tokenized = ds.map(self._preprocess_function, batched=True)
        data_collator = DataCollatorForSeq2Seq(tokenizer=self._tokenizer, model=self._model)
        gen_max_len = self.generation_max_length or self.Max_Target_Length
        training_args = Seq2SeqTrainingArguments(
            output_dir=ModelPath,
            evaluation_strategy="epoch",
            save_strategy="epoch",
            learning_rate=self.Learning_Rate,
            per_device_train_batch_size=self.Batch_Size,
            per_device_eval_batch_size=self.Batch_Size,
            weight_decay=self.Weight_Decay,
            num_train_epochs=self.Num_Train_Epochs,
            predict_with_generate=True,
            generation_max_length=gen_max_len,
            generation_num_beams=self.num_beams,
            fp16=self.fp16,  # requires a CUDA GPU; pass fp16=False when training on CPU
            gradient_accumulation_steps=self.gradient_accumulation_steps,
            warmup_ratio=self.warmup_ratio,
            lr_scheduler_type=self.lr_scheduler_type,
            logging_steps=self.logging_steps,
            load_best_model_at_end=True,
            metric_for_best_model="rougeL",
            greater_is_better=True,
            save_total_limit=3,
            report_to=self.report_to,
        )
        trainer = Seq2SeqTrainer(
            model=self._model,
            args=training_args,
            train_dataset=tokenized["train"],
            eval_dataset=tokenized["validation"],
            tokenizer=self._tokenizer,
            data_collator=data_collator,
            compute_metrics=self._compute_metrics,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=self.early_stopping_patience)],
        )
        print("\n🚀 STARTING TRAINING ...")
        trainer.train()
        print("✅ TRAINING COMPLETE.")
        trainer.save_model(ModelPath)
        self._tokenizer.save_pretrained(ModelPath)
        print(f"💾 Saved model & tokenizer to: {ModelPath}")
        return trainer

    # =========================================================
    # 5️⃣ Generate a summary
    # =========================================================
    def generate(self, text: str, max_new_tokens: Optional[int] = None) -> str:
        if self._model is None or self._tokenizer is None:
            raise RuntimeError("Model/tokenizer not initialized; call run() first.")
        prompt = (self.prefix + text) if self.prefix else text
        inputs = self._tokenizer(prompt, return_tensors="pt", truncation=True, max_length=self.Max_Input_Length)
        # Move inputs to the model's device so generation also works after GPU training
        inputs = {k: v.to(self._model.device) for k, v in inputs.items()}
        gen_len = max_new_tokens or self.Max_Target_Length
        outputs = self._model.generate(**inputs, max_new_tokens=gen_len, num_beams=self.num_beams)
        return self._tokenizer.decode(outputs[0], skip_special_tokens=True)

    # =========================================================
    # 6️⃣ Reload a local Arrow Dataset
    # =========================================================
    @staticmethod
    def load_local_dataset(DataPath: str) -> DatasetDict:
        return load_from_disk(DataPath)
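

# =========================================================
# Minimal usage sketch: the checkpoint name, file paths, and
# hyperparameter values below are illustrative assumptions;
# substitute your own Seq2Seq checkpoint and data locations.
# =========================================================
if __name__ == "__main__":
    trainer = SummarizationTrainer(
        Max_Input_Length=1024,
        Max_Target_Length=256,
        Batch_Size=4,
        Num_Train_Epochs=1,
        fp16=False,  # assumed CPU-only environment; set True on a CUDA GPU
    )
    # "facebook/bart-base" is one example of an AutoModelForSeq2SeqLM-compatible
    # checkpoint; any Seq2Seq summarization checkpoint works the same way.
    trainer.run(
        Checkpoint="facebook/bart-base",
        ModelPath="./summarizer_out",   # hypothetical output directory
        DataPath="./data/train.jsonl",  # hypothetical JSONL file
    )
    print(trainer.generate("Some long article text to summarize ..."))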