import atexit
import datetime
import os
import threading
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Sequence

from rag_core.config import NODE_CACHE_PATH, VECTORSTORE_PATH
from rag_core.evaluator import evaluate_answer
from rag_core.index_builder import build_and_save_index, load_node_cache, load_vectorstore
from rag_core.logging_utils import get_model_flow_logger, log_event
from rag_core.rag_chain import ContextDocument, build_rag_chain
from rag_core.rag_chain_helper import rewrite_question_with_history

# Environment-variable values treated as "on" by RefreshConfig.from_env.
_TRUTHY_ENV_VALUES = frozenset({"1", "true", "yes", "on"})

# A first answer whose evaluator overall_score is below this value is regenerated once.
RETRY_SCORE_THRESHOLD = 0.70

# Maximum number of answer characters included in log previews.
PREVIEW_CHARS = 400

# Retrieval settings passed to build_rag_chain at startup and on every refresh.
RAG_TOP_K = 5
RAG_MAX_DOCS = 3

# User-facing fallback returned when the RAG chain itself fails.
_KB_UNAVAILABLE_MESSAGE = (
    "I'm having trouble accessing my knowledge base right now. "
    "Please try again in a moment."
)

# Used when the chain asks for clarification but supplies no prompt text.
_DEFAULT_CLARIFICATION_PROMPT = "Could you clarify which area you want me to focus on?"


def _env_flag(name: str, default: str) -> bool:
    """Read a boolean flag from the environment.

    Accepts "1", "true", "yes" and "on" (case- and whitespace-insensitive) as
    true; any other value is false.
    """
    return os.getenv(name, default).strip().lower() in _TRUTHY_ENV_VALUES


def _preview(text: str, limit: int = PREVIEW_CHARS) -> str:
    """Truncate *text* to *limit* characters for logging, appending '...' if cut."""
    return text[:limit] + ("..." if len(text) > limit else "")


@dataclass(frozen=True)
class RefreshConfig:
    """Settings for the daily background index refresh."""

    enabled: bool
    at_hour: int  # local wall-clock hour of the daily refresh, 0..23
    at_minute: int  # local wall-clock minute of the daily refresh, 0..59
    only_fixed_urls: bool
    rebuild_on_startup: bool

    def __post_init__(self) -> None:
        # Fail fast at construction instead of crashing the refresh thread later.
        if self.enabled:
            if not 0 <= self.at_hour <= 23:
                raise ValueError(f"at_hour must be within 0..23, got {self.at_hour}")
            if not 0 <= self.at_minute <= 59:
                raise ValueError(f"at_minute must be within 0..59, got {self.at_minute}")

    @classmethod
    def from_env(cls) -> "RefreshConfig":
        """Build a config from REFRESH_* environment variables.

        Raises:
            ValueError: if an hour/minute variable is not an integer, or is out
                of range while refresh is enabled.
        """
        return cls(
            enabled=_env_flag("REFRESH_ENABLED", "true"),
            at_hour=int(os.getenv("REFRESH_AT_HOUR", "3")),
            at_minute=int(os.getenv("REFRESH_AT_MINUTE", "0")),
            only_fixed_urls=_env_flag("REFRESH_ONLY_FIXED_URLS", "false"),
            rebuild_on_startup=_env_flag("REFRESH_ON_STARTUP", "false"),
        )


def _history_to_text(history: Sequence[Sequence[str]]) -> str:
    """Convert Gradio history ([[user, bot], ...]) into a compact text block.

    Each usable turn becomes a "User: ..." line followed by an "Assistant: ..."
    line. Turns that are empty or have fewer than two entries are skipped.
    """
    if not history:
        return ""
    lines: List[str] = []
    for turn in history:
        if not turn or len(turn) < 2:
            continue
        user_msg, assistant_msg = turn[0], turn[1]
        lines.append(f"User: {user_msg}")
        lines.append(f"Assistant: {assistant_msg}")
    return "\n".join(lines)


def _docs_to_loggable(docs: Sequence[Any], max_chars: int = 220) -> List[dict]:
    """Return lightweight document metadata for logs without dumping full context.

    Each entry holds the document's "source" metadata (or "unknown"), a
    single-line preview of at most *max_chars* characters, and its metadata.
    """
    summaries: List[dict] = []
    for doc in docs or []:
        metadata = doc.metadata or {}
        text = (doc.page_content or "").strip().replace("\n", " ")
        summaries.append(
            {
                "source": metadata.get("source", "unknown"),
                "preview": text[:max_chars] + ("..." if len(text) > max_chars else ""),
                "metadata": metadata,
            }
        )
    return summaries


