import os
import torch
from datasets import load_dataset
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
BitsAndBytesConfig,
)
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig
MODEL_NAME = "SeaLLMs/SeaLLMs-v3-1.5B"
DATASET_NAME = "bltlab/lr-sum"
DATASET_CONFIG = "khm"
def load_khm_dataset():
    raw = load_dataset(DATASET_NAME, DATASET_CONFIG)
    # Try to find train/validation splits; if absent, carve an eval set out of the data.
    if "train" in raw:
        train = raw["train"]
        if "validation" in raw:
            eval_ds = raw["validation"]
        elif "test" in raw:
            eval_ds = raw["test"]
        else:
            split = train.train_test_split(test_size=0.05, seed=42)
            train, eval_ds = split["train"], split["test"]
    else:
        # Some LR-Sum subsets only have 'test'; we split that.
        split = raw["test"].train_test_split(test_size=0.1, seed=42)
        train, eval_ds = split["train"], split["test"]
    def format_example(example):
        article = example["text"]
        summary = example["summary"]
        # Simple Khmer instruction -> Khmer summary
        # ("Summarize the following article in Khmer." ... "Summary: ")
        text = (
            "សូមសង្ខេបអត្ថបទខាងក្រោមជាភាសាខ្មែរ។\n\n"
            f"{article}\n\n"
            "សេចក្តីសង្ខេប៖ "
            f"{summary}"
        )
        return {"text": text}
    cols_to_remove = list(train.features)
    train = train.map(
        format_example,
        remove_columns=cols_to_remove,
        desc="Formatting train set",
    )
    eval_ds = eval_ds.map(
        format_example,
        remove_columns=cols_to_remove,
        desc="Formatting eval set",
    )
    return train, eval_ds
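# Optional sanity check (not part of the original pipeline; the helper name is
# introduced here for illustration): print one formatted record so the
# instruction/article/summary layout can be eyeballed before training.
def preview_formatted_example(dataset, index=0):
    """Print a single formatted training example."""
    print(dataset[index]["text"])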
def load_model_and_tokenizer():
    # QLoRA 4-bit quantization config
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
    # Enable gradient checkpointing to reduce activation memory
    model.gradient_checkpointing_enable()
    return model, tokenizer
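# Optional helper (an assumption, not in the original script): report how much
# memory the 4-bit base model occupies. get_memory_footprint() is a standard
# transformers PreTrainedModel method.
def report_memory_footprint(model):
    """Print the model's parameter/buffer memory footprint in GiB."""
    gib = model.get_memory_footprint() / (1024 ** 3)
    print(f"Base model memory footprint: {gib:.2f} GiB")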
def main():
    train_ds, eval_ds = load_khm_dataset()
    model, tokenizer = load_model_and_tokenizer()
    lora_config = LoraConfig(
        r=64,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        # target_modules is left unset, so PEFT falls back to its per-architecture defaults.
    )
    sft_config = SFTConfig(
        output_dir="seallm-khm-sum-lora",
        num_train_epochs=2,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=2e-4,
        logging_steps=10,
        eval_strategy="steps",
        eval_steps=200,
        save_steps=200,
        save_total_limit=2,
        max_seq_length=1024,
        packing=True,
        dataset_text_field="text",
        lr_scheduler_type="cosine",
        warmup_ratio=0.03,
        bf16=True,
        gradient_checkpointing=True,
        report_to="none",  # or "wandb" etc.
    )
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        peft_config=lora_config,
        args=sft_config,
    )
    trainer.train()
    # Save the LoRA adapter and tokenizer
    trainer.model.save_pretrained("seallm-khm-sum-lora")
    tokenizer.save_pretrained("seallm-khm-sum-lora")
    # Optionally push directly to the Hub (needs HF_TOKEN env)
    repo_id = os.environ.get("OUTPUT_REPO_ID", "")
    if repo_id:
        trainer.model.push_to_hub(repo_id)
        tokenizer.push_to_hub(repo_id)


if __name__ == "__main__":
    main()
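# Hedged inference sketch (an assumption, not part of the training run above):
# load the saved adapter on top of the 4-bit base model and summarize one
# article. PeftModel.from_pretrained and model.generate are standard
# PEFT/transformers calls; the prompt mirrors format_example without the
# reference summary. Left commented out so the script's behavior is unchanged.
#
# from peft import PeftModel
#
# def summarize(article: str, adapter_dir: str = "seallm-khm-sum-lora") -> str:
#     base_model, tok = load_model_and_tokenizer()
#     model = PeftModel.from_pretrained(base_model, adapter_dir)
#     model.eval()
#     prompt = (
#         "សូមសង្ខេបអត្ថបទខាងក្រោមជាភាសាខ្មែរ។\n\n"
#         f"{article}\n\n"
#         "សេចក្តីសង្ខេប៖ "
#     )
#     inputs = tok(prompt, return_tensors="pt").to(model.device)
#     with torch.no_grad():
#         out = model.generate(**inputs, max_new_tokens=256, do_sample=False)
#     return tok.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)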