vedaco committed on
Commit
4ce71b0
Β·
verified Β·
1 Parent(s): 64022af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +259 -142
app.py CHANGED
@@ -1,10 +1,9 @@
1
- """Gradio App for Veda Programming Assistant - Gradio 6.x compatible (with math solving)"""
2
 
3
  import gradio as gr
4
  import tensorflow as tf
5
  import os
6
  import json
7
-
8
  import re
9
  import ast
10
  import operator as op
@@ -13,37 +12,29 @@ from model import VedaProgrammingLLM
13
  from tokenizer import VedaTokenizer
14
  from database import db
15
  from train import VedaTrainer
16
- from config import MODEL_DIR
 
17
 
18
 
19
  # --------- Globals ----------
20
  model = None
21
  tokenizer = None
22
- conversation_history = [] # used for building prompt context for the model
23
  current_conv_id = -1
24
 
25
 
26
- # --------- Helpers (Gradio message parsing) ----------
27
  def extract_text(message):
28
- """
29
- Convert Gradio multimodal / messages objects -> plain string.
30
- Handles:
31
- - str
32
- - dict: {"text": "..."} or {"content": "..."}
33
- - list of parts: [{"type":"text","text":"..."}]
34
- """
35
  if message is None:
36
  return ""
37
  if isinstance(message, str):
38
  return message
39
-
40
  if isinstance(message, dict):
41
  if "text" in message:
42
  return str(message.get("text", ""))
43
  if "content" in message:
44
  return extract_text(message["content"])
45
  return ""
46
-
47
  if isinstance(message, list):
48
  parts = []
49
  for part in message:
@@ -52,33 +43,17 @@ def extract_text(message):
52
  elif isinstance(part, str):
53
  parts.append(part)
54
  return "".join(parts).strip()
55
-
56
  return str(message)
57
 
58
 
59
  def ensure_messages_history(history):
60
- """
61
- Ensure Chatbot history is ALWAYS messages format:
62
- [{"role":"user","content":"..."}, {"role":"assistant","content":"..."}]
63
-
64
- Also converts old tuple format [(user, bot), ...] -> messages.
65
- """
66
  if history is None:
67
  return []
68
-
69
- # Already messages format
70
- if (
71
- len(history) > 0
72
- and isinstance(history[0], dict)
73
- and "role" in history[0]
74
- and "content" in history[0]
75
- ):
76
  fixed = []
77
  for m in history:
78
  fixed.append({"role": m["role"], "content": extract_text(m["content"])})
79
  return fixed
80
-
81
- # Tuple/pair format -> messages format
82
  fixed = []
83
  for pair in history:
84
  if isinstance(pair, (list, tuple)) and len(pair) == 2:
@@ -87,7 +62,7 @@ def ensure_messages_history(history):
87
  return fixed
88
 
89
 
