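"""Fine-tune SeaLLMs/SeaLLMs-v3-1.5B for Khmer summarization with QLoRA.

Trains a 4-bit-quantized base model with LoRA adapters on the `khm` subset of
the bltlab/lr-sum dataset, then saves (and optionally pushes) the adapter.
"""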
import os
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

MODEL_NAME = "SeaLLMs/SeaLLMs-v3-1.5B"
DATASET_NAME = "bltlab/lr-sum"
DATASET_CONFIG = "khm"


def load_khm_dataset():
    raw = load_dataset(DATASET_NAME, DATASET_CONFIG)
    # Try standard splits first
    if "train" in raw:
        train = raw["train"]
        if "validation" in raw:
            eval_ds = raw["validation"]
        elif "test" in raw:
            eval_ds = raw["test"]
        else:
            split = train.train_test_split(test_size=0.05, seed=42)
            train, eval_ds = split["train"], split["test"]
    else:
        # Some subsets only have 'test'; split that
        split = raw["test"].train_test_split(test_size=0.1, seed=42)
        train, eval_ds = split["train"], split["test"]

    def format_example(example):
        article = example["text"]
        summary = example["summary"]
        # Simple Khmer instruction-style format
        # ("Please summarize the article below:" ... "Summary:")
        text = (
            "សូមសង្ខេបអត្ថបទខាងក្រោម៖\n\n"
            f"{article}\n\n"
            "សេចក្តីសង្ខេប៖ "
            f"{summary}"
        )
        return {"text": text}

    cols_to_remove = list(train.features)
    train = train.map(
        format_example,
        remove_columns=cols_to_remove,
        desc="Formatting train set",
    )
    eval_ds = eval_ds.map(
        format_example,
        remove_columns=cols_to_remove,
        desc="Formatting eval set",
    )
    return train, eval_ds


def load_model_and_tokenizer():
    # QLoRA 4-bit config: NF4 quantization with bf16 compute
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
    # Standard QLoRA prep: casts norms to fp32 and makes inputs require grads.
    # Gradient checkpointing stays disabled; old transformers breaks autograd here.
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)
    return model, tokenizer


def main():
    train_ds, eval_ds = load_khm_dataset()
    model, tokenizer = load_model_and_tokenizer()

    # Apply LoRA to the model
    lora_config = LoraConfig(
        r=64,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
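    # PEFT helper: report how few parameters the LoRA adapter actually trains
    model.print_trainable_parameters()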

    # Tokenize datasets
    max_length = 1024

    def tokenize_function(batch):
        out = tokenizer(
            batch["text"],
            max_length=max_length,
            truncation=True,
            padding="max_length",
        )
        # Causal LM: labels = input_ids
        out["labels"] = out["input_ids"].copy()
        return out

    train_tokenized = train_ds.map(
        tokenize_function,
        batched=True,
        remove_columns=["text"],
        desc="Tokenizing train set",
    )
    eval_tokenized = eval_ds.map(
        tokenize_function,
        batched=True,
        remove_columns=["text"],
        desc="Tokenizing eval set",
    )

    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )
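    # Note: with mlm=False the collator above rebuilds labels from input_ids
    # and masks pad positions with -100, so padding tokens are ignored by the
    # loss (the manual labels copy in tokenize_function gets overwritten).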

    training_args = TrainingArguments(
        output_dir="seallm-khm-sum-lora",
        num_train_epochs=2,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        gradient_accumulation_steps=8,  # effective batch size: 2 * 8 = 16
        learning_rate=2e-4,
        logging_steps=10,
        eval_strategy="steps",  # `evaluation_strategy` on older transformers
        eval_steps=200,
        save_steps=200,
        save_total_limit=2,
        lr_scheduler_type="cosine",
        warmup_ratio=0.03,
        fp16=False,  # no fp16 autocast; bf16 compute comes from the 4-bit config
        report_to="none",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_tokenized,
        eval_dataset=eval_tokenized,
        data_collator=data_collator,
    )
    trainer.train()

    # Save LoRA adapter + tokenizer
    model.save_pretrained("seallm-khm-sum-lora")
    tokenizer.save_pretrained("seallm-khm-sum-lora")

    # Optionally push to the Hub when OUTPUT_REPO_ID is set
    repo_id = os.environ.get("OUTPUT_REPO_ID", "")
    if repo_id:
        model.push_to_hub(repo_id)
        tokenizer.push_to_hub(repo_id)
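

# Hedged usage sketch: after training, the saved adapter can be loaded back
# onto the base model for inference via peft.PeftModel. The prompt must match
# the training format above. The function name and defaults are illustrative;
# nothing in this script calls it.
def generate_summary(article, adapter_dir="seallm-khm-sum-lora"):
    from peft import PeftModel

    tok = AutoTokenizer.from_pretrained(adapter_dir, trust_remote_code=True)
    base = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    model = PeftModel.from_pretrained(base, adapter_dir)
    model.eval()
    prompt = f"សូមសង្ខេបអត្ថបទខាងក្រោម៖\n\n{article}\n\nសេចក្តីសង្ខេប៖ "
    inputs = tok(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=256)
    # Decode only the tokens generated after the prompt
    return tok.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)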


if __name__ == "__main__":
    main()