import os
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
)
from trl import SFTTrainer
from peft import LoraConfig

MODEL_NAME = "SeaLLMs/SeaLLMs-v3-1.5B"
DATASET_NAME = "bltlab/lr-sum"
DATASET_CONFIG = "khm"

def load_khm_dataset():
    raw = load_dataset(DATASET_NAME, DATASET_CONFIG)

    # Try to find train/validation; if not, split test
    if "train" in raw:
        train = raw["train"]
        if "validation" in raw:
            eval_ds = raw["validation"]
        elif "test" in raw:
            eval_ds = raw["test"]
        else:
            split = train.train_test_split(test_size=0.05, seed=42)
            train, eval_ds = split["train"], split["test"]
    else:
        # Some LR-Sum subsets only have 'test'; we split that.
        split = raw["test"].train_test_split(test_size=0.1, seed=42)
        train, eval_ds = split["train"], split["test"]

    def format_example(example):
        article = example["text"]
        summary = example["summary"]

        # Simple Khmer instruction -> Khmer summary. The prompt reads roughly:
        # "Please summarize the text below in Khmer:", followed by the article,
        # then "Summary: " and the reference summary.
        text = (
            "αžŸαžΌαž˜αžŸαž„αŸ’αžαŸαž”αž’αžαŸ’αžαž”αž‘αžαžΆαž„αž€αŸ’αžšαŸ„αž˜αž‡αžΆαž—αžΆαžŸαžΆαžαŸ’αž˜αŸ‚αžšαŸ–\n\n"
            f"{article}\n\n"
            "αžŸαŸαž…αž€αŸ’αžαžΈαžŸαž„αŸ’αžαŸαž”αŸ– "
            f"{summary}"
        )
        return {"text": text}

    cols_to_remove = list(train.features)

    train = train.map(
        format_example,
        remove_columns=cols_to_remove,
        desc="Formatting train set",
    )
    eval_ds = eval_ds.map(
        format_example,
        remove_columns=cols_to_remove,
        desc="Formatting eval set",
    )

    return train, eval_ds


def load_model_and_tokenizer():
    # QLoRA 4-bit quantization config (NF4 with double quantization).
    # Compute in bfloat16 where the GPU supports it, otherwise float16, so the
    # compute dtype matches the mixed-precision setting used during training.
    compute_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=compute_dtype,
    )

    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
    )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )

    # Enable gradient checkpointing to save activation memory; the KV cache is
    # incompatible with checkpointed training, so turn it off as well.
    model.gradient_checkpointing_enable()
    model.config.use_cache = False

    return model, tokenizer

def main():
    train_ds, eval_ds = load_khm_dataset()
    model, tokenizer = load_model_and_tokenizer()

    lora_config = LoraConfig(
        r=64,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
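    # target_modules is left unset above, so peft falls back to its built-in
    # module mapping for the detected architecture; for a Qwen2-based
    # checkpoint such as SeaLLMs-v3 one could also list the projection layers
    # explicitly (e.g. q_proj, k_proj, v_proj, o_proj) -- an assumption, not
    # verified against this exact checkpoint.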

    # Plain TrainingArguments keeps this compatible with trl releases that
    # predate SFTConfig.
    training_args = TrainingArguments(
        output_dir="seallm-khm-sum-lora",
        num_train_epochs=2,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=2e-4,
        logging_steps=10,
        save_steps=200,
        save_total_limit=2,
        lr_scheduler_type="cosine",
        warmup_ratio=0.03,
        # Mixed precision should match the 4-bit compute dtype chosen above:
        # bf16 where the GPU supports it, fp16 otherwise.
        bf16=torch.cuda.is_bf16_supported(),
        fp16=not torch.cuda.is_bf16_supported(),
        report_to="none",   # disable external logging integrations (wandb, etc.)
    )


    # The call below follows the older trl SFTTrainer signature, where
    # tokenizer, dataset_text_field and max_seq_length are passed directly
    # rather than through SFTConfig.
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        peft_config=lora_config,
        args=training_args,
        dataset_text_field="text",
        max_seq_length=1024,
        # packing left at its default (False) for compatibility.
    )

    trainer.train()

    # Save LoRA adapter and tokenizer
    trainer.model.save_pretrained("seallm-khm-sum-lora")
    tokenizer.save_pretrained("seallm-khm-sum-lora")
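
    # Possible follow-up (sketch, not executed): merge the adapter into
    # full-precision base weights for standalone inference. This assumes the
    # base model is reloaded without 4-bit quantization first and that
    # `from peft import PeftModel` is available.
    #   base = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16)
    #   merged = PeftModel.from_pretrained(base, "seallm-khm-sum-lora").merge_and_unload()
    #   merged.save_pretrained("seallm-khm-sum-merged")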

    repo_id = os.environ.get("OUTPUT_REPO_ID", "")
    if repo_id:
        trainer.model.push_to_hub(repo_id)
        tokenizer.push_to_hub(repo_id)


if __name__ == "__main__":
    main()
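

# Usage sketch (not executed): load the saved LoRA adapter for inference with
# the same prompt format used in training. Paths, dtype, and generation
# settings are illustrative assumptions, not part of the training run above.
#
#   import torch
#   from peft import PeftModel
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   tok = AutoTokenizer.from_pretrained("seallm-khm-sum-lora")
#   base = AutoModelForCausalLM.from_pretrained(
#       "SeaLLMs/SeaLLMs-v3-1.5B",
#       torch_dtype=torch.bfloat16,
#       device_map="auto",
#   )
#   model = PeftModel.from_pretrained(base, "seallm-khm-sum-lora")
#
#   article = "..."  # Khmer news article to summarize
#   prompt = (
#       "αžŸαžΌαž˜αžŸαž„αŸ’αžαŸαž”αž’αžαŸ’αžαž”αž‘αžαžΆαž„αž€αŸ’αžšαŸ„αž˜αž‡αžΆαž—αžΆαžŸαžΆαžαŸ’αž˜αŸ‚αžšαŸ–\n\n" + article + "\n\nαžŸαŸαž…αž€αŸ’αžαžΈαžŸαž„αŸ’αžαŸαž”αŸ– "
#   )
#   inputs = tok(prompt, return_tensors="pt").to(model.device)
#   out = model.generate(**inputs, max_new_tokens=128)
#   print(tok.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))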