vedaco committed on
Commit
5740a72
·
verified ·
1 Parent(s): f162639

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +607 -74
train.py CHANGED
@@ -1,4 +1,4 @@
1
- """Training for Veda Programming Assistant"""
2
 
3
  import tensorflow as tf
4
  from tensorflow import keras
@@ -8,10 +8,12 @@ import json
8
 
9
  from model import VedaProgrammingLLM
10
  from tokenizer import VedaTokenizer
 
11
  from config import MODEL_DIR, VOCAB_SIZE, MAX_LENGTH, BATCH_SIZE
12
 
13
 
14
- TRAINING_DATA = '''
 
15
  <USER> Hello!
16
  <ASSISTANT> Hello! Welcome to Veda Programming Assistant! I can help you with coding, explain concepts, or answer programming questions. What would you like to learn today?
17
 
@@ -46,6 +48,8 @@ result = add(5, 3)
46
  print(result)
47
  <ENDCODE>
48
 
 
 
49
  This function takes two parameters and returns their sum.
50
 
51
  <USER> Explain what a function is
@@ -75,6 +79,8 @@ def greet(name):
75
  print(greet("Alice"))
76
  <ENDCODE>
77
 
 
 
78
  <USER> What is a list?
79
  <ASSISTANT> A list is an ordered, mutable collection in Python that can hold items of any type.
80
 
@@ -107,19 +113,23 @@ print(len(fruits))
107
  <USER> Write a fibonacci function
108
  <ASSISTANT> Here is a Fibonacci function:
109
 
 
 
110
  <CODE>
111
  def fibonacci(n):
112
  if n <= 0:
113
  return 0
114
  if n == 1:
115
  return 1
116
- return fibonacci(n-1) + fibonacci(n-2)
117
 
118
  for i in range(10):
119
  print(fibonacci(i), end=" ")
120
  <ENDCODE>
121
 
122
- Iterative version:
 
 
123
 
124
  <CODE>
125
  def fibonacci_iterative(n):
@@ -127,7 +137,6 @@ def fibonacci_iterative(n):
127
  return 0
128
  if n == 1:
129
  return 1
130
-
131
  a, b = 0, 1
132
  for _ in range(2, n + 1):
133
  a, b = b, a + b
@@ -141,7 +150,7 @@ Key components:
141
  1. Base case - stops the recursion
142
  2. Recursive case - function calls itself
143
 
144
- Example:
145
 
146
  <CODE>
147
  def countdown(n):
@@ -154,6 +163,8 @@ def countdown(n):
154
  countdown(5)
155
  <ENDCODE>
156
 
 
 
157
  Factorial example:
158
 
159
  <CODE>
@@ -165,32 +176,39 @@ def factorial(n):
165
  print(factorial(5))
166
  <ENDCODE>
167
 
 
 
168
  <USER> Write bubble sort
169
  <ASSISTANT> Here is the Bubble Sort algorithm:
170
 
171
  <CODE>
172
  def bubble_sort(arr):
173
  n = len(arr)
174
-
175
  for i in range(n):
176
  swapped = False
177
-
178
  for j in range(0, n - i - 1):
179
  if arr[j] > arr[j + 1]:
180
  arr[j], arr[j + 1] = arr[j + 1], arr[j]
181
  swapped = True
182
-
183
  if not swapped:
184
  break
185
-
186
  return arr
187
 
188
  numbers = [64, 34, 25, 12, 22, 11, 90]
 
189
  print("Sorted:", bubble_sort(numbers.copy()))
190
  <ENDCODE>
191
 
 
 
 
 
 
 
 
 
192
  <USER> Write binary search
193
- <ASSISTANT> Here is Binary Search:
194
 
195
  <CODE>
196
  def binary_search(arr, target):
@@ -211,10 +229,18 @@ def binary_search(arr, target):
211
 
212
  numbers = [1, 3, 5, 7, 9, 11, 13, 15]
213
  print(binary_search(numbers, 7))
 
214
  <ENDCODE>
215
 
 
 
 
 
 
 
 
216
  <USER> What is a class?
217
- <ASSISTANT> A class is a blueprint for creating objects:
218
 
219
  <CODE>
220
  class Dog:
@@ -224,26 +250,106 @@ class Dog:
224
 
225
  def bark(self):
226
  return f"{self.name} says Woof!"
 
 
 
227
 
228
  my_dog = Dog("Buddy", 3)
229
  print(my_dog.bark())
 
230
  <ENDCODE>
231
 
 
 
 
 
 
 
 
 
 
 
232
  <USER> How do I read a file?
233
- <ASSISTANT> Here is how to read files:
 
 
234
 
235
  <CODE>
236
  with open("filename.txt", "r") as file:
237
  content = file.read()
238
  print(content)
 
 
 
239
 
 
240
  with open("filename.txt", "r") as file:
241
  for line in file:
242
  print(line.strip())
243
  <ENDCODE>
244
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
  <USER> How do I handle errors?
246
- <ASSISTANT> Use try-except blocks:
 
 
247
 
248
  <CODE>
249
  try:
@@ -266,82 +372,480 @@ except Exception as e:
266
  print(f"Error: {e}")
267
  <ENDCODE>
268
 
269
- <USER> Thanks!
270
- <ASSISTANT> You are welcome! Feel free to ask me anything anytime! Happy coding!
271
 
272
- <USER> Hi
273
- <ASSISTANT> Hello! How can I help you today? I can write code, explain concepts, or answer your programming questions!
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
 
275
  <USER> What is the difference between list and tuple?
276
  <ASSISTANT> The main difference is mutability:
277
- - Lists are mutable (can change)
278
- - Tuples are immutable (cannot change)
 
 
279
 
280
  <CODE>
281
  my_list = [1, 2, 3]
282
  my_tuple = (1, 2, 3)
 
 
 
283
 
 
 
284
  my_list[0] = 100
285
  my_list.append(4)
 
 
 
 
 
286
  <ENDCODE>
287
 
288
- Use lists for changing data, tuples for fixed data.
289
- '''
 
