Trouter-Library committed on
Commit 5d950b8 · verified · 1 Parent(s): 33fc213

Create train.py

Files changed (1)
  1. train.py +163 -0
train.py ADDED
@@ -0,0 +1,163 @@
"""
Fine-tuning script for the Kat-Gen1 model.
"""

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling
)
from datasets import Dataset, load_dataset
from typing import Optional


class KatGen1Trainer:
    def __init__(
        self,
        model_name: str = "Katisim/Kat-Gen1",
        output_dir: str = "./kat-gen1-finetuned"
    ):
        """
        Initialize the training setup.

        Args:
            model_name: Base model to fine-tune
            output_dir: Directory to save the fine-tuned model
        """
        self.model_name = model_name
        self.output_dir = output_dir
        self.model = None
        self.tokenizer = None

    def load_model(self):
        """Load the model and tokenizer."""
        self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        # Causal LM tokenizers often lack a pad token; fall back to EOS.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.model.config.pad_token_id = self.tokenizer.pad_token_id

    def prepare_dataset(
        self,
        dataset_name: str,
        text_column: str = "text",
        max_length: int = 512
    ):
        """
        Prepare a dataset for training.

        Args:
            dataset_name: Name of a dataset on the Hugging Face Hub
            text_column: Column name containing the text data
            max_length: Maximum sequence length

        Returns:
            Tokenized dataset
        """
        dataset = load_dataset(dataset_name)

        def tokenize_function(examples):
            return self.tokenizer(
                examples[text_column],
                truncation=True,
                max_length=max_length,
                padding="max_length"
            )

        # Drop the raw columns so only token IDs and attention masks remain.
        tokenized_dataset = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=dataset["train"].column_names
        )

        return tokenized_dataset

    def train(
        self,
        train_dataset,
        eval_dataset: Optional[Dataset] = None,
        num_train_epochs: int = 3,
        per_device_train_batch_size: int = 4,
        per_device_eval_batch_size: int = 4,
        learning_rate: float = 5e-5,
        warmup_steps: int = 500,
        weight_decay: float = 0.01,
        logging_steps: int = 100,
        save_steps: int = 1000,
        eval_steps: int = 500
    ):
        """
        Fine-tune the model.

        Args:
            train_dataset: Training dataset
            eval_dataset: Evaluation dataset (optional)
            num_train_epochs: Number of training epochs
            per_device_train_batch_size: Training batch size per device
            per_device_eval_batch_size: Evaluation batch size per device
            learning_rate: Learning rate
            warmup_steps: Number of warmup steps
            weight_decay: Weight decay coefficient
            logging_steps: Log every N steps
            save_steps: Save a checkpoint every N steps
            eval_steps: Evaluate every N steps
        """
        training_args = TrainingArguments(
            output_dir=self.output_dir,
            num_train_epochs=num_train_epochs,
            per_device_train_batch_size=per_device_train_batch_size,
            per_device_eval_batch_size=per_device_eval_batch_size,
            learning_rate=learning_rate,
            warmup_steps=warmup_steps,
            weight_decay=weight_decay,
            logging_dir=f"{self.output_dir}/logs",
            logging_steps=logging_steps,
            save_steps=save_steps,
            eval_steps=eval_steps if eval_dataset else None,
            evaluation_strategy="steps" if eval_dataset else "no",
            save_total_limit=3,
            fp16=torch.cuda.is_available(),
            gradient_accumulation_steps=4,
            load_best_model_at_end=bool(eval_dataset)
        )

        # mlm=False configures the collator for causal (next-token) language modeling.
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=data_collator
        )

        trainer.train()
        trainer.save_model(self.output_dir)
        self.tokenizer.save_pretrained(self.output_dir)


def main():
    """Example training workflow."""
    trainer = KatGen1Trainer(output_dir="./kat-gen1-custom")
    trainer.load_model()

    # Load and prepare your dataset
    # dataset = trainer.prepare_dataset("your_dataset_name")

    # trainer.train(
    #     train_dataset=dataset["train"],
    #     eval_dataset=dataset["validation"]
    # )

    print("Training setup complete. Uncomment dataset loading to begin training.")


if __name__ == "__main__":
    main()
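
For reference, a minimal sketch of how the file above might be driven end to end from a separate script. It assumes the Katisim/Kat-Gen1 checkpoint is reachable on the Hub; the dataset name "imdb" is used only as an illustrative corpus with a "text" column and is not something train.py prescribes.

# example_run.py — hypothetical driver, not part of this commit
from train import KatGen1Trainer

trainer = KatGen1Trainer(output_dir="./kat-gen1-custom")
trainer.load_model()

# "imdb" is an assumed example dataset; any Hub dataset exposing a "text" column works.
dataset = trainer.prepare_dataset("imdb", text_column="text", max_length=512)

# imdb has no validation split, so evaluation is skipped (evaluation_strategy="no").
trainer.train(
    train_dataset=dataset["train"],
    eval_dataset=dataset.get("validation")
)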