File size: 5,301 Bytes
4a2546a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""
Predict module for RAG-based utterance prediction.

This module uses retrieval to find similar utterances instead of generating.
"""
from typing import Any
from traceback import format_exc
import os
from datetime import datetime


def _predict(
    model: Any | None, data: BBPredictedUtterance, model_name: str
) -> BBPredictOutput:
    """Make a prediction by retrieving the most similar stored utterance.

    Instead of generating text, this looks up the closest match for
    ``data.prefix`` (optionally conditioned on ``data.context``) via the
    retriever stored in *model*, and returns either the matched
    utterance's continuation or a fallback completion.

    Args:
        model: Dict expected to contain a ``"retriever"`` entry (plus
            config). ``None`` / empty means the model was never loaded.
        data: Input utterance data — prefix to complete, plus optional
            context, ground truth, and bookkeeping fields.
        model_name: Model identifier echoed back in the output.

    Returns:
        BBPredictOutput with ``success=True`` and a populated
        ``utterance.prediction`` on success, or ``success=False`` and an
        ``error`` message on any validation failure or exception.
    """
    predict_start = datetime.now()
    print("[PREDICT] =" * 40)
    print("[PREDICT] 🎯 PREDICTION REQUEST")
    print("[PREDICT] =" * 40)

    print(f"[PREDICT] Index: {data.index}")
    print(f"[PREDICT] Step: {data.step}")
    print(f"[PREDICT] Prefix length: {len(data.prefix) if data.prefix else 0} chars")
    print(f"[PREDICT] Context length: {len(data.context) if data.context else 0} chars")

    def _fail(error: str) -> BBPredictOutput:
        # Single place to build a failure response so all validation
        # branches stay consistent (same empty context_used, same model).
        return BBPredictOutput(
            success=False,
            error=error,
            utterance=data,
            context_used="",
            model=model_name,
        )

    try:
        # Validate that the model container was loaded at all.
        if not model:
            print("[PREDICT] ❌ Model not loaded")
            return _fail("Model not loaded")

        # Validate input: nothing to complete without a prefix.
        if not data.prefix:
            print("[PREDICT] ❌ No prefix provided")
            return _fail("No input provided")

        # Extract the retriever from the model container.
        retriever = model.get("retriever")
        if not retriever:
            print("[PREDICT] ❌ Retriever not found in model")
            return _fail("Retriever not found in model")

        print(f"[PREDICT] Prefix: '{data.prefix}'")
        if data.context:
            print(f"[PREDICT] Context: '{data.context}'")

        # Retrieve the single most similar utterance (timed).
        print("[PREDICT] Querying retriever...")
        retrieval_start = datetime.now()

        result = retriever.retrieve_top1(
            prefix=data.prefix,
            context=data.context,
        )

        retrieval_elapsed = (datetime.now() - retrieval_start).total_seconds()
        print(f"[PREDICT] Retrieval completed in {retrieval_elapsed:.3f}s")

        if not result:
            # No match found — fall back to an env-configured completion.
            prediction = os.getenv("CHUTE_FALLBACK_COMPLETION", "...")
            print(f"[PREDICT] ⚠️  No match found, using fallback: '{prediction}'")
        else:
            # Extract the continuation from the matched utterance.
            matched_utterance = result.utterance

            print("[PREDICT] ✓ Retrieved match:")
            print(f"[PREDICT]   Score: {result.score:.4f}")
            print(f"[PREDICT]   Utterance: '{matched_utterance}'")
            print(f"[PREDICT]   Dialogue: {result.dialogue_uid}")
            print(f"[PREDICT]   Index: {result.utterance_index}")

            # Default strategy: return the full matched utterance.
            prediction = matched_utterance

            # If the match literally starts with the prefix, return only
            # the remainder — the caller already has the prefix.
            if data.prefix and matched_utterance.startswith(data.prefix):
                continuation = matched_utterance[len(data.prefix):].strip()
                if continuation:
                    prediction = continuation
                    print(f"[PREDICT]   Extracted continuation: '{prediction}'")

            # Guard against an empty/whitespace-only prediction.
            if not prediction or prediction.strip() == "":
                prediction = matched_utterance
                print("[PREDICT]   Using full utterance as prediction")

        # Rebuild the utterance with the prediction filled in; all other
        # fields are carried over unchanged from the input.
        predicted_utterance = BBPredictedUtterance(
            index=data.index,
            step=data.step,
            prefix=data.prefix,
            prediction=prediction,
            context=data.context,
            ground_truth=data.ground_truth,
            done=data.done,
        )

        total_elapsed = (datetime.now() - predict_start).total_seconds()
        print(f"[PREDICT] ✅ Prediction complete in {total_elapsed:.3f}s")
        print(f"[PREDICT] Prediction: '{prediction}'")
        print("[PREDICT] =" * 40)

        return BBPredictOutput(
            success=True,
            utterance=predicted_utterance,
            # Normalize None -> "" so context_used is always a string,
            # matching the failure paths.
            context_used=data.context or "",
            model=model_name,
        )

    except Exception as e:
        # Boundary handler: log the full traceback and convert any
        # unexpected error into a structured failure output.
        elapsed = (datetime.now() - predict_start).total_seconds()
        print(f"[PREDICT] ❌ PREDICTION FAILED after {elapsed:.3f}s: {str(e)}")
        print(format_exc())
        print("[PREDICT] =" * 40)

        return _fail(str(e))