290
 
 
291
 
292
- class VedaTrainer:
293
- """Trainer for Veda Programming Assistant"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
  def __init__(self, vocab_size: int = 8000, max_length: int = 512, batch_size: int = 4):
296
  self.vocab_size = vocab_size
297
  self.max_length = max_length
298
  self.batch_size = batch_size
299
  self.tokenizer = VedaTokenizer(vocab_size=vocab_size)
300
  self.model = None
301
-
302
- def prepare_data(self, extra_data: str = ""):
303
- """Prepare training data"""
304
  data = TRAINING_DATA
 
305
  if extra_data:
306
  data += "\n\n" + extra_data
307
-
 
 
 
308
  if os.path.exists("programming.txt"):
309
- with open("programming.txt", 'r', encoding='utf-8') as f:
310
- code_data = f.read()
311
- data += "\n\n" + code_data
312
-
 
 
 
313
  self.tokenizer.fit([data])
314
-
315
  all_tokens = self.tokenizer.encode(data)
316
  print(f"Total tokens: {len(all_tokens)}")
317
-
318
  sequences = []
319
  stride = self.max_length // 2
320
-
321
  for i in range(0, len(all_tokens) - self.max_length - 1, stride):
322
- seq = all_tokens[i:i + self.max_length + 1]
323
  if len(seq) == self.max_length + 1:
324
  sequences.append(seq)
325
-
326
  if len(sequences) < 10:
327
  stride = self.max_length // 4
328
  sequences = []
329
  for i in range(0, len(all_tokens) - self.max_length - 1, stride):
330
- seq = all_tokens[i:i + self.max_length + 1]
331
  if len(seq) == self.max_length + 1:
332
  sequences.append(seq)
333
-
334
  print(f"Created {len(sequences)} training sequences")
335
-
 
 
 
 
 
 
 
336
  sequences = np.array(sequences)
337
  X = sequences[:, :-1]
338
  y = sequences[:, 1:]
339
-
340
  dataset = tf.data.Dataset.from_tensor_slices((X, y))
341
  dataset = dataset.shuffle(1000).batch(self.batch_size).prefetch(1)
342
-
343
  return dataset
344
-
345
  def build_model(self):
346
  """Build the model"""
347
  self.model = VedaProgrammingLLM(
@@ -350,73 +854,102 @@ class VedaTrainer:
350
  d_model=256,
351
  num_heads=8,
352
  num_layers=4,
353
- ff_dim=512
354
  )
355
-
356
  self.model.compile(
357
- optimizer=keras.optimizers.Adam(1e-4),
358
  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
359
- metrics=['accuracy']
360
  )
361
-
362
  dummy = tf.zeros((1, self.max_length), dtype=tf.int32)
363
  self.model(dummy)
364
-
365
  return self.model
366
-
367
- def train(self, epochs: int = 15, save_path: str = None, extra_data: str = ""):
 
 
 
 
 
 
368
  """Train the model"""
369
  if save_path is None:
370
  save_path = MODEL_DIR
371
-
372
- dataset = self.prepare_data(extra_data)
373
  self.build_model()
374
-
375
  self.model.summary()
376
-
377
  os.makedirs(save_path, exist_ok=True)
378
-
379
  history = self.model.fit(dataset, epochs=epochs, verbose=1)
380
-
 
381
  self.model.save_weights(os.path.join(save_path, "weights.h5"))
382
- self.tokenizer.save(os.path.join(save_path, "tokenizer.json"))
383
 
 
 
 
 
384
  config = self.model.get_config()
385
- with open(os.path.join(save_path, "config.json"), 'w') as f:
386
- json.dump(config, f)
387
-
388
  print(f"Model saved to {save_path}")
389
  return history
390
-
391
- def generate_response(self, user_input: str, max_tokens: int = 200, temperature: float = 0.7) -> str:
 
 
392
  """Generate a response"""
 
 
 
393
  prompt = f"<USER> {user_input}\n<ASSISTANT>"
394
-
395
  tokens = self.tokenizer.encode(prompt)
396
-
397
  generated = self.model.generate(
398
  tokens,
399
  max_new_tokens=max_tokens,
400
  temperature=temperature,
401
- repetition_penalty=1.2
402
  )
403
-
404
  response = self.tokenizer.decode(generated)
405
-
406
  if "<ASSISTANT>" in response:
407
  response = response.split("<ASSISTANT>")[-1].strip()
408
  if "<USER>" in response:
409
  response = response.split("<USER>")[0].strip()
410
-
411
  return response
412
 
413
 
414
  if __name__ == "__main__":
 
 
 
 
415
  trainer = VedaTrainer()
416
  trainer.train(epochs=20)
 
 
 
 
 
 
 
 
 
 
 
417
 
418
- print("\nTesting:")
419
- tests = ["Hello!", "What is a function?"]
420
- for test in tests:
421
- print(f"User: {test}")
422
- print(f"Assistant: {trainer.generate_response(test)}")
 
1
+ """Training for Veda Programming Assistant with Distillation Support"""
2
 
3
  import tensorflow as tf
4
  from tensorflow import keras
 
8
 
9
  from model import VedaProgrammingLLM
10
  from tokenizer import VedaTokenizer
11
+ from database import db
12
  from config import MODEL_DIR, VOCAB_SIZE, MAX_LENGTH, BATCH_SIZE
13
 
14
 
15
+ # Training data with proper format (no markdown code blocks - using <CODE> tags)
16
+ TRAINING_DATA = """
17
  <USER> Hello!
