File size: 4,027 Bytes
8f20942
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import os
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig

# Base model to fine-tune (loaded in 4-bit below for QLoRA training).
MODEL_NAME = "SeaLLMs/SeaLLMs-v3-1.5B"
# LR-Sum summarization dataset; "khm" selects the Khmer-language subset.
DATASET_NAME = "bltlab/lr-sum"
DATASET_CONFIG = "khm"

def load_khm_dataset():
    """Load the Khmer LR-Sum subset and format it for SFT.

    Returns a ``(train, eval)`` pair of datasets whose single column
    ``"text"`` holds a Khmer instruction prompt followed by the article
    and its reference summary.
    """
    raw = load_dataset(DATASET_NAME, DATASET_CONFIG)

    # Pick train/eval splits, falling back to a seeded random split when
    # the subset does not ship a dedicated validation split.
    if "train" not in raw:
        # Some LR-Sum subsets only ship a 'test' split; carve eval out of it.
        parts = raw["test"].train_test_split(test_size=0.1, seed=42)
        train, eval_ds = parts["train"], parts["test"]
    elif "validation" in raw:
        train, eval_ds = raw["train"], raw["validation"]
    elif "test" in raw:
        train, eval_ds = raw["train"], raw["test"]
    else:
        parts = raw["train"].train_test_split(test_size=0.05, seed=42)
        train, eval_ds = parts["train"], parts["test"]

    def to_sft_text(row):
        # Simple Khmer instruction β†’ Khmer summary, as one training string.
        return {
            "text": (
                f"αžŸαžΌαž˜αžŸαž„αŸ’αžαŸαž”αž’αžαŸ’αžαž”αž‘αžαžΆαž„αž€αŸ’αžšαŸ„αž˜αž‡αžΆαž—αžΆαžŸαžΆαžαŸ’αž˜αŸ‚αžšαŸ–\n\n"
                f"{row['text']}\n\n"
                f"αžŸαŸαž…αž€αŸ’αžαžΈαžŸαž„αŸ’αžαŸαž”αŸ– {row['summary']}"
            )
        }

    # Drop every original column so only the formatted "text" field remains.
    original_columns = list(train.features)
    train = train.map(
        to_sft_text,
        remove_columns=original_columns,
        desc="Formatting train set",
    )
    eval_ds = eval_ds.map(
        to_sft_text,
        remove_columns=original_columns,
        desc="Formatting eval set",
    )

    return train, eval_ds


def load_model_and_tokenizer():
    """Load the base model in 4-bit (QLoRA) plus its tokenizer.

    Returns:
        (model, tokenizer): the quantized causal-LM ready for PEFT
        training, and its tokenizer with a pad token guaranteed to exist.
    """
    # QLoRA 4-bit quantization config: NF4 with double quantization,
    # computing in bfloat16 (matches bf16=True in the trainer config).
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
    )

    # Some causal-LM tokenizers ship without a pad token; padding is
    # required for batched training, so reuse EOS.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )

    # Enable gradient checkpointing to trade compute for activation memory.
    model.gradient_checkpointing_enable()
    # The KV cache is incompatible with gradient checkpointing (transformers
    # otherwise warns and disables it at train time) — turn it off explicitly.
    model.config.use_cache = False

    return model, tokenizer


def main():
    """Fine-tune the base model on Khmer summarization with QLoRA + SFT."""
    train_ds, eval_ds = load_khm_dataset()
    model, tokenizer = load_model_and_tokenizer()

    # LoRA adapter config; no target_modules given, so PEFT falls back to
    # its per-architecture defaults for this model.
    lora_config = LoraConfig(
        r=64,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )

    sft_config = SFTConfig(
        output_dir="seallm-khm-sum-lora",
        num_train_epochs=2,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        gradient_accumulation_steps=8,  # effective batch = 2 * 8 per device
        learning_rate=2e-4,
        logging_steps=10,
        eval_strategy="steps",
        eval_steps=200,
        save_steps=200,
        save_total_limit=2,
        max_seq_length=1024,
        packing=True,
        # Train on the single formatted column produced by load_khm_dataset.
        # dataset_text_field belongs on SFTConfig, not SFTTrainer: recent TRL
        # releases reject it as a trainer kwarg.
        dataset_text_field="text",
        lr_scheduler_type="cosine",
        warmup_ratio=0.03,
        bf16=True,
        gradient_checkpointing=True,
        report_to="none",  # or "wandb" etc.
    )

    # NOTE(review): newer TRL/transformers rename `tokenizer=` to
    # `processing_class=` — switch if the installed version warns/errors.
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        peft_config=lora_config,
        args=sft_config,
    )

    trainer.train()

    # Save only the LoRA adapter (not merged weights) and the tokenizer.
    trainer.model.save_pretrained("seallm-khm-sum-lora")
    tokenizer.save_pretrained("seallm-khm-sum-lora")

    # Optionally push directly to the Hub (needs HF_TOKEN env)
    repo_id = os.environ.get("OUTPUT_REPO_ID", "")
    if repo_id:
        trainer.model.push_to_hub(repo_id)
        tokenizer.push_to_hub(repo_id)


if __name__ == "__main__":
    main()