vedaco commited on
Commit
2b7f32a
·
verified ·
1 Parent(s): ee25239

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +166 -260
train.py CHANGED
@@ -1,4 +1,4 @@
1
- """Training - MODIFIED with conversation data"""
2
 
3
  import tensorflow as tf
4
  from tensorflow import keras
@@ -10,21 +10,21 @@ from model import VedaProgrammingLLM
10
  from tokenizer import VedaTokenizer
11
  from config import MODEL_DIR, VOCAB_SIZE, MAX_LENGTH, BATCH_SIZE
12
 
13
- # ADDED: Conversation training data
14
  TRAINING_DATA = '''
15
  <USER> Hello!
16
- <ASSISTANT> Hello! Welcome to Veda Programming Assistant! I'm here to help you with coding, explain concepts, or answer any programming questions. What would you like to learn today?
17
 
18
  <USER> What is Python?
19
- <ASSISTANT> Python is a high-level, interpreted programming language created by Guido van Rossum in 1991. It's known for its simple, readable syntax that emphasizes code readability.
20
 
21
  Key features:
22
  - Easy to learn and read
23
  - Dynamic typing
24
  - Large standard library
25
- - Great for beginners and experts alike
26
 
27
- Python is used for web development, data science, AI, automation, and much more!
28
 
29
  <USER> How do I print hello world?
30
  <ASSISTANT> To print "Hello, World!" in Python, use the print() function:
@@ -35,10 +35,8 @@ print("Hello, World!")
35
 
36
  When you run this, it displays: Hello, World!
37
 
38
- The print() function outputs text to the console. You can print strings, numbers, and variables!
39
-
40
  <USER> Write a function to add two numbers
41
- <ASSISTANT> Here's a simple function to add two numbers:
42
 
43
  <CODE>
44
  def add(a, b):
@@ -52,9 +50,9 @@ print(result) # Output: 8
52
  This function takes two parameters and returns their sum.
53
 
54
  <USER> Explain what a function is
55
- <ASSISTANT> A function is a reusable block of code that performs a specific task. Functions help organize code, avoid repetition, and make programs easier to understand.
56
 
57
- Here's the basic structure:
58
 
59
  <CODE>
60
  def function_name(parameters):
@@ -99,23 +97,23 @@ Common operations:
99
  fruits = ["apple", "banana"]
100
 
101
  # Add items
102
- fruits.append("cherry") # Add to end
103
- fruits.insert(0, "apricot") # Insert at position
104
 
105
  # Access items
106
  print(fruits[0]) # First item
107
  print(fruits[-1]) # Last item
108
 
109
  # Remove items
110
- fruits.remove("banana") # Remove by value
111
- fruits.pop() # Remove last item
112
 
113
  # Length
114
  print(len(fruits))
115
  <ENDCODE>
116
 
117
  <USER> Write a fibonacci function
118
- <ASSISTANT> Here's a Fibonacci function with two approaches:
119
 
120
  Recursive approach:
121
 
@@ -148,10 +146,8 @@ def fibonacci_iterative(n):
148
  return b
149
  <ENDCODE>
150
 
151
- The Fibonacci sequence: each number is the sum of the two before it!
152
-
153
  <USER> What is recursion?
154
- <ASSISTANT> Recursion is when a function calls itself to solve a problem. It breaks complex problems into smaller, similar sub-problems.
155
 
156
  Key components:
157
  1. Base case - stops the recursion
@@ -182,10 +178,8 @@ def factorial(n):
182
  print(factorial(5)) # Output: 120
183
  <ENDCODE>
184
 
185
- Warning: Always have a base case to avoid infinite recursion!
186
-
187
  <USER> Write bubble sort
188
- <ASSISTANT> Here's the Bubble Sort algorithm:
189
 
190
  <CODE>
191
  def bubble_sort(arr):
@@ -214,12 +208,8 @@ Output:
214
  Original: [64, 34, 25, 12, 22, 11, 90]
215
  Sorted: [11, 12, 22, 25, 34, 64, 90]
216
 
217
- How it works: repeatedly compares adjacent elements and swaps them if in wrong order. Largest elements "bubble up" to the end.
218
-
219
- Time complexity: O(n²)
220
-
221
  <USER> Write binary search
222
- <ASSISTANT> Here's the Binary Search algorithm:
223
 
224
  <CODE>
225
  def binary_search(arr, target):
@@ -245,14 +235,11 @@ print(binary_search(numbers, 6)) # Output: -1
245
  <ENDCODE>
246
 
247
  Important: The array must be sorted!
248
-
249
- Time complexity: O(log n) - much faster than linear search!
250
 
251
  <USER> What is a class?
252
  <ASSISTANT> A class is a blueprint for creating objects. It bundles data (attributes) and functions (methods) together.
253
 
254
- Basic class structure:
255
-
256
  <CODE>
257
  class Dog:
258
  def __init__(self, name, age):
@@ -278,9 +265,7 @@ Key concepts:
278
  - Methods: functions inside the class
279
 
280
  <USER> How do I read a file?
281
- <ASSISTANT> Here's how to read files in Python:
282
-
283
- Basic file reading:
284
 
285
  <CODE>
286
  # Read entire file
@@ -312,13 +297,9 @@ def read_file_safely(filename):
312
  content = read_file_safely("myfile.txt")
313
  <ENDCODE>
314
 
315
- The 'with' statement automatically closes the file when done - always recommended!
316
-
317
  <USER> How do I handle errors?
318
  <ASSISTANT> Python uses try-except blocks for error handling:
319
 
320
- Basic try-except:
321
-
322
  <CODE>
323
  try:
324
  result = 10 / 0
@@ -333,7 +314,7 @@ try:
333
  number = int(input("Enter number: "))
334
  result = 10 / number
335
  except ValueError:
336
- print("That's not a valid number!")
337
  except ZeroDivisionError:
338
  print("Cannot divide by zero!")
339
  except Exception as e:
@@ -348,258 +329,183 @@ try:
348
  except SomeError:
349
  print("Handle error")
350
  else:
351
- print("Success!") # Runs if no error
352
  finally:
353
- print("Always runs") # Cleanup code
354
  <ENDCODE>
355
 
356
- Common exceptions:
357
- - ValueError: wrong value
358
- - TypeError: wrong type
359
- - IndexError: index out of range
360
- - KeyError: dictionary key not found
361
- - FileNotFoundError: file doesn't exist
362
-
363
  <USER> Thanks!
364
- <ASSISTANT> You're welcome! I'm glad I could help. Remember:
365
-
366
- - Practice coding every day
367
- - Build small projects
368
- - Don't fear errors - they help you learn
369
- - Ask questions whenever stuck
370
-
371
- Feel free to ask me anything anytime! Happy coding! 🐍
372
-
373
- <USER> debug this: def add(a b): return a + b
374
- <ASSISTANT> I found the issue! You're missing a comma between parameters.
375
-
376
- Your code:
377
- ```python
378
- def add(a b): return a + b # ❌ Missing commaFixed code:
379
-
380
- <CODE> def add(a, b): return a + b <ENDCODE>
381
- The fix: Add a comma between a and b in the parameters.
382
 
383
- Testing it:
384
-
385
- <CODE> def add(a, b): return a + b
386
- print(add(5, 3)) # Output: 8
387
- print(add(10, 20)) # Output: 30
388
- <ENDCODE>
389
-
390
- Common syntax errors to watch:
391
-
392
- Missing commas between parameters
393
- Missing colons after def, if, for, while
394
- Incorrect indentation
395
- Mismatched parentheses
396
  <USER> What is the difference between list and tuple?
397
  <ASSISTANT> The main difference is mutability:
 
 
398
 
399
- Lists are mutable (can be changed)
400
- Tuples are immutable (cannot be changed)
401
  Syntax:
402
 
403
- <CODE> # List - square brackets my_list = [1, 2, 3]
404
- Tuple - parentheses
 
 
 
405
  my_tuple = (1, 2, 3)
406
  <ENDCODE>
407
 
408
  Mutability difference:
409
 
410
- <CODE> # Lists can be modified my_list = [1, 2, 3] my_list[0] = 100 # OK my_list.append(4) # OK
411
- Tuples cannot be modified
412
- my_tuple = (1, 2, 3)
 
 
413
 
414
- my_tuple[0] = 100 # ERROR!
415
- my_tuple.append(4) # ERROR!
 
416
  <ENDCODE>
417
- When to use:
418
 
419
- List: data that will change (shopping cart, scores)
420
- Tuple: fixed data (coordinates, RGB colors)
421
- Example:
422
 
423
- <CODE> # Tuple for fixed data coordinates = (10, 20) rgb_color = (255, 128, 0)
424
- List for changing data
425
- shopping_cart = ["apple", "bread"]
426
- shopping_cart.append("milk")
427
- <ENDCODE>
428
 
429
- Other differences:
430
 
431
- Tuples are slightly faster
432
- Tuples use less memory
433
- Tuples can be dictionary keys
434
- '''
435
  class VedaTrainer:
