"""
Predict module for RAG-based utterance prediction.
This module uses retrieval to find similar utterances instead of generating.
"""
from typing import Any
from traceback import format_exc
import os
from datetime import datetime
# Log banner: "[PREDICT] " followed by 40 "=".  The previous code did
# `"[PREDICT] =" * 40`, which repeated the whole 11-char string 40 times
# and printed a ~440-character line of "[PREDICT] =[PREDICT] =…".
_SEPARATOR = "[PREDICT] " + "=" * 40


def _failure(error: str, data: "BBPredictedUtterance", model_name: str) -> "BBPredictOutput":
    """Build a failed BBPredictOutput with the shape every error path shares."""
    return BBPredictOutput(
        success=False,
        error=error,
        utterance=data,
        context_used="",
        model=model_name,
    )


def _predict(
    model: Any | None, data: BBPredictedUtterance, model_name: str
) -> BBPredictOutput:
    """Make prediction using RAG retriever.

    Args:
        model: Dict containing retriever and config (expects a "retriever"
            key whose value exposes ``retrieve_top1``).
        data: Input utterance data.
        model_name: Model identifier echoed back in the output.

    Returns:
        BBPredictOutput with the prediction on success, or ``success=False``
        plus an ``error`` message on any validation/runtime failure.
    """
    predict_start = datetime.now()
    print(_SEPARATOR)
    print("[PREDICT] 🎯 PREDICTION REQUEST")
    print(_SEPARATOR)
    print(f"[PREDICT] Index: {data.index}")
    print(f"[PREDICT] Step: {data.step}")
    print(f"[PREDICT] Prefix length: {len(data.prefix) if data.prefix else 0} chars")
    print(f"[PREDICT] Context length: {len(data.context) if data.context else 0} chars")
    try:
        # --- Validation: model container, input, retriever --------------
        if not model:
            print("[PREDICT] ❌ Model not loaded")
            return _failure("Model not loaded", data, model_name)
        if not data.prefix:
            print("[PREDICT] ❌ No prefix provided")
            return _failure("No input provided", data, model_name)
        retriever = model.get("retriever")
        if not retriever:
            print("[PREDICT] ❌ Retriever not found in model")
            return _failure("Retriever not found in model", data, model_name)

        print(f"[PREDICT] Prefix: '{data.prefix}'")
        if data.context:
            print(f"[PREDICT] Context: '{data.context}'")

        # --- Retrieval: single most similar utterance --------------------
        print("[PREDICT] Querying retriever...")
        retrieval_start = datetime.now()
        result = retriever.retrieve_top1(
            prefix=data.prefix,
            context=data.context,
        )
        retrieval_elapsed = (datetime.now() - retrieval_start).total_seconds()
        print(f"[PREDICT] Retrieval completed in {retrieval_elapsed:.3f}s")

        if not result:
            # No match found — fall back to a configurable canned completion.
            prediction = os.getenv("CHUTE_FALLBACK_COMPLETION", "...")
            print(f"[PREDICT] ⚠️ No match found, using fallback: '{prediction}'")
        else:
            matched_utterance = result.utterance
            print("[PREDICT] ✓ Retrieved match:")
            print(f"[PREDICT] Score: {result.score:.4f}")
            print(f"[PREDICT] Utterance: '{matched_utterance}'")
            print(f"[PREDICT] Dialogue: {result.dialogue_uid}")
            print(f"[PREDICT] Index: {result.utterance_index}")

            # Default strategy: return the full matched utterance; if the
            # match literally starts with the prefix, return only the
            # continuation after it instead.
            prediction = matched_utterance
            if data.prefix and matched_utterance.startswith(data.prefix):
                continuation = matched_utterance.removeprefix(data.prefix).strip()
                if continuation:
                    prediction = continuation
                    print(f"[PREDICT] Extracted continuation: '{prediction}'")

            # Guard: never emit an empty/whitespace-only prediction.
            if not prediction or not prediction.strip():
                prediction = matched_utterance
                print("[PREDICT] Using full utterance as prediction")

        # --- Package the result ------------------------------------------
        predicted_utterance = BBPredictedUtterance(
            index=data.index,
            step=data.step,
            prefix=data.prefix,
            prediction=prediction,
            context=data.context,
            ground_truth=data.ground_truth,
            done=data.done,
        )
        total_elapsed = (datetime.now() - predict_start).total_seconds()
        print(f"[PREDICT] ✅ Prediction complete in {total_elapsed:.3f}s")
        print(f"[PREDICT] Prediction: '{prediction}'")
        print(_SEPARATOR)
        return BBPredictOutput(
            success=True,
            utterance=predicted_utterance,
            # Normalized to "" when context is absent, matching the error
            # paths (which always pass "") instead of leaking None.
            context_used=data.context or "",
            model=model_name,
        )
    except Exception as e:
        elapsed = (datetime.now() - predict_start).total_seconds()
        print(f"[PREDICT] ❌ PREDICTION FAILED after {elapsed:.3f}s: {str(e)}")
        print(format_exc())
        print(_SEPARATOR)
        return _failure(str(e), data, model_name)
|