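"""QLoRA fine-tuning of SeaLLMs-v3-1.5B for Khmer abstractive summarization.

Trains a LoRA adapter on the Khmer ("khm") subset of the LR-Sum dataset
(bltlab/lr-sum) using TRL's SFTTrainer with 4-bit NF4 quantization.
"""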
import os

import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig

MODEL_NAME = "SeaLLMs/SeaLLMs-v3-1.5B"
DATASET_NAME = "bltlab/lr-sum"
DATASET_CONFIG = "khm"


def load_khm_dataset():
    raw = load_dataset(DATASET_NAME, DATASET_CONFIG)
    # Try to find train/validation splits; if absent, carve one out.
    if "train" in raw:
        train = raw["train"]
        if "validation" in raw:
            eval_ds = raw["validation"]
        elif "test" in raw:
            eval_ds = raw["test"]
        else:
            split = train.train_test_split(test_size=0.05, seed=42)
            train, eval_ds = split["train"], split["test"]
    else:
        # Some LR-Sum subsets only have 'test'; we split that.
        split = raw["test"].train_test_split(test_size=0.1, seed=42)
        train, eval_ds = split["train"], split["test"]

    def format_example(example):
        article = example["text"]
        summary = example["summary"]
        # Simple Khmer instruction -> Khmer summary.
        # The instruction says "Please summarize the following text in Khmer."
        # and the label says "Summary:".
        text = (
            "សូមសង្ខេបអត្ថបទខាងក្រោមជាភាសាខ្មែរ។\n\n"
            f"{article}\n\n"
            "សេចក្ដីសង្ខេប៖ "
            f"{summary}"
        )
        return {"text": text}

    cols_to_remove = list(train.features)
    train = train.map(
        format_example,
        remove_columns=cols_to_remove,
        desc="Formatting train set",
    )
    eval_ds = eval_ds.map(
        format_example,
        remove_columns=cols_to_remove,
        desc="Formatting eval set",
    )
    return train, eval_ds
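
# Each formatted record is a single string:
#   <Khmer instruction>\n\n<article>\n\n<Khmer "Summary:" label> <summary>
# SFTTrainer consumes it through dataset_text_field="text" below.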


def load_model_and_tokenizer():
    # QLoRA 4-bit quantization config (NF4 with double quantization).
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
    # Enable gradient checkpointing to cut activation memory.
    model.gradient_checkpointing_enable()
    return model, tokenizer
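
# Note: when a quantized model is passed to SFTTrainer together with a
# peft_config, recent TRL versions run peft's k-bit training preparation
# internally, so no explicit prepare_model_for_kbit_training call is needed.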


def main():
    train_ds, eval_ds = load_khm_dataset()
    model, tokenizer = load_model_and_tokenizer()

    lora_config = LoraConfig(
        r=64,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        # target_modules left unset: peft picks architecture defaults.
    )

    sft_config = SFTConfig(
        output_dir="seallm-khm-sum-lora",
        num_train_epochs=2,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        gradient_accumulation_steps=8,  # effective batch size 16 per device
        learning_rate=2e-4,
        logging_steps=10,
        eval_strategy="steps",
        eval_steps=200,
        save_steps=200,
        save_total_limit=2,
        dataset_text_field="text",  # belongs in SFTConfig, not SFTTrainer
        max_seq_length=1024,
        packing=True,
        lr_scheduler_type="cosine",
        warmup_ratio=0.03,
        bf16=True,
        gradient_checkpointing=True,
        report_to="none",  # or "wandb" etc.
    )
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,  # `tokenizer=` is deprecated in recent TRL
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        peft_config=lora_config,
        args=sft_config,
    )
    trainer.train()

    # Save LoRA adapter and tokenizer.
    trainer.model.save_pretrained("seallm-khm-sum-lora")
    tokenizer.save_pretrained("seallm-khm-sum-lora")

    # Optionally push directly to the Hub (needs HF_TOKEN in the env).
    repo_id = os.environ.get("OUTPUT_REPO_ID", "")
    if repo_id:
        trainer.model.push_to_hub(repo_id)
        tokenizer.push_to_hub(repo_id)
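

# A minimal inference sketch (assumes the adapter directory saved above and
# peft's AutoPeftModelForCausalLM): rebuild the training prompt, minus the
# summary, and let the model complete it.
def summarize(article: str, adapter_dir: str = "seallm-khm-sum-lora") -> str:
    from peft import AutoPeftModelForCausalLM

    model = AutoPeftModelForCausalLM.from_pretrained(
        adapter_dir, device_map="auto", torch_dtype=torch.bfloat16
    )
    tok = AutoTokenizer.from_pretrained(adapter_dir)
    # Must mirror the prompt built in format_example (without the summary).
    prompt = (
        "សូមសង្ខេបអត្ថបទខាងក្រោមជាភាសាខ្មែរ។\n\n"
        f"{article}\n\n"
        "សេចក្ដីសង្ខេប៖ "
    )
    inputs = tok(prompt, return_tensors="pt").to(model.device)
    out = model.generate(**inputs, max_new_tokens=256, do_sample=False)
    # Decode only the newly generated tokens after the prompt.
    return tok.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)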
| if __name__ == "__main__": | |
| main() | |