vedaco committed on
Commit
ed6fa70
·
verified ·
1 Parent(s): 89d56eb

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +341 -154
train.py CHANGED
@@ -1,191 +1,378 @@
 
 
1
  import tensorflow as tf
2
  from tensorflow import keras
3
  import numpy as np
4
  import os
5
  import json
 
6
  from model import VedaProgrammingLLM
7
  from tokenizer import VedaTokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- SAMPLE_CODE = '''
10
- def hello_world():
11
- print("Hello, World!")
12
- return True
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  def fibonacci(n):
15
- if n <= 1:
16
- return n
 
 
17
  return fibonacci(n-1) + fibonacci(n-2)
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  def factorial(n):
20
- if n == 0:
21
  return 1
22
- return n * factorial(n-1)
23
 
24
- class Calculator:
25
- def __init__(self):
26
- self.result = 0
27
-
28
- def add(self, a, b):
29
- return a + b
30
-
31
- def subtract(self, a, b):
32
- return a - b
33
-
34
- def multiply(self, a, b):
35
- return a * b
36
 
 
 
 
 
37
  def bubble_sort(arr):
38
  n = len(arr)
 
39
  for i in range(n):
40
- for j in range(0, n-i-1):
41
- if arr[j] > arr[j+1]:
42
- arr[j], arr[j+1] = arr[j+1], arr[j]
 
 
 
 
 
 
 
43
  return arr
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  def binary_search(arr, target):
46
- left, right = 0, len(arr) - 1
 
 
47
  while left <= right:
48
  mid = (left + right) // 2
 
49
  if arr[mid] == target:
50
  return mid
51
  elif arr[mid] < target:
52
  left = mid + 1
53
  else:
54
  right = mid - 1
55
- return -1
 
56
 
57
- def is_palindrome(s):
58
- s = s.lower()
59
- return s == s[::-1]
 
 
60
 
61
- def sum_list(numbers):
62
- total = 0
63
- for num in numbers:
64
- total += num
65
- return total
66
 
67
- def find_max(arr):
68
- if not arr:
69
- return None
70
- max_val = arr[0]
71
- for val in arr:
72
- if val > max_val:
73
- max_val = val
74
- return max_val
75
- '''
76
-
77
- class VedaTrainer:
78
- """Trainer for Veda Programming LLM"""
79
-
80
- def __init__(self, data_path: str = "programming.txt",
81
- vocab_size: int = 3000, max_length: int = 128, batch_size: int = 8):
82
- self.data_path = data_path
83
- self.vocab_size = vocab_size
84
- self.max_length = max_length
85
- self.batch_size = batch_size
86
- self.tokenizer = VedaTokenizer(vocab_size=vocab_size)
87
- self.model = None
88
-
89
- def load_data(self):
90
- """Load training data"""
91
- if not os.path.exists(self.data_path):
92
- with open(self.data_path, 'w') as f:
93
- f.write(SAMPLE_CODE)
94
-
95
- with open(self.data_path, 'r', encoding='utf-8') as f:
96
- content = f.read()
97
-
98
- samples = [s.strip() for s in content.split('\n\n') if s.strip()]
99
- print(f"Loaded {len(samples)} samples")
100
- return samples
101
-
102
- def prepare_dataset(self, samples):
103
- """Prepare TensorFlow dataset"""
104
- self.tokenizer.fit(samples)
105
-
106
- all_tokens = []
107
- for sample in samples:
108
- tokens = self.tokenizer.encode(sample)
109
- all_tokens.extend(tokens)
110
-
111
- sequences = []
112
- step = max(1, self.max_length // 4)
113
- for i in range(0, len(all_tokens) - self.max_length, step):
114
- seq = all_tokens[i:i + self.max_length + 1]
115
- if len(seq) == self.max_length + 1:
116
- sequences.append(seq)
117
-
118
- if len(sequences) < 2:
119
- for sample in samples:
120
- tokens = self.tokenizer.encode(sample, max_length=self.max_length + 1)
121
- sequences.append(tokens)
122
-
123
- sequences = np.array(sequences[:100]) # Limit for memory
124
- X = sequences[:, :-1]
125
- y = sequences[:, 1:]
126
-
127
- dataset = tf.data.Dataset.from_tensor_slices((X, y))
128
- dataset = dataset.shuffle(100).batch(self.batch_size).prefetch(1)
129
-
130
- print(f"Created {len(sequences)} sequences")
131
- return dataset
132
-
133
- def build_model(self):
134
- """Build the model"""
135
- self.model = VedaProgrammingLLM(
136
- vocab_size=self.tokenizer.vocabulary_size,
137
- max_length=self.max_length,
138
- d_model=128,
139
- num_heads=4,
140
- num_layers=2,
141
- ff_dim=256
142
- )
143
-
144
- self.model.compile(
145
- optimizer=keras.optimizers.Adam(1e-4),
146
- loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
147
- metrics=['accuracy']
148
- )
149
-
150
- # Build
151
- dummy = tf.zeros((1, self.max_length), dtype=tf.int32)
152
- self.model(dummy)
153
- self.model.summary()
154
- return self.model
155
 
156
- def train(self, epochs: int = 5, save_path: str = "veda_model"):
157
- """Train the model"""
158
- samples = self.load_data()
159
- dataset = self.prepare_dataset(samples)
160
- self.build_model()
161
-
162
- os.makedirs(save_path, exist_ok=True)
163
-
164
- history = self.model.fit(dataset, epochs=epochs, verbose=1)
165
-
166
- # Save weights only (more reliable)
167
- self.model.save_weights(os.path.join(save_path, "weights.h5"))
168
- self.tokenizer.save(os.path.join(save_path, "tokenizer.json"))
169
-
170
- config = self.model.get_config()
171
- with open(os.path.join(save_path, "config.json"), 'w') as f:
172
- json.dump(config, f)
173
-
174
- print(f"Model saved to {save_path}")
175
- return history
176
 
177
- def generate(self, prompt: str, max_tokens: int = 50, temperature: float = 0.8):
178
- """Generate code"""
179
- if self.model is None:
180
- raise ValueError("Model not loaded")
181
-
182
- tokens = self.tokenizer.encode(prompt)
183
- generated = self.model.generate(tokens, max_tokens, temperature)
184
- return self.tokenizer.decode(generated)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
 
 
 
186
 
187
- if __name__ == "__main__":
188
- trainer = VedaTrainer()
189
- trainer.train(epochs=10)
190
- print("\nTest generation:")
191
- print(trainer.generate("def calculate"))
 
1
+ """Training - MODIFIED with conversation data"""
2
+
3
  import tensorflow as tf
4
  from tensorflow import keras
5
  import numpy as np
6
  import os
7
  import json
8
+
9
  from model import VedaProgrammingLLM
10
  from tokenizer import VedaTokenizer
11
+ from config import MODEL_DIR, VOCAB_SIZE, MAX_LENGTH, BATCH_SIZE
12
+
13
+ # ADDED: Conversation training data
14
+ TRAINING_DATA = '''
15
+ <USER> Hello!
16
+ <ASSISTANT> Hello! Welcome to Veda Programming Assistant! I'm here to help you with coding, explain concepts, or answer any programming questions. What would you like to learn today?
17
+
18
+ <USER> What is Python?
19
+ <ASSISTANT> Python is a high-level, interpreted programming language created by Guido van Rossum and first released in 1991. It's known for its clean, simple syntax that emphasizes code readability.
20
+
21
+ Key features:
22
+ - Easy to learn and read
23
+ - Dynamic typing
24
+ - Large standard library
25
+ - Great for beginners and experts alike
26
+
27
+ Python is used for web development, data science, AI, automation, and much more!
28
+
29
+ <USER> How do I print hello world?
30
+ <ASSISTANT> To print "Hello, World!" in Python, use the print() function:
31
+
32
+ <CODE>
33
+ print("Hello, World!")
34
+ <ENDCODE>
35
+
36
+ When you run this, it displays: Hello, World!
37
+
38
+ The print() function outputs text to the console. You can print strings, numbers, and variables!
39
+
40
+ <USER> Write a function to add two numbers
41
+ <ASSISTANT> Here's a simple function to add two numbers:
42
+
43
+ <CODE>
44
+ def add(a, b):
45
+ return a + b
46
+
47
+ # Example usage
48
+ result = add(5, 3)
49
+ print(result) # Output: 8
50
+ <ENDCODE>
51
+
52
+ This function takes two parameters and returns their sum.
53
+
54
+ <USER> Explain what a function is
55
+ <ASSISTANT> A function is a reusable block of code that performs a specific task. Functions help organize code, avoid repetition, and make programs easier to understand.
56
+
57
+ Here's the basic structure:
58
+
59
+ <CODE>
60
+ def function_name(parameters):
61
+ # Code block
62
+ return result
63
+ <ENDCODE>
64
+
65
+ Parts of a function:
66
+ - def: keyword to define a function
67
+ - function_name: what you call your function
68
+ - parameters: input values (optional)
69
+ - return: sends back a result (optional)
70
+
71
+ Example:
72
+
73
+ <CODE>
74
+ def greet(name):
75
+ message = f"Hello, {name}!"
76
+ return message
77
+
78
+ print(greet("Alice")) # Output: Hello, Alice!
79
+ <ENDCODE>
80
+
81
+ <USER> What is a list?
82
+ <ASSISTANT> A list is an ordered, mutable collection in Python that can hold items of any type.
83
+
84
+ Creating lists:
85
+
86
+ <CODE>
87
+ # Empty list
88
+ my_list = []
89
+
90
+ # List with items
91
+ numbers = [1, 2, 3, 4, 5]
92
+ fruits = ["apple", "banana", "cherry"]
93
+ mixed = [1, "hello", 3.14, True]
94
+ <ENDCODE>
95
+
96
+ Common operations:
97
+
98
+ <CODE>
99
+ fruits = ["apple", "banana"]
100
+
101
+ # Add items
102
+ fruits.append("cherry") # Add to end
103
+ fruits.insert(0, "apricot") # Insert at position
104
 
105
+ # Access items
106
+ print(fruits[0]) # First item
107
+ print(fruits[-1]) # Last item
 
108
 
109
+ # Remove items
110
+ fruits.remove("banana") # Remove by value
111
+ fruits.pop() # Remove last item
112
+
113
+ # Length
114
+ print(len(fruits))
115
+ <ENDCODE>
116
+
117
+ <USER> Write a fibonacci function
118
+ <ASSISTANT> Here's a Fibonacci function with two approaches:
119
+
120
+ Recursive approach:
121
+
122
+ <CODE>
123
  def fibonacci(n):
124
+ if n <= 0:
125
+ return 0
126
+ if n == 1:
127
+ return 1
128
  return fibonacci(n-1) + fibonacci(n-2)
129
 
130
+ # Print first 10 numbers
131
+ for i in range(10):
132
+ print(fibonacci(i), end=" ")
133
+ # Output: 0 1 1 2 3 5 8 13 21 34
134
+ <ENDCODE>
135
+
136
+ Iterative approach (more efficient):
137
+
138
+ <CODE>
139
+ def fibonacci_iterative(n):
140
+ if n <= 0:
141
+ return 0
142
+ if n == 1:
143
+ return 1
144
+
145
+ a, b = 0, 1
146
+ for _ in range(2, n + 1):
147
+ a, b = b, a + b
148
+ return b
149
+ <ENDCODE>
150
+
151
+ The Fibonacci sequence: each number is the sum of the two before it!
152
+
153
+ <USER> What is recursion?
154
+ <ASSISTANT> Recursion is when a function calls itself to solve a problem. It breaks complex problems into smaller, similar sub-problems.
155
+
156
+ Key components:
157
+ 1. Base case - stops the recursion
158
+ 2. Recursive case - function calls itself
159
+
160
+ Example - countdown:
161
+
162
+ <CODE>
163
+ def countdown(n):
164
+ if n <= 0: # Base case
165
+ print("Done!")
166
+ return
167
+ print(n)
168
+ countdown(n - 1) # Recursive call
169
+
170
+ countdown(5)
171
+ # Output: 5, 4, 3, 2, 1, Done!
172
+ <ENDCODE>
173
+
174
+ Another example - factorial:
175
+
176
+ <CODE>
177
  def factorial(n):
178
+ if n <= 1: # Base case
179
  return 1
180
+ return n * factorial(n - 1) # Recursive call
181
 
182
+ print(factorial(5)) # Output: 120
183
+ <ENDCODE>
184
+
185
+ Warning: Always have a base case to avoid infinite recursion!
 
 
 
 
 
 
 
 
186
 
187
+ <USER> Write bubble sort
188
+ <ASSISTANT> Here's the Bubble Sort algorithm:
189
+
190
+ <CODE>
191
  def bubble_sort(arr):
192
  n = len(arr)
193
+
194
  for i in range(n):
195
+ swapped = False
196
+
197
+ for j in range(0, n - i - 1):
198
+ if arr[j] > arr[j + 1]:
199
+ arr[j], arr[j + 1] = arr[j + 1], arr[j]
200
+ swapped = True
201
+
202
+ if not swapped:
203
+ break
204
+
205
  return arr
206
 
207
+ # Example
208
+ numbers = [64, 34, 25, 12, 22, 11, 90]
209
+ print("Original:", numbers)
210
+ print("Sorted:", bubble_sort(numbers.copy()))
211
+ <ENDCODE>
212
+
213
+ Output:
214
+ Original: [64, 34, 25, 12, 22, 11, 90]
215
+ Sorted: [11, 12, 22, 25, 34, 64, 90]
216
+
217
+ How it works: it repeatedly compares adjacent elements and swaps them if they are in the wrong order. The largest elements "bubble up" to the end.
218
+
219
+ Time complexity: O(n²)
220
+
221
+ <USER> Write binary search
222
+ <ASSISTANT> Here's the Binary Search algorithm:
223
+
224
+ <CODE>
225
  def binary_search(arr, target):
226
+ left = 0
227
+ right = len(arr) - 1
228
+
229
  while left <= right:
230
  mid = (left + right) // 2
231
+
232
  if arr[mid] == target:
233
  return mid
234
  elif arr[mid] < target:
235
  left = mid + 1
236
  else:
237
  right = mid - 1
238
+
239
+ return -1 # Not found
240
 
241
+ # Example
242
+ numbers = [1, 3, 5, 7, 9, 11, 13, 15]
243
+ print(binary_search(numbers, 7)) # Output: 3
244
+ print(binary_search(numbers, 6)) # Output: -1
245
+ <ENDCODE>
246
 
247
+ Important: The array must be sorted!
 
 
 
 
248
 
249
+ Time complexity: O(log n) - much faster than linear search!
250
+
251
+ <USER> What is a class?
252
+ <ASSISTANT> A class is a blueprint for creating objects. It bundles data (attributes) and functions (methods) together.
253
+
254
+ Basic class structure:
255
+
256
+ <CODE>
257
+ class Dog:
258
+ def __init__(self, name, age):
259
+ self.name = name
260
+ self.age = age
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
 
262
+ def bark(self):
263
+ return f"{self.name} says Woof!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
 
265
+ def get_info(self):
266
+ return f"{self.name} is {self.age} years old"
267
+
268
+ # Create objects
269
+ my_dog = Dog("Buddy", 3)
270
+ print(my_dog.bark()) # Buddy says Woof!
271
+ print(my_dog.get_info()) # Buddy is 3 years old
272
+ <ENDCODE>
273
+
274
+ Key concepts:
275
+ - __init__: constructor, runs when object is created
276
+ - self: refers to the current object
277
+ - Attributes: variables inside the class
278
+ - Methods: functions inside the class
279
+
280
+ <USER> How do I read a file?
281
+ <ASSISTANT> Here's how to read files in Python:
282
+
283
+ Basic file reading:
284
+
285
+ <CODE>
286
+ # Read entire file
287
+ with open("filename.txt", "r") as file:
288
+ content = file.read()
289
+ print(content)
290
+
291
+ # Read line by line
292
+ with open("filename.txt", "r") as file:
293
+ for line in file:
294
+ print(line.strip())
295
+
296
+ # Read all lines into list
297
+ with open("filename.txt", "r") as file:
298
+ lines = file.readlines()
299
+ <ENDCODE>
300
+
301
+ With error handling:
302
+
303
+ <CODE>
304
+ def read_file_safely(filename):
305
+ try:
306
+ with open(filename, "r") as file:
307
+ return file.read()
308
+ except FileNotFoundError:
309
+ print(f"File '{filename}' not found")
310
+ return None
311
+
312
+ content = read_file_safely("myfile.txt")
313
+ <ENDCODE>
314
+
315
+ The 'with' statement automatically closes the file when done - always recommended!
316
+
317
+ <USER> How do I handle errors?
318
+ <ASSISTANT> Python uses try-except blocks for error handling:
319
+
320
+ Basic try-except:
321
+
322
+ <CODE>
323
+ try:
324
+ result = 10 / 0
325
+ except ZeroDivisionError:
326
+ print("Cannot divide by zero!")
327
+ <ENDCODE>
328
+
329
+ Multiple exceptions:
330
+
331
+ <CODE>
332
+ try:
333
+ number = int(input("Enter number: "))
334
+ result = 10 / number
335
+ except ValueError:
336
+ print("That's not a valid number!")
337
+ except ZeroDivisionError:
338
+ print("Cannot divide by zero!")
339
+ except Exception as e:
340
+ print(f"Error: {e}")
341
+ <ENDCODE>
342
+
343
+ Complete structure:
344
+
345
+ <CODE>
346
+ try:
347
+ result = risky_operation()
348
+ except SomeError:
349
+ print("Handle error")
350
+ else:
351
+ print("Success!") # Runs if no error
352
+ finally:
353
+ print("Always runs") # Cleanup code
354
+ <ENDCODE>
355
+
356
+ Common exceptions:
357
+ - ValueError: wrong value
358
+ - TypeError: wrong type
359
+ - IndexError: index out of range
360
+ - KeyError: dictionary key not found
361
+ - FileNotFoundError: file doesn't exist
362
+
363
+ <USER> Thanks!
364
+ <ASSISTANT> You're welcome! I'm glad I could help. Remember:
365
+
366
+ - Practice coding every day
367
+ - Build small projects
368
+ - Don't fear errors - they help you learn
369
+ - Ask questions whenever stuck
370
+
371
+ Feel free to ask me anything anytime! Happy coding! 🐍
372
 
373
+ <USER> debug this: def add(a b): return a + b
374
+ <ASSISTANT> I found the issue! You're missing a comma between parameters.
375
 
376
+ Your code:
377
+ ```python
378
+ def add(a b): return a + b # ❌ Missing comma