Spaces:
Sleeping
Sleeping
File size: 16,816 Bytes
a0b643a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 | import atexit
import os
import threading
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Sequence
from rag_core.config import NODE_CACHE_PATH, VECTORSTORE_PATH
from rag_core.evaluator import evaluate_answer
from rag_core.index_builder import build_and_save_index, load_node_cache, load_vectorstore
from rag_core.logging_utils import get_model_flow_logger, log_event
from rag_core.rag_chain import ContextDocument, build_rag_chain
from rag_core.rag_chain_helper import rewrite_question_with_history
@dataclass(frozen=True)
class RefreshConfig:
    """Immutable settings for the scheduled index refresh.

    Instances are normally created with :meth:`from_env`, which reads the
    ``REFRESH_*`` environment variables.
    """

    enabled: bool
    at_hour: int
    at_minute: int
    only_fixed_urls: bool
    rebuild_on_startup: bool

    @staticmethod
    def _env_flag(name: str, default: str) -> bool:
        """Read a boolean environment variable.

        Surrounding whitespace and letter case are ignored, and the common
        spellings ``1``/``true``/``yes``/``on`` all count as true. Any other
        value (including an empty string) is false.
        """
        return os.getenv(name, default).strip().lower() in {"1", "true", "yes", "on"}

    @staticmethod
    def _env_int(name: str, default: int) -> int:
        """Read an integer environment variable.

        An unset or blank variable yields ``default``. A non-blank value that
        is not an integer still raises ``ValueError`` so that a typo in the
        configuration is noticed at startup instead of being ignored.
        """
        raw = os.getenv(name, "").strip()
        return int(raw) if raw else default

    @classmethod
    def from_env(cls) -> "RefreshConfig":
        """Build a config from the ``REFRESH_*`` environment variables.

        Defaults when a variable is unset: refresh enabled, scheduled at
        03:00, ``only_fixed_urls`` off, no rebuild on startup.

        Raises:
            ValueError: if ``REFRESH_AT_HOUR`` or ``REFRESH_AT_MINUTE`` is set
                to a non-blank value that is not an integer.
        """
        return cls(
            enabled=cls._env_flag("REFRESH_ENABLED", "true"),
            at_hour=cls._env_int("REFRESH_AT_HOUR", 3),
            at_minute=cls._env_int("REFRESH_AT_MINUTE", 0),
            only_fixed_urls=cls._env_flag("REFRESH_ONLY_FIXED_URLS", "false"),
            rebuild_on_startup=cls._env_flag("REFRESH_ON_STARTUP", "false"),
        )
