PersonaRag / rag_core /runtime.py
github-actions[bot]
Deploy snapshot to HF (binaries stripped)
a0b643a
import atexit
import os
import threading
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Sequence
from rag_core.config import NODE_CACHE_PATH, VECTORSTORE_PATH
from rag_core.evaluator import evaluate_answer
from rag_core.index_builder import build_and_save_index, load_node_cache, load_vectorstore
from rag_core.logging_utils import get_model_flow_logger, log_event
from rag_core.rag_chain import ContextDocument, build_rag_chain
from rag_core.rag_chain_helper import rewrite_question_with_history
# Values (compared case-insensitively, surrounding whitespace ignored) that a
# boolean environment flag treats as "on"; anything else means "off".
_TRUTHY_ENV_VALUES = frozenset({"1", "true", "yes", "on"})


def _env_flag(name: str, default: bool) -> bool:
    """Read boolean environment variable ``name``; return ``default`` if unset."""
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() in _TRUTHY_ENV_VALUES


def _env_int(name: str, default: int) -> int:
    """Read integer environment variable ``name``; return ``default`` if unset.

    Raises:
        ValueError: if the variable is set but does not parse as an integer.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError as exc:
        # Name the offending variable instead of int()'s generic message.
        raise ValueError(f"{name} must be an integer, got {raw!r}") from exc


@dataclass(frozen=True)
class RefreshConfig:
    """Options controlling the background daily index refresh.

    Instances are immutable; ``from_env`` builds one from ``REFRESH_*``
    environment variables. Construction raises ``ValueError`` when the
    scheduled time is not a valid wall-clock time.
    """

    enabled: bool  # start the daily refresh thread at all
    at_hour: int  # local hour of the daily refresh, 0-23
    at_minute: int  # minute of the daily refresh, 0-59
    # Presumably restricts crawling to configured URLs; the runtime in this
    # file only logs it. TODO confirm where it is consumed.
    only_fixed_urls: bool
    rebuild_on_startup: bool  # rebuild the index at startup even if one exists

    def __post_init__(self) -> None:
        # Out-of-range values would otherwise be silently normalised by
        # time.mktime (e.g. hour 25 -> 01:00 the next day).
        if not 0 <= self.at_hour <= 23:
            raise ValueError(f"at_hour must be in 0..23, got {self.at_hour}")
        if not 0 <= self.at_minute <= 59:
            raise ValueError(f"at_minute must be in 0..59, got {self.at_minute}")

    @classmethod
    def from_env(cls) -> "RefreshConfig":
        """Build a config from environment variables.

        Defaults when unset: REFRESH_ENABLED=true, REFRESH_AT_HOUR=3,
        REFRESH_AT_MINUTE=0, REFRESH_ONLY_FIXED_URLS=false,
        REFRESH_ON_STARTUP=false. Boolean flags accept 1/true/yes/on
        (case-insensitive).

        Raises:
            ValueError: on a non-integer or out-of-range hour/minute.
        """
        return cls(
            enabled=_env_flag("REFRESH_ENABLED", True),
            at_hour=_env_int("REFRESH_AT_HOUR", 3),
            at_minute=_env_int("REFRESH_AT_MINUTE", 0),
            only_fixed_urls=_env_flag("REFRESH_ONLY_FIXED_URLS", False),
            rebuild_on_startup=_env_flag("REFRESH_ON_STARTUP", False),
        )
def _history_to_text(history: Sequence[Sequence[str]]) -> str:
"""Convert Gradio history ([[user, bot], ...]) into a compact text block."""
if not history:
return ""
lines: List[str] = []
for turn in history:
if not turn or len(turn) < 2:
continue
user_msg, assistant_msg = turn[0], turn[1]
lines.append(f"User: {user_msg}")
lines.append(f"Assistant: {assistant_msg}")
return "\n".join(lines)
def _docs_to_loggable(docs: Sequence[Any], max_chars: int = 220) -> List[dict]:
"""Return lightweight document metadata for logs without dumping full context."""
summaries: List[dict] = []
for doc in docs or []:
source = (doc.metadata or {}).get("source", "unknown")
text = (doc.page_content or "").strip().replace("\n", " ")
summaries.append(
{
"source": source,
"preview": text[:max_chars] + ("..." if len(text) > max_chars else ""),
"metadata": doc.metadata or {},
}
)
return summaries
class CareerQARuntime:
"""Owns model state, refresh scheduling, and answer generation."""
def __init__(self, refresh_config: RefreshConfig | None = None):
self.refresh_config = refresh_config or RefreshConfig.from_env()
self.logger = get_model_flow_logger()
self.state_lock = threading.RLock()
self.stop_refresh_event = threading.Event()
self.vectorstore = None
self.rag_chain = None
self.retriever = None
self.system_prompt = None
self.refresh_thread = None
self.pending_clarification: Dict[str, Any] | None = None
self.init_rag()
self.start_refresh_thread()
atexit.register(self.stop)
def _load_context_docs(self) -> List[ContextDocument]:
return [
ContextDocument(
page_content=record["text"],
metadata=record["metadata"],
)
for record in load_node_cache()
]
def init_rag(self) -> None:
"""Build the index if needed, then load the vectorstore and chain."""
index_path = NODE_CACHE_PATH
should_rebuild = self.refresh_config.rebuild_on_startup or not index_path.exists()
if should_rebuild:
try:
chunk_count, _ = build_and_save_index()
self.log_event(
"refresh.index_built",
mode="startup_rebuild",
chunks=chunk_count,
)
except Exception as exc:
if not index_path.exists():
raise
self.log_event(
"refresh.startup_rebuild_failed",
error=str(exc),
fallback="loading_existing_index",
)
vector_index = load_vectorstore()
docs = self._load_context_docs()
rag_chain, retriever, system_prompt = build_rag_chain(
vector_index,
docs,
k=5,
max_docs=3,
)
with self.state_lock:
self.vectorstore = vector_index
self.rag_chain = rag_chain
self.retriever = retriever
self.system_prompt = system_prompt
self.log_event("init_rag.ready", vectorstore_path=VECTORSTORE_PATH)
def log_event(self, event: str, **payload) -> None:
log_event(self.logger, event, **payload)
def refresh_rag_once(self) -> None:
"""Rebuild the index and atomically swap in a fresh chain."""
self.log_event(
"refresh.start",
only_fixed_urls=self.refresh_config.only_fixed_urls,
)
try:
chunk_count, _ = build_and_save_index()
self.log_event("refresh.index_built", mode="crawl", chunks=chunk_count)
vector_index = load_vectorstore()
docs = self._load_context_docs()
rag_chain, retriever, system_prompt = build_rag_chain(
vector_index,
docs,
k=5,
max_docs=3,
)
with self.state_lock:
self.vectorstore = vector_index
self.rag_chain = rag_chain
self.retriever = retriever
self.system_prompt = system_prompt
self.log_event("refresh.done", status="ok")
except Exception as exc:
self.log_event("refresh.error", error=str(exc))
def _seconds_until_next_run(self) -> int:
"""Compute the delay until the next scheduled refresh in local time."""
now = time.localtime()
target = time.mktime(
(
now.tm_year,
now.tm_mon,
now.tm_mday,
self.refresh_config.at_hour,
self.refresh_config.at_minute,
0,
now.tm_wday,
now.tm_yday,
now.tm_isdst,
)
)
now_ts = time.time()
if target <= now_ts:
target += 24 * 60 * 60
return int(target - now_ts)
def _daily_refresh_loop(self) -> None:
time.sleep(3)
while not self.stop_refresh_event.is_set():
sleep_seconds = self._seconds_until_next_run()
self.log_event(
"refresh.sleep",
seconds=sleep_seconds,
at_hour=self.refresh_config.at_hour,
at_minute=self.refresh_config.at_minute,
)
while sleep_seconds > 0 and not self.stop_refresh_event.is_set():
step = min(5, sleep_seconds)
time.sleep(step)
sleep_seconds -= step
if self.stop_refresh_event.is_set():
break
self.refresh_rag_once()
def start_refresh_thread(self) -> None:
if not self.refresh_config.enabled:
self.log_event("refresh.disabled")
return
if self.refresh_thread and self.refresh_thread.is_alive():
return
self.refresh_thread = threading.Thread(
target=self._daily_refresh_loop,
daemon=True,
)
self.refresh_thread.start()
self.log_event(
"refresh.thread_started",
daily_at=f"{self.refresh_config.at_hour:02d}:{self.refresh_config.at_minute:02d}",
)
def stop(self) -> None:
self.stop_refresh_event.set()
def _run_rag(
self,
question: str,
history_text: str,
forced_tool: str | None = None,
) -> Dict[str, Any]:
with self.state_lock:
local_rag_chain = self.rag_chain
payload: Dict[str, Any] = {
"input": question,
"chat_history": history_text,
}
if forced_tool:
payload["forced_tool"] = forced_tool
return local_rag_chain.invoke(payload)
def generate_answer(self, message: str, history: Sequence[Sequence[str]]) -> str:
"""Run rewrite, RAG, evaluation, and optional retry for one user message."""
self.log_event("request.start", user_message=message)
history_text = _history_to_text(history)
with self.state_lock:
pending = self.pending_clarification
if pending:
with self.state_lock:
local_rag_chain = self.rag_chain
forced_tool = local_rag_chain.resolve_clarification_reply(
message,
pending.get("candidate_tools", []),
pending.get("preferred_tool", "about"),
)
standalone_question = pending.get("original_question", message)
self.log_event(
"routing.clarification_resolved",
original_question=standalone_question,
clarification_reply=message,
forced_tool=forced_tool,
candidate_tools=pending.get("candidate_tools", []),
)
with self.state_lock:
self.pending_clarification = None
try:
rag_result = self._run_rag(standalone_question, history_text, forced_tool=forced_tool)
except Exception as exc:
self.log_event("rag.error", error=str(exc))
fallback = (
"I'm having trouble accessing my knowledge base right now. "
"Please try again in a moment."
)
self.log_event("request.end", final_answer_preview=fallback[:400])
return fallback
else:
try:
standalone_question = rewrite_question_with_history(history, message)
except Exception as exc:
self.log_event("rewrite.error", error=str(exc))
standalone_question = message
self.log_event(
"rewrite.done",
standalone_question=standalone_question,
history_chars=len(history_text),
)
try:
rag_result = self._run_rag(standalone_question, history_text)
except Exception as exc:
self.log_event("rag.error", error=str(exc))
fallback = (
"I'm having trouble accessing my knowledge base right now. "
"Please try again in a moment."
)
self.log_event("request.end", final_answer_preview=fallback[:400])
return fallback
if rag_result.get("needs_clarification"):
clarification_answer = rag_result.get("answer", "") or (
"Could you clarify which area you want me to focus on?"
)
with self.state_lock:
self.pending_clarification = {
"candidate_tools": rag_result.get("candidate_tools", []),
"preferred_tool": rag_result.get("preferred_tool", "about"),
"original_question": rag_result.get("original_question", standalone_question),
}
self.log_event(
"routing.clarification_requested",
original_question=standalone_question,
candidate_tools=rag_result.get("candidate_tools", []),
preferred_tool=rag_result.get("preferred_tool", "about"),
)
self.log_event("request.end", final_answer_preview=clarification_answer[:400])
return clarification_answer
answer_1 = rag_result.get("answer", "") or ""
context_docs_1 = rag_result.get("context", []) or []
self.log_event(
"rag.done",
answer_preview=answer_1[:400] + ("..." if len(answer_1) > 400 else ""),
retrieved_count=len(context_docs_1),
retrieved_docs=_docs_to_loggable(context_docs_1),
)
with self.state_lock:
local_system_prompt = self.system_prompt
eval_result_1 = None
try:
eval_result_1 = evaluate_answer(
system_prompt=local_system_prompt,
question=message,
context_docs=context_docs_1,
answer=answer_1,
)
self.log_event(
"eval.done",
overall_score=float(eval_result_1.overall_score),
grounded=float(eval_result_1.grounded_in_context_score),
hallucination=bool(eval_result_1.hallucination_detected),
feedback=str(eval_result_1.feedback),
)
except Exception as exc:
self.log_event("eval.error", error=str(exc))
final_answer = answer_1
try:
should_retry = (
eval_result_1 is not None
and (
eval_result_1.overall_score < 0.70
or getattr(eval_result_1, "should_retry", True)
)
)
if should_retry:
revision_prompt = (
f"{standalone_question}\n\n"
f"You previously answered this:\n{answer_1}\n\n"
"An evaluator found issues. Revise your answer to address the feedback below.\n"
"Rules:\n"
"- Use ONLY the provided context.\n"
'- If the context does not support the claim, say "I don\'t know".\n'
"- Be specific and grounded.\n\n"
f"Evaluator feedback:\n{eval_result_1.feedback}\n"
)
self.log_event(
"retry.triggered",
reason="eval_score_below_threshold",
threshold=0.90,
)
try:
retry_result = self._run_rag(revision_prompt, history_text)
answer_2 = retry_result.get("answer", "") or ""
context_docs_2 = retry_result.get("context", []) or []
self.log_event(
"rag.retry_done",
answer_preview=answer_2[:400] + ("..." if len(answer_2) > 400 else ""),
retrieved_count=len(context_docs_2),
retrieved_docs=_docs_to_loggable(context_docs_2),
)
eval_result_2 = None
try:
eval_result_2 = evaluate_answer(
system_prompt=local_system_prompt,
question=message,
context_docs=context_docs_2,
answer=answer_2,
)
self.log_event(
"eval.retry_done",
overall_score=float(eval_result_2.overall_score),
grounded=float(eval_result_2.grounded_in_context_score),
hallucination=bool(eval_result_2.hallucination_detected),
feedback=str(eval_result_2.feedback),
)
except Exception as exc:
self.log_event("eval.retry_error", error=str(exc))
if eval_result_2 is not None and eval_result_1.overall_score <= eval_result_2.overall_score:
final_answer = answer_2
else:
final_answer = answer_1
except Exception as exc:
self.log_event("rag.retry_error", error=str(exc))
final_answer = answer_1
except Exception as exc:
self.log_event("retry.block_error", error=str(exc))
final_answer = answer_1
self.log_event(
"request.end",
final_answer_preview=final_answer[:400] + ("..." if len(final_answer) > 400 else ""),
)
return final_answer
def respond(self, message: str, history: Sequence[Sequence[str]]):
"""Gradio callback wrapper that converts failures into safe user responses."""
history = history or []
if not message:
return "", history
try:
answer = self.generate_answer(message, history)
except Exception as exc:
self.log_event("respond.fatal_error", error=str(exc))
answer = (
"Something went wrong on my side while trying to answer. "
"Please try again in a moment."
)
updated_history = list(history) + [[message, answer]]
return "", updated_history