""" Fine-tuning script for Kat-Gen1 model """ import torch from transformers import ( AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling ) from datasets import load_dataset from typing import Optional class KatGen1Trainer: def __init__( self, model_name: str = "Katisim/Kat-Gen1", output_dir: str = "./kat-gen1-finetuned" ): """ Initialize the training setup. Args: model_name: Base model to fine-tune output_dir: Directory to save fine-tuned model """ self.model_name = model_name self.output_dir = output_dir self.model = None self.tokenizer = None def load_model(self): """Load model and tokenizer.""" self.model = AutoModelForCausalLM.from_pretrained(self.model_name) self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token self.model.config.pad_token_id = self.tokenizer.pad_token_id def prepare_dataset( self, dataset_name: str, text_column: str = "text", max_length: int = 512 ): """ Prepare dataset for training. Args: dataset_name: Name of dataset from HuggingFace Hub text_column: Column name containing text data max_length: Maximum sequence length Returns: Tokenized dataset """ dataset = load_dataset(dataset_name) def tokenize_function(examples): return self.tokenizer( examples[text_column], truncation=True, max_length=max_length, padding="max_length" ) tokenized_dataset = dataset.map( tokenize_function, batched=True, remove_columns=dataset["train"].column_names ) return tokenized_dataset def train( self, train_dataset, eval_dataset: Optional = None, num_train_epochs: int = 3, per_device_train_batch_size: int = 4, per_device_eval_batch_size: int = 4, learning_rate: float = 5e-5, warmup_steps: int = 500, weight_decay: float = 0.01, logging_steps: int = 100, save_steps: int = 1000, eval_steps: int = 500 ): """ Fine-tune the model. 
Args: train_dataset: Training dataset eval_dataset: Evaluation dataset (optional) num_train_epochs: Number of training epochs per_device_train_batch_size: Training batch size per device per_device_eval_batch_size: Evaluation batch size per device learning_rate: Learning rate warmup_steps: Number of warmup steps weight_decay: Weight decay coefficient logging_steps: Log every N steps save_steps: Save checkpoint every N steps eval_steps: Evaluate every N steps """ training_args = TrainingArguments( output_dir=self.output_dir, num_train_epochs=num_train_epochs, per_device_train_batch_size=per_device_train_batch_size, per_device_eval_batch_size=per_device_eval_batch_size, learning_rate=learning_rate, warmup_steps=warmup_steps, weight_decay=weight_decay, logging_dir=f"{self.output_dir}/logs", logging_steps=logging_steps, save_steps=save_steps, eval_steps=eval_steps if eval_dataset else None, evaluation_strategy="steps" if eval_dataset else "no", save_total_limit=3, fp16=torch.cuda.is_available(), gradient_accumulation_steps=4, load_best_model_at_end=True if eval_dataset else False ) data_collator = DataCollatorForLanguageModeling( tokenizer=self.tokenizer, mlm=False ) trainer = Trainer( model=self.model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset, data_collator=data_collator ) trainer.train() trainer.save_model(self.output_dir) self.tokenizer.save_pretrained(self.output_dir) def main(): """Example training workflow.""" trainer = KatGen1Trainer(output_dir="./kat-gen1-custom") trainer.load_model() # Load and prepare your dataset # dataset = trainer.prepare_dataset("your_dataset_name") # trainer.train( # train_dataset=dataset["train"], # eval_dataset=dataset["validation"] # ) print("Training setup complete. Uncomment dataset loading to begin training.") if __name__ == "__main__": main()
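

# Quick sanity check after training (illustrative sketch, not run by default):
# load the fine-tuned checkpoint back from the output directory and generate a
# short completion. The checkpoint path, prompt string, and generation settings
# below are assumptions for illustration, not part of the training workflow.
#
# model = AutoModelForCausalLM.from_pretrained("./kat-gen1-custom")
# tokenizer = AutoTokenizer.from_pretrained("./kat-gen1-custom")
# inputs = tokenizer("Hello, Kat-Gen1!", return_tensors="pt")
# outputs = model.generate(**inputs, max_new_tokens=50)
# print(tokenizer.decode(outputs[0], skip_special_tokens=True))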