# Babelbit-hksa01 / predict.py
# Uploaded by aitask1024 from sasn59/Babelbit-hksa01 (commit 4a2546a, verified).
"""
Predict module for RAG-based utterance prediction.
This module uses retrieval to find similar utterances instead of generating.
"""
from typing import Any
from traceback import format_exc
import os
from datetime import datetime
# Banner line printed around each prediction request so logs are easy to scan.
# NOTE: the original printed "[PREDICT] =" * 40, repeating the whole tag 40
# times; a single tagged "=" rule was clearly intended.
_LOG_SEPARATOR = "[PREDICT] " + "=" * 40


def _error_output(
    error: str, data: BBPredictedUtterance, model_name: str
) -> BBPredictOutput:
    """Build a failed BBPredictOutput with the shape shared by all error paths."""
    return BBPredictOutput(
        success=False,
        error=error,
        utterance=data,
        context_used="",
        model=model_name,
    )


def _derive_prediction(result: Any, prefix: str) -> str:
    """Turn a retrieval hit into a non-empty prediction string.

    Prefers the continuation (matched utterance minus the matching prefix);
    falls back to the full matched utterance when the prefix does not match
    or the continuation would be empty.

    Args:
        result: Retrieval hit exposing ``utterance``, ``score``,
            ``dialogue_uid`` and ``utterance_index`` (shape assumed from
            usage here — confirm against the retriever's return type).
        prefix: The input prefix the caller is trying to complete.

    Returns:
        The prediction text (never empty when ``result.utterance`` is non-empty).
    """
    matched_utterance = result.utterance
    print("[PREDICT] ✓ Retrieved match:")
    print(f"[PREDICT] Score: {result.score:.4f}")
    print(f"[PREDICT] Utterance: '{matched_utterance}'")
    print(f"[PREDICT] Dialogue: {result.dialogue_uid}")
    print(f"[PREDICT] Index: {result.utterance_index}")

    # Strategy: default to the full matched utterance as the prediction.
    prediction = matched_utterance

    # Prefer just the continuation when the match starts with the prefix.
    if prefix and matched_utterance.startswith(prefix):
        continuation = matched_utterance[len(prefix):].strip()
        if continuation:
            prediction = continuation
            print(f"[PREDICT] Extracted continuation: '{prediction}'")

    # Guard: never return empty/whitespace-only text.
    if not prediction or prediction.strip() == "":
        prediction = matched_utterance
        print("[PREDICT] Using full utterance as prediction")
    return prediction


def _predict(
    model: Any | None, data: BBPredictedUtterance, model_name: str
) -> BBPredictOutput:
    """Make prediction using RAG retriever.

    Args:
        model: Dict containing retriever and config.
        data: Input utterance data.
        model_name: Model identifier.

    Returns:
        BBPredictOutput with prediction; on validation failure or any
        exception, ``success=False`` with an ``error`` message.
    """
    predict_start = datetime.now()
    print(_LOG_SEPARATOR)
    print("[PREDICT] 🎯 PREDICTION REQUEST")
    print(_LOG_SEPARATOR)
    print(f"[PREDICT] Index: {data.index}")
    print(f"[PREDICT] Step: {data.step}")
    print(f"[PREDICT] Prefix length: {len(data.prefix) if data.prefix else 0} chars")
    print(f"[PREDICT] Context length: {len(data.context) if data.context else 0} chars")

    try:
        # Validate model
        if not model:
            print("[PREDICT] ❌ Model not loaded")
            return _error_output("Model not loaded", data, model_name)

        # Validate input
        if not data.prefix:
            print("[PREDICT] ❌ No prefix provided")
            return _error_output("No input provided", data, model_name)

        # Extract retriever from the model dict
        retriever = model.get("retriever")
        if not retriever:
            print("[PREDICT] ❌ Retriever not found in model")
            return _error_output("Retriever not found in model", data, model_name)

        print(f"[PREDICT] Prefix: '{data.prefix}'")
        if data.context:
            print(f"[PREDICT] Context: '{data.context}'")

        # Retrieve most similar utterance (timed separately from total latency)
        print("[PREDICT] Querying retriever...")
        retrieval_start = datetime.now()
        result = retriever.retrieve_top1(
            prefix=data.prefix,
            context=data.context,
        )
        retrieval_elapsed = (datetime.now() - retrieval_start).total_seconds()
        print(f"[PREDICT] Retrieval completed in {retrieval_elapsed:.3f}s")

        if not result:
            # No match found — fall back to the configured completion text.
            prediction = os.getenv("CHUTE_FALLBACK_COMPLETION", "...")
            print(f"[PREDICT] ⚠️ No match found, using fallback: '{prediction}'")
        else:
            prediction = _derive_prediction(result, data.prefix)

        # Echo the input fields back with the prediction attached.
        predicted_utterance = BBPredictedUtterance(
            index=data.index,
            step=data.step,
            prefix=data.prefix,
            prediction=prediction,
            context=data.context,
            ground_truth=data.ground_truth,
            done=data.done,
        )

        total_elapsed = (datetime.now() - predict_start).total_seconds()
        print(f"[PREDICT] ✅ Prediction complete in {total_elapsed:.3f}s")
        print(f"[PREDICT] Prediction: '{prediction}'")
        print(_LOG_SEPARATOR)

        return BBPredictOutput(
            success=True,
            utterance=predicted_utterance,
            # Normalize a None context to "" so the output shape matches
            # every error path (which always uses "").
            context_used=data.context or "",
            model=model_name,
        )
    except Exception as e:
        elapsed = (datetime.now() - predict_start).total_seconds()
        print(f"[PREDICT] ❌ PREDICTION FAILED after {elapsed:.3f}s: {str(e)}")
        print(format_exc())
        print(_LOG_SEPARATOR)
        return _error_output(str(e), data, model_name)