"""Knowledge Distillation Engine - Student learns from Teacher"""
from typing import Optional, Tuple, Dict
from teacher import teacher
from database import db
from config import (
DISTILLATION_ENABLED,
AUTO_LEARN_FROM_TEACHER,
MIN_RESPONSE_LENGTH,
MIN_SAMPLES_FOR_DISTILL_TRAINING,
)
class DistillationEngine:
    """
    Manages knowledge distillation from teacher to student.

    Modes:
        1. AUTO: Always ask teacher, save responses, use student for speed
        2. FALLBACK: Only ask teacher if student response is poor
        3. COMPARE: Show both responses for comparison
    """

    # Modes accepted by set_mode(). "disabled" (or any unknown mode) makes
    # process_with_distillation() return the student response unchanged.
    VALID_MODES = ("auto", "fallback", "compare", "disabled")

    def __init__(self):
        self.teacher = teacher          # module-level teacher client
        self.mode = "fallback"          # one of VALID_MODES
        self.teacher_call_count = 0     # successful teacher responses obtained
        self.student_call_count = 0     # calls to process_with_distillation()

    def should_ask_teacher(self, student_response: str) -> bool:
        """Decide if we should ask the teacher based on student response quality.

        Returns True only when distillation is enabled, the teacher is
        available, and the student response looks low quality (empty, shorter
        than MIN_RESPONSE_LENGTH, or containing an "I don't know"-style phrase).
        """
        if not DISTILLATION_ENABLED:
            return False
        if not self.teacher.is_available():
            return False
        # Heuristics for low-quality response
        if not student_response:
            return True
        if len(student_response.strip()) < MIN_RESPONSE_LENGTH:
            return True
        # Phrases (case-insensitive) that signal the student failed to answer.
        low_quality_indicators = (
            "i'm not sure",
            "i don't know",
            "could you try rephrasing",
            "error:",
            "not sure how to respond",
        )
        # Lower-case once instead of once per indicator inside the loop.
        lowered = student_response.lower()
        return any(indicator in lowered for indicator in low_quality_indicators)

    def get_teacher_response(
        self,
        user_input: str,
        conversation_history: Optional[list] = None,
        student_response: Optional[str] = None,
    ) -> Optional[str]:
        """Get response from teacher and optionally save for training.

        When AUTO_LEARN_FROM_TEACHER is set, the exchange is persisted as
        distillation training data. Returns the teacher's response, or None
        if the teacher produced nothing.
        """
        teacher_response = self.teacher.ask(
            user_message=user_input,
            conversation_history=conversation_history,
        )
        if teacher_response:
            # Count only calls that actually yielded a response.
            self.teacher_call_count += 1
            if AUTO_LEARN_FROM_TEACHER:
                # Save for future training
                db.save_distillation_data(
                    user_input=user_input,
                    teacher_response=teacher_response,
                    student_response=student_response,
                    quality_score=1.0,  # Teacher responses are high quality
                )
        return teacher_response

    def process_with_distillation(
        self,
        user_input: str,
        student_response: str,
        conversation_history: Optional[list] = None,
    ) -> Tuple[str, str]:
        """
        Process a response with potential teacher assistance.

        Returns:
            Tuple of (final_response, source) where source is "student",
            "teacher", or "both". In "disabled" (or any unrecognized) mode
            the student response is returned unchanged.
        """
        self.student_call_count += 1
        if self.mode == "auto":
            # Always get teacher response for learning, but return student for speed
            self.get_teacher_response(
                user_input, conversation_history, student_response
            )
            return student_response, "student"
        if self.mode == "fallback":
            # Only ask teacher if student response is poor
            if self.should_ask_teacher(student_response):
                teacher_resp = self.get_teacher_response(
                    user_input, conversation_history, student_response
                )
                if teacher_resp:
                    return teacher_resp, "teacher"
            return student_response, "student"
        if self.mode == "compare":
            # Return both for comparison (useful for debugging/evaluation)
            teacher_resp = self.get_teacher_response(
                user_input, conversation_history, student_response
            )
            if teacher_resp:
                combined = f"**🎓 Teacher (Dolphin):**\n{teacher_resp}\n\n---\n\n**🧠 Student (Veda):**\n{student_response}"
                return combined, "both"
            return student_response, "student"
        # "disabled" or unknown mode: student answers alone.
        return student_response, "student"

    def set_mode(self, mode: str) -> bool:
        """Set distillation mode: 'auto', 'fallback', 'compare', or 'disabled'.

        Returns True if the mode was accepted, False otherwise.
        """
        if mode in self.VALID_MODES:
            self.mode = mode
            return True
        return False

    def get_stats(self) -> Dict:
        """Get distillation statistics as a plain dict (for display/logging)."""
        distill_data = db.get_distillation_count()
        return {
            "mode": self.mode,
            "teacher_calls": self.teacher_call_count,
            "student_calls": self.student_call_count,
            "teacher_available": self.teacher.is_available(),
            "distillation_samples": distill_data["total"],
            "unused_samples": distill_data["unused"],
            "ready_for_training": distill_data["unused"] >= MIN_SAMPLES_FOR_DISTILL_TRAINING,
        }

    def get_training_data(self) -> str:
        """Get accumulated teacher responses as <USER>/<ASSISTANT> training text."""
        unused = db.get_unused_distillation_data()
        if not unused:
            return ""
        # str.join avoids the quadratic cost of repeated string concatenation.
        return "".join(
            f"<USER> {item['user_input']}\n<ASSISTANT> {item['teacher_response']}\n\n"
            for item in unused
        )

    def mark_training_complete(self, ids: list):
        """Mark distillation data as used after training (no-op for empty ids)."""
        if ids:
            db.mark_distillation_used(ids)
# Global engine instance: module-level singleton imported by the rest of
# the application; constructed once at import time.
distillation_engine = DistillationEngine()