18
  <ASSISTANT> Hello! Welcome to Veda Programming Assistant! I can help you with coding, explain concepts, or answer programming questions. What would you like to learn today?
19
 
 
48
  print(result)
49
  <ENDCODE>
50
 
51
+ Output: 8
52
+
53
  This function takes two parameters and returns their sum.
54
 
55
  <USER> Explain what a function is
 
79
  print(greet("Alice"))
80
  <ENDCODE>
81
 
82
+ Output: Hello, Alice!
83
+
84
  <USER> What is a list?
85
  <ASSISTANT> A list is an ordered, mutable collection in Python that can hold items of any type.
86
 
 
113
  <USER> Write a fibonacci function
114
  <ASSISTANT> Here is a Fibonacci function:
115
 
116
+ Recursive approach:
117
+
118
  <CODE>
119
  def fibonacci(n):
120
  if n <= 0:
121
  return 0
122
  if n == 1:
123
  return 1
124
+ return fibonacci(n - 1) + fibonacci(n - 2)
125
 
126
  for i in range(10):
127
  print(fibonacci(i), end=" ")
128
  <ENDCODE>
129
 
130
+ Output: 0 1 1 2 3 5 8 13 21 34
131
+
132
+ Iterative approach (more efficient):
133
 
134
  <CODE>
135
  def fibonacci_iterative(n):
 
137
  return 0
138
  if n == 1:
139
  return 1
 
140
  a, b = 0, 1
141
  for _ in range(2, n + 1):
142
  a, b = b, a + b
 
150
  1. Base case - stops the recursion
151
  2. Recursive case - function calls itself
152
 
153
+ Example - countdown:
154
 
155
  <CODE>
156
  def countdown(n):
 
163
  countdown(5)
164
  <ENDCODE>
165
 
166
+ Output: 5, 4, 3, 2, 1, Done!
167
+
168
  Factorial example:
169
 
170
  <CODE>
 
176
  print(factorial(5))
177
  <ENDCODE>
178
 
179
+ Output: 120
180
+
181
  <USER> Write bubble sort
182
  <ASSISTANT> Here is the Bubble Sort algorithm:
183
 
184
  <CODE>
185
  def bubble_sort(arr):
186
  n = len(arr)
 
187
  for i in range(n):
188
  swapped = False
 
189
  for j in range(0, n - i - 1):
190
  if arr[j] > arr[j + 1]:
191
  arr[j], arr[j + 1] = arr[j + 1], arr[j]
192
  swapped = True
 
193
  if not swapped:
194
  break
 
