vedaco committed · verified
Commit dbb535a · 1 Parent(s): 8f5ff41

Create train.py

Files changed (1)
  1. train.py +339 -0
train.py ADDED
import json
import os
from typing import List, Optional

import numpy as np
import tensorflow as tf
from tensorflow import keras

from model import VedaProgrammingLLM, create_veda_model
from tokenizer import VedaTokenizer


class VedaTrainer:
    """Trainer class for the Veda Programming LLM."""

    def __init__(
        self,
        data_path: str = "programming.txt",
        vocab_size: int = 10000,
        max_length: int = 256,
        batch_size: int = 32,
        model_size: str = "small"
    ):
        self.data_path = data_path
        self.vocab_size = vocab_size
        self.max_length = max_length
        self.batch_size = batch_size
        self.model_size = model_size

        self.tokenizer = VedaTokenizer(vocab_size=vocab_size)
        self.model: Optional[VedaProgrammingLLM] = None

    def load_data(self) -> List[str]:
        """Load programming data from file, creating sample data if missing."""
        if not os.path.exists(self.data_path):
            print(f"Creating sample {self.data_path}...")
            self._create_sample_data()

        with open(self.data_path, 'r', encoding='utf-8') as f:
            content = f.read()

        # Split into code samples on blank lines
        samples = []
        current_sample = []

        for line in content.split('\n'):
            if line.strip() == '' and current_sample:
                samples.append('\n'.join(current_sample))
                current_sample = []
            else:
                current_sample.append(line)

        if current_sample:
            samples.append('\n'.join(current_sample))

        # Drop empty samples
        samples = [s.strip() for s in samples if s.strip()]
        print(f"Loaded {len(samples)} code samples")
        return samples

    def _create_sample_data(self):
        """Write a small corpus of example Python snippets to data_path."""
        sample_code = '''
def hello_world():
    print("Hello, World!")
    return True

def fibonacci(n):
    if n <= 1:
        return n
    return fibonacci(n-1) + fibonacci(n-2)

def factorial(n):
    if n == 0:
        return 1
    return n * factorial(n-1)

class Calculator:
    def __init__(self):
        self.result = 0

    def add(self, a, b):
        self.result = a + b
        return self.result

    def subtract(self, a, b):
        self.result = a - b
        return self.result

    def multiply(self, a, b):
        self.result = a * b
        return self.result

    def divide(self, a, b):
        if b != 0:
            self.result = a / b
        return self.result

def bubble_sort(arr):
    n = len(arr)
    for i in range(n):
        for j in range(0, n-i-1):
            if arr[j] > arr[j+1]:
                arr[j], arr[j+1] = arr[j+1], arr[j]
    return arr

def binary_search(arr, target):
    left, right = 0, len(arr) - 1
    while left <= right:
        mid = (left + right) // 2
        if arr[mid] == target:
            return mid
        elif arr[mid] < target:
            left = mid + 1
        else:
            right = mid - 1
    return -1

def quicksort(arr):
    if len(arr) <= 1:
        return arr
    pivot = arr[len(arr) // 2]
    left = [x for x in arr if x < pivot]
    middle = [x for x in arr if x == pivot]
    right = [x for x in arr if x > pivot]
    return quicksort(left) + middle + quicksort(right)

class Node:
    def __init__(self, data):
        self.data = data
        self.next = None

class LinkedList:
    def __init__(self):
        self.head = None

    def append(self, data):
        new_node = Node(data)
        if not self.head:
            self.head = new_node
            return
        current = self.head
        while current.next:
            current = current.next
        current.next = new_node

def merge(left, right):
    result = []
    i = j = 0
    while i < len(left) and j < len(right):
        if left[i] <= right[j]:
            result.append(left[i])
            i += 1
        else:
            result.append(right[j])
            j += 1
    result.extend(left[i:])
    result.extend(right[j:])
    return result

def merge_sort(arr):
    if len(arr) <= 1:
        return arr
    mid = len(arr) // 2
    left = merge_sort(arr[:mid])
    right = merge_sort(arr[mid:])
    return merge(left, right)

def is_palindrome(s):
    s = s.lower().replace(" ", "")
    return s == s[::-1]

def count_words(text):
    words = text.split()
    return len(words)

async def fetch_data(url):
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            return await response.json()

def read_file(filename):
    with open(filename, 'r') as f:
        return f.read()

def write_file(filename, content):
    with open(filename, 'w') as f:
        f.write(content)
'''
        with open(self.data_path, 'w', encoding='utf-8') as f:
            f.write(sample_code)
        print(f"Created sample {self.data_path}")

    def prepare_dataset(self, samples: List[str]) -> tf.data.Dataset:
        """Prepare a TensorFlow dataset for next-token training."""
        # Fit tokenizer on the corpus
        self.tokenizer.fit(samples)

        # Encode all samples into one flat token stream
        all_tokens = []
        for sample in samples:
            tokens = self.tokenizer.encode(sample)
            all_tokens.extend(tokens)

        # Slice overlapping windows of max_length + 1 tokens, advancing by
        # max_length // 2 so consecutive windows overlap by half
        sequences = []
        for i in range(0, len(all_tokens) - self.max_length, self.max_length // 2):
            seq = all_tokens[i:i + self.max_length + 1]
            if len(seq) == self.max_length + 1:
                sequences.append(seq)
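
        # Worked example (illustrative numbers): with max_length = 4 and a
        # stream of ten tokens t0..t9, the starts come from range(0, 6, 2),
        # i.e. 0, 2, 4, yielding windows [t0..t4], [t2..t6], [t4..t8]; each
        # window holds max_length + 1 = 5 tokens (4 inputs plus one
        # lookahead token for the shifted targets computed below).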

        if not sequences:
            # Fall back to per-sample padded sequences when the corpus is
            # too small to fill a single window
            for sample in samples:
                tokens = self.tokenizer.encode(sample, max_length=self.max_length + 1)
                sequences.append(tokens)

        print(f"Created {len(sequences)} training sequences")

        # Convert to a numpy array of shape (num_sequences, max_length + 1)
        sequences = np.array(sequences)

        # Split into inputs and next-token targets
        X = sequences[:, :-1]
        y = sequences[:, 1:]
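
        # Illustrative: a window [t0, t1, t2, t3, t4] becomes
        # X = [t0, t1, t2, t3] and y = [t1, t2, t3, t4], so at every
        # position the model is trained to predict the following token.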

        # Build the tf.data pipeline
        dataset = tf.data.Dataset.from_tensor_slices((X, y))
        dataset = dataset.shuffle(buffer_size=len(sequences))
        dataset = dataset.batch(self.batch_size)
        dataset = dataset.prefetch(tf.data.AUTOTUNE)

        return dataset

    def build_model(self):
        """Build and compile the Veda Programming model."""
        self.model = create_veda_model(
            vocab_size=self.tokenizer.vocabulary_size,
            max_length=self.max_length,
            model_size=self.model_size
        )

        # Compile model; the model outputs raw logits, hence from_logits=True
        optimizer = keras.optimizers.Adam(learning_rate=1e-4)
        loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

        self.model.compile(
            optimizer=optimizer,
            loss=loss_fn,
            metrics=['accuracy']
        )

        # Build model with a dummy forward pass
        dummy_input = tf.zeros((1, self.max_length), dtype=tf.int32)
        self.model(dummy_input)
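
        # The dummy pass makes the subclassed model create its weight
        # variables (Keras builds them lazily on first call), so summary()
        # below can report shapes and parameter counts.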

        self.model.summary()
        return self.model

    def train(
        self,
        epochs: int = 10,
        save_path: str = "veda_model"
    ):
        """Train the model."""
        # Load and prepare data
        samples = self.load_data()
        dataset = self.prepare_dataset(samples)

        # Build model
        self.build_model()

        # Callbacks: checkpoint the best model, stop early when the loss
        # plateaus, and halve the learning rate after two stalled epochs
        callbacks = [
            keras.callbacks.ModelCheckpoint(
                filepath=os.path.join(save_path, "model_checkpoint.keras"),
                save_best_only=True,
                monitor='loss'
            ),
            keras.callbacks.EarlyStopping(
                monitor='loss',
                patience=5,
                restore_best_weights=True
            ),
            keras.callbacks.ReduceLROnPlateau(
                monitor='loss',
                factor=0.5,
                patience=2
            )
        ]
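
        # All three callbacks watch the training loss because fit() below is
        # called without a validation split; with held-out data, 'val_loss'
        # would be the more usual monitor.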

        # Create save directory
        os.makedirs(save_path, exist_ok=True)

        # Train
        history = self.model.fit(
            dataset,
            epochs=epochs,
            callbacks=callbacks
        )

        # Save final weights and tokenizer (Keras 3 requires the
        # ".weights.h5" suffix for save_weights)
        self.model.save_weights(os.path.join(save_path, "model.weights.h5"))
        self.tokenizer.save(os.path.join(save_path, "tokenizer.json"))

        # Save model config
        config = self.model.get_config()
        config['tokenizer_vocab_size'] = self.tokenizer.vocabulary_size

        with open(os.path.join(save_path, "config.json"), 'w') as f:
            json.dump(config, f)

        print(f"Model saved to {save_path}")
        return history

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 100,
        temperature: float = 0.7
    ) -> str:
        """Generate code from a prompt."""
        if self.model is None:
            raise ValueError("Model not loaded. Train or load a model first.")

        # Encode prompt
        prompt_tokens = self.tokenizer.encode(prompt)

        # Generate
        generated_tokens = self.model.generate(
            prompt_tokens,
            max_new_tokens=max_new_tokens,
            temperature=temperature
        )

        # Decode
        generated_text = self.tokenizer.decode(generated_tokens)
        return generated_text


def main():
    """Main training function."""
    trainer = VedaTrainer(
        data_path="programming.txt",
        vocab_size=10000,
        max_length=256,
        batch_size=16,
        model_size="small"
    )

    # Train model
    history = trainer.train(epochs=20, save_path="veda_model")

    # Test generation
    test_prompt = "def calculate"
    generated = trainer.generate(test_prompt, max_new_tokens=50)
    print(f"\nGenerated code:\n{generated}")


if __name__ == "__main__":
    main()
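
train.py depends on two local modules that are not part of this commit: model.py (create_veda_model / VedaProgrammingLLM) and tokenizer.py (VedaTokenizer). A minimal sketch of how a separate script could reload the artifacts that train() writes and sample from the model, assuming only the interfaces exercised above plus a hypothetical VedaTokenizer.load() counterpart to save():

import os

import tensorflow as tf

from model import create_veda_model
from tokenizer import VedaTokenizer

save_path = "veda_model"

# Rebuild the tokenizer; load() is assumed here, mirroring tokenizer.save()
tokenizer = VedaTokenizer.load(os.path.join(save_path, "tokenizer.json"))

# Rebuild the architecture with the same settings used for training,
# run a dummy pass to create the variables, then restore the weights
model = create_veda_model(
    vocab_size=tokenizer.vocabulary_size,
    max_length=256,
    model_size="small"
)
model(tf.zeros((1, 256), dtype=tf.int32))
model.load_weights(os.path.join(save_path, "model.weights.h5"))

# Sample from a prompt, mirroring VedaTrainer.generate()
tokens = model.generate(
    tokenizer.encode("def calculate"),
    max_new_tokens=50,
    temperature=0.7
)
print(tokenizer.decode(tokens))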