"""
Predict module for RAG-based utterance prediction.

This module uses retrieval to find similar utterances instead of generating.
"""

import os
from datetime import datetime
from traceback import format_exc
from typing import Any

# NOTE(review): BBPredictedUtterance and BBPredictOutput are referenced below
# but not imported in this file; presumably the environment that loads this
# module provides them -- confirm before importing this module standalone.


def _predict(
    model: Any | None, data: BBPredictedUtterance, model_name: str
) -> BBPredictOutput:
    """Make a prediction for an utterance using the RAG retriever.

    The retriever is asked for the single best match for the prefix/context.
    If the matched utterance starts with the prefix, only the remaining text
    (the continuation) is returned; otherwise the full matched utterance is
    returned. When the retriever yields nothing, the value of the
    ``CHUTE_FALLBACK_COMPLETION`` environment variable (default ``"..."``)
    is used instead.

    Args:
        model: Object exposing ``.get("retriever")``. The retriever must
            provide ``retrieve_top1(prefix=..., context=...)`` returning
            either a falsy value or an object with ``utterance``, ``score``,
            ``dialogue_uid`` and ``utterance_index`` attributes.
        data: Input utterance data; ``prefix`` must be non-empty.
        model_name: Model identifier echoed back in the output.

    Returns:
        On success, a ``BBPredictOutput`` with ``success=True`` wrapping a
        new ``BBPredictedUtterance`` that copies the input fields and sets
        ``prediction``. On any validation failure or exception, a
        ``BBPredictOutput`` with ``success=False``, an ``error`` message and
        the original ``data``. Exceptions derived from ``Exception`` are
        caught and reported rather than propagated.
    """
    predict_start = datetime.now()
    # Tagged horizontal rule. Multiplying the whole "[PREDICT] =" string
    # would repeat the tag 40 times instead of drawing a 40-character rule.
    banner = "[PREDICT] " + "=" * 40

    def failure(message: str) -> BBPredictOutput:
        """Build the failure response shared by every error path."""
        return BBPredictOutput(
            success=False,
            error=message,
            utterance=data,
            context_used="",
            model=model_name,
        )

    print(banner)
    print("[PREDICT] 🎯 PREDICTION REQUEST")
    print(banner)
    print(f"[PREDICT] Index: {data.index}")
    print(f"[PREDICT] Step: {data.step}")
    print(f"[PREDICT] Prefix length: {len(data.prefix) if data.prefix else 0} chars")
    print(f"[PREDICT] Context length: {len(data.context) if data.context else 0} chars")

    try:
        # Validate model
        if not model:
            print("[PREDICT] ❌ Model not loaded")
            return failure("Model not loaded")

        # Validate input
        if not data.prefix:
            print("[PREDICT] ❌ No prefix provided")
            return failure("No input provided")

        # Extract retriever
        retriever = model.get("retriever")
        if not retriever:
            print("[PREDICT] ❌ Retriever not found in model")
            return failure("Retriever not found in model")

        print(f"[PREDICT] Prefix: '{data.prefix}'")
        if data.context:
            print(f"[PREDICT] Context: '{data.context}'")

        # Retrieve most similar utterance
        print("[PREDICT] Querying retriever...")
        retrieval_start = datetime.now()
        result = retriever.retrieve_top1(
            prefix=data.prefix,
            context=data.context,
        )
        retrieval_elapsed = (datetime.now() - retrieval_start).total_seconds()
        print(f"[PREDICT] Retrieval completed in {retrieval_elapsed:.3f}s")

        if not result:
            # No match found - return fallback
            prediction = os.getenv("CHUTE_FALLBACK_COMPLETION", "...")
            print(f"[PREDICT] ⚠️ No match found, using fallback: '{prediction}'")
        else:
            # Extract the continuation from the matched utterance
            matched_utterance = result.utterance
            print("[PREDICT] ✓ Retrieved match:")
            print(f"[PREDICT] Score: {result.score:.4f}")
            print(f"[PREDICT] Utterance: '{matched_utterance}'")
            print(f"[PREDICT] Dialogue: {result.dialogue_uid}")
            print(f"[PREDICT] Index: {result.utterance_index}")

            # Strategy: return the full matched utterance as the prediction
            prediction = matched_utterance

            # If the match begins with the prefix, keep only what follows it.
            # An empty remainder (match == prefix) leaves the full utterance.
            if data.prefix and matched_utterance.startswith(data.prefix):
                continuation = matched_utterance[len(data.prefix):].strip()
                if continuation:
                    prediction = continuation
                    print(f"[PREDICT] Extracted continuation: '{prediction}'")

            # Ensure we have some prediction. This lives inside the match
            # branch because `matched_utterance` only exists here; at this
            # point the prediction can only be blank if the matched
            # utterance itself is blank.
            if not prediction or prediction.strip() == "":
                prediction = matched_utterance
                print("[PREDICT] Using full utterance as prediction")

        # Build a new utterance carrying the prediction; the input object is
        # left untouched.
        predicted_utterance = BBPredictedUtterance(
            index=data.index,
            step=data.step,
            prefix=data.prefix,
            prediction=prediction,
            context=data.context,
            ground_truth=data.ground_truth,
            done=data.done,
        )

        total_elapsed = (datetime.now() - predict_start).total_seconds()
        print(f"[PREDICT] ✅ Prediction complete in {total_elapsed:.3f}s")
        print(f"[PREDICT] Prediction: '{prediction}'")
        print(banner)

        return BBPredictOutput(
            success=True,
            utterance=predicted_utterance,
            context_used=data.context,
            model=model_name,
        )

    except Exception as e:
        # Boundary handler: report the failure to the caller as a result
        # object instead of letting the exception escape.
        elapsed = (datetime.now() - predict_start).total_seconds()
        print(f"[PREDICT] ❌ PREDICTION FAILED after {elapsed:.3f}s: {e!s}")
        print(format_exc())
        print(banner)
        return failure(str(e))