195
  return arr
196
 
197
  numbers = [64, 34, 25, 12, 22, 11, 90]
198
+ print("Original:", numbers)
199
  print("Sorted:", bubble_sort(numbers.copy()))
200
  <ENDCODE>
201
 
202
+ Output:
203
+ Original: [64, 34, 25, 12, 22, 11, 90]
204
+ Sorted: [11, 12, 22, 25, 34, 64, 90]
205
+
206
+ How it works: repeatedly compares adjacent elements and swaps them if in wrong order. Largest elements bubble up to the end.
207
+
208
+ Time complexity: O(n^2)
209
+
210
  <USER> Write binary search
211
+ <ASSISTANT> Here is the Binary Search algorithm:
212
 
213
  <CODE>
214
  def binary_search(arr, target):
 
229
 
230
  numbers = [1, 3, 5, 7, 9, 11, 13, 15]
231
  print(binary_search(numbers, 7))
232
+ print(binary_search(numbers, 6))
233
  <ENDCODE>
234
 
235
+ Output:
236
+ 3
237
+ -1
238
+
239
+ Important: The array must be sorted!
240
+ Time complexity: O(log n)
241
+
242
  <USER> What is a class?
243
+ <ASSISTANT> A class is a blueprint for creating objects. It bundles data (attributes) and functions (methods) together.
244
 
245
  <CODE>
246
  class Dog:
 
250
 
251
  def bark(self):
252
  return f"{self.name} says Woof!"
253
+
254
+ def get_info(self):
255
+ return f"{self.name} is {self.age} years old"
256
 
257
  my_dog = Dog("Buddy", 3)
258
  print(my_dog.bark())
259
+ print(my_dog.get_info())
260
  <ENDCODE>
261
 
262
+ Output:
263
+ Buddy says Woof!
264
+ Buddy is 3 years old
265
+
266
+ Key concepts:
267
+ - __init__: constructor, runs when object is created
268
+ - self: refers to the current object
269
+ - Attributes: variables inside the class
270
+ - Methods: functions inside the class
271
+
272
  <USER> How do I read a file?
273
+ <ASSISTANT> Here is how to read files in Python:
274
+
275
+ Read entire file:
276
 
277
  <CODE>
278
  with open("filename.txt", "r") as file:
279
  content = file.read()
280
  print(content)
281
+ <ENDCODE>
282
+
283
+ Read line by line:
284
 
285
+ <CODE>
286
  with open("filename.txt", "r") as file:
287
  for line in file:
288
  print(line.strip())
289
  <ENDCODE>
290
 
291
+ Read all lines into list:
292
+
293
+ <CODE>
294
+ with open("filename.txt", "r") as file:
295
+ lines = file.readlines()
296
+ <ENDCODE>
297
+
298
+ With error handling:
299
+
300
+ <CODE>
301
+ def read_file_safely(filename):
302
+ try:
303
+ with open(filename, "r") as file:
304
+ return file.read()
305
+ except FileNotFoundError:
306
+ print(f"File not found: {filename}")
307
+ return None
308
+
309
+ content = read_file_safely("myfile.txt")
310
+ <ENDCODE>
311
+
312
+ The with statement automatically closes the file when done.
313
+
314
+ <USER> How do I write to a file?
315
+ <ASSISTANT> Here is how to write to files in Python:
316
+
317
+ Write (overwrites existing content):
318
+
319
+ <CODE>
320
+ with open("output.txt", "w") as file:
321
+ file.write("Hello, World!")
322
+ file.write("This is a new line")
323
+ <ENDCODE>
324
+
325
+ Append (adds to existing content):
326
+
327
+ <CODE>
328
+ with open("output.txt", "a") as file:
329
+ file.write("This is appended")
330
+ <ENDCODE>
331
+
332
+ Write multiple lines:
333
+
334
+ <CODE>
335
+ lines = ["Line 1", "Line 2", "Line 3"]
336
+ with open("output.txt", "w") as file:
337
+ for line in lines:
338
+ file.write(line + "\\n")
339
+ <ENDCODE>
340
+
341
+ Or use writelines:
342
+
343
+ <CODE>
344
+ lines = ["Line 1\\n", "Line 2\\n", "Line 3\\n"]
345
+ with open("output.txt", "w") as file:
346
+ file.writelines(lines)
347
+ <ENDCODE>
348
+
349
  <USER> How do I handle errors?
350
+ <ASSISTANT> Python uses try-except blocks for error handling:
351
+
352
+ Basic try-except:
353
 
354
  <CODE>
355
  try:
 
372
  print(f"Error: {e}")
373
  <ENDCODE>
374
 
375
+ Complete structure:
 
376
 
