|
|
""" |
|
|
Predict module for RAG-based utterance prediction. |
|
|
|
|
|
This module uses retrieval to find similar utterances instead of generating. |
|
|
""" |
|
|
from typing import Any |
|
|
from traceback import format_exc |
|
|
import os |
|
|
from datetime import datetime |
|
|
|
|
|
|
|
|
def _make_error_output(
    data: BBPredictedUtterance, model_name: str, error: str
) -> BBPredictOutput:
    """Build a failed BBPredictOutput that echoes the input utterance.

    Args:
        data: Input utterance to return unchanged on the failed output.
        model_name: Model identifier recorded on the output.
        error: Human-readable failure reason.

    Returns:
        BBPredictOutput with success=False, the given error, and empty
        context_used.
    """
    return BBPredictOutput(
        success=False,
        error=error,
        utterance=data,
        context_used="",
        model=model_name,
    )


def _predict(
    model: Any | None, data: BBPredictedUtterance, model_name: str
) -> BBPredictOutput:
    """Make prediction using RAG retriever.

    Looks up the most similar stored utterance for the given prefix/context
    instead of generating text. If no match is found, falls back to the
    CHUTE_FALLBACK_COMPLETION environment variable (default "...").

    Args:
        model: Dict containing retriever and config (expects key "retriever")
        data: Input utterance data
        model_name: Model identifier

    Returns:
        BBPredictOutput with prediction; success=False with an error message
        on missing model/input/retriever or on any unexpected exception.
    """
    predict_start = datetime.now()
    print("[PREDICT] =" * 40)
    print("[PREDICT] 🎯 PREDICTION REQUEST")
    print("[PREDICT] =" * 40)

    print(f"[PREDICT] Index: {data.index}")
    print(f"[PREDICT] Step: {data.step}")
    print(f"[PREDICT] Prefix length: {len(data.prefix) if data.prefix else 0} chars")
    print(f"[PREDICT] Context length: {len(data.context) if data.context else 0} chars")

    try:
        # Guard clauses: fail fast on missing model, input, or retriever.
        if not model:
            print("[PREDICT] ❌ Model not loaded")
            return _make_error_output(data, model_name, "Model not loaded")

        if not data.prefix:
            print("[PREDICT] ❌ No prefix provided")
            return _make_error_output(data, model_name, "No input provided")

        retriever = model.get("retriever")
        if not retriever:
            print("[PREDICT] ❌ Retriever not found in model")
            return _make_error_output(
                data, model_name, "Retriever not found in model"
            )

        print(f"[PREDICT] Prefix: '{data.prefix}'")
        if data.context:
            print(f"[PREDICT] Context: '{data.context}'")

        print("[PREDICT] Querying retriever...")
        retrieval_start = datetime.now()

        result = retriever.retrieve_top1(
            prefix=data.prefix,
            context=data.context,
        )

        retrieval_elapsed = (datetime.now() - retrieval_start).total_seconds()
        print(f"[PREDICT] Retrieval completed in {retrieval_elapsed:.3f}s")

        if not result:
            # No similar utterance found; use the configurable fallback text.
            prediction = os.getenv("CHUTE_FALLBACK_COMPLETION", "...")
            print(f"[PREDICT] ⚠️ No match found, using fallback: '{prediction}'")
        else:
            matched_utterance = result.utterance

            print("[PREDICT] ✓ Retrieved match:")
            print(f"[PREDICT] Score: {result.score:.4f}")
            print(f"[PREDICT] Utterance: '{matched_utterance}'")
            print(f"[PREDICT] Dialogue: {result.dialogue_uid}")
            print(f"[PREDICT] Index: {result.utterance_index}")

            prediction = matched_utterance

            # If the match extends the typed prefix, return only the
            # continuation past the prefix.
            if data.prefix and matched_utterance.startswith(data.prefix):
                continuation = matched_utterance[len(data.prefix):].strip()
                if continuation:
                    prediction = continuation
                    print(f"[PREDICT] Extracted continuation: '{prediction}'")

            # Never return an empty prediction; fall back to the full match.
            if not prediction or prediction.strip() == "":
                prediction = matched_utterance
                print("[PREDICT] Using full utterance as prediction")

        predicted_utterance = BBPredictedUtterance(
            index=data.index,
            step=data.step,
            prefix=data.prefix,
            prediction=prediction,
            context=data.context,
            ground_truth=data.ground_truth,
            done=data.done,
        )

        total_elapsed = (datetime.now() - predict_start).total_seconds()
        print(f"[PREDICT] ✅ Prediction complete in {total_elapsed:.3f}s")
        print(f"[PREDICT] Prediction: '{prediction}'")
        print("[PREDICT] =" * 40)

        return BBPredictOutput(
            success=True,
            utterance=predicted_utterance,
            context_used=data.context,
            model=model_name,
        )

    except Exception as e:
        # Top-level boundary: log the full traceback and return a failed
        # output rather than propagating to the caller.
        elapsed = (datetime.now() - predict_start).total_seconds()
        print(f"[PREDICT] ❌ PREDICTION FAILED after {elapsed:.3f}s: {str(e)}")
        print(format_exc())
        print("[PREDICT] =" * 40)

        return _make_error_output(data, model_name, str(e))
|
|
|