def _history_to_text(history: Sequence[Sequence[str]]) -> str:
"""Convert Gradio history ([[user, bot], ...]) into a compact text block."""
if not history:
return ""
lines: List[str] = []
for turn in history:
if not turn or len(turn) < 2:
continue
user_msg, assistant_msg = turn[0], turn[1]
lines.append(f"User: {user_msg}")
lines.append(f"Assistant: {assistant_msg}")
return "\n".join(lines)
def _docs_to_loggable(docs: Sequence[Any], max_chars: int = 220) -> List[dict]:
"""Return lightweight document metadata for logs without dumping full context."""
summaries: List[dict] = []
for doc in docs or []:
source = (doc.metadata or {}).get("source", "unknown")
text = (doc.page_content or "").strip().replace("\n", " ")
summaries.append(
{
"source": source,
"preview": text[:max_chars] + ("..." if len(text) > max_chars else ""),
"metadata": doc.metadata or {},
}
)
return summaries
class CareerQARuntime:
"""Owns model state, refresh scheduling, and answer generation."""
def __init__(self, refresh_config: RefreshConfig | None = None):
self.refresh_config = refresh_config or RefreshConfig.from_env()
self.logger = get_model_flow_logger()
self.state_lock = threading.RLock()
self.stop_refresh_event = threading.Event()
self.vectorstore = None
self.rag_chain = None
self.retriever = None
self.system_prompt = None
self.refresh_thread = None
self.pending_clarification: Dict[str, Any] | None = None
self.init_rag()
self.start_refresh_thread()
atexit.register(self.stop)
def _load_context_docs(self) -> List[ContextDocument]:
return [
ContextDocument(
page_content=record["text"],
metadata=record["metadata"],
)
for record in load_node_cache()
]
def init_rag(self) -> None:
"""Build the index if needed, then load the vectorstore and chain."""
index_path = NODE_CACHE_PATH
should_rebuild = self.refresh_config.rebuild_on_startup or not index_path.exists()
if should_rebuild:
try:
chunk_count, _ = build_and_save_index()
self.log_event(
"refresh.index_built",
mode="startup_rebuild",
chunks=chunk_count,
)
except Exception as exc:
if not index_path.exists():
raise
self.log_event(
"refresh.startup_rebuild_failed",
error=str(exc),
fallback="loading_existing_index",
)
vector_index = load_vectorstore()
docs = self._load_context_docs()
rag_chain, retriever, system_prompt = build_rag_chain(
vector_index,
docs,
k=5,
max_docs=3,
)
with self.state_lock:
self.vectorstore = vector_index
self.rag_chain = rag_chain
self.retriever = retriever
self.system_prompt = system_prompt
self.log_event("init_rag.ready", vectorstore_path=VECTORSTORE_PATH)
def log_event(self, event: str, **payload) -> None:
log_event(self.logger, event, **payload)
def refresh_rag_once(self) -> None:
"""Rebuild the index and atomically swap in a fresh chain."""
self.log_event(
"refresh.start",
only_fixed_urls=self.refresh_config.only_fixed_urls,
)
try:
chunk_count, _ = build_and_save_index()
self.log_event("refresh.index_built", mode="crawl", chunks=chunk_count)
vector_index = load_vectorstore()
docs = self._load_context_docs()
rag_chain, retriever, system_prompt = build_rag_chain(
vector_index,
docs,
k=5,
max_docs=3,
)
with self.state_lock:
self.vectorstore = vector_index
self.rag_chain = rag_chain
self.retriever = retriever
self.system_prompt = system_prompt
self.log_event("refresh.done", status="ok")
except Exception as exc:
self.log_event("refresh.error", error=str(exc))
def _seconds_until_next_run(self) -> int:
"""Compute the delay until the next scheduled refresh in local time."""
now = time.localtime()
target = time.mktime(
(
now.tm_year,
now.tm_mon,
now.tm_mday,
self.refresh_config.at_hour,
self.refresh_config.at_minute,
0,
now.tm_wday,
now.tm_yday,
now.tm_isdst,
)
)
now_ts = time.time()
if target <= now_ts:
target += 24 * 60 * 60
return int(target - now_ts)
def _daily_refresh_loop(self) -> None:
time.sleep(3)
while not self.stop_refresh_event.is_set():
sleep_seconds = self._seconds_until_next_run()
self.log_event(
"refresh.sleep",
seconds=sleep_seconds,
at_hour=self.refresh_config.at_hour,
at_minute=self.refresh_config.at_minute,
)
while sleep_seconds > 0 and not self.stop_refresh_event.is_set():
step = min(5, sleep_seconds)
time.sleep(step)
sleep_seconds -= step
if self.stop_refresh_event.is_set():
break
self.refresh_rag_once()
def start_refresh_thread(self) -> None:
if not self.refresh_config.enabled:
self.log_event("refresh.disabled")
return
if self.refresh_thread and self.refresh_thread.is_alive():
return
self.refresh_thread = threading.Thread(
target=self._daily_refresh_loop,
daemon=True,
)
self.refresh_thread.start()
self.log_event(
"refresh.thread_started",
daily_at=f"{self.refresh_config.at_hour:02d}:{self.refresh_config.at_minute:02d}",
)
def stop(self) -> None:
self.stop_refresh_event.set()
def _run_rag(
self,
question: str,
history_text: str,
forced_tool: str | None = None,
) -> Dict[str, Any]:
with self.state_lock:
local_rag_chain = self.rag_chain
payload: Dict[str, Any] = {
"input": question,
"chat_history": history_text,
}
if forced_tool:
payload["forced_tool"] = forced_tool
return local_rag_chain.invoke(payload)
def generate_answer(self, message: str, history: Sequence[Sequence[str]]) -> str:
"""Run rewrite, RAG, evaluation, and optional retry for one user message."""
self.log_event("request.start", user_message=message)
history_text = _history_to_text(history)
with self.state_lock:
pending = self.pending_clarification
if pending:
with self.state_lock:
local_rag_chain = self.rag_chain
forced_tool = local_rag_chain.resolve_clarification_reply(
message,
pending.get("candidate_tools", []),
pending.get("preferred_tool", "about"),
)
standalone_question = pending.get("original_question", message)
self.log_event(
"routing.clarification_resolved",
original_question=standalone_question,
clarification_reply=message,
forced_tool=forced_tool,
candidate_tools=pending.get("candidate_tools", []),
)
with self.state_lock:
self.pending_clarification = None
try:
rag_result = self._run_rag(standalone_question, history_text, forced_tool=forced_tool)
except Exception as exc:
self.log_event("rag.error", error=str(exc))
fallback = (
"I'm having trouble accessing my knowledge base right now. "
"Please try again in a moment."
)
self.log_event("request.end", final_answer_preview=fallback[:400])
return fallback
else:
try:
standalone_question = rewrite_question_with_history(history, message)
except Exception as exc:
self.log_event("rewrite.error", error=str(exc))
standalone_question = message
self.log_event(
"rewrite.done",
standalone_question=standalone_question,
history_chars=len(history_text),
)
try:
rag_result = self._run_rag(standalone_question, history_text)
except Exception as exc:
self.log_event("rag.error", error=str(exc))
fallback = (
"I'm having trouble accessing my knowledge base right now. "
"Please try again in a moment."
)
self.log_event("request.end", final_answer_preview=fallback[:400])
return fallback
if rag_result.get("needs_clarification"):
clarification_answer = rag_result.get("answer", "") or (
"Could you clarify which area you want me to focus on?"
)
with self.state_lock:
self.pending_clarification = {
"candidate_tools": rag_result.get("candidate_tools", []),
"preferred_tool": rag_result.get("preferred_tool", "about"),
"original_question": rag_result.get("original_question", standalone_question),
}
self.log_event(
"routing.clarification_requested",
original_question=standalone_question,
candidate_tools=rag_result.get("candidate_tools", []),
preferred_tool=rag_result.get("preferred_tool", "about"),
)
self.log_event("request.end", final_answer_preview=clarification_answer[:400])
return clarification_answer
answer_1 = rag_result.get("answer", "") or ""
context_docs_1 = rag_result.get("context", []) or []
self.log_event(
"rag.done",
answer_preview=answer_1[:400] + ("..." if len(answer_1) > 400 else ""),
retrieved_count=len(context_docs_1),
retrieved_docs=_docs_to_loggable(context_docs_1),
)
with self.state_lock:
local_system_prompt = self.system_prompt
eval_result_1 = None
try:
eval_result_1 = evaluate_answer(
system_prompt=local_system_prompt,
question=message,
context_docs=context_docs_1,
answer=answer_1,
)
self.log_event(
"eval.done",
overall_score=float(eval_result_1.overall_score),
grounded=float(eval_result_1.grounded_in_context_score),
hallucination=bool(eval_result_1.hallucination_detected),
feedback=str(eval_result_1.feedback),
)
except Exception as exc:
self.log_event("eval.error", error=str(exc))
final_answer = answer_1
try:
should_retry = (
eval_result_1 is not None
and (
eval_result_1.overall_score < 0.70
or getattr(eval_result_1, "should_retry", True)
)
)
if should_retry:
revision_prompt = (
f"{standalone_question}\n\n"
f"You previously answered this:\n{answer_1}\n\n"
"An evaluator found issues. Revise your answer to address the feedback below.\n"
"Rules:\n"
"- Use ONLY the provided context.\n"
'- If the context does not support the claim, say "I don\'t know".\n'
"- Be specific and grounded.\n\n"
f"Evaluator feedback:\n{eval_result_1.feedback}\n"
)
self.log_event(
"retry.triggered",
reason="eval_score_below_threshold",
threshold=0.90,
)
try:
retry_result = self._run_rag(revision_prompt, history_text)
answer_2 = retry_result.get("answer", "") or ""
context_docs_2 = retry_result.get("context", []) or []
self.log_event(
"rag.retry_done",
answer_preview=answer_2[:400] + ("..." if len(answer_2) > 400 else ""),
retrieved_count=len(context_docs_2),
retrieved_docs=_docs_to_loggable(context_docs_2),
)
eval_result_2 = None
try:
eval_result_2 = evaluate_answer(
system_prompt=local_system_prompt,
question=message,
context_docs=context_docs_2,
answer=answer_2,
)
self.log_event(
"eval.retry_done",
overall_score=float(eval_result_2.overall_score),
grounded=float(eval_result_2.grounded_in_context_score),
hallucination=bool(eval_result_2.hallucination_detected),
feedback=str(eval_result_2.feedback),
)
except Exception as exc:
self.log_event("eval.retry_error", error=str(exc))
if eval_result_2 is not None and eval_result_1.overall_score <= eval_result_2.overall_score:
final_answer = answer_2
else:
final_answer = answer_1
except Exception as exc:
self.log_event("rag.retry_error", error=str(exc))
final_answer = answer_1
except Exception as exc:
self.log_event("retry.block_error", error=str(exc))
final_answer = answer_1
self.log_event(
"request.end",
final_answer_preview=final_answer[:400] + ("..." if len(final_answer) > 400 else ""),
)
return final_answer
def respond(self, message: str, history: Sequence[Sequence[str]]):
"""Gradio callback wrapper that converts failures into safe user responses."""
history = history or []
if not message:
return "", history
try:
answer = self.generate_answer(message, history)
except Exception as exc:
self.log_event("respond.fatal_error", error=str(exc))
answer = (
"Something went wrong on my side while trying to answer. "
"Please try again in a moment."
)
updated_history = list(history) + [[message, answer]]
return "", updated_history
|