vedaco committed
Commit 3096f4a · verified · 1 Parent(s): 3fc35d3

Create continuous_trainer.py

Files changed (1):
  1. continuous_trainer.py  +368 -0
continuous_trainer.py ADDED
@@ -0,0 +1,368 @@
+ """Continuous training system for Veda Programming LLM"""
+
+ import os
+ import json
+ import shutil
+ import threading
+ import time
+ import traceback
+ from datetime import datetime
+ from typing import Optional
+
+ import numpy as np
+ import tensorflow as tf
+ from tensorflow import keras
+
+ from model import VedaProgrammingLLM
+ from tokenizer import VedaTokenizer
+ from database import db
+ from data_collector import collector
+ from config import (
+     MODEL_DIR, VERSIONS_DIR, VOCAB_SIZE, MAX_LENGTH,
+     D_MODEL, NUM_HEADS, NUM_LAYERS, FF_DIM, BATCH_SIZE,
+     MIN_SAMPLES_FOR_TRAINING, EPOCHS_PER_RETRAIN,
+     AUTO_TRAIN_INTERVAL_HOURS
+ )
+
+ class ContinuousTrainer:
+     """Handles continuous learning and model updates"""
+
+     def __init__(self):
+         self.model: Optional[VedaProgrammingLLM] = None
+         self.tokenizer: Optional[VedaTokenizer] = None
+         self.is_training = False
+         self.training_progress = 0
+         self.last_training_time = None
+         self.model_version = self._get_current_version()
+
+         # Background training thread
+         self._training_thread = None
+         self._stop_background = False
+
+     def _get_current_version(self) -> str:
+         """Get current model version"""
+         config_path = os.path.join(MODEL_DIR, "config.json")
+         if os.path.exists(config_path):
+             with open(config_path, 'r') as f:
+                 config = json.load(f)
+             return config.get('version', 'v1.0')
+         return 'v1.0'
+
+     def _generate_version(self) -> str:
+         """Generate new version string"""
+         return f"v{datetime.now().strftime('%Y%m%d_%H%M%S')}"
+
+     def load_model(self) -> bool:
+         """Load the current model"""
+         config_path = os.path.join(MODEL_DIR, "config.json")
+
+         if not os.path.exists(config_path):
+             print("No existing model found.")
+             return False
+
+         try:
+             # Load config
+             with open(config_path, 'r') as f:
+                 config = json.load(f)
+
+             # Load tokenizer
+             self.tokenizer = VedaTokenizer()
+             self.tokenizer.load(os.path.join(MODEL_DIR, "tokenizer.json"))
+
+             # Create model
+             self.model = VedaProgrammingLLM(
+                 vocab_size=config['vocab_size'],
+                 max_length=config['max_length'],
+                 d_model=config['d_model'],
+                 num_heads=config['num_heads'],
+                 num_layers=config['num_layers'],
+                 ff_dim=config['ff_dim']
+             )
+
+             # Build and load weights
+             dummy = tf.zeros((1, config['max_length']), dtype=tf.int32)
+             self.model(dummy)
+             self.model.load_weights(os.path.join(MODEL_DIR, "weights.h5"))
+
+             self.model_version = config.get('version', 'v1.0')
+             print(f"Model loaded: {self.model_version}")
+             return True
+
+         except Exception as e:
+             print(f"Error loading model: {e}")
+             return False
+
+     def save_model(self, version: Optional[str] = None):
+         """Save the current model"""
+         if self.model is None or self.tokenizer is None:
+             return
+
+         version = version or self._generate_version()
+
+         # Save to main directory
+         os.makedirs(MODEL_DIR, exist_ok=True)
+
+         self.model.save_weights(os.path.join(MODEL_DIR, "weights.h5"))
+         self.tokenizer.save(os.path.join(MODEL_DIR, "tokenizer.json"))
+
+         config = self.model.get_config()
+         config['version'] = version
+         config['last_trained'] = datetime.now().isoformat()
+
+         with open(os.path.join(MODEL_DIR, "config.json"), 'w') as f:
+             json.dump(config, f, indent=2)
+
+         # Save version backup
+         version_dir = os.path.join(VERSIONS_DIR, version)
+         os.makedirs(version_dir, exist_ok=True)
+
+         for filename in ("weights.h5", "tokenizer.json", "config.json"):
+             shutil.copy(
+                 os.path.join(MODEL_DIR, filename),
+                 os.path.join(version_dir, filename)
+             )
+
+         self.model_version = version
+         print(f"Model saved: {version}")
+
+     def should_retrain(self) -> bool:
+         """Check if retraining is needed"""
+         pending = collector.get_pending_count()
+         return pending >= MIN_SAMPLES_FOR_TRAINING
+
+     def prepare_training_data(self) -> Optional[tf.data.Dataset]:
+         """Prepare dataset for training"""
+         # Get all training samples
+         samples = collector.get_training_data(include_base=True)
+
+         if not samples:
+             return None
+
+         # Combine all samples
+         all_text = '\n\n'.join(samples)
+
+         # Fit or update tokenizer
+         # NOTE: refitting here assumes VedaTokenizer.fit keeps existing
+         # token ids stable; otherwise previously saved weights would no
+         # longer line up with the vocabulary.
+         if self.tokenizer is None:
+             self.tokenizer = VedaTokenizer(vocab_size=VOCAB_SIZE)
+
+         self.tokenizer.fit([all_text])
+
+         # Encode
+         all_tokens = self.tokenizer.encode(all_text)
+
+         # Slice into overlapping windows of MAX_LENGTH + 1 tokens
+         def make_sequences(stride: int) -> list:
+             seqs = []
+             for i in range(0, len(all_tokens) - MAX_LENGTH - 1, stride):
+                 seq = all_tokens[i:i + MAX_LENGTH + 1]
+                 if len(seq) == MAX_LENGTH + 1:
+                     seqs.append(seq)
+             return seqs
+
+         sequences = make_sequences(MAX_LENGTH // 2)
+
+         # Very little data: retry with a denser stride for more windows
+         if len(sequences) < 5:
+             sequences = make_sequences(max(1, MAX_LENGTH // 8))
+
+         if not sequences:
+             return None
+
+         sequences = np.array(sequences)
+         X = sequences[:, :-1]  # inputs
+         y = sequences[:, 1:]   # next-token targets
+
+         dataset = tf.data.Dataset.from_tensor_slices((X, y))
+         dataset = dataset.shuffle(buffer_size=min(1000, len(sequences)))
+         dataset = dataset.batch(BATCH_SIZE)
+         dataset = dataset.prefetch(tf.data.AUTOTUNE)
+
+         print(f"Prepared {len(sequences)} training sequences")
+         return dataset
+
+     def train(
+         self,
+         epochs: int = EPOCHS_PER_RETRAIN,
+         callback=None
+     ) -> dict:
+         """Train/retrain the model"""
+         if self.is_training:
+             return {'status': 'error', 'message': 'Training already in progress'}
+
+         self.is_training = True
+         self.training_progress = 0
+
+         try:
+             # Prepare data
+             dataset = self.prepare_training_data()
+             if dataset is None:
+                 self.is_training = False
+                 return {'status': 'error', 'message': 'No training data available'}
+
+             # Create/update model
+             if self.model is None:
+                 self.model = VedaProgrammingLLM(
+                     vocab_size=self.tokenizer.vocabulary_size,
+                     max_length=MAX_LENGTH,
+                     d_model=D_MODEL,
+                     num_heads=NUM_HEADS,
+                     num_layers=NUM_LAYERS,
+                     ff_dim=FF_DIM
+                 )
+
+             # Compile
+             self.model.compile(
+                 optimizer=keras.optimizers.Adam(learning_rate=1e-4),
+                 loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+                 metrics=['accuracy']
+             )
+
+             # Build
+             dummy = tf.zeros((1, MAX_LENGTH), dtype=tf.int32)
+             self.model(dummy)
+
+             # Custom callback for progress
+             class ProgressCallback(keras.callbacks.Callback):
+                 def __init__(self, trainer, total_epochs):
+                     super().__init__()
+                     self.trainer = trainer
+                     self.total_epochs = total_epochs
+
+                 def on_epoch_end(self, epoch, logs=None):
+                     self.trainer.training_progress = (epoch + 1) / self.total_epochs * 100
+
+             callbacks = [ProgressCallback(self, epochs)]
+             if callback:
+                 callbacks.append(callback)
+
+             # Train
+             history = self.model.fit(
+                 dataset,
+                 epochs=epochs,
+                 callbacks=callbacks,
+                 verbose=1
+             )
+
+             # Save model
+             new_version = self._generate_version()
+             self.save_model(new_version)
+
+             # Mark samples as used
+             new_samples = collector.get_new_training_data()
+             if new_samples:
+                 sample_ids = [s['id'] for s in new_samples]
+                 db.mark_as_used_for_training(sample_ids)
+
+             # Record training run
+             final_loss = history.history['loss'][-1]
+             final_acc = history.history.get('accuracy', [0])[-1]
+
+             samples_count = len(new_samples) if new_samples else 0
+             db.save_training_run(
+                 samples_used=samples_count,
+                 epochs=epochs,
+                 final_loss=final_loss,
+                 final_accuracy=final_acc,
+                 model_version=new_version
+             )
+
+             self.last_training_time = datetime.now()
+             self.is_training = False
+             self.training_progress = 100
+
+             return {
+                 'status': 'success',
+                 'version': new_version,
+                 'loss': final_loss,
+                 'accuracy': final_acc,
+                 'samples_used': samples_count
+             }
+
+         except Exception as e:
+             self.is_training = False
+             traceback.print_exc()
+             return {'status': 'error', 'message': str(e)}
+
+     def train_async(self, epochs: int = EPOCHS_PER_RETRAIN):
+         """Start training in background thread"""
+         if self.is_training:
+             return False
+
+         def train_thread():
+             result = self.train(epochs=epochs)
+             print(f"Background training completed: {result}")
+
+         self._training_thread = threading.Thread(target=train_thread)
+         self._training_thread.start()
+         return True
+
+     def start_auto_training(self):
+         """Start automatic retraining scheduler"""
+         def auto_train_loop():
+             while not self._stop_background:
+                 # Sleep for the configured check interval
+                 time.sleep(AUTO_TRAIN_INTERVAL_HOURS * 3600)
+
+                 if self._stop_background:
+                     break
+
+                 # Check if retraining needed
+                 if self.should_retrain():
+                     print("Auto-training triggered...")
+                     self.train()
+
+         self._stop_background = False
+         thread = threading.Thread(target=auto_train_loop, daemon=True)
+         thread.start()
+         print("Auto-training scheduler started")
+
+     def stop_auto_training(self):
+         """Stop automatic retraining"""
+         self._stop_background = True
+
+     def get_status(self) -> dict:
+         """Get trainer status"""
+         return {
+             'model_loaded': self.model is not None,
+             'model_version': self.model_version,
+             'is_training': self.is_training,
+             'training_progress': self.training_progress,
+             'last_training': self.last_training_time.isoformat() if self.last_training_time else None,
+             'pending_samples': collector.get_pending_count(),
+             'min_samples_for_training': MIN_SAMPLES_FOR_TRAINING
+         }
+
+     def generate(
+         self,
+         prompt: str,
+         max_tokens: int = 100,
+         temperature: float = 0.7,
+         repetition_penalty: float = 1.2,
+         top_k: int = 50
+     ) -> str:
+         """Generate code using the model"""
+         if self.model is None or self.tokenizer is None:
+             raise ValueError("Model not loaded")
+
+         tokens = self.tokenizer.encode(prompt)
+         if len(tokens) == 0:
+             tokens = [ord(' ')]
+
+         generated = self.model.generate(
+             tokens,
+             max_new_tokens=max_tokens,
+             temperature=temperature,
+             top_k=top_k,
+             repetition_penalty=repetition_penalty
+         )
+
+         return self.tokenizer.decode(generated)
+
+
+ # Global trainer instance
+ trainer = ContinuousTrainer()
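
For reference, a minimal usage sketch of the module above, assuming the sibling modules it imports (model, tokenizer, database, data_collector, config) are importable from this repo. All method names are taken from the class as committed; the prompt string is only an illustration.

    from continuous_trainer import trainer

    # Try to restore the last saved model; otherwise do a blocking
    # initial train on whatever samples the collector already holds.
    if not trainer.load_model():
        print(trainer.train())

    # Periodic background check that retrains once enough new samples arrive.
    trainer.start_auto_training()

    print(trainer.get_status())
    print(trainer.generate("def fibonacci(n):", max_tokens=64, temperature=0.7))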