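"""Fine-tune SeaLLMs/SeaLLMs-v3-1.5B for Khmer summarization on the Khmer
(khm) subset of the bltlab/lr-sum dataset, using 4-bit QLoRA.

The LoRA adapter and tokenizer are saved to ./seallm-khm-sum-lora and,
when the OUTPUT_REPO_ID environment variable is set, pushed to the Hub.
"""
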
import os
import torch
from datasets import load_dataset
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TrainingArguments,
Trainer,
DataCollatorForLanguageModeling,
BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model

MODEL_NAME = "SeaLLMs/SeaLLMs-v3-1.5B"
DATASET_NAME = "bltlab/lr-sum"
DATASET_CONFIG = "khm"


def load_khm_dataset():
    """Load the Khmer subset of lr-sum and return (train, eval) instruction-formatted splits."""
    raw = load_dataset(DATASET_NAME, DATASET_CONFIG)

    # Try the standard splits first
    if "train" in raw:
        train = raw["train"]
        if "validation" in raw:
            eval_ds = raw["validation"]
        elif "test" in raw:
            eval_ds = raw["test"]
        else:
            split = train.train_test_split(test_size=0.05, seed=42)
            train, eval_ds = split["train"], split["test"]
    else:
        # Some subsets only ship a 'test' split; carve train/eval out of that
        split = raw["test"].train_test_split(test_size=0.1, seed=42)
        train, eval_ds = split["train"], split["test"]

    def format_example(example):
        article = example["text"]
        summary = example["summary"]
        # Simple Khmer instruction-style format:
        # "Please summarize the text below in Khmer: <article> Summary: <summary>"
        text = (
            "αžŸαžΌαž˜αžŸαž„αŸ’αžαŸαž”αž’αžαŸ’αžαž”αž‘αžαžΆαž„αž€αŸ’αžšαŸ„αž˜αž‡αžΆαž—αžΆαžŸαžΆαžαŸ’αž˜αŸ‚αžšαŸ–\n\n"
            f"{article}\n\n"
            "αžŸαŸαž…αž€αŸ’αžαžΈαžŸαž„αŸ’αžαŸαž”αŸ– "
            f"{summary}"
        )
        return {"text": text}

    cols_to_remove = list(train.features)
    train = train.map(
        format_example,
        remove_columns=cols_to_remove,
        desc="Formatting train set",
    )
    eval_ds = eval_ds.map(
        format_example,
        remove_columns=cols_to_remove,
        desc="Formatting eval set",
    )
    return train, eval_ds


def load_model_and_tokenizer():
    """Load the base model in 4-bit (QLoRA) along with its tokenizer."""
    # QLoRA 4-bit quantization config
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
    # Gradient checkpointing is left disabled; older transformers versions
    # break autograd when it is enabled with this setup.
    # model.gradient_checkpointing_enable()
    return model, tokenizer


def main():
    train_ds, eval_ds = load_khm_dataset()
    model, tokenizer = load_model_and_tokenizer()

    # Apply LoRA adapters to the quantized base model
    lora_config = LoraConfig(
        r=64,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)

    # Tokenize datasets
    max_length = 1024

    def tokenize_function(batch):
        out = tokenizer(
            batch["text"],
            max_length=max_length,
            truncation=True,
            padding="max_length",
        )
        # Causal LM: labels are the input_ids themselves
        out["labels"] = out["input_ids"].copy()
        return out

    train_tokenized = train_ds.map(
        tokenize_function,
        batched=True,
        remove_columns=["text"],
        desc="Tokenizing train set",
    )
    eval_tokenized = eval_ds.map(
        tokenize_function,
        batched=True,
        remove_columns=["text"],
        desc="Tokenizing eval set",
    )

    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )

    training_args = TrainingArguments(
        output_dir="seallm-khm-sum-lora",
        num_train_epochs=2,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=2e-4,
        logging_steps=10,
        save_steps=200,
        save_total_limit=2,
        lr_scheduler_type="cosine",
        warmup_ratio=0.03,
        fp16=False,  # mixed precision off; 4-bit compute dtype is set via bnb_4bit_compute_dtype
        report_to="none",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_tokenized,
        eval_dataset=eval_tokenized,
        data_collator=data_collator,
    )
    trainer.train()

    # Save the LoRA adapter + tokenizer locally
    model.save_pretrained("seallm-khm-sum-lora")
    tokenizer.save_pretrained("seallm-khm-sum-lora")

    # Optionally push to the Hub when OUTPUT_REPO_ID is set
    repo_id = os.environ.get("OUTPUT_REPO_ID", "")
    if repo_id:
        model.push_to_hub(repo_id)
        tokenizer.push_to_hub(repo_id)


if __name__ == "__main__":
    main()