import os

import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

MODEL_NAME = "SeaLLMs/SeaLLMs-v3-1.5B"
DATASET_NAME = "bltlab/lr-sum"
DATASET_CONFIG = "khm"  # Khmer subset of LR-Sum


def load_khm_dataset():
    raw = load_dataset(DATASET_NAME, DATASET_CONFIG)

    # Try the standard splits first.
    if "train" in raw:
        train = raw["train"]
        if "validation" in raw:
            eval_ds = raw["validation"]
        elif "test" in raw:
            eval_ds = raw["test"]
        else:
            split = train.train_test_split(test_size=0.05, seed=42)
            train, eval_ds = split["train"], split["test"]
    else:
        # Some subsets only ship a 'test' split; carve an eval set out of it.
        split = raw["test"].train_test_split(test_size=0.1, seed=42)
        train, eval_ds = split["train"], split["test"]

    def format_example(example):
        article = example["text"]
        summary = example["summary"]
        # Simple Khmer instruction-style format:
        # "Please summarize the following article in Khmer: ... Summary: ..."
        text = (
            "សូមសង្ខេបអត្ថបទខាងក្រោមជាភាសាខ្មែរ៖\n\n"
            f"{article}\n\n"
            "សេចក្តីសង្ខេប៖ "
            f"{summary}"
        )
        return {"text": text}

    cols_to_remove = list(train.features)
    train = train.map(
        format_example,
        remove_columns=cols_to_remove,
        desc="Formatting train set",
    )
    eval_ds = eval_ds.map(
        format_example,
        remove_columns=cols_to_remove,
        desc="Formatting eval set",
    )
    return train, eval_ds


def load_model_and_tokenizer():
    # QLoRA 4-bit quantization config.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        # EOS doubles as PAD; note the collator will then mask EOS positions
        # in the labels along with the padding.
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
    # Standard QLoRA prep: cast norms/embeddings to fp32 and enable input
    # grads so the LoRA params receive gradients through the frozen 4-bit base.
    # Gradient checkpointing stays disabled: old transformers versions break
    # autograd here (hence no model.gradient_checkpointing_enable()).
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)

    return model, tokenizer


def main():
    train_ds, eval_ds = load_khm_dataset()
    model, tokenizer = load_model_and_tokenizer()

    # Apply LoRA to the quantized base model.
    lora_config = LoraConfig(
        r=64,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)

    # Tokenize datasets.
    max_length = 1024

    def tokenize_function(batch):
        out = tokenizer(
            batch["text"],
            max_length=max_length,
            truncation=True,
            padding="max_length",
        )
        # Causal LM: labels mirror input_ids. The collator below re-derives
        # labels and masks padding with -100, so this is a harmless fallback.
        out["labels"] = out["input_ids"].copy()
        return out

    train_tokenized = train_ds.map(
        tokenize_function,
        batched=True,
        remove_columns=["text"],
        desc="Tokenizing train set",
    )
    eval_tokenized = eval_ds.map(
        tokenize_function,
        batched=True,
        remove_columns=["text"],
        desc="Tokenizing eval set",
    )

    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,  # causal LM, not masked LM
    )

    training_args = TrainingArguments(
        output_dir="seallm-khm-sum-lora",
        num_train_epochs=2,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=2e-4,
        logging_steps=10,
        save_steps=200,
        save_total_limit=2,
        lr_scheduler_type="cosine",
        warmup_ratio=0.03,
        bf16=True,  # match bnb_4bit_compute_dtype; 4-bit bitsandbytes needs a GPU
        report_to="none",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_tokenized,
        eval_dataset=eval_tokenized,
        data_collator=data_collator,
    )
    trainer.train()

    # Save LoRA adapter + tokenizer.
    model.save_pretrained("seallm-khm-sum-lora")
    tokenizer.save_pretrained("seallm-khm-sum-lora")

    repo_id = os.environ.get("OUTPUT_REPO_ID", "")
    if repo_id:
        model.push_to_hub(repo_id)
        tokenizer.push_to_hub(repo_id)


if __name__ == "__main__":
    main()
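
# --- Usage sketch ------------------------------------------------------------
# A minimal, illustrative way to load the adapter saved above for inference.
# Kept as comments so the training script stays side-effect free; `article`
# is a placeholder for your own input text, and the prompt mirrors the
# training format.
#
#   from peft import PeftModel
#
#   base = AutoModelForCausalLM.from_pretrained(
#       MODEL_NAME, device_map="auto", trust_remote_code=True
#   )
#   tok = AutoTokenizer.from_pretrained("seallm-khm-sum-lora", trust_remote_code=True)
#   model = PeftModel.from_pretrained(base, "seallm-khm-sum-lora")
#
#   prompt = (
#       "សូមសង្ខេបអត្ថបទខាងក្រោមជាភាសាខ្មែរ៖\n\n"
#       f"{article}\n\n"
#       "សេចក្តីសង្ខេប៖ "
#   )
#   inputs = tok(prompt, return_tensors="pt").to(model.device)
#   out = model.generate(**inputs, max_new_tokens=256)
#   print(tok.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))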