436
- """Trainer for Veda Programming Assistant"""def __init__(self, vocab_size: int = 8000, max_length: int = 512, batch_size: int = 4):
437
- self.vocab_size = vocab_size
438
- self.max_length = max_length
439
- self.batch_size = batch_size
440
- self.tokenizer = VedaTokenizer(vocab_size=vocab_size)
441
- self.model = None
442
-
443
- def prepare_data(self, extra_data: str = ""):
444
- """Prepare training data"""
445
- # Combine training data
446
- data = TRAINING_DATA
447
- if extra_data:
448
- data += "\n\n" + extra_data
449
-
450
- # Load additional code from programming.txt if exists
451
- if os.path.exists("programming.txt"):
452
- with open("programming.txt", 'r', encoding='utf-8') as f:
453
- code_data = f.read()
454
- data += "\n\n" + code_data
455
-
456
- # Fit tokenizer
457
- self.tokenizer.fit([data])
458
-
459
- # Encode
460
- all_tokens = self.tokenizer.encode(data)
461
- print(f"Total tokens: {len(all_tokens)}")
462
-
463
- # Create sequences
464
- sequences = []
465
- stride = self.max_length // 2
466
-
467
- for i in range(0, len(all_tokens) - self.max_length - 1, stride):
468
- seq = all_tokens[i:i + self.max_length + 1]
469
- if len(seq) == self.max_length + 1:
470
- sequences.append(seq)
471
-
472
- if len(sequences) < 10:
473
- stride = self.max_length // 4
474
  sequences = []
 
 
