vedaco committed on
Commit
f162639
·
verified ·
1 Parent(s): 0fe7d00

Create distillation.py

Browse files
Files changed (1) hide show
  1. distillation.py +170 -0
distillation.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Knowledge Distillation Engine - Student learns from Teacher"""
2
+
3
+ from typing import Optional, Tuple, Dict
4
+ from teacher import teacher
5
+ from database import db
6
+ from config import (
7
+ DISTILLATION_ENABLED,
8
+ AUTO_LEARN_FROM_TEACHER,
9
+ MIN_RESPONSE_LENGTH,
10
+ MIN_SAMPLES_FOR_DISTILL_TRAINING,
11
+ )
12
+
13
+
14
class DistillationEngine:
    """
    Manages knowledge distillation from teacher to student.

    Modes:
        1. AUTO: Always ask teacher, save responses, use student for speed
        2. FALLBACK: Only ask teacher if student response is poor
        3. COMPARE: Show both responses for comparison

    A fourth mode, "disabled", is accepted by set_mode(); any mode other
    than the three above makes process_with_distillation() return the
    student response unchanged.
    """

    def __init__(self):
        # Shared global teacher client imported from the teacher module.
        self.teacher = teacher
        self.mode = "fallback"  # "auto", "fallback", "compare", or "disabled"
        # Call counters, surfaced via get_stats().
        self.teacher_call_count = 0
        self.student_call_count = 0

    def should_ask_teacher(self, student_response: str) -> bool:
        """Decide if we should ask the teacher based on student response quality.

        Returns True only when distillation is enabled, the teacher is
        reachable, and the student response looks low quality (empty,
        shorter than MIN_RESPONSE_LENGTH, or containing a known
        failure/uncertainty phrase).
        """
        if not DISTILLATION_ENABLED:
            return False

        if not self.teacher.is_available():
            return False

        # Heuristics for low-quality response
        if not student_response:
            return True

        if len(student_response.strip()) < MIN_RESPONSE_LENGTH:
            return True

        # Check for error messages / canned uncertainty phrases
        low_quality_indicators = [
            "I'm not sure",
            "I don't know",
            "Could you try rephrasing",
            "Error:",
            "not sure how to respond",
        ]
        # Lowercase once instead of on every loop iteration.
        lowered = student_response.lower()
        return any(
            indicator.lower() in lowered for indicator in low_quality_indicators
        )

    def get_teacher_response(
        self,
        user_input: str,
        conversation_history: Optional[list] = None,
        student_response: Optional[str] = None,
    ) -> Optional[str]:
        """Get response from teacher and optionally save for training.

        Args:
            user_input: The user's message.
            conversation_history: Prior turns passed through to the teacher.
            student_response: The student's answer, stored alongside the
                teacher's for later comparison/training.

        Returns:
            The teacher's response text, or None if the teacher produced
            nothing.
        """
        teacher_response = self.teacher.ask(
            user_message=user_input,
            conversation_history=conversation_history,
        )

        if teacher_response and AUTO_LEARN_FROM_TEACHER:
            # Save for future training
            db.save_distillation_data(
                user_input=user_input,
                teacher_response=teacher_response,
                student_response=student_response,
                quality_score=1.0,  # Teacher responses are high quality
            )

        if teacher_response:
            self.teacher_call_count += 1

        return teacher_response

    def process_with_distillation(
        self,
        user_input: str,
        student_response: str,
        conversation_history: Optional[list] = None,
    ) -> Tuple[str, str]:
        """
        Process a response with potential teacher assistance.

        Returns:
            Tuple of (final_response, source) where source is "student",
            "teacher", or "both".
        """
        self.student_call_count += 1

        if self.mode == "auto":
            # Always get teacher response for learning (called purely for
            # its save-for-training side effect), but return student for speed.
            self.get_teacher_response(
                user_input, conversation_history, student_response
            )
            return student_response, "student"

        elif self.mode == "fallback":
            # Only ask teacher if student response is poor
            if self.should_ask_teacher(student_response):
                teacher_resp = self.get_teacher_response(
                    user_input, conversation_history, student_response
                )
                if teacher_resp:
                    return teacher_resp, "teacher"
            return student_response, "student"

        elif self.mode == "compare":
            # Return both for comparison (useful for debugging/evaluation)
            teacher_resp = self.get_teacher_response(
                user_input, conversation_history, student_response
            )
            if teacher_resp:
                combined = f"**🎓 Teacher (Dolphin):**\n{teacher_resp}\n\n---\n\n**🧠 Student (Veda):**\n{student_response}"
                return combined, "both"
            return student_response, "student"

        # Unknown or "disabled" mode: pass the student response through.
        return student_response, "student"

    def set_mode(self, mode: str) -> bool:
        """Set distillation mode: 'auto', 'fallback', 'compare', or 'disabled'.

        Returns True if the mode was recognized and applied, False otherwise.
        """
        if mode in ["auto", "fallback", "compare", "disabled"]:
            self.mode = mode
            return True
        return False

    def get_stats(self) -> Dict:
        """Get distillation statistics as a plain dict (for display/logging)."""
        distill_data = db.get_distillation_count()
        return {
            "mode": self.mode,
            "teacher_calls": self.teacher_call_count,
            "student_calls": self.student_call_count,
            "teacher_available": self.teacher.is_available(),
            "distillation_samples": distill_data["total"],
            "unused_samples": distill_data["unused"],
            "ready_for_training": distill_data["unused"] >= MIN_SAMPLES_FOR_DISTILL_TRAINING,
        }

    def get_training_data(self) -> str:
        """Get accumulated teacher responses as training data.

        Formats each unused sample as a <USER>/<ASSISTANT> pair; returns
        an empty string when there is nothing to train on.
        """
        unused = db.get_unused_distillation_data()

        if not unused:
            return ""

        # str.join avoids quadratic repeated concatenation.
        return "".join(
            f"<USER> {item['user_input']}\n"
            f"<ASSISTANT> {item['teacher_response']}\n\n"
            for item in unused
        )

    def mark_training_complete(self, ids: list):
        """Mark distillation data as used after training (no-op for empty ids)."""
        if ids:
            db.mark_distillation_used(ids)
167
+
168
+
169
# Global engine instance — module-level singleton shared by all importers.
distillation_engine = DistillationEngine()