"""
Fine-tuning script for the Kat-Gen1 model.
"""
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)
from datasets import Dataset, load_dataset
from typing import Optional


class KatGen1Trainer:
    def __init__(
        self,
        model_name: str = "Katisim/Kat-Gen1",
        output_dir: str = "./kat-gen1-finetuned",
    ):
        """
        Initialize the training setup.

        Args:
            model_name: Base model to fine-tune
            output_dir: Directory to save the fine-tuned model
        """
        self.model_name = model_name
        self.output_dir = output_dir
        self.model = None
        self.tokenizer = None

    def load_model(self):
        """Load the model and tokenizer."""
        self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        # GPT-style tokenizers often ship without a pad token; reuse EOS so
        # that batched padding works during training.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.model.config.pad_token_id = self.tokenizer.pad_token_id

    def prepare_dataset(
        self,
        dataset_name: str,
        text_column: str = "text",
        max_length: int = 512,
    ):
        """
        Prepare a dataset for training.

        Args:
            dataset_name: Name of a dataset on the Hugging Face Hub
            text_column: Column name containing the text data
            max_length: Maximum sequence length

        Returns:
            The tokenized dataset
        """
        dataset = load_dataset(dataset_name)

        def tokenize_function(examples):
            return self.tokenizer(
                examples[text_column],
                truncation=True,
                max_length=max_length,
                padding="max_length",
            )

        # Drop the raw columns after tokenization; the data collator later
        # builds the labels from input_ids for causal LM training.
        tokenized_dataset = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=dataset["train"].column_names,
        )
        return tokenized_dataset

    def train(
        self,
        train_dataset,
        eval_dataset: Optional[Dataset] = None,
        num_train_epochs: int = 3,
        per_device_train_batch_size: int = 4,
        per_device_eval_batch_size: int = 4,
        learning_rate: float = 5e-5,
        warmup_steps: int = 500,
        weight_decay: float = 0.01,
        logging_steps: int = 100,
        save_steps: int = 1000,
        eval_steps: int = 500,
    ):
        """
        Fine-tune the model.

        Args:
            train_dataset: Training dataset
            eval_dataset: Evaluation dataset (optional)
            num_train_epochs: Number of training epochs
            per_device_train_batch_size: Training batch size per device
            per_device_eval_batch_size: Evaluation batch size per device
            learning_rate: Learning rate
            warmup_steps: Number of warmup steps
            weight_decay: Weight decay coefficient
            logging_steps: Log every N steps
            save_steps: Save a checkpoint every N steps
            eval_steps: Evaluate every N steps
        """
        training_args = TrainingArguments(
            output_dir=self.output_dir,
            num_train_epochs=num_train_epochs,
            per_device_train_batch_size=per_device_train_batch_size,
            per_device_eval_batch_size=per_device_eval_batch_size,
            learning_rate=learning_rate,
            warmup_steps=warmup_steps,
            weight_decay=weight_decay,
            logging_dir=f"{self.output_dir}/logs",
            logging_steps=logging_steps,
            eval_steps=eval_steps if eval_dataset else None,
            # Note: newer transformers releases rename this to `eval_strategy`.
            evaluation_strategy="steps" if eval_dataset else "no",
            save_steps=save_steps,
            save_total_limit=3,
            fp16=torch.cuda.is_available(),
            # Effective batch size is 4x per_device_train_batch_size
            # (16 per device with the defaults above).
            gradient_accumulation_steps=4,
            load_best_model_at_end=bool(eval_dataset),
        )
        # mlm=False makes the collator produce causal-LM labels (a shifted
        # copy of input_ids) instead of masked-token labels.
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,
        )
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=data_collator,
        )
        trainer.train()
        trainer.save_model(self.output_dir)
        self.tokenizer.save_pretrained(self.output_dir)
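

# The train() method accepts an optional eval_dataset, but many Hub datasets
# ship with only a "train" split. The helper below is a minimal sketch (not
# part of the original script) of one way to carve out an evaluation split;
# the helper name and the 10% test_size are illustrative assumptions.
def split_train_eval(tokenized_dataset, test_size: float = 0.1, seed: int = 42):
    """Split a tokenized dataset's "train" split into train/eval subsets."""
    split = tokenized_dataset["train"].train_test_split(test_size=test_size, seed=seed)
    return split["train"], split["test"]

# Hypothetical usage with the trainer above:
#   dataset = trainer.prepare_dataset("your_dataset_name")
#   train_ds, eval_ds = split_train_eval(dataset)
#   trainer.train(train_dataset=train_ds, eval_dataset=eval_ds)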


def main():
    """Example training workflow."""
    trainer = KatGen1Trainer(output_dir="./kat-gen1-custom")
    trainer.load_model()

    # Load and prepare your dataset:
    # dataset = trainer.prepare_dataset("your_dataset_name")
    # trainer.train(
    #     train_dataset=dataset["train"],
    #     eval_dataset=dataset["validation"]
    # )
    print("Training setup complete. Uncomment dataset loading to begin training.")


if __name__ == "__main__":
    main()