# Source: veda-programming / distillation.py
# Author: vedaco
# Commit: f162639 (verified) - "Create distillation.py"
"""Knowledge Distillation Engine - Student learns from Teacher"""
from typing import Optional, Tuple, Dict
from teacher import teacher
from database import db
from config import (
DISTILLATION_ENABLED,
AUTO_LEARN_FROM_TEACHER,
MIN_RESPONSE_LENGTH,
MIN_SAMPLES_FOR_DISTILL_TRAINING,
)
class DistillationEngine:
    """Manage knowledge distillation from the teacher model to the student.

    Modes:
        auto:     Always ask the teacher (so its answer can be saved for
                  training) but return the student's answer for speed.
        fallback: Ask the teacher only when the student's answer looks poor.
        compare:  Return both answers, formatted side by side.
        disabled: Never ask the teacher; always return the student's answer.

    The ``DISTILLATION_ENABLED`` config flag acts as a global switch: while it
    is false, ``process_with_distillation`` never calls the teacher, whatever
    the current mode is.
    """

    # Every mode accepted by ``set_mode``.
    VALID_MODES = ("auto", "fallback", "compare", "disabled")

    # Lower-cased phrases that mark a student answer as low quality. They are
    # matched case-insensitively as substrings of the answer.
    _LOW_QUALITY_INDICATORS = (
        "i'm not sure",
        "i don't know",
        "could you try rephrasing",
        "error:",
        "not sure how to respond",
    )

    def __init__(self):
        self.teacher = teacher
        self.mode = "fallback"  # one of VALID_MODES
        self.teacher_call_count = 0  # teacher calls that produced a response
        self.student_call_count = 0  # calls to process_with_distillation

    def should_ask_teacher(self, student_response: str) -> bool:
        """Decide whether the teacher should be asked, based on answer quality.

        Args:
            student_response: The answer produced by the student model.

        Returns:
            False when distillation is disabled or the teacher is unavailable.
            Otherwise True when the answer is empty, shorter (after stripping)
            than ``MIN_RESPONSE_LENGTH``, or contains a low-quality indicator
            phrase; False for any other answer.
        """
        if not DISTILLATION_ENABLED:
            return False
        if not self.teacher.is_available():
            return False

        # Heuristics for a low-quality response.
        if not student_response:
            return True
        if len(student_response.strip()) < MIN_RESPONSE_LENGTH:
            return True

        # Lower-case once, then look for any known "I can't answer" phrase.
        lowered = student_response.lower()
        return any(
            indicator in lowered for indicator in self._LOW_QUALITY_INDICATORS
        )

    def get_teacher_response(
        self,
        user_input: str,
        conversation_history: Optional[list] = None,
        student_response: Optional[str] = None,
    ) -> Optional[str]:
        """Ask the teacher for a response and optionally save it for training.

        Args:
            user_input: The user's message, forwarded to the teacher.
            conversation_history: Prior conversation, forwarded to the teacher
                unchanged.
            student_response: The student's answer to the same input; stored
                alongside the teacher's answer when saving.

        Returns:
            Whatever ``self.teacher.ask`` returned. A falsy result is treated
            as "no response": nothing is saved and the call is not counted.
        """
        teacher_response = self.teacher.ask(
            user_message=user_input,
            conversation_history=conversation_history,
        )

        if teacher_response:
            if AUTO_LEARN_FROM_TEACHER:
                # Save the pair for future training of the student.
                db.save_distillation_data(
                    user_input=user_input,
                    teacher_response=teacher_response,
                    student_response=student_response,
                    quality_score=1.0,  # Teacher responses are high quality
                )
            self.teacher_call_count += 1

        return teacher_response

    def process_with_distillation(
        self,
        user_input: str,
        student_response: str,
        conversation_history: Optional[list] = None,
    ) -> Tuple[str, str]:
        """Process a student response with potential teacher assistance.

        Args:
            user_input: The user's message.
            student_response: The answer already produced by the student.
            conversation_history: Prior conversation, forwarded to the teacher.

        Returns:
            Tuple of (final_response, source) where source is "student",
            "teacher", or "both".
        """
        self.student_call_count += 1
        mode = self.mode

        if mode == "fallback":
            # Only ask the teacher if the student response is poor.
            # should_ask_teacher already honours DISTILLATION_ENABLED.
            if self.should_ask_teacher(student_response):
                teacher_resp = self.get_teacher_response(
                    user_input, conversation_history, student_response
                )
                if teacher_resp:
                    return teacher_resp, "teacher"
        elif mode in ("auto", "compare") and DISTILLATION_ENABLED:
            # Both modes always consult the teacher, but only while the global
            # switch is on, so that turning distillation off stops every
            # teacher call, not just the fallback ones.
            teacher_resp = self.get_teacher_response(
                user_input, conversation_history, student_response
            )
            if mode == "compare" and teacher_resp:
                # Return both for comparison (useful for debugging/evaluation).
                combined = f"**🎓 Teacher (Dolphin):**\n{teacher_resp}\n\n---\n\n**🧠 Student (Veda):**\n{student_response}"
                return combined, "both"
            # In "auto" mode the teacher call exists only for its side effect
            # (saving training data); the student's answer is what is returned.

        return student_response, "student"

    def set_mode(self, mode: str) -> bool:
        """Set the distillation mode.

        Args:
            mode: One of 'auto', 'fallback', 'compare', or 'disabled'.

        Returns:
            True if the mode was accepted and stored, False (leaving the
            current mode unchanged) otherwise.
        """
        if mode in self.VALID_MODES:
            self.mode = mode
            return True
        return False

    def get_stats(self) -> Dict:
        """Return distillation statistics as a dictionary.

        Combines this engine's mode and call counters with the teacher's
        availability and the "total"/"unused" sample counts reported by
        ``db.get_distillation_count()``.
        """
        distill_data = db.get_distillation_count()
        return {
            "mode": self.mode,
            "teacher_calls": self.teacher_call_count,
            "student_calls": self.student_call_count,
            "teacher_available": self.teacher.is_available(),
            "distillation_samples": distill_data["total"],
            "unused_samples": distill_data["unused"],
            "ready_for_training": distill_data["unused"] >= MIN_SAMPLES_FOR_DISTILL_TRAINING,
        }

    def get_training_data(self) -> str:
        """Return accumulated, not-yet-used teacher responses as training text.

        Each sample is rendered as a ``<USER>`` line followed by an
        ``<ASSISTANT>`` line and a blank line. Returns an empty string when
        there are no unused samples.
        """
        unused = db.get_unused_distillation_data()
        if not unused:
            return ""

        # Join once instead of growing a string with += inside the loop.
        return "".join(
            f"<USER> {item['user_input']}\n"
            f"<ASSISTANT> {item['teacher_response']}\n\n"
            for item in unused
        )

    def mark_training_complete(self, ids: list):
        """Mark distillation data as used after training.

        Args:
            ids: Identifiers passed to ``db.mark_distillation_used``; an empty
                or missing list is a no-op.
        """
        # NOTE(review): get_training_data() returns only text, so callers must
        # obtain these ids elsewhere (presumably from the db rows) - confirm.
        if ids:
            db.mark_distillation_used(ids)
# Global engine instance: a single DistillationEngine created at import time,
# so every importer of this module shares the same mode and call counters.
distillation_engine = DistillationEngine()