475
  for i in range(0, len(all_tokens) - self.max_length - 1, stride):
476
  seq = all_tokens[i:i + self.max_length + 1]
477
  if len(seq) == self.max_length + 1:
478
  sequences.append(seq)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
479
 
480
- print(f"Created {len(sequences)} training sequences")
481
-
482
- sequences = np.array(sequences)
483
- X = sequences[:, :-1]
484
- y = sequences[:, 1:]
485
-
486
- dataset = tf.data.Dataset.from_tensor_slices((X, y))
487
- dataset = dataset.shuffle(1000).batch(self.batch_size).prefetch(1)
488
-
489
- return dataset
490
-
491
- def build_model(self):
492
- """Build the model"""
493
- self.model = VedaProgrammingLLM(
494
- vocab_size=self.tokenizer.vocabulary_size,
495
- max_length=self.max_length,
496
- d_model=256,
497
- num_heads=8,
498
- num_layers=4,
499
- ff_dim=512
500
- )
501
-
502
- self.model.compile(
503
- optimizer=keras.optimizers.Adam(1e-4),
504
- loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
505
- metrics=['accuracy']
506
- )
507
-
508
- dummy = tf.zeros((1, self.max_length), dtype=tf.int32)
509
- self.model(dummy)
510
 
511
- return self.model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
512
 
513
- def train(self, epochs: int = 15, save_path: str = None, extra_data: str = ""):
514
- """Train the model"""
515
- if save_path is None:
516
- save_path = MODEL_DIR
517
-
518
- dataset = self.prepare_data(extra_data)
519
- self.build_model()
520
-
521
- self.model.summary()
522
-
523
- os.makedirs(save_path, exist_ok=True)
524
-
525
- history = self.model.fit(dataset, epochs=epochs, verbose=1)
526
-
527
- # Save
528
- self.model.save_weights(os.path.join(save_path, "weights.h5"))
529
- self.tokenizer.save(os.path.join(save_path, "tokenizer.json"))
530
-
531
- config = self.model.get_config()
532
- with open(os.path.join(save_path, "config.json"), 'w') as f:
533
- json.dump(config, f)
534
-
535
- print(f"Model saved to {save_path}")
536
- return history
537
 
538
- def generate_response(self, user_input: str, max_tokens: int = 200,
539
- temperature: float = 0.7) -> str:
540
- """Generate a response"""
541
- prompt = f"<USER> {user_input}\n<ASSISTANT>"
542
-
543
- tokens = self.tokenizer.encode(prompt)
544
-
545
- generated = self.model.generate(
546
- tokens,
547
- max_new_tokens=max_tokens,
548
- temperature=temperature,
549
- repetition_penalty=1.2
550
- )
551
 
