File size: 5,702 Bytes
f162639
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
"""Knowledge Distillation Engine - Student learns from Teacher"""

from typing import Optional, Tuple, Dict
from teacher import teacher
from database import db
from config import (
    DISTILLATION_ENABLED,
    AUTO_LEARN_FROM_TEACHER,
    MIN_RESPONSE_LENGTH,
    MIN_SAMPLES_FOR_DISTILL_TRAINING,
)


class DistillationEngine:
    """
    Manages knowledge distillation from teacher to student.

    Modes:
    1. AUTO: Always ask teacher, save responses, use student for speed
    2. FALLBACK: Only ask teacher if student response is poor
    3. COMPARE: Show both responses for comparison
    4. DISABLED: Never consult the teacher; always return the student response
    """

    # Single source of truth for the modes set_mode() accepts.
    VALID_MODES = ("auto", "fallback", "compare", "disabled")

    def __init__(self):
        self.teacher = teacher  # global teacher client from the `teacher` module
        self.mode = "fallback"  # one of VALID_MODES
        self.teacher_call_count = 0  # successful teacher responses received
        self.student_call_count = 0  # calls to process_with_distillation

    def should_ask_teacher(self, student_response: str) -> bool:
        """Decide if we should ask the teacher based on student response quality.

        Returns True only when distillation is enabled, the teacher is
        available, and the student response looks low-quality (empty, too
        short, or containing a known failure phrase).
        """
        if not DISTILLATION_ENABLED:
            return False

        if not self.teacher.is_available():
            return False

        # Heuristics for low-quality response
        if not student_response:
            return True

        if len(student_response.strip()) < MIN_RESPONSE_LENGTH:
            return True

        # Check for error messages / canned fallback phrases.
        low_quality_indicators = [
            "I'm not sure",
            "I don't know",
            "Could you try rephrasing",
            "Error:",
            "not sure how to respond",
        ]
        # Hoist the lowercasing out of the loop (loop-invariant).
        lowered = student_response.lower()
        return any(indicator.lower() in lowered for indicator in low_quality_indicators)

    def get_teacher_response(
        self,
        user_input: str,
        conversation_history: Optional[list] = None,
        student_response: Optional[str] = None,
    ) -> Optional[str]:
        """Get response from teacher and optionally save for training.

        Returns the teacher's response text, or None if the teacher did not
        answer. On success, increments the teacher call counter and — when
        AUTO_LEARN_FROM_TEACHER is set — persists the exchange for later
        distillation training.
        """
        teacher_response = self.teacher.ask(
            user_message=user_input,
            conversation_history=conversation_history,
        )

        if teacher_response and AUTO_LEARN_FROM_TEACHER:
            # Save for future training
            db.save_distillation_data(
                user_input=user_input,
                teacher_response=teacher_response,
                student_response=student_response,
                quality_score=1.0,  # Teacher responses are high quality
            )

        if teacher_response:
            self.teacher_call_count += 1

        return teacher_response

    def process_with_distillation(
        self,
        user_input: str,
        student_response: str,
        conversation_history: Optional[list] = None,
    ) -> Tuple[str, str]:
        """
        Process a response with potential teacher assistance.

        Returns:
            Tuple of (final_response, source) where source is "student", "teacher", or "both"
        """
        self.student_call_count += 1

        if self.mode == "auto":
            # Always get teacher response for learning (side effect: saved to
            # db), but return the student answer for speed. Result discarded.
            self.get_teacher_response(
                user_input, conversation_history, student_response
            )
            return student_response, "student"

        elif self.mode == "fallback":
            # Only ask teacher if student response is poor
            if self.should_ask_teacher(student_response):
                teacher_resp = self.get_teacher_response(
                    user_input, conversation_history, student_response
                )
                if teacher_resp:
                    return teacher_resp, "teacher"
            return student_response, "student"

        elif self.mode == "compare":
            # Return both for comparison (useful for debugging/evaluation)
            teacher_resp = self.get_teacher_response(
                user_input, conversation_history, student_response
            )
            if teacher_resp:
                combined = f"**🎓 Teacher (Dolphin):**\n{teacher_resp}\n\n---\n\n**🧠 Student (Veda):**\n{student_response}"
                return combined, "both"
            return student_response, "student"

        # "disabled" (or any unknown mode) falls through to the student.
        return student_response, "student"

    def set_mode(self, mode: str) -> bool:
        """Set distillation mode: 'auto', 'fallback', 'compare', or 'disabled'.

        Returns True if the mode was recognized and applied, False otherwise.
        """
        if mode in self.VALID_MODES:
            self.mode = mode
            return True
        return False

    def get_stats(self) -> Dict:
        """Get distillation statistics as a plain dict (for display/logging)."""
        distill_data = db.get_distillation_count()
        return {
            "mode": self.mode,
            "teacher_calls": self.teacher_call_count,
            "student_calls": self.student_call_count,
            "teacher_available": self.teacher.is_available(),
            "distillation_samples": distill_data["total"],
            "unused_samples": distill_data["unused"],
            "ready_for_training": distill_data["unused"] >= MIN_SAMPLES_FOR_DISTILL_TRAINING,
        }

    def get_training_data(self) -> str:
        """Get accumulated teacher responses as training data.

        Formats each unused exchange as a <USER>/<ASSISTANT> pair; returns an
        empty string when nothing is pending.
        """
        unused = db.get_unused_distillation_data()

        if not unused:
            return ""

        # str.join avoids quadratic string concatenation over many samples.
        return "".join(
            f"<USER> {item['user_input']}\n"
            f"<ASSISTANT> {item['teacher_response']}\n\n"
            for item in unused
        )

    def mark_training_complete(self, ids: list) -> None:
        """Mark distillation data as used after training (no-op on empty list)."""
        if ids:
            db.mark_distillation_used(ids)

# Module-level singleton: the rest of the app imports this shared engine
# instance rather than constructing its own DistillationEngine.
distillation_engine = DistillationEngine()