377
+ <CODE>
378
+ try:
379
+ result = risky_operation()
380
+ except SomeError:
381
+ print("Handle error")
382
+ else:
383
+ print("Success! No error occurred")
384
+ finally:
385
+ print("This always runs")
386
+ <ENDCODE>
387
+
388
+ Common exceptions:
389
+ - ValueError: wrong value
390
+ - TypeError: wrong type
391
+ - IndexError: index out of range
392
+ - KeyError: dictionary key not found
393
+ - FileNotFoundError: file does not exist
394
+ - ZeroDivisionError: division by zero
395
+
396
+ <USER> What is a dictionary?
397
+ <ASSISTANT> A dictionary is a collection of key-value pairs in Python.
398
+
399
+ Creating dictionaries:
400
+
401
+ <CODE>
402
+ empty_dict = {}
403
+ person = {"name": "Alice", "age": 25, "city": "NYC"}
404
+ another = dict(name="Bob", age=30)
405
+ <ENDCODE>
406
+
407
+ Accessing values:
408
+
409
+ <CODE>
410
+ person = {"name": "Alice", "age": 25}
411
+
412
+ print(person["name"])
413
+ print(person.get("name"))
414
+ print(person.get("job", "Unknown"))
415
+ <ENDCODE>
416
+
417
+ Modifying:
418
+
419
+ <CODE>
420
+ person = {"name": "Alice", "age": 25}
421
+
422
+ person["city"] = "Boston"
423
+ person["age"] = 26
424
+ del person["city"]
425
+ age = person.pop("age")
426
+ <ENDCODE>
427
+
428
+ Iterating:
429
+
430
+ <CODE>
431
+ person = {"name": "Alice", "age": 25}
432
+
433
+ for key in person:
434
+ print(key)
435
+
436
+ for value in person.values():
437
+ print(value)
438
+
439
+ for key, value in person.items():
440
+ print(f"{key}: {value}")
441
+ <ENDCODE>
442
 
443
  <USER> What is the difference between list and tuple?
444
  <ASSISTANT> The main difference is mutability:
445
+ - Lists are mutable (can be changed)
446
+ - Tuples are immutable (cannot be changed)
447
+
448
+ Syntax:
449
 
450
  <CODE>
451
  my_list = [1, 2, 3]
452
  my_tuple = (1, 2, 3)
453
+ <ENDCODE>
454
+
455
+ Mutability difference:
456
 
457
+ <CODE>
458
+ my_list = [1, 2, 3]
459
  my_list[0] = 100
460
  my_list.append(4)
461
+ print(my_list)
462
+
463
+ my_tuple = (1, 2, 3)
464
+ # my_tuple[0] = 100 # ERROR! Cannot modify
465
+ # my_tuple.append(4) # ERROR! No append method
466
  <ENDCODE>
467
 
468
+ When to use:
469
+ - List: data that will change (shopping cart, scores)
470
+ - Tuple: fixed data (coordinates, RGB colors)
471
 
472
+ Examples:
473
 
