File size: 5,042 Bytes
5d950b8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
"""
Fine-tuning script for Kat-Gen1 model
"""
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
Trainer,
TrainingArguments,
DataCollatorForLanguageModeling
)
from datasets import load_dataset
from typing import Optional
class KatGen1Trainer:
def __init__(
self,
model_name: str = "Katisim/Kat-Gen1",
output_dir: str = "./kat-gen1-finetuned"
):
"""
Initialize the training setup.
Args:
model_name: Base model to fine-tune
output_dir: Directory to save fine-tuned model
"""
self.model_name = model_name
self.output_dir = output_dir
self.model = None
self.tokenizer = None
def load_model(self):
"""Load model and tokenizer."""
self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.model.config.pad_token_id = self.tokenizer.pad_token_id
def prepare_dataset(
self,
dataset_name: str,
text_column: str = "text",
max_length: int = 512
):
"""
Prepare dataset for training.
Args:
dataset_name: Name of dataset from HuggingFace Hub
text_column: Column name containing text data
max_length: Maximum sequence length
Returns:
Tokenized dataset
"""
dataset = load_dataset(dataset_name)
def tokenize_function(examples):
return self.tokenizer(
examples[text_column],
truncation=True,
max_length=max_length,
padding="max_length"
)
tokenized_dataset = dataset.map(
tokenize_function,
batched=True,
remove_columns=dataset["train"].column_names
)
return tokenized_dataset
def train(
self,
train_dataset,
eval_dataset: Optional = None,
num_train_epochs: int = 3,
per_device_train_batch_size: int = 4,
per_device_eval_batch_size: int = 4,
learning_rate: float = 5e-5,
warmup_steps: int = 500,
weight_decay: float = 0.01,
logging_steps: int = 100,
save_steps: int = 1000,
eval_steps: int = 500
):
"""
Fine-tune the model.
Args:
train_dataset: Training dataset
eval_dataset: Evaluation dataset (optional)
num_train_epochs: Number of training epochs
per_device_train_batch_size: Training batch size per device
per_device_eval_batch_size: Evaluation batch size per device
learning_rate: Learning rate
warmup_steps: Number of warmup steps
weight_decay: Weight decay coefficient
logging_steps: Log every N steps
save_steps: Save checkpoint every N steps
eval_steps: Evaluate every N steps
"""
training_args = TrainingArguments(
output_dir=self.output_dir,
num_train_epochs=num_train_epochs,
per_device_train_batch_size=per_device_train_batch_size,
per_device_eval_batch_size=per_device_eval_batch_size,
learning_rate=learning_rate,
warmup_steps=warmup_steps,
weight_decay=weight_decay,
logging_dir=f"{self.output_dir}/logs",
logging_steps=logging_steps,
save_steps=save_steps,
eval_steps=eval_steps if eval_dataset else None,
evaluation_strategy="steps" if eval_dataset else "no",
save_total_limit=3,
fp16=torch.cuda.is_available(),
gradient_accumulation_steps=4,
load_best_model_at_end=True if eval_dataset else False
)
data_collator = DataCollatorForLanguageModeling(
tokenizer=self.tokenizer,
mlm=False
)
trainer = Trainer(
model=self.model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=data_collator
)
trainer.train()
trainer.save_model(self.output_dir)
self.tokenizer.save_pretrained(self.output_dir)
def main():
"""Example training workflow."""
trainer = KatGen1Trainer(output_dir="./kat-gen1-custom")
trainer.load_model()
# Load and prepare your dataset
# dataset = trainer.prepare_dataset("your_dataset_name")
# trainer.train(
# train_dataset=dataset["train"],
# eval_dataset=dataset["validation"]
# )
print("Training setup complete. Uncomment dataset loading to begin training.")
if __name__ == "__main__":
main() |