90
- # --------- Safe Math Solver ----------
91
  _ALLOWED_OPS = {
92
  ast.Add: op.add,
93
  ast.Sub: op.sub,
@@ -101,10 +76,6 @@ _ALLOWED_OPS = {
101
 
102
 
103
  def safe_eval_math(expr: str):
104
- """
105
- Safely evaluate arithmetic expression (no variables, no function calls).
106
- Supports: + - * / % ** and parentheses, integers/floats.
107
- """
108
  node = ast.parse(expr, mode="eval").body
109
 
110
  def _eval(n):
@@ -114,45 +85,79 @@ def safe_eval_math(expr: str):
114
  return _ALLOWED_OPS[type(n.op)](_eval(n.left), _eval(n.right))
115
  if isinstance(n, ast.UnaryOp) and type(n.op) in _ALLOWED_OPS:
116
  return _ALLOWED_OPS[type(n.op)](_eval(n.operand))
117
- raise ValueError("Unsupported expression")
118
 
119
  return _eval(node)
120
 
121
 
122
  def try_math_answer(user_text: str):
123
- """
124
- If user text looks like a pure math expression, return computed answer as string.
125
- Otherwise return None.
126
- Examples:
127
- "2+2=?" -> "4"
128
- "2^5" -> "32"
129
- "(10+5)/3" -> "5"
130
- """
131
  if not user_text:
132
  return None
133
-
134
- # Normalize common decorations
135
- s = user_text.strip()
136
- s = s.replace("=", "").replace("?", "").strip()
137
- s = s.replace("^", "**") # allow ^ as power
138
-
139
- # Only allow digits/operators/parentheses/dots/spaces
140
  if not re.fullmatch(r"[0-9\.\s\+\-\*\/\(\)%]+", s):
141
  return None
142
-
143
  try:
144
  val = safe_eval_math(s)
145
- # pretty formatting: 4.0 -> 4
146
  if isinstance(val, float) and val.is_integer():
147
  val = int(val)
148
  return str(val)
149
- except Exception:
150
  return None
151
 
152
 
153
- # --------- Model init ----------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  def initialize():
155
- """Initialize the assistant (load if exists, else train once)."""
156
  global model, tokenizer
157
 
158
  print("Initializing Veda Programming Assistant...")
@@ -183,7 +188,7 @@ def initialize():
183
 
184
  print("Model loaded!")
185
  else:
186
- print("No saved model found. Training a new model...")
187
  trainer = VedaTrainer()
188
  trainer.train(epochs=15)
189
  model = trainer.model
@@ -192,17 +197,19 @@ def initialize():
192
 
193
 
194
  def clean_response(text: str) -> str:
195
- """Clean the response text for display."""
 
 
196
  text = text.replace("<CODE>", "\n```python\n")
197
  text = text.replace("<ENDCODE>", "\n```\n")
198
-
199
  for token in ["<PAD>", "<UNK>", "<START>", "<END>", "<USER>", "<ASSISTANT>"]:
200
  text = text.replace(token, "")
201
-
202
  lines = text.split("\n")
203
  cleaned = []
204
  empty_count = 0
205
-
206
  for line in lines:
207
  if line.strip() == "":
208
  empty_count += 1
@@ -211,38 +218,21 @@ def clean_response(text: str) -> str:
211
  else:
212
  empty_count = 0
213
  cleaned.append(line)
214
-
215
  return "\n".join(cleaned).strip()
216
 
217
 
218
- def generate_response(user_input: str, temperature: float = 0.7, max_tokens: int = 200) -> str:
219
- """Generate a response from the model OR solve math deterministically."""
220
- global current_conv_id, conversation_history
221
-
222
- # Convert Gradio multimodal -> text
223
- user_input = extract_text(user_input).strip()
224
- if not user_input:
225
- return "Please type a message!"
226
-
227
- # 1) Try math solver first
228
- math_ans = try_math_answer(user_input)
229
- if math_ans is not None:
230
- # Save conversation too (optional)
231
- conversation_history.append({"user": user_input, "assistant": math_ans})
232
- current_conv_id = db.save_conversation(user_input, math_ans)
233
- return math_ans
234
-
235
- # 2) Otherwise use model
236
- if model is None:
237
- return "Model is loading, please wait..."
238
-
239
  try:
240
  context = ""
241
  for msg in conversation_history[-3:]:
242
  context += f"<USER> {msg['user']}\n<ASSISTANT> {msg['assistant']}\n"
243
 
244
  prompt = context + f"<USER> {user_input}\n<ASSISTANT>"
245
-
246
  tokens = tokenizer.encode(prompt)
247
 
248
  if len(tokens) > model.max_length - max_tokens:
@@ -264,27 +254,100 @@ def generate_response(user_input: str, temperature: float = 0.7, max_tokens: int
264
  if "<USER>" in response:
265
  response = response.split("<USER>")[0].strip()
266
 
267
- response = clean_response(response)
 
 
 
 
 
268
 
269
- if not response:
270
- response = "I'm not sure how to respond to that. Could you try rephrasing?"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
 
272
- conversation_history.append({"user": user_input, "assistant": response})
273
- current_conv_id = db.save_conversation(user_input, response)
274
 
275
- return response
 
 
276
 
277
- except Exception as e:
278
- import traceback
279
- traceback.print_exc()
280
- return f"Error: {str(e)}"
281
 
 
 
 
 
 
 
282
 
283
- # --------- Gradio handlers ----------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284
  def respond(message, history, temperature, max_tokens):
285
- """Always return messages-format history."""
286
  history = ensure_messages_history(history)
287
-
288
  user_text = extract_text(message).strip()
289
  if not user_text:
290
  return "", history
@@ -298,19 +361,17 @@ def respond(message, history, temperature, max_tokens):
298
 
299
 
300
  def feedback_good():
301
- global current_conv_id
302
  if current_conv_id > 0:
303
  db.update_feedback(current_conv_id, 1)
304
- return "πŸ‘ Thanks for the positive feedback!"
305
- return "No conversation to rate yet."
306
 
307
 
308
  def feedback_bad():
309
- global current_conv_id
310
  if current_conv_id > 0:
311
  db.update_feedback(current_conv_id, -1)
312
- return "πŸ‘Ž Thanks! I'll try to improve."
313
- return "No conversation to rate yet."
314
 
315
 
316
  def clear_chat():
@@ -319,56 +380,111 @@ def clear_chat():
319
  return [], "Chat cleared."
320
 
321
 
322
- def retrain(epochs):
323
- """Retrain with good conversations."""
324
  global model, tokenizer
325
 
 
326
  good_convs = db.get_good_conversations()
327
- if not good_convs:
328
- return "No approved conversations yet. Rate some responses as 'Good' first!"
329
-
330
  extra_data = ""
331
  for conv in good_convs:
332
  extra_data += f"<USER> {conv['user_input']}\n"
333
  extra_data += f"<ASSISTANT> {conv['assistant_response']}\n\n"
334
 
 
 
 
 
 
 
 
 
 
 
 
 
335
  trainer = VedaTrainer()
336
- history = trainer.train(epochs=int(epochs), extra_data=extra_data)
 
 
 
 
337
 
338
  model = trainer.model
339
  tokenizer = trainer.tokenizer
340
 
 
 
 
 
 
341
  loss = history.history["loss"][-1]
342
- return f"βœ… Training complete! Loss: {loss:.4f}, Used {len(good_convs)} conversations"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
343
 
344
 
345
  def get_stats():
346
  stats = db.get_stats()
 
 
347
  return f"""## πŸ“Š Statistics
348
 
 
349
  | Metric | Count |
350
  |--------|-------|
351
- | πŸ’¬ Total Conversations | {stats['total']} |
352
- | πŸ‘ Positive Feedback | {stats['positive']} |
353
- | πŸ‘Ž Negative Feedback | {stats['negative']} |
 
 
 
 
 
 
 
354
  """
355
 
356
 
357
  # --------- Startup ----------
358
- print("Starting initialization...")
 
 
359
  initialize()
360
- print("Initialization complete!")
 
 
 
 
 
 
 
361
 
362
 
363
  # --------- UI ----------
364
  with gr.Blocks(title="Veda Programming Assistant") as demo:
365
- gr.Markdown(
366
- """
367
  # πŸ•‰οΈ Veda Programming Assistant
368
 
369
- Now supports **math** (e.g., `2+2=?`, `(10+5)/3`, `2^5`) plus coding/chatting.
370
- """
371
- )
 
372
 
373
  with gr.Tabs():
374
  with gr.TabItem("πŸ’¬ Chat"):
@@ -377,7 +493,7 @@ Now supports **math** (e.g., `2+2=?`, `(10+5)/3`, `2^5`) plus coding/chatting.
377
  with gr.Row():
378
  msg = gr.Textbox(
379
  label="Your message",
380
- placeholder="Ask me anything about programming... or type math like 2+2=?",
381
  lines=2,
382
  scale=4,
383
  )
@@ -394,9 +510,8 @@ Now supports **math** (e.g., `2+2=?`, `(10+5)/3`, `2^5`) plus coding/chatting.
394
 
395
  feedback_msg = gr.Textbox(label="Status", lines=1, interactive=False)
396
 
397
- send_btn.click(respond, inputs=[msg, chatbot, temperature, max_tokens], outputs=[msg, chatbot])
398
- msg.submit(respond, inputs=[msg, chatbot, temperature, max_tokens], outputs=[msg, chatbot])
399
-
400
  good_btn.click(feedback_good, outputs=feedback_msg)
401
  bad_btn.click(feedback_bad, outputs=feedback_msg)
402
  clear_btn.click(clear_chat, outputs=[chatbot, feedback_msg])
@@ -404,38 +519,40 @@ Now supports **math** (e.g., `2+2=?`, `(10+5)/3`, `2^5`) plus coding/chatting.
404
  gr.Markdown("### πŸ’‘ Examples")
405
  gr.Examples(
406
  examples=[
407
- ["2+2=?"],
408
- ["(10+5)/3"],
409
- ["2^8"],
410
  ["What is Python?"],
411
- ["Write a function to calculate factorial"],
412
  ["Explain recursion"],
 
 
 
413
  ],
414
  inputs=msg,
415
  )
416
 
417
  with gr.TabItem("πŸŽ“ Training"):
418
- gr.Markdown(
419
- """
420
- ### Improve the Assistant
421
 
422
- 1. Chat with the assistant
423
- 2. Rate good responses with πŸ‘
424
- 3. Click "Retrain Model" to learn from good conversations
425
- """
426
- )
427
 
428
- train_epochs = gr.Slider(5, 20, 10, step=1, label="Training Epochs")
429
- train_btn = gr.Button("πŸ”„ Retrain Model", variant="primary")
 
 
 
430
  train_output = gr.Markdown()
431
- train_btn.click(retrain, inputs=[train_epochs], outputs=train_output)
 
432
 
433
  with gr.TabItem("πŸ“Š Statistics"):
434
  stats_out = gr.Markdown()
435
- refresh_btn = gr.Button("πŸ”„ Refresh Statistics")
436
  refresh_btn.click(get_stats, outputs=stats_out)
437
 
438
- gr.Markdown("---\n**Veda Programming Assistant**")
439
 
440
 
441
  if __name__ == "__main__":
 
1
+ """Gradio App for Veda Programming Assistant - Fixed Distillation"""
2
 
3
  import gradio as gr
4
  import tensorflow as tf
5
  import os
6
  import json
 
7
  import re
8
  import ast
9
  import operator as op
 
12
  from tokenizer import VedaTokenizer
13
  from database import db
14
  from train import VedaTrainer
15
+ from teacher import teacher
16
+ from config import MODEL_DIR, DISTILLATION_ENABLED
17
 
18
 
19
# --------- Globals ----------
model = None  # student model (VedaProgrammingLLM); populated by initialize()
tokenizer = None  # VedaTokenizer instance; populated by initialize()
conversation_history = []  # [{"user": ..., "assistant": ...}] turns used as prompt context
current_conv_id = -1  # db row id of the most recent saved conversation; -1 = none yet
24
 
25
 
26
# --------- Helpers ----------
def extract_text(message):
    """Normalize a Gradio message payload into a plain string.

    Handles:
    - str (returned as-is)
    - dict: {"text": "..."} or {"content": "..."} (recurses into "content")
    - list of multimodal parts, e.g. [{"type": "text", "text": "..."}]
    - None -> ""
    Anything else falls back to str(message).
    """
    if message is None:
        return ""
    if isinstance(message, str):
        return message
    if isinstance(message, dict):
        if "text" in message:
            return str(message.get("text", ""))
        if "content" in message:
            return extract_text(message["content"])
        return ""
    if isinstance(message, list):
        parts = []
        for part in message:
            if isinstance(part, dict):
                # NOTE(review): dict-part handling was elided in the diff view;
                # reconstructed from the documented {"type":"text","text":...}
                # shape — confirm against the repo.
                parts.append(str(part.get("text", "")))
            elif isinstance(part, str):
                parts.append(part)
        return "".join(parts).strip()
    return str(message)


def ensure_messages_history(history):
    """Coerce Chatbot history into messages format.

    Accepts either messages format ([{"role": ..., "content": ...}]) or the
    legacy pair format ([(user, bot), ...]) and always returns messages
    format with string contents.
    """
    if history is None:
        return []
    # Already messages format: re-normalize each content to a plain string.
    if len(history) > 0 and isinstance(history[0], dict) and "role" in history[0]:
        fixed = []
        for m in history:
            fixed.append({"role": m["role"], "content": extract_text(m["content"])})
        return fixed
    # Legacy (user, bot) pairs -> two messages per pair.
    fixed = []
    for pair in history:
        if isinstance(pair, (list, tuple)) and len(pair) == 2:
            # NOTE(review): append body elided in the diff view; reconstructed — confirm.
            fixed.append({"role": "user", "content": extract_text(pair[0])})
            fixed.append({"role": "assistant", "content": extract_text(pair[1])})
    return fixed
63
 
64
 
65
# --------- Math Solver ----------
# Operators permitted in safe_eval_math. Entries beyond Add/Sub were elided
# in the diff view; reconstructed from the documented contract
# ("+ - * / % ** and parentheses") — TODO confirm against the repo.
_ALLOWED_OPS = {
    ast.Add: op.add,
    ast.Sub: op.sub,
    ast.Mult: op.mul,
    ast.Div: op.truediv,
    ast.Mod: op.mod,
    ast.Pow: op.pow,
    ast.USub: op.neg,
    ast.UAdd: op.pos,
}


def safe_eval_math(expr: str):
    """Safely evaluate a pure arithmetic expression (no names, no calls).

    Only numeric constants plus the binary/unary operators in _ALLOWED_OPS
    are accepted; anything else raises ValueError.
    """
    node = ast.parse(expr, mode="eval").body

    def _eval(n):
        # Numeric literal (int/float only; rejects strings, None, etc.).
        if isinstance(n, ast.Constant) and isinstance(n.value, (int, float)):
            return n.value
        if isinstance(n, ast.BinOp) and type(n.op) in _ALLOWED_OPS:
            return _ALLOWED_OPS[type(n.op)](_eval(n.left), _eval(n.right))
        if isinstance(n, ast.UnaryOp) and type(n.op) in _ALLOWED_OPS:
            return _ALLOWED_OPS[type(n.op)](_eval(n.operand))
        raise ValueError("Unsupported")

    return _eval(node)


def try_math_answer(user_text: str):
    """Return the computed answer when user_text is a pure math expression.

    Examples: "2+2=?" -> "4", "2^5" -> "32", "(10+5)/3" -> "5".
    Returns None when the text is not pure arithmetic or evaluation fails.
    """
    if not user_text:
        return None
    # Strip decorations ("=", "?") and allow ^ as a power operator.
    s = user_text.strip().replace("=", "").replace("?", "").strip().replace("^", "**")
    # Whitelist digits / operators / parens / dots / spaces only.
    if not re.fullmatch(r"[0-9\.\s\+\-\*\/\(\)%]+", s):
        return None
    try:
        val = safe_eval_math(s)
        # Pretty formatting: 4.0 -> 4
        if isinstance(val, float) and val.is_integer():
            val = int(val)
        return str(val)
    except Exception:
        # Was a bare `except:` — narrowed so KeyboardInterrupt/SystemExit
        # are not silently swallowed. ZeroDivisionError/SyntaxError/ValueError
        # all land here and map to "not a math question".
        return None
106
 
107
 
108
# --------- Response Quality Check ----------
# Heuristic markers of bad student generations. Compiled once at module load
# instead of on every call (the original rebuilt the pattern list per call).
_GIBBERISH_PATTERNS = [re.compile(p) for p in (
    r'\["\]',
    r'arr\[\s*a',
    r'print\s*\(\s*"\s*,',
    r'=\s+=\s+=',
    r'\[\.\]',
    r'return\s+if\s+is',
    r'\s{10,}',       # too many consecutive whitespace characters
    r'(\w)\1{5,}',    # same word character repeated 6+ times
)]

# Stock phrases that signal the model failed to answer.
_ERROR_PHRASES = (
    "i'm not sure",
    "i don't know",
    "could you try rephrasing",
    "error:",
    "cannot understand",
)


def is_good_response(response: str) -> bool:
    """Heuristically decide whether a student response is usable.

    Rejects empty or very short text, known gibberish patterns, text
    dominated by bracket/operator characters, and stock failure phrases.
    Returns True only when none of the rejection rules fire.
    """
    if not response:
        return False

    response = response.strip()

    # Too short to be a real answer.
    if len(response) < 20:
        return False

    for pattern in _GIBBERISH_PATTERNS:
        if pattern.search(response):
            return False

    # Too many special characters compared to letters.
    letters = sum(1 for c in response if c.isalpha())
    special = sum(1 for c in response if c in '[]{}()=<>|\\')
    if letters > 0 and special / letters > 0.5:
        return False

    response_lower = response.lower()
    for phrase in _ERROR_PHRASES:
        if phrase in response_lower:
            return False

    return True
157
+
158
+
159
+ # --------- Model Init ----------
160
  def initialize():
 
161
  global model, tokenizer
162
 
163
  print("Initializing Veda Programming Assistant...")
 
188
 
189
  print("Model loaded!")
190
  else:
191
+ print("Training new model...")
192
  trainer = VedaTrainer()
193
  trainer.train(epochs=15)
194
  model = trainer.model
 
197
 
198
 
199
def clean_response(text: str) -> str:
    """Strip model control tokens and tidy whitespace for display.

    <CODE>/<ENDCODE> markers become Markdown python fences, other special
    tokens are removed, and long runs of blank lines are collapsed.
    Returns "" for empty/None input.
    """
    if not text:
        return ""

    text = text.replace("<CODE>", "\n```python\n")
    text = text.replace("<ENDCODE>", "\n```\n")

    for token in ["<PAD>", "<UNK>", "<START>", "<END>", "<USER>", "<ASSISTANT>"]:
        text = text.replace(token, "")

    lines = text.split("\n")
    cleaned = []
    empty_count = 0

    for line in lines:
        if line.strip() == "":
            empty_count += 1
            # Keep at most two consecutive blank lines.
            # NOTE(review): exact threshold was elided in the diff view — confirm.
            if empty_count <= 2:
                cleaned.append(line)
        else:
            empty_count = 0
            cleaned.append(line)

    return "\n".join(cleaned).strip()
223
 
224
 
225
+ def get_student_response(user_input: str, temperature: float = 0.7, max_tokens: int = 200) -> str:
226
+ """Get response from student model (Veda)"""
227
+ if model is None or tokenizer is None:
228
+ return ""
229
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
  try:
231
  context = ""
232
  for msg in conversation_history[-3:]:
233
  context += f"<USER> {msg['user']}\n<ASSISTANT> {msg['assistant']}\n"
234
 
235
  prompt = context + f"<USER> {user_input}\n<ASSISTANT>"
 
236
  tokens = tokenizer.encode(prompt)
237
 
238
  if len(tokens) > model.max_length - max_tokens:
 
254
  if "<USER>" in response:
255
  response = response.split("<USER>")[0].strip()
256
 
257
+ return clean_response(response)
258
+
259
+ except Exception as e:
260
+ print(f"Student model error: {e}")
261
+ return ""
262
+
263
 
264
def get_teacher_response(user_input: str) -> str:
    """Ask the teacher model (Dolphin Mistral) for a response.

    Sends the last 4 turns of conversation_history as chat-format context.
    Returns "" on any failure so callers can fall back to the student answer.
    """
    try:
        # Build conversation history in the messages shape the teacher expects.
        conv_history = []
        for msg in conversation_history[-4:]:
            conv_history.append({"role": "user", "content": msg["user"]})
            conv_history.append({"role": "assistant", "content": msg["assistant"]})

        response = teacher.ask(
            user_message=user_input,
            conversation_history=conv_history,
        )

        return response if response else ""

    except Exception as e:
        # Best-effort: a teacher outage must not crash the chat flow.
        print(f"Teacher model error: {e}")
        return ""
283
 
 
 
284
 
285
def generate_response(user_input: str, temperature: float = 0.7, max_tokens: int = 200) -> str:
    """Generate a reply: math solver first, then student model, teacher fallback.

    Good teacher answers are persisted as distillation data so the student
    can later be retrained on them. Appends the turn to conversation_history
    and updates current_conv_id as side effects. Teacher-sourced replies are
    prefixed with a πŸŽ“ marker for the UI.
    """
    global current_conv_id, conversation_history

    user_input = extract_text(user_input).strip()
    if not user_input:
        return "Please type a message!"

    # 1) Deterministic math answers bypass both models.
    math_ans = try_math_answer(user_input)
    if math_ans is not None:
        conversation_history.append({"user": user_input, "assistant": math_ans})
        current_conv_id = db.save_conversation(user_input, math_ans)
        return math_ans

    # 2) Try the local student model.
    print(f"[Student] Generating response for: {user_input[:50]}...")
    student_response = get_student_response(user_input, temperature, max_tokens)

    # 3) Keep the student answer when it passes the quality heuristics.
    if is_good_response(student_response):
        print("[Student] Response is good quality, using it.")
        final_response = student_response
        source = "student"
    else:
        # 4) Student failed; fall back to the teacher model.
        print("[Student] Response is poor quality, asking teacher...")
        print(f"[Student Bad Response]: {student_response[:100]}...")

        teacher_response = get_teacher_response(user_input)

        if teacher_response:
            print("[Teacher] Got good response from teacher!")
            final_response = teacher_response
            source = "teacher"

            # Persist the teacher answer for future distillation training.
            db.save_distillation_data(
                user_input=user_input,
                teacher_response=teacher_response,
                student_response=student_response,
                quality_score=1.0,
            )
        else:
            # Teacher also failed; use whatever the student produced.
            print("[Teacher] No response from teacher, using student response.")
            final_response = student_response if student_response else "I'm sorry, I couldn't generate a good response. Please try again."
            source = "student"

    # 5) Save and return.
    if not final_response:
        final_response = "I'm having trouble responding. Please try asking in a different way."

    conversation_history.append({"user": user_input, "assistant": final_response})
    current_conv_id = db.save_conversation(user_input, final_response)

    # Tag teacher-sourced answers so the UI can show their origin.
    if source == "teacher":
        final_response = f"πŸŽ“ {final_response}"

    return final_response
346
+
347
+
348
+ # --------- Gradio Handlers ----------
349
  def respond(message, history, temperature, max_tokens):
 
350
  history = ensure_messages_history(history)
 
351
  user_text = extract_text(message).strip()
352
  if not user_text:
353
  return "", history
 
361
 
362
 
363
def feedback_good():
    """Record a positive rating for the most recent conversation.

    Returns a status string for the UI; "" when nothing has been rated yet.
    (Emoji restored from mojibake in the scraped source.)
    """
    if current_conv_id > 0:
        db.update_feedback(current_conv_id, 1)
        return "πŸ‘ Thanks! This helps me learn."
    return ""
368
 
369
 
370
  def feedback_bad():
 
371
  if current_conv_id > 0:
372
  db.update_feedback(current_conv_id, -1)
373
+ return "πŸ‘Ž Thanks for feedback. I'll improve!"
374
+ return ""
375
 
376
 
377
  def clear_chat():
 
380
  return [], "Chat cleared."
381
 
382
 
383
def retrain_with_distillation(epochs):
    """Retrain the student on user-approved chats plus unused teacher data.

    Combines πŸ‘-rated conversations with distillation rows not yet trained
    on, runs VedaTrainer, swaps in the new model/tokenizer, marks the
    consumed distillation rows as used, and records a training-history
    entry. Returns a Markdown status string for the UI.
    """
    global model, tokenizer

    # User-approved conversations.
    good_convs = db.get_good_conversations()
    extra_data = ""
    for conv in good_convs:
        extra_data += f"<USER> {conv['user_input']}\n"
        extra_data += f"<ASSISTANT> {conv['assistant_response']}\n\n"

    # Teacher responses not yet trained on.
    unused_distill = db.get_unused_distillation_data()
    distillation_data = ""
    for item in unused_distill:
        distillation_data += f"<USER> {item['user_input']}\n"
        distillation_data += f"<ASSISTANT> {item['teacher_response']}\n\n"

    total_samples = len(good_convs) + len(unused_distill)

    if total_samples == 0:
        return "❌ No training data available. Chat more and rate responses!"

    trainer = VedaTrainer()
    history = trainer.train(
        epochs=int(epochs),
        extra_data=extra_data,
        distillation_data=distillation_data,
    )

    model = trainer.model
    tokenizer = trainer.tokenizer

    # Mark consumed distillation rows so they are not retrained on next time.
    if unused_distill:
        ids = [item["id"] for item in unused_distill]
        db.mark_distillation_used(ids)

    loss = history.history["loss"][-1]

    db.save_training_history(
        training_type="distillation",
        samples_used=total_samples,
        epochs=int(epochs),
        final_loss=loss,
    )

    return f"""βœ… Training Complete!

πŸ“Š **Results:**
- Loss: {loss:.4f}
- User samples: {len(good_convs)}
- Teacher samples: {len(unused_distill)}
- Total epochs: {epochs}

Your model has learned from the teacher!
"""
440
 
441
 
442
def get_stats():
    """Render conversation and distillation statistics as Markdown.

    Distillation counters use .get() with 0 defaults so an older db schema
    without those keys still renders.
    """
    stats = db.get_stats()
    teacher_available = teacher.is_available()

    return f"""## πŸ“Š Statistics

### Conversations
| Metric | Count |
|--------|-------|
| πŸ’¬ Total | {stats['total']} |
| πŸ‘ Positive | {stats['positive']} |
| πŸ‘Ž Negative | {stats['negative']} |

### πŸŽ“ Distillation
| Metric | Value |
|--------|-------|
| Teacher Available | {'βœ… Yes' if teacher_available else '❌ No'} |
| Teacher Samples | {stats.get('distillation_total', 0)} |
| Ready to Train | {stats.get('distillation_unused', 0)} |
"""
462
 
463
 
464
# --------- Startup ----------
# Runs at import time: load (or train) the student model, then probe the
# teacher endpoint so misconfigured API keys are visible in the logs early.
print("=" * 50)
print("Starting Veda Programming Assistant...")
print("=" * 50)
initialize()
print("Checking teacher availability...")
if teacher.is_available():
    print("βœ… Teacher model (Dolphin Mistral) is available!")
else:
    print("❌ Teacher model not available - check API key")
print("=" * 50)
print("Ready!")
print("=" * 50)
477
 
478
 
479
  # --------- UI ----------
480
  with gr.Blocks(title="Veda Programming Assistant") as demo:
481
+ gr.Markdown("""
 
482
  # πŸ•‰οΈ Veda Programming Assistant
483
 
484
+ I can help you with **coding**, **programming concepts**, and **math**!
485
+
486
+ *Responses marked with πŸŽ“ come from an advanced AI teacher.*
487
+ """)
488
 
489
  with gr.Tabs():
490
  with gr.TabItem("πŸ’¬ Chat"):
 
493
  with gr.Row():
494
  msg = gr.Textbox(
495
  label="Your message",
496
+ placeholder="Ask me anything about programming...",
497
  lines=2,
498
  scale=4,
499
  )
 
510
 
511
  feedback_msg = gr.Textbox(label="Status", lines=1, interactive=False)
512
 
513
+ send_btn.click(respond, [msg, chatbot, temperature, max_tokens], [msg, chatbot])
514
+ msg.submit(respond, [msg, chatbot, temperature, max_tokens], [msg, chatbot])
 
515
  good_btn.click(feedback_good, outputs=feedback_msg)
516
  bad_btn.click(feedback_bad, outputs=feedback_msg)
517
  clear_btn.click(clear_chat, outputs=[chatbot, feedback_msg])
 
519
  gr.Markdown("### πŸ’‘ Examples")
520
  gr.Examples(
521
  examples=[
522
+ ["Hello! What can you do?"],
 
 
523
  ["What is Python?"],
524
+ ["Write a factorial function"],
525
  ["Explain recursion"],
526
+ ["Write bubble sort"],
527
+ ["2+2=?"],
528
+ ["What is the difference between list and tuple?"],
529
  ],
530
  inputs=msg,
531
  )
532
 
533
  with gr.TabItem("πŸŽ“ Training"):
534
+ gr.Markdown("""
535
+ ### Improve the Model
 
536
 
537
+ The model learns from:
538
+ 1. **Your feedback** - Rate responses πŸ‘ or πŸ‘Ž
539
+ 2. **Teacher knowledge** - Learns from advanced AI
 
 
540
 
541
+ Click below to train with collected data.
542
+ """)
543
+
544
+ train_epochs = gr.Slider(5, 30, 15, step=1, label="Training Epochs")
545
+ train_btn = gr.Button("πŸš€ Train Model", variant="primary")
546
  train_output = gr.Markdown()
547
+
548
+ train_btn.click(retrain_with_distillation, inputs=train_epochs, outputs=train_output)
549
 
550
  with gr.TabItem("πŸ“Š Statistics"):
551
  stats_out = gr.Markdown()
552
+ refresh_btn = gr.Button("πŸ”„ Refresh")
553
  refresh_btn.click(get_stats, outputs=stats_out)
554
 
555
+ gr.Markdown("---\n**Veda Programming Assistant** | Made with ❀️")
556
 
557
 
558
  if __name__ == "__main__":