# Hugging Face "Spaces: Sleeping" status banner captured by the scraper —
# not part of the program.
| import os | |
| import torch | |
| from datasets import load_dataset | |
| from transformers import ( | |
| AutoTokenizer, | |
| AutoModelForCausalLM, | |
| TrainingArguments, | |
| Trainer, | |
| DataCollatorForLanguageModeling, | |
| BitsAndBytesConfig, | |
| ) | |
| from peft import LoraConfig, get_peft_model | |
# Base model: SeaLLMs v3 (1.5B), a multilingual Southeast-Asian-language LLM.
MODEL_NAME = "SeaLLMs/SeaLLMs-v3-1.5B"
# Summarization corpus and its Khmer-language subset.
DATASET_NAME = "bltlab/lr-sum"
DATASET_CONFIG = "khm"
def load_khm_dataset():
    """Load the Khmer lr-sum subset and return (train, eval) datasets.

    Prefers the dataset's own train/validation splits; when a split is
    missing, carves an eval set out of whatever split is available
    (seeded, so the carve-out is reproducible).
    """
    ds = load_dataset(DATASET_NAME, DATASET_CONFIG)

    if "train" not in ds:
        # Some subsets only have 'test'; split that
        parts = ds["test"].train_test_split(test_size=0.1, seed=42)
        train, eval_ds = parts["train"], parts["test"]
    else:
        train = ds["train"]
        if "validation" in ds:
            eval_ds = ds["validation"]
        elif "test" in ds:
            eval_ds = ds["test"]
        else:
            parts = train.train_test_split(test_size=0.05, seed=42)
            train, eval_ds = parts["train"], parts["test"]

    def to_prompt(example):
        # Simple Khmer instruction-style format: instruction, article, summary.
        article = example["text"]
        summary = example["summary"]
        prompt = (
            "ααΌαααααααα’αααααααΆααααααααΆααΆααΆαααααα\n\n"
            f"{article}\n\n"
            "ααα ααααΈααααααα "
            f"{summary}"
        )
        return {"text": prompt}

    original_columns = list(train.features)
    train = train.map(
        to_prompt, remove_columns=original_columns, desc="Formatting train set"
    )
    eval_ds = eval_ds.map(
        to_prompt, remove_columns=original_columns, desc="Formatting eval set"
    )
    return train, eval_ds
def load_model_and_tokenizer():
    """Build the 4-bit-quantized base model and its tokenizer (QLoRA setup).

    Returns:
        (model, tokenizer) ready for LoRA adapter injection.
    """
    # QLoRA 4-bit config
    quant_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    tok = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    # Causal-LM tokenizers often ship without a pad token; reuse EOS for padding.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    lm = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=quant_cfg,
        device_map="auto",
        trust_remote_code=True,
    )

    # Disable gradient checkpointing; old transformers breaks autograd here
    # model.gradient_checkpointing_enable()
    return lm, tok
def main():
    """Fine-tune SeaLLMs-v3-1.5B on Khmer summarization with QLoRA.

    Loads the formatted dataset, attaches LoRA adapters to the 4-bit base
    model, trains with periodic evaluation, then saves the adapter and
    tokenizer (and optionally pushes them to the Hub).
    """
    train_ds, eval_ds = load_khm_dataset()
    model, tokenizer = load_model_and_tokenizer()

    # Apply LoRA to the model
    lora_config = LoraConfig(
        r=64,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)

    # Tokenize datasets.
    # FIX: no static `padding="max_length"` and no manual labels column here.
    # DataCollatorForLanguageModeling(mlm=False) pads each batch dynamically
    # and builds labels from input_ids with pad positions masked to -100, and
    # it overwrites any precomputed "labels" column — so the old
    # `out["labels"] = out["input_ids"].copy()` was dead code and padding
    # every example to 1024 tokens only wasted compute.
    max_length = 1024

    def tokenize_function(batch):
        return tokenizer(
            batch["text"],
            max_length=max_length,
            truncation=True,
        )

    train_tokenized = train_ds.map(
        tokenize_function,
        batched=True,
        remove_columns=["text"],
        desc="Tokenizing train set",
    )
    eval_tokenized = eval_ds.map(
        tokenize_function,
        batched=True,
        remove_columns=["text"],
        desc="Tokenizing eval set",
    )

    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )

    training_args = TrainingArguments(
        output_dir="seallm-khm-sum-lora",
        num_train_epochs=2,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=2e-4,
        logging_steps=10,
        # BUGFIX: eval_ds was tokenized but never used — with no eval
        # strategy the Trainer silently skips evaluation. Evaluate on the
        # same cadence as checkpointing.
        # (NOTE: on transformers < 4.41 this kwarg is `evaluation_strategy`.)
        eval_strategy="steps",
        eval_steps=200,
        save_steps=200,
        save_total_limit=2,
        lr_scheduler_type="cosine",
        warmup_ratio=0.03,
        fp16=False,  # turn off mixed precision for CPU
        report_to="none",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_tokenized,
        eval_dataset=eval_tokenized,
        data_collator=data_collator,
    )
    trainer.train()

    # Save LoRA adapter + tokenizer
    model.save_pretrained("seallm-khm-sum-lora")
    tokenizer.save_pretrained("seallm-khm-sum-lora")

    # Optional Hub upload, gated on an env var so local runs stay offline.
    repo_id = os.environ.get("OUTPUT_REPO_ID", "")
    if repo_id:
        model.push_to_hub(repo_id)
        tokenizer.push_to_hub(repo_id)


if __name__ == "__main__":
    main()