"""Fine-tuning script for the Kat-Gen1 model."""

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)
from datasets import Dataset, load_dataset
from typing import Optional


class KatGen1Trainer:
    def __init__(
        self,
        model_name: str = "Katisim/Kat-Gen1",
        output_dir: str = "./kat-gen1-finetuned",
    ):
        """
        Initialize the training setup.

        Args:
            model_name: Base model to fine-tune.
            output_dir: Directory to save the fine-tuned model.
        """
        self.model_name = model_name
        self.output_dir = output_dir
        self.model = None
        self.tokenizer = None

    def load_model(self):
        """Load the model and tokenizer."""
        self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        # Some causal-LM tokenizers ship without a pad token; fall back to
        # EOS so that batch padding works during training.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.model.config.pad_token_id = self.tokenizer.pad_token_id

    def prepare_dataset(
        self,
        dataset_name: str,
        text_column: str = "text",
        max_length: int = 512,
    ):
        """
        Prepare a dataset for training.

        Args:
            dataset_name: Name of a dataset on the Hugging Face Hub.
            text_column: Column name containing the text data.
            max_length: Maximum sequence length.

        Returns:
            The tokenized dataset.
        """
        dataset = load_dataset(dataset_name)

        def tokenize_function(examples):
            return self.tokenizer(
                examples[text_column],
                truncation=True,
                max_length=max_length,
                padding="max_length",
            )

        # Tokenize in batches and drop the raw columns so only model
        # inputs (input_ids, attention_mask) remain.
        tokenized_dataset = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=dataset["train"].column_names,
        )

        return tokenized_dataset

    def train(
        self,
        train_dataset: Dataset,
        eval_dataset: Optional[Dataset] = None,
        num_train_epochs: int = 3,
        per_device_train_batch_size: int = 4,
        per_device_eval_batch_size: int = 4,
        learning_rate: float = 5e-5,
        warmup_steps: int = 500,
        weight_decay: float = 0.01,
        logging_steps: int = 100,
        save_steps: int = 1000,
        eval_steps: int = 500,
    ):
        """
        Fine-tune the model.

        Args:
            train_dataset: Training dataset.
            eval_dataset: Evaluation dataset (optional).
            num_train_epochs: Number of training epochs.
            per_device_train_batch_size: Training batch size per device.
            per_device_eval_batch_size: Evaluation batch size per device.
            learning_rate: Learning rate.
            warmup_steps: Number of warmup steps.
            weight_decay: Weight decay coefficient.
            logging_steps: Log every N steps.
            save_steps: Save a checkpoint every N steps.
            eval_steps: Evaluate every N steps.
        """
        training_args = TrainingArguments(
            output_dir=self.output_dir,
            num_train_epochs=num_train_epochs,
            per_device_train_batch_size=per_device_train_batch_size,
            per_device_eval_batch_size=per_device_eval_batch_size,
            learning_rate=learning_rate,
            warmup_steps=warmup_steps,
            weight_decay=weight_decay,
            logging_dir=f"{self.output_dir}/logs",
            logging_steps=logging_steps,
            save_steps=save_steps,
            eval_steps=eval_steps if eval_dataset else None,
            # Note: newer transformers releases rename this to `eval_strategy`.
            evaluation_strategy="steps" if eval_dataset else "no",
            save_total_limit=3,
            # Mixed precision only when a CUDA device is available.
            fp16=torch.cuda.is_available(),
            gradient_accumulation_steps=4,
            load_best_model_at_end=eval_dataset is not None,
        )

        # mlm=False selects standard causal-LM training (labels are the
        # inputs shifted by one) rather than masked-language modeling.
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=data_collator,
        )

        trainer.train()
        trainer.save_model(self.output_dir)
        self.tokenizer.save_pretrained(self.output_dir)


def main():
    """Example training workflow."""
    trainer = KatGen1Trainer(output_dir="./kat-gen1-custom")
    trainer.load_model()
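
    # Example (commented out): prepare a dataset and launch fine-tuning.
    # This is a minimal sketch; "your-username/your-dataset" is a
    # placeholder, so substitute any Hub dataset with a "text" column.
    #
    # tokenized = trainer.prepare_dataset("your-username/your-dataset")
    # trainer.train(
    #     train_dataset=tokenized["train"],
    #     eval_dataset=tokenized.get("validation"),
    # )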

    print("Training setup complete. Uncomment dataset loading to begin training.")


if __name__ == "__main__":
    main()