474
+ <CODE>
475
+ coordinates = (10, 20)
476
+ rgb_color = (255, 128, 0)
477
+
478
+ shopping_cart = ["apple", "bread"]
479
+ shopping_cart.append("milk")
480
+ <ENDCODE>
481
+
482
+ Other differences:
483
+ - Tuples are slightly faster
484
+ - Tuples use less memory
485
+ - Tuples can be dictionary keys
486
+
487
+ <USER> How do I use list comprehension?
488
+ <ASSISTANT> List comprehension is a concise way to create lists.
489
+
490
+ Basic syntax:
491
+
492
+ <CODE>
493
+ new_list = [expression for item in iterable]
494
+ <ENDCODE>
495
+
496
+ Examples:
497
+
498
+ <CODE>
499
+ squares = [x ** 2 for x in range(5)]
500
+ print(squares)
501
+ <ENDCODE>
502
+
503
+ Output: [0, 1, 4, 9, 16]
504
+
505
+ With condition:
506
+
507
+ <CODE>
508
+ evens = [x for x in range(10) if x % 2 == 0]
509
+ print(evens)
510
+ <ENDCODE>
511
+
512
+ Output: [0, 2, 4, 6, 8]
513
+
514
+ Transform strings:
515
+
516
+ <CODE>
517
+ words = ["hello", "world"]
518
+ upper_words = [word.upper() for word in words]
519
+ print(upper_words)
520
+ <ENDCODE>
521
+
522
+ Output: ["HELLO", "WORLD"]
523
+
524
+ With if-else:
525
+
526
+ <CODE>
527
+ numbers = [1, 2, 3, 4, 5]
528
+ labels = ["even" if x % 2 == 0 else "odd" for x in numbers]
529
+ print(labels)
530
+ <ENDCODE>
531
+
532
+ Output: ["odd", "even", "odd", "even", "odd"]
533
+
534
+ <USER> Write quicksort
535
+ <ASSISTANT> Here is the QuickSort algorithm:
536
+
537
+ <CODE>
538
+ def quicksort(arr):
539
+ if len(arr) <= 1:
540
+ return arr
541
+
542
+ pivot = arr[len(arr) // 2]
543
+ left = [x for x in arr if x < pivot]
544
+ middle = [x for x in arr if x == pivot]
545
+ right = [x for x in arr if x > pivot]
546
+
547
+ return quicksort(left) + middle + quicksort(right)
548
+
549
+ numbers = [64, 34, 25, 12, 22, 11, 90]
550
+ print("Original:", numbers)
551
+ print("Sorted:", quicksort(numbers))
552
+ <ENDCODE>
553
+
554
+ Output:
555
+ Original: [64, 34, 25, 12, 22, 11, 90]
556
+ Sorted: [11, 12, 22, 25, 34, 64, 90]
557
+
558
+ In-place version:
559
+
560
+ <CODE>
561
+ def quicksort_inplace(arr, low, high):
562
+ if low < high:
563
+ pivot_idx = partition(arr, low, high)
564
+ quicksort_inplace(arr, low, pivot_idx - 1)
565
+ quicksort_inplace(arr, pivot_idx + 1, high)
566
+
567
+ def partition(arr, low, high):
568
+ pivot = arr[high]
569
+ i = low - 1
570
+ for j in range(low, high):
571
+ if arr[j] <= pivot:
572
+ i += 1
573
+ arr[i], arr[j] = arr[j], arr[i]
574
+ arr[i + 1], arr[high] = arr[high], arr[i + 1]
575
+ return i + 1
576
+
577
+ numbers = [64, 34, 25, 12, 22, 11, 90]
578
+ quicksort_inplace(numbers, 0, len(numbers) - 1)
579
+ print(numbers)
580
+ <ENDCODE>
581
+
582
+ Time complexity:
583
+ - Average: O(n log n)
584
+ - Worst: O(n^2)
585
+
586
+ <USER> Write merge sort
587
+ <ASSISTANT> Here is the Merge Sort algorithm:
588
+
589
+ <CODE>
590
+ def merge_sort(arr):
591
+ if len(arr) <= 1:
592
+ return arr
593
+
594
+ mid = len(arr) // 2
595
+ left = merge_sort(arr[:mid])
596
+ right = merge_sort(arr[mid:])
597
 
598
+ return merge(left, right)
599
+
600
+ def merge(left, right):
601
+ result = []
602
+ i = j = 0
603
+
604
+ while i < len(left) and j < len(right):
605
+ if left[i] <= right[j]:
606
+ result.append(left[i])
607
+ i += 1
608
+ else:
609
+ result.append(right[j])
610
+ j += 1
611
+
612
+ result.extend(left[i:])
613
+ result.extend(right[j:])
614
+ return result
615
+
616
+ numbers = [64, 34, 25, 12, 22, 11, 90]
617
+ print("Original:", numbers)
618
+ print("Sorted:", merge_sort(numbers))
619
+ <ENDCODE>
620
+
621
+ Output:
622
+ Original: [64, 34, 25, 12, 22, 11, 90]
623
+ Sorted: [11, 12, 22, 25, 34, 64, 90]
624
+
625
+ How it works:
626
+ 1. Divide array into two halves
627
+ 2. Recursively sort each half
628
+ 3. Merge the sorted halves
629
+
630
+ Time complexity: O(n log n) always
631
+ Space complexity: O(n)
632
+
633
+ <USER> What is a lambda function?
634
+ <ASSISTANT> A lambda function is a small anonymous function defined in one line.
635
+
636
+ Syntax:
637
+
638
+ <CODE>
639
+ lambda arguments: expression
640
+ <ENDCODE>
641
+
642
+ Examples:
643
+
644
+ <CODE>
645
+ add = lambda x, y: x + y
646
+ print(add(5, 3))
647
+
648
+ square = lambda x: x ** 2
649
+ print(square(4))
650
+
651
+ is_even = lambda x: x % 2 == 0
652
+ print(is_even(4))
653
+ <ENDCODE>
654
+
655
+ Output:
656
+ 8
657
+ 16
658
+ True
659
+
660
+ Common uses:
661
+
662
+ With sorted():
663
+
664
+ <CODE>
665
+ students = [("Alice", 85), ("Bob", 92), ("Charlie", 78)]
666
+ sorted_by_score = sorted(students, key=lambda x: x[1])
667
+ print(sorted_by_score)
668
+ <ENDCODE>
669
+
670
+ With map():
671
+
672
+ <CODE>
673
+ numbers = [1, 2, 3, 4, 5]
674
+ squares = list(map(lambda x: x ** 2, numbers))
675
+ print(squares)
676
+ <ENDCODE>
677
+
678
+ With filter():
679
+
680
+ <CODE>
681
+ numbers = [1, 2, 3, 4, 5, 6]
682
+ evens = list(filter(lambda x: x % 2 == 0, numbers))
683
+ print(evens)
684
+ <ENDCODE>
685
+
686
+ <USER> What are decorators?
687
+ <ASSISTANT> Decorators are functions that modify the behavior of other functions.
688
+
689
+ Basic decorator:
690
+
691
+ <CODE>
692
+ def my_decorator(func):
693
+ def wrapper():
694
+ print("Before function")
695
+ func()
696
+ print("After function")
697
+ return wrapper
698
+
699
+ @my_decorator
700
+ def say_hello():
701
+ print("Hello!")
702
+
703
+ say_hello()
704
+ <ENDCODE>
705
+
706
+ Output:
707
+ Before function
708
+ Hello!
709
+ After function
710
+
711
+ With arguments:
712
+
713
+ <CODE>
714
+ def my_decorator(func):
715
+ def wrapper(*args, **kwargs):
716
+ print("Before function")
717
+ result = func(*args, **kwargs)
718
+ print("After function")
719
+ return result
720
+ return wrapper
721
+
722
+ @my_decorator
723
+ def add(a, b):
724
+ return a + b
725
+
726
+ print(add(5, 3))
727
+ <ENDCODE>
728
+
729
+ Practical example - timer:
730
+
731
+ <CODE>
732
+ import time
733
+
734
+ def timer(func):
735
+ def wrapper(*args, **kwargs):
736
+ start = time.time()
737
+ result = func(*args, **kwargs)
738
+ end = time.time()
739
+ print(f"{func.__name__} took {end - start:.4f} seconds")
740
+ return result
741
+ return wrapper
742
+
743
+ @timer
744
+ def slow_function():
745
+ time.sleep(1)
746
+ return "Done"
747
+
748
+ slow_function()
749
+ <ENDCODE>
750
+
751
+ <USER> Thanks!
752
+ <ASSISTANT> You are welcome! I am glad I could help. Feel free to ask me anything anytime. Happy coding!
753
+
754
+ <USER> Hi
755
+ <ASSISTANT> Hello! How can I help you today? I can write code, explain concepts, or answer your programming questions!
756
+
757
+ <USER> What can you do?
758
+ <ASSISTANT> I am Veda Programming Assistant! I can help you with:
759
+
760
+ 1. Writing Python code
761
+ 2. Explaining programming concepts
762
+ 3. Debugging code
763
+ 4. Answering questions about algorithms
764
+ 5. Teaching data structures
765
+ 6. Helping with file handling
766
+ 7. Explaining object-oriented programming
767
+ 8. And much more!
768
+
769
+ Just ask me anything about programming and I will do my best to help you!
770
+
771
+ <USER> 2+2=?
772
+ <ASSISTANT> 4
773
+
774
+ <USER> What is 10 times 5?
775
+ <ASSISTANT> 10 times 5 equals 50.
776
+
777
+ <USER> Calculate 100 divided by 4
778
+ <ASSISTANT> 100 divided by 4 equals 25.
779
+ """
780
+
781
+
782
+ class VedaTrainer:
783
+ """Trainer for Veda Programming Assistant with Distillation Support"""
784
+
785
  def __init__(self, vocab_size: int = 8000, max_length: int = 512, batch_size: int = 4):
786
  self.vocab_size = vocab_size
787
  self.max_length = max_length
788
  self.batch_size = batch_size
789
  self.tokenizer = VedaTokenizer(vocab_size=vocab_size)
790
  self.model = None
791
+
792
+ def prepare_data(self, extra_data: str = "", distillation_data: str = ""):
793
+ """Prepare training data including distillation data"""
794
  data = TRAINING_DATA
795
+
796
  if extra_data:
797
  data += "\n\n" + extra_data
798
+
799
+ if distillation_data:
800
+ data += "\n\n" + distillation_data
801
+
802
  if os.path.exists("programming.txt"):
803
+ try:
804
+ with open("programming.txt", "r", encoding="utf-8") as f:
805
+ code_data = f.read()
806
+ data += "\n\n" + code_data
807
+ except Exception as e:
808
+ print(f"Warning: Could not read programming.txt: {e}")
809
+
810
  self.tokenizer.fit([data])
811
+
812
  all_tokens = self.tokenizer.encode(data)
813
  print(f"Total tokens: {len(all_tokens)}")
814
+
815
  sequences = []
816
  stride = self.max_length // 2
817
+
818
  for i in range(0, len(all_tokens) - self.max_length - 1, stride):
819
+ seq = all_tokens[i : i + self.max_length + 1]
820
  if len(seq) == self.max_length + 1:
821
  sequences.append(seq)
822
+
823
  if len(sequences) < 10:
824
  stride = self.max_length // 4
825
  sequences = []
826
  for i in range(0, len(all_tokens) - self.max_length - 1, stride):
827
+ seq = all_tokens[i : i + self.max_length + 1]
828
  if len(seq) == self.max_length + 1:
829
  sequences.append(seq)
830
+
831
  print(f"Created {len(sequences)} training sequences")
832
+
833
+ if len(sequences) == 0:
834
+ print("Warning: No sequences created. Using minimal sequence.")
835
+ min_seq = all_tokens[:self.max_length + 1]
836
+ while len(min_seq) < self.max_length + 1:
837
+ min_seq.append(0)
838
+ sequences = [min_seq]
839
+
840
  sequences = np.array(sequences)
841
  X = sequences[:, :-1]
842
  y = sequences[:, 1:]
843
+
844
  dataset = tf.data.Dataset.from_tensor_slices((X, y))
845
  dataset = dataset.shuffle(1000).batch(self.batch_size).prefetch(1)
846
+
847
  return dataset
848
+
849
  def build_model(self):
850
  """Build the model"""
851
  self.model = VedaProgrammingLLM(
 
854
  d_model=256,
855
  num_heads=8,
856
  num_layers=4,
857
+ ff_dim=512,
858
  )
859
+
860
  self.model.compile(
861
+ optimizer=keras.optimizers.Adam(learning_rate=1e-4),
862
  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
863
+ metrics=["accuracy"],
864
  )
865
+
866
  dummy = tf.zeros((1, self.max_length), dtype=tf.int32)
867
  self.model(dummy)
868
+
869
  return self.model
870
+
871
+ def train(
872
+ self,
873
+ epochs: int = 15,
874
+ save_path: str = None,
875
+ extra_data: str = "",
876
+ distillation_data: str = "",
877
+ ):
878
  """Train the model"""
879
  if save_path is None:
880
  save_path = MODEL_DIR
881
+
882
+ dataset = self.prepare_data(extra_data, distillation_data)
883
  self.build_model()
884
+
885
  self.model.summary()
886
+
887
  os.makedirs(save_path, exist_ok=True)
888
+
889
  history = self.model.fit(dataset, epochs=epochs, verbose=1)
890
+
891
+ # Save weights
892
  self.model.save_weights(os.path.join(save_path, "weights.h5"))
 
893
 
894
+ # Save tokenizer
895
+ self.tokenizer.save(os.path.join(save_path, "tokenizer.json"))
896
+
897
+ # Save config
898
  config = self.model.get_config()
899
+ with open(os.path.join(save_path, "config.json"), "w") as f:
900
+ json.dump(config, f, indent=2)
901
+
902
  print(f"Model saved to {save_path}")
903
  return history
904
+
905
+ def generate_response(
906
+ self, user_input: str, max_tokens: int = 200, temperature: float = 0.7
907
+ ) -> str:
908
  """Generate a response"""
909
+ if self.model is None:
910
+ return "Model not loaded."
911
+
912
  prompt = f"<USER> {user_input}\n<ASSISTANT>"
913
+
914
  tokens = self.tokenizer.encode(prompt)
915
+
916
  generated = self.model.generate(
917
  tokens,
918
  max_new_tokens=max_tokens,
919
  temperature=temperature,
920
+ repetition_penalty=1.2,
921
  )
922
+
923
  response = self.tokenizer.decode(generated)
924
+
925
  if "<ASSISTANT>" in response:
926
  response = response.split("<ASSISTANT>")[-1].strip()
927
  if "<USER>" in response:
928
  response = response.split("<USER>")[0].strip()
929
+
930
  return response
931
 
932
 
933
  if __name__ == "__main__":
934
+ print("=" * 50)
935
+ print("Training Veda Programming Assistant")
936
+ print("=" * 50)
937
+
938
  trainer = VedaTrainer()
939
  trainer.train(epochs=20)
940
+
941
+ print("\n" + "=" * 50)
942
+ print("Testing the model:")
943
+ print("=" * 50)
944
+
945
+ test_prompts = [
946
+ "Hello!",
947
+ "What is a function?",
948
+ "Write a function to reverse a string",
949
+ "2+2=?",
950
+ ]
951
 
952
+ for prompt in test_prompts:
953
+ print(f"\nUser: {prompt}")
954
+ response = trainer.generate_response(prompt)
955
+ print(f"Assistant: {response}")