class CareerQARuntime:
    """Owns model state, refresh scheduling, and answer generation."""

    def __init__(self, refresh_config: RefreshConfig | None = None):
        """Load (or build) the index, start the refresh thread, and register shutdown."""
        self.refresh_config = refresh_config or RefreshConfig.from_env()
        self.logger = get_model_flow_logger()
        # Guards the chain/vectorstore fields, which the refresh thread swaps.
        self.state_lock = threading.RLock()
        self.stop_refresh_event = threading.Event()
        self.vectorstore = None
        self.rag_chain = None
        self.retriever = None
        self.system_prompt = None
        self.refresh_thread = None
        # NOTE(review): shared by every caller of this runtime, not per chat
        # session; concurrent users can consume each other's clarification.
        self.pending_clarification: Dict[str, Any] | None = None
        self.init_rag()
        self.start_refresh_thread()
        atexit.register(self.stop)

    def _load_context_docs(self) -> List[ContextDocument]:
        """Wrap every record of the node cache as a ContextDocument."""
        return [
            ContextDocument(
                page_content=record["text"],
                metadata=record["metadata"],
            )
            for record in load_node_cache()
        ]

    def _load_and_swap_chain(self) -> None:
        """Load the persisted index, build a fresh chain, and swap it in under the lock."""
        vector_index = load_vectorstore()
        docs = self._load_context_docs()
        rag_chain, retriever, system_prompt = build_rag_chain(
            vector_index,
            docs,
            k=RAG_TOP_K,
            max_docs=RAG_MAX_DOCS,
        )
        # Swap all four fields together so readers never see a mixed state.
        with self.state_lock:
            self.vectorstore = vector_index
            self.rag_chain = rag_chain
            self.retriever = retriever
            self.system_prompt = system_prompt

    def init_rag(self) -> None:
        """Build the index if needed, then load the vectorstore and chain.

        A failed startup rebuild is tolerated only when a previously saved node
        cache exists; otherwise the build error is re-raised.
        """
        index_path = NODE_CACHE_PATH
        should_rebuild = self.refresh_config.rebuild_on_startup or not index_path.exists()
        if should_rebuild:
            try:
                chunk_count, _ = build_and_save_index()
            except Exception as exc:
                if not index_path.exists():
                    raise
                self.log_event(
                    "refresh.startup_rebuild_failed",
                    error=str(exc),
                    fallback="loading_existing_index",
                )
            else:
                self.log_event(
                    "refresh.index_built",
                    mode="startup_rebuild",
                    chunks=chunk_count,
                )
        self._load_and_swap_chain()
        self.log_event("init_rag.ready", vectorstore_path=VECTORSTORE_PATH)

    def log_event(self, event: str, **payload) -> None:
        """Forward a structured event to this runtime's model-flow logger."""
        log_event(self.logger, event, **payload)

    def refresh_rag_once(self) -> None:
        """Rebuild the index and atomically swap in a fresh chain.

        Best-effort: runs on the background thread, so every failure is logged
        and the previous chain keeps serving.
        """
        self.log_event(
            "refresh.start",
            only_fixed_urls=self.refresh_config.only_fixed_urls,
        )
        try:
            # NOTE(review): only_fixed_urls is logged but never passed to
            # build_and_save_index; confirm whether the builder reads it itself.
            chunk_count, _ = build_and_save_index()
            self.log_event("refresh.index_built", mode="crawl", chunks=chunk_count)
            self._load_and_swap_chain()
            self.log_event("refresh.done", status="ok")
        except Exception as exc:
            self.log_event("refresh.error", error=str(exc))

    def _seconds_until_next_run(self) -> int:
        """Compute the delay until the next scheduled refresh in local time."""
        now = datetime.datetime.now()  # naive local wall-clock time
        target = now.replace(
            hour=self.refresh_config.at_hour,
            minute=self.refresh_config.at_minute,
            second=0,
            microsecond=0,
        )
        if target <= now:
            # Wall-clock arithmetic: "same time tomorrow", even across a DST change.
            target += datetime.timedelta(days=1)
        # A naive datetime's .timestamp() is resolved in local time (DST-aware),
        # so the difference is the real number of seconds to wait.
        return max(0, int(target.timestamp() - time.time()))

    def _daily_refresh_loop(self) -> None:
        """Refresh once per day at the configured time until stop() is called."""
        # Short grace period after startup; Event.wait returns True early on stop().
        if self.stop_refresh_event.wait(3):
            return
        while not self.stop_refresh_event.is_set():
            sleep_seconds = self._seconds_until_next_run()
            self.log_event(
                "refresh.sleep",
                seconds=sleep_seconds,
                at_hour=self.refresh_config.at_hour,
                at_minute=self.refresh_config.at_minute,
            )
            if self.stop_refresh_event.wait(sleep_seconds):
                break
            self.refresh_rag_once()

    def start_refresh_thread(self) -> None:
        """Start the daemon refresh thread unless disabled or already running."""
        if not self.refresh_config.enabled:
            self.log_event("refresh.disabled")
            return
        if self.refresh_thread and self.refresh_thread.is_alive():
            return
        self.refresh_thread = threading.Thread(
            target=self._daily_refresh_loop,
            name="rag-daily-refresh",
            daemon=True,
        )
        self.refresh_thread.start()
        self.log_event(
            "refresh.thread_started",
            daily_at=f"{self.refresh_config.at_hour:02d}:{self.refresh_config.at_minute:02d}",
        )

    def stop(self) -> None:
        """Signal the refresh loop to exit; it does not wait for the thread."""
        self.stop_refresh_event.set()

    def _run_rag(
        self,
        question: str,
        history_text: str,
        forced_tool: str | None = None,
    ) -> Dict[str, Any]:
        """Invoke the current chain snapshot with the question and chat history.

        Raises:
            RuntimeError: if no chain has been loaded yet.
        """
        with self.state_lock:
            local_rag_chain = self.rag_chain
        if local_rag_chain is None:
            raise RuntimeError("RAG chain is not initialised")
        payload: Dict[str, Any] = {
            "input": question,
            "chat_history": history_text,
        }
        if forced_tool:
            payload["forced_tool"] = forced_tool
        return local_rag_chain.invoke(payload)

    def _resolve_clarification(
        self, message: str, pending: Dict[str, Any]
    ) -> tuple[str, str | None]:
        """Map a clarification reply to a tool; return (original question, tool)."""
        with self.state_lock:
            local_rag_chain = self.rag_chain
        candidate_tools = pending.get("candidate_tools", [])
        forced_tool = local_rag_chain.resolve_clarification_reply(
            message,
            candidate_tools,
            pending.get("preferred_tool", "about"),
        )
        standalone_question = pending.get("original_question", message)
        self.log_event(
            "routing.clarification_resolved",
            original_question=standalone_question,
            clarification_reply=message,
            forced_tool=forced_tool,
            candidate_tools=candidate_tools,
        )
        return standalone_question, forced_tool

    def _rewrite_question(
        self, message: str, history: Sequence[Sequence[str]], history_text: str
    ) -> str:
        """Rewrite *message* as a standalone question, falling back to it on error."""
        try:
            standalone_question = rewrite_question_with_history(history, message)
        except Exception as exc:
            self.log_event("rewrite.error", error=str(exc))
            standalone_question = message
        self.log_event(
            "rewrite.done",
            standalone_question=standalone_question,
            history_chars=len(history_text),
        )
        return standalone_question

    def _request_clarification(self, rag_result: Dict[str, Any], standalone_question: str) -> str:
        """Remember the routing candidates and return the clarification prompt."""
        clarification_answer = rag_result.get("answer", "") or _DEFAULT_CLARIFICATION_PROMPT
        candidate_tools = rag_result.get("candidate_tools", [])
        preferred_tool = rag_result.get("preferred_tool", "about")
        with self.state_lock:
            self.pending_clarification = {
                "candidate_tools": candidate_tools,
                "preferred_tool": preferred_tool,
                "original_question": rag_result.get("original_question", standalone_question),
            }
        self.log_event(
            "routing.clarification_requested",
            original_question=standalone_question,
            candidate_tools=candidate_tools,
            preferred_tool=preferred_tool,
        )
        return clarification_answer

    def _evaluate(
        self,
        done_event: str,
        error_event: str,
        system_prompt: Any,
        question: str,
        context_docs: Sequence[Any],
        answer: str,
    ) -> Any:
        """Score *answer* with the evaluator; return None if evaluation fails."""
        try:
            result = evaluate_answer(
                system_prompt=system_prompt,
                question=question,
                context_docs=context_docs,
                answer=answer,
            )
            self.log_event(
                done_event,
                overall_score=float(result.overall_score),
                grounded=float(result.grounded_in_context_score),
                hallucination=bool(result.hallucination_detected),
                feedback=str(result.feedback),
            )
        except Exception as exc:
            self.log_event(error_event, error=str(exc))
            return None
        return result

    def _retry_if_needed(
        self,
        eval_1: Any,
        answer_1: str,
        question: str,
        history_text: str,
        forced_tool: str | None,
        system_prompt: Any,
    ) -> str:
        """Regenerate a weak answer once; keep whichever answer scores higher."""
        score_1 = float(eval_1.overall_score)
        below_threshold = score_1 < RETRY_SCORE_THRESHOLD
        # Default False: a result without the flag must not force a retry,
        # otherwise the score threshold above would never matter.
        evaluator_requested = bool(getattr(eval_1, "should_retry", False))
        if not (below_threshold or evaluator_requested):
            return answer_1

        revision_prompt = (
            f"{question}\n\n"
            f"You previously answered this:\n{answer_1}\n\n"
            "An evaluator found issues. Revise your answer to address the feedback below.\n"
            "Rules:\n"
            "- Use ONLY the provided context.\n"
            '- If the context does not support the claim, say "I don\'t know".\n'
            "- Be specific and grounded.\n\n"
            f"Evaluator feedback:\n{eval_1.feedback}\n"
        )
        self.log_event(
            "retry.triggered",
            reason="eval_score_below_threshold" if below_threshold else "evaluator_requested_retry",
            threshold=RETRY_SCORE_THRESHOLD,
            overall_score=score_1,
        )

        try:
            # Keep the tool chosen via clarification so the retry routes the same way.
            retry_result = self._run_rag(revision_prompt, history_text, forced_tool=forced_tool)
        except Exception as exc:
            self.log_event("rag.retry_error", error=str(exc))
            return answer_1
        if retry_result.get("needs_clarification"):
            # A clarification prompt is not a revised answer; keep the first one.
            self.log_event("rag.retry_error", error="retry requested clarification")
            return answer_1

        answer_2 = retry_result.get("answer", "") or ""
        context_docs_2 = retry_result.get("context", []) or []
        self.log_event(
            "rag.retry_done",
            answer_preview=_preview(answer_2),
            retrieved_count=len(context_docs_2),
            retrieved_docs=_docs_to_loggable(context_docs_2),
        )
        eval_2 = self._evaluate(
            "eval.retry_done", "eval.retry_error", system_prompt, question, context_docs_2, answer_2
        )
        # Ties favour the revised answer.
        if eval_2 is not None and score_1 <= float(eval_2.overall_score):
            return answer_2
        return answer_1

    def _evaluate_and_maybe_retry(
        self,
        rag_result: Dict[str, Any],
        question: str,
        history_text: str,
        forced_tool: str | None,
    ) -> str:
        """Log and score the first answer, then return it or a better retry."""
        answer_1 = rag_result.get("answer", "") or ""
        context_docs_1 = rag_result.get("context", []) or []
        self.log_event(
            "rag.done",
            answer_preview=_preview(answer_1),
            retrieved_count=len(context_docs_1),
            retrieved_docs=_docs_to_loggable(context_docs_1),
        )
        with self.state_lock:
            system_prompt = self.system_prompt

        # Score against the question the answer was generated for; on the
        # clarification path the raw message is only the clarification reply.
        eval_1 = self._evaluate(
            "eval.done", "eval.error", system_prompt, question, context_docs_1, answer_1
        )
        if eval_1 is None:
            return answer_1
        try:
            return self._retry_if_needed(
                eval_1, answer_1, question, history_text, forced_tool, system_prompt
            )
        except Exception as exc:
            self.log_event("retry.block_error", error=str(exc))
            return answer_1

    def generate_answer(self, message: str, history: Sequence[Sequence[str]]) -> str:
        """Run rewrite, RAG, evaluation, and optional retry for one user message.

        If the previous turn asked for clarification, *message* is treated as
        the reply and the original question is answered with the chosen tool.
        RAG failures return a fixed apology message instead of raising.
        """
        self.log_event("request.start", user_message=message)
        history_text = _history_to_text(history)

        # Take and clear in one critical section so two concurrent requests
        # cannot both consume the same pending clarification.
        with self.state_lock:
            pending = self.pending_clarification
            self.pending_clarification = None

        forced_tool: str | None = None
        if pending:
            standalone_question, forced_tool = self._resolve_clarification(message, pending)
        else:
            standalone_question = self._rewrite_question(message, history, history_text)

        try:
            rag_result = self._run_rag(standalone_question, history_text, forced_tool=forced_tool)
        except Exception as exc:
            self.log_event("rag.error", error=str(exc))
            self.log_event("request.end", final_answer_preview=_preview(_KB_UNAVAILABLE_MESSAGE))
            return _KB_UNAVAILABLE_MESSAGE

        if rag_result.get("needs_clarification"):
            clarification_answer = self._request_clarification(rag_result, standalone_question)
            self.log_event("request.end", final_answer_preview=_preview(clarification_answer))
            return clarification_answer

        final_answer = self._evaluate_and_maybe_retry(
            rag_result, standalone_question, history_text, forced_tool
        )
        self.log_event("request.end", final_answer_preview=_preview(final_answer))
        return final_answer

    def respond(self, message: str, history: Sequence[Sequence[str]]):
        """Gradio callback wrapper that converts failures into safe user responses.

        Returns ("", updated_history). Empty or whitespace-only messages leave
        the history unchanged.
        """
        history = history or []
        if not message or not message.strip():
            return "", history
        try:
            answer = self.generate_answer(message, history)
        except Exception as exc:
            self.log_event("respond.fatal_error", error=str(exc))
            answer = (
                "Something went wrong on my side while trying to answer. "
                "Please try again in a moment."
            )
        updated_history = list(history) + [[message, answer]]
        return "", updated_history