552
- response = self.tokenizer.decode(generated)
 
 
553
 
554
- # Extract assistant response
555
- if "<ASSISTANT>" in response:
556
- response = response.split("<ASSISTANT>")[-1].strip()
557
- if "<USER>" in response:
558
- response = response.split("<USER>")[0].strip()
559
 
560
- return responseif name == "main":
561
- trainer = VedaTrainer()
562
- trainer.train(epochs=20)# Test
563
- print("\n" + "="*50)
564
- print("Testing:")
565
- print("="*50)
566
-
567
- tests = [
568
- "Hello!",
569
- "What is a function?",
570
- "Write a function to reverse a string",
571
- ]
572
-
573
- for test in tests:
574
- print(f"\nUser: {test}")
575
- print(f"Assistant: {trainer.generate_response(test)}")
576
- ---
577
-
578
- ### 3. config.py (MODIFY - Increase max_length)
579
-
580
- ```python
581
- """Configuration - MODIFIED for conversation"""
582
-
583
- import os
584
-
585
- BASE_DIR = os.path.dirname(os.path.abspath(__file__))
586
- DATA_DIR = os.path.join(BASE_DIR, "data")
587
- MODEL_DIR = os.path.join(BASE_DIR, "veda_model")
588
-
589
- os.makedirs(DATA_DIR, exist_ok=True)
590
- os.makedirs(MODEL_DIR, exist_ok=True)
591
-
592
- DATABASE_PATH = os.path.join(DATA_DIR, "conversations.db")
593
-
594
- # Model settings - MODIFIED
595
- VOCAB_SIZE = 8000 # Increased for more words
596
- MAX_LENGTH = 512 # Increased for longer conversations
597
- D_MODEL = 256
598
- NUM_HEADS = 8
599
- NUM_LAYERS = 4
600
- FF_DIM = 512
601
- BATCH_SIZE = 4 # Smaller for longer sequences
602
-
603
- # Generation defaults
604
- DEFAULT_TEMPERATURE = 0.7
605
- DEFAULT_MAX_TOKENS = 200
 
1
+ """Training for Veda Programming Assistant"""
2
 
3
  import tensorflow as tf
4
  from tensorflow import keras
 
10
  from tokenizer import VedaTokenizer
11
  from config import MODEL_DIR, VOCAB_SIZE, MAX_LENGTH, BATCH_SIZE
12
 
13
+
14
  TRAINING_DATA = '''
15
  <USER> Hello!
16
+ <ASSISTANT> Hello! Welcome to Veda Programming Assistant! I can help you with coding, explain concepts, or answer programming questions. What would you like to learn today?
17
 
18
  <USER> What is Python?
19
+ <ASSISTANT> Python is a high-level, interpreted programming language created by Guido van Rossum in 1991. It is known for its simple, readable syntax.
20
 
21
  Key features:
22
  - Easy to learn and read
23
  - Dynamic typing
24
  - Large standard library
25
+ - Great for beginners and experts
26
 
27
+ Python is used for web development, data science, AI, automation, and more!
28
 
29
  <USER> How do I print hello world?
30
  <ASSISTANT> To print "Hello, World!" in Python, use the print() function:
 
35
 
36
  When you run this, it displays: Hello, World!
37
 
 
 
38
  <USER> Write a function to add two numbers
39
+ <ASSISTANT> Here is a simple function to add two numbers:
40
 
41
  <CODE>
42
  def add(a, b):
 
50
  This function takes two parameters and returns their sum.
51
 
52
  <USER> Explain what a function is
53
+ <ASSISTANT> A function is a reusable block of code that performs a specific task. Functions help organize code and avoid repetition.
54
 
55
+ Here is the basic structure:
56
 
57
  <CODE>
58
  def function_name(parameters):
 
97
  fruits = ["apple", "banana"]
98
 
99
  # Add items
100
+ fruits.append("cherry")
101
+ fruits.insert(0, "apricot")
102
 
103
  # Access items
104
  print(fruits[0]) # First item
105
  print(fruits[-1]) # Last item
106
 
107
  # Remove items
108
+ fruits.remove("banana")
109
+ fruits.pop()
110
 
111
  # Length
112
  print(len(fruits))
113
  <ENDCODE>
114
 
115
  <USER> Write a fibonacci function
116
+ <ASSISTANT> Here is a Fibonacci function:
117
 
118
  Recursive approach:
119
 
 
146
  return b
147
  <ENDCODE>
148
 
 
 
149
  <USER> What is recursion?
150
+ <ASSISTANT> Recursion is when a function calls itself to solve a problem.
151
 
152
  Key components:
153
  1. Base case - stops the recursion
 
178
  print(factorial(5)) # Output: 120
179
  <ENDCODE>
180
 
 
 
181
  <USER> Write bubble sort
182
+ <ASSISTANT> Here is the Bubble Sort algorithm:
183
 
184
  <CODE>
185
  def bubble_sort(arr):
 
208
  Original: [64, 34, 25, 12, 22, 11, 90]
209
  Sorted: [11, 12, 22, 25, 34, 64, 90]
210
 
 
 
 
 
211
  <USER> Write binary search
212
+ <ASSISTANT> Here is the Binary Search algorithm:
213
 
214
  <CODE>
215
  def binary_search(arr, target):
 
235
  <ENDCODE>
236
 
237
  Important: The array must be sorted!
238
+ Time complexity: O(log n)
 
239
 
240
  <USER> What is a class?
241
  <ASSISTANT> A class is a blueprint for creating objects. It bundles data (attributes) and functions (methods) together.
242
 
 
 
243
  <CODE>
244
  class Dog:
245
  def __init__(self, name, age):
 
265
  - Methods: functions inside the class
266
 
267
  <USER> How do I read a file?
268
+ <ASSISTANT> Here is how to read files in Python:
 
 
269
 
270
  <CODE>
271
  # Read entire file
 
297
  content = read_file_safely("myfile.txt")
298
  <ENDCODE>
299
 
 
 
300
  <USER> How do I handle errors?
301
  <ASSISTANT> Python uses try-except blocks for error handling:
302
 
 
 
303
  <CODE>
304
  try:
305
  result = 10 / 0
 
314
  number = int(input("Enter number: "))
315
  result = 10 / number
316
  except ValueError:
317
+ print("Not a valid number!")
318
  except ZeroDivisionError:
319
  print("Cannot divide by zero!")
320
  except Exception as e:
 
329
  except SomeError:
330
  print("Handle error")
331
  else:
332
+ print("Success!")
333
  finally:
334
+ print("Always runs")
335
  <ENDCODE>
336
 
 
 
 
 
 
 
 
337
  <USER> Thanks!
338
+ <ASSISTANT> You are welcome! I am glad I could help. Feel free to ask me anything anytime! Happy coding!
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
339
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340
  <USER> What is the difference between list and tuple?
341
  <ASSISTANT> The main difference is mutability:
342
+ - Lists are mutable (can be changed)
343
+ - Tuples are immutable (cannot be changed)
344
 
 
 
345
  Syntax:
346
 
347
+ <CODE>
348
+ # List - square brackets
349
+ my_list = [1, 2, 3]
350
+
351
+ # Tuple - parentheses
352
  my_tuple = (1, 2, 3)
353
  <ENDCODE>
354
 
355
  Mutability difference:
356
 
357
+ <CODE>
358
+ # Lists can be modified
359
+ my_list = [1, 2, 3]
360
+ my_list[0] = 100 # OK
361
+ my_list.append(4) # OK
362
 
363
+ # Tuples cannot be modified
364
+ my_tuple = (1, 2, 3)
365
+ # my_tuple[0] = 100 # ERROR!
366
  <ENDCODE>
 
367
 
368
+ When to use:
369
+ - List: data that will change
370
+ - Tuple: fixed data like coordinates
371
 
372
+ <USER> Hi
373
+ <ASSISTANT> Hello! How can I help you today? I can write code, explain concepts, or answer your programming questions!
374
+ '''
 
 
375
 
 
376
 
 
 
 
 
377
  class VedaTrainer:
378
+ """Trainer for Veda Programming Assistant"""
379
+
380
+ def __init__(self, vocab_size: int = 8000, max_length: int = 512, batch_size: int = 4):
381
+ self.vocab_size = vocab_size
382
+ self.max_length = max_length
383
+ self.batch_size = batch_size
384
+ self.tokenizer = VedaTokenizer(vocab_size=vocab_size)
385
+ self.model = None
386
+
387
+ def prepare_data(self, extra_data: str = ""):
388
+ """Prepare training data"""
389
+ data = TRAINING_DATA
390
+ if extra_data:
391
+ data += "\n\n" + extra_data
392
+
393
+ if os.path.exists("programming.txt"):
394
+ with open("programming.txt", 'r', encoding='utf-8') as f:
395
+ code_data = f.read()
396
+ data += "\n\n" + code_data
397
+
398
+ self.tokenizer.fit([data])
399
+
400
+ all_tokens = self.tokenizer.encode(data)
401
+ print(f"Total tokens: {len(all_tokens)}")
402
+
 
 
 
 
 
 
 
 
 
 
 
 
 
403
  sequences = []
404
+ stride = self.max_length // 2
405
+
406
  for i in range(0, len(all_tokens) - self.max_length - 1, stride):
407
  seq = all_tokens[i:i + self.max_length + 1]
408
  if len(seq) == self.max_length + 1:
409
  sequences.append(seq)
410
+
411
+ if len(sequences) < 10:
412
+ stride = self.max_length // 4
413
+ sequences = []
414
+ for i in range(0, len(all_tokens) - self.max_length - 1, stride):
415
+ seq = all_tokens[i:i + self.max_length + 1]
416
+ if len(seq) == self.max_length + 1:
417
+ sequences.append(seq)
418
+
419
+ print(f"Created {len(sequences)} training sequences")
420
+
421
+ sequences = np.array(sequences)
422
+ X = sequences[:, :-1]
423
+ y = sequences[:, 1:]
424
+
425
+ dataset = tf.data.Dataset.from_tensor_slices((X, y))
426
+ dataset = dataset.shuffle(1000).batch(self.batch_size).prefetch(1)
427
+
428
+ return dataset
429
+
430
+ def build_model(self):
431
+ """Build the model"""
432
+ self.model = VedaProgrammingLLM(
433
+ vocab_size=self.tokenizer.vocabulary_size,
434
+ max_length=self.max_length,
435
+ d_model=256,
436
+ num_heads=8,
437
+ num_layers=4,
438
+ ff_dim=512
439
+ )
440
+
441
+ self.model.compile(
442
+ optimizer=keras.optimizers.Adam(1e-4),
443
+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
444
+ metrics=['accuracy']
445
+ )
446
+
447
+ dummy = tf.zeros((1, self.max_length), dtype=tf.int32)
448
+ self.model(dummy)
449
+
450
+ return self.model
451
 
452
+ def train(self, epochs: int = 15, save_path: str = None, extra_data: str = ""):
453
+ """Train the model"""
454
+ if save_path is None:
455
+ save_path = MODEL_DIR
456
+
457
+ dataset = self.prepare_data(extra_data)
458
+ self.build_model()
459
+
460
+ self.model.summary()
461
+
462
+ os.makedirs(save_path, exist_ok=True)
463
+
464
+ history = self.model.fit(dataset, epochs=epochs, verbose=1)
465
+
466
+ self.model.save_weights(os.path.join(save_path, "weights.h5"))
467
+ self.tokenizer.save(os.path.join(save_path, "tokenizer.json"))
468
+
469
+ config = self.model.get_config()
470
+ with open(os.path.join(save_path, "config.json"), 'w') as f:
471
+ json.dump(config, f)
472
+
473
+ print(f"Model saved to {save_path}")
474
+ return history
 
 
 
 
 
 
 
475
 
476
+ def generate_response(self, user_input: str, max_tokens: int = 200, temperature: float = 0.7) -> str:
477
+ """Generate a response"""
478
+ prompt = f"<USER> {user_input}\n<ASSISTANT>"
479
+
480
+ tokens = self.tokenizer.encode(prompt)
481
+
482
+ generated = self.model.generate(
483
+ tokens,
484
+ max_new_tokens=max_tokens,
485
+ temperature=temperature,
486
+ repetition_penalty=1.2
487
+ )
488
+
489
+ response = self.tokenizer.decode(generated)
490
+
491
+ if "<ASSISTANT>" in response:
492
+ response = response.split("<ASSISTANT>")[-1].strip()
493
+ if "<USER>" in response:
494
+ response = response.split("<USER>")[0].strip()
495
+
496
+ return response
497
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
498
 
499
+ if __name__ == "__main__":
500
+ trainer = VedaTrainer()
501
+ trainer.train(epochs=20)
 
 
 
 
 
 
 
 
 
 
502
 
503
+ print("\n" + "="*50)
504
+ print("Testing:")
505
+ print("="*50)
506
 
507
+ tests = ["Hello!", "What is a function?", "Write a function to reverse a string"]
 
 
 
 
508
 
509
+ for test in tests:
510
+ print(f"\nUser: {test}")
511
+ print(f"Assistant: {trainer.generate_response(test)}")