"""Knowledge Distillation Engine - Student learns from Teacher"""

from typing import Optional, Tuple, Dict

from teacher import teacher
from database import db
from config import (
    DISTILLATION_ENABLED,
    AUTO_LEARN_FROM_TEACHER,
    MIN_RESPONSE_LENGTH,
    MIN_SAMPLES_FOR_DISTILL_TRAINING,
)


class DistillationEngine:
    """
    Manages knowledge distillation from teacher to student.

    Modes:
    1. AUTO: Always ask teacher and save its response for training, but
       return the student's response (the teacher call still happens
       synchronously before returning).
    2. FALLBACK: Only ask teacher if student response looks poor
       (see should_ask_teacher) and return the teacher's answer if one
       comes back.
    3. COMPARE: Return both responses combined, for debugging/evaluation.
    4. DISABLED: Never contact the teacher; always return the student.

    When DISTILLATION_ENABLED is false in config, every mode behaves like
    DISABLED.
    """

    # Modes accepted by set_mode().
    VALID_MODES = ("auto", "fallback", "compare", "disabled")

    # Lower-cased phrases (hedges / error messages) that mark a student
    # response as low quality and worth escalating to the teacher.
    _LOW_QUALITY_INDICATORS = (
        "i'm not sure",
        "i don't know",
        "could you try rephrasing",
        "error:",
        "not sure how to respond",
    )

    def __init__(self):
        self.teacher = teacher
        self.mode = "fallback"  # one of VALID_MODES
        # Counters are per-instance and in-memory only (not persisted).
        self.teacher_call_count = 0
        self.student_call_count = 0

    def should_ask_teacher(self, student_response: Optional[str]) -> bool:
        """Decide whether the teacher should be consulted for this turn.

        Returns False when distillation is disabled in config or the teacher
        reports itself unavailable. Otherwise returns True when the student
        response is empty/None, shorter than MIN_RESPONSE_LENGTH after
        stripping whitespace, or contains one of the low-quality indicator
        phrases (case-insensitive).
        """
        if not DISTILLATION_ENABLED:
            return False
        if not self.teacher.is_available():
            return False

        # Heuristics for low-quality response
        if not student_response:
            return True
        if len(student_response.strip()) < MIN_RESPONSE_LENGTH:
            return True

        # Lower-case the response once rather than once per indicator.
        lowered = student_response.lower()
        return any(phrase in lowered for phrase in self._LOW_QUALITY_INDICATORS)

    def get_teacher_response(
        self,
        user_input: str,
        conversation_history: Optional[list] = None,
        student_response: Optional[str] = None,
    ) -> Optional[str]:
        """Ask the teacher for a response and optionally store it for training.

        Args:
            user_input: The user's message forwarded to the teacher.
            conversation_history: Prior turns passed through to teacher.ask().
            student_response: The student's answer, stored alongside the
                teacher's answer when saving distillation data.

        Returns:
            Whatever teacher.ask() returned. Only a truthy response is saved
            (when AUTO_LEARN_FROM_TEACHER is set) and counted in
            teacher_call_count.
        """
        teacher_response = self.teacher.ask(
            user_message=user_input,
            conversation_history=conversation_history,
        )
        if not teacher_response:
            return teacher_response

        if AUTO_LEARN_FROM_TEACHER:
            # Save for future training
            db.save_distillation_data(
                user_input=user_input,
                teacher_response=teacher_response,
                student_response=student_response,
                quality_score=1.0,  # Teacher responses are treated as high quality
            )
        self.teacher_call_count += 1
        return teacher_response

    def process_with_distillation(
        self,
        user_input: str,
        student_response: str,
        conversation_history: Optional[list] = None,
    ) -> Tuple[str, str]:
        """
        Process a response with potential teacher assistance.

        Returns:
            Tuple of (final_response, source) where source is "student",
            "teacher", or "both".
        """
        self.student_call_count += 1

        # Distillation switched off (by mode or by config): never contact
        # the teacher or write distillation data.
        if self.mode == "disabled" or not DISTILLATION_ENABLED:
            return student_response, "student"

        if self.mode == "auto":
            # Fetch the teacher response only so it gets saved as training
            # data; the user still receives the student's response.
            self.get_teacher_response(
                user_input, conversation_history, student_response
            )
            return student_response, "student"

        if self.mode == "fallback":
            # Only ask teacher if student response is poor
            if self.should_ask_teacher(student_response):
                teacher_resp = self.get_teacher_response(
                    user_input, conversation_history, student_response
                )
                if teacher_resp:
                    return teacher_resp, "teacher"
            return student_response, "student"

        if self.mode == "compare":
            # Return both for comparison (useful for debugging/evaluation)
            teacher_resp = self.get_teacher_response(
                user_input, conversation_history, student_response
            )
            if teacher_resp:
                combined = f"**🎓 Teacher (Dolphin):**\n{teacher_resp}\n\n---\n\n**🧠 Student (Veda):**\n{student_response}"
                return combined, "both"
            return student_response, "student"

        # Unknown mode (set_mode rejects these, but self.mode is a plain
        # attribute): fall back to the student.
        return student_response, "student"

    def set_mode(self, mode: str) -> bool:
        """Set the distillation mode.

        Args:
            mode: One of 'auto', 'fallback', 'compare', or 'disabled'.

        Returns:
            True if the mode was accepted and applied, False otherwise
            (the current mode is left unchanged).
        """
        if mode in self.VALID_MODES:
            self.mode = mode
            return True
        return False

    def get_stats(self) -> Dict:
        """Return a snapshot of distillation statistics.

        Combines the in-memory call counters with the sample counts reported
        by db.get_distillation_count() (its "total" and "unused" keys).
        """
        distill_data = db.get_distillation_count()
        return {
            "mode": self.mode,
            "teacher_calls": self.teacher_call_count,
            "student_calls": self.student_call_count,
            "teacher_available": self.teacher.is_available(),
            "distillation_samples": distill_data["total"],
            "unused_samples": distill_data["unused"],
            "ready_for_training": distill_data["unused"] >= MIN_SAMPLES_FOR_DISTILL_TRAINING,
        }

    def get_training_data(self) -> str:
        """Return not-yet-used teacher responses formatted as training text.

        Each stored sample becomes two lines (user input, then teacher
        response) followed by a blank line. Returns "" when there are no
        unused samples.
        """
        unused = db.get_unused_distillation_data()
        if not unused:
            return ""
        # NOTE(review): each line starts with a bare space; this looks like it
        # may be a remnant of a role marker — confirm the intended format.
        return "".join(
            f" {item['user_input']}\n {item['teacher_response']}\n\n"
            for item in unused
        )

    def mark_training_complete(self, ids: list):
        """Mark distillation data as used after training.

        Does nothing when ``ids`` is empty. Note that get_training_data()
        returns only text, so callers must obtain the ids separately
        (presumably from db.get_unused_distillation_data() — confirm).
        """
        if ids:
            db.mark_distillation_used(ids)


# Global engine instance
distillation_engine = DistillationEngine()