from __future__ import annotations """ Minimal chat backend (FastAPI) that delegates to the agent app pipeline. Run: uvicorn agent.server:app --reload --port 8000 """ import uuid import json import traceback from typing import Optional, Callable from collections import deque, OrderedDict import time import math from fastapi import FastAPI from fastapi import Response from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse from fastapi.staticfiles import StaticFiles from pydantic import BaseModel from functools import lru_cache import os from data.catalog_loader import load_catalog from recommenders.bm25 import BM25Recommender from recommenders.vector_recommender import VectorRecommender from retrieval.vector_index import VectorIndex from models.embedding_model import EmbeddingModel from rerankers.cross_encoder import CrossEncoderReranker from tools.query_plan_tool import build_query_plan from tools.query_plan_tool_llm import build_query_plan_llm from llm.nu_extract import NuExtractWrapper, default_query_rewrite_examples from llm.qwen_rewriter import QwenRewriter from tools.retrieve_tool import retrieve_candidates from tools.rerank_tool import rerank_candidates from tools.constraints_tool import apply_constraints os.environ.setdefault("HF_HOME", "/home/user/.cache/huggingface") import hashlib # --------------------------- # Simple in-memory TTL + LRU cache for responses # --------------------------- _CACHE_MAX_ITEMS = int(os.getenv("RECO_CACHE_MAX_ITEMS", "500")) _CACHE_TTL_SECONDS = int(os.getenv("RECO_CACHE_TTL_SECONDS", str(24 * 3600))) _reco_cache: OrderedDict[str, tuple[float, dict]] = OrderedDict() def _normalize_query(q: str) -> str: return " ".join((q or "").lower().split()) def _cache_key(query: str, llm_model: str | None, verbose: bool, endpoint: str) -> str: model = (llm_model or os.getenv("LLM_MODEL", "") or "default").strip().lower() raw = f"ep={endpoint}|q={_normalize_query(query)}|m={model}|v={int(verbose)}" return hashlib.sha256(raw.encode("utf-8")).hexdigest() def _cache_get(key: str): item = _reco_cache.get(key) if not item: return None expires_at, value = item if time.time() > expires_at: _reco_cache.pop(key, None) return None _reco_cache.move_to_end(key) # LRU refresh return value def _cache_set(key: str, value: dict): _reco_cache[key] = (time.time() + _CACHE_TTL_SECONDS, value) _reco_cache.move_to_end(key) while len(_reco_cache) > _CACHE_MAX_ITEMS: _reco_cache.popitem(last=False) class ChatRequest(BaseModel): query: str clarification_answer: Optional[str] = None verbose: bool = False class RecommendRequest(BaseModel): query: str llm_model: Optional[str] = None verbose: bool = False def _make_catalog_lookup(df_catalog) -> Callable[[str], dict]: cat = df_catalog.set_index("assessment_id") def lookup(aid: str) -> dict: if aid in cat.index: return cat.loc[aid].to_dict() return {} return lookup @lru_cache(maxsize=1) def load_resources(llm_model_override: Optional[str] = None): df_catalog, _, _ = load_catalog("data/catalog_docs_rich.jsonl") bm25 = BM25Recommender(df_catalog) embed = EmbeddingModel("BAAI/bge-small-en-v1.5") index = VectorIndex.load("data/faiss_index/index_bge.faiss") with open("data/embeddings_bge/assessment_ids.json") as f: ids = json.load(f) vec = VectorRecommender(embed, index, df_catalog, ids, k_candidates=200) reranker = CrossEncoderReranker(model_name="models/reranker_crossenc/v0.1.0") lookup = _make_catalog_lookup(df_catalog) catalog_by_id = {row["assessment_id"]: row for _, row in df_catalog.iterrows()} vocab = {} vocab_path = "data/catalog_role_vocab.json" if os.path.exists(vocab_path): try: with open(vocab_path) as vf: vocab = json.load(vf) except Exception: vocab = {} # Optional LLM rewriter; choose via request override or env LLM_MODEL llm_extractor = None llm_model = llm_model_override or os.getenv("LLM_MODEL", "").strip() if not llm_model: llm_model = "Qwen/Qwen2.5-1.5B-Instruct" try: if "qwen" in llm_model.lower(): llm_extractor = QwenRewriter(model_name=llm_model, default_examples=default_query_rewrite_examples()) elif not os.getenv("GOOGLE_API_KEY"): llm_extractor = NuExtractWrapper(default_examples=default_query_rewrite_examples()) except Exception as e: print("LLM init failed:", repr(e)) traceback.print_exc() llm_extractor = None return df_catalog, bm25, vec, reranker, lookup, vocab, llm_extractor, catalog_by_id def _infer_remote_adaptive(meta: dict) -> (Optional[bool], Optional[bool]): remote = meta.get("remote_support", True if meta.get("remote_support") is None else meta.get("remote_support")) adaptive = meta.get("adaptive_support") text_blob = " ".join([str(meta.get("name", "")), str(meta.get("description", "")), str(meta.get("doc_text", ""))]).lower() if adaptive is None and "adaptive" in text_blob: adaptive = True return remote, adaptive def _build_plan_with_fallback(query: str, vocab: dict, llm_extractor): """ Build the query plan using the LLM rewriter (Qwen) when available, otherwise fall back to deterministic rewrite. No Gemini refinement to keep behavior predictable. """ try: return build_query_plan(query, vocab=vocab, llm_extractor=llm_extractor) except Exception: return build_query_plan(query, vocab=vocab) def _safe_num(val): try: if val is None: return None f = float(val) if math.isfinite(f): return f except Exception: return None return None def _sanitize_debug(obj): """Recursively replace NaN/inf with None to keep JSON safe.""" if isinstance(obj, dict): return {k: _sanitize_debug(v) for k, v in obj.items()} if isinstance(obj, list): return [_sanitize_debug(v) for v in obj] if isinstance(obj, tuple): return tuple(_sanitize_debug(v) for v in obj) if isinstance(obj, (int, float)): return _safe_num(obj) return obj CODE_TO_FULL = { "A": "Ability & Aptitude", "B": "Biodata & Situational Judgement", "C": "Competencies", "D": "Development & 360", "E": "Assessment Exercises", "K": "Knowledge & Skills", "P": "Personality & Behavior", "S": "Simulations", } def _format_test_types(meta: dict) -> list[str]: if meta.get("test_type_full"): raw = meta["test_type_full"] elif meta.get("test_type"): raw = meta["test_type"] else: return [] if isinstance(raw, list): vals = raw else: vals = str(raw).replace("/", ",").split(",") out = [] for v in vals: v = v.strip() if not v: continue # Map letter codes to full names when applicable if len(v) == 1 and v in CODE_TO_FULL: out.append(CODE_TO_FULL[v]) else: out.append(v) return out def _run_pipeline(query: str, topn: int = 200, verbose: bool = False, llm_model: Optional[str] = None): df_catalog, bm25, vec, reranker, lookup, vocab, llm_extractor, catalog_by_id = load_resources(llm_model_override=llm_model) plan = _build_plan_with_fallback(query, vocab=vocab, llm_extractor=llm_extractor) cand_set = retrieve_candidates(plan, bm25, vec, topn=topn, catalog_df=df_catalog) ranked = rerank_candidates(plan, cand_set, reranker, df_catalog, k=10) final_list = apply_constraints(plan, ranked, catalog_by_id, k=10) debug_payload = {} if verbose: debug_payload["plan"] = plan.dict() # If plan carries a source (from planner), include it if hasattr(plan, "plan_source"): debug_payload["plan_source"] = getattr(plan, "plan_source") # Capture NuExtract LLM debug if present if hasattr(plan, "llm_debug") and plan.llm_debug: debug_payload["llm_debug"] = plan.llm_debug if hasattr(cand_set, "fusion") and cand_set.fusion: debug_payload["fusion"] = cand_set.fusion debug_payload["candidates"] = [ { "assessment_id": c.assessment_id, "bm25_rank": c.bm25_rank, "vector_rank": c.vector_rank, "hybrid_rank": c.hybrid_rank, "bm25_score": _safe_num(c.bm25_score), "vector_score": _safe_num(c.vector_score), "score": _safe_num(c.score), } for c in cand_set.candidates[: min(20, len(cand_set.candidates))] ] debug_payload["rerank"] = [ {"assessment_id": r.assessment_id, "score": _safe_num(r.score)} for r in ranked.items[: min(20, len(ranked.items))] ] debug_payload["constraints"] = [ { "assessment_id": r.assessment_id, "score": _safe_num(r.score), "debug": r.debug, } for r in final_list.items ] final_results = [] for item in final_list.items: meta = lookup(item.assessment_id) remote, adaptive = _infer_remote_adaptive(meta) score = _safe_num(item.score) duration = _safe_num(meta.get("duration_minutes") or meta.get("duration")) duration_int = int(duration) if duration is not None else None description = meta.get("description") or meta.get("doc_text") or "" test_types = _format_test_types(meta) final_results.append( { "url": meta.get("url"), "name": meta.get("name"), "adaptive_support": "Yes" if adaptive else "No", "description": description, "duration": duration_int if duration_int is not None else 0, "remote_support": "Yes" if remote else "No", "test_type": test_types, } ) # Guarantee at least one result if pipeline produced candidates if not final_results and ranked.items: item = ranked.items[0] meta = lookup(item.assessment_id) remote, adaptive = _infer_remote_adaptive(meta) duration = _safe_num(meta.get("duration_minutes") or meta.get("duration")) duration_int = int(duration) if duration is not None else 0 final_results.append( { "url": meta.get("url"), "name": meta.get("name"), "adaptive_support": "Yes" if adaptive else "No", "description": meta.get("description") or meta.get("doc_text") or "", "duration": duration_int, "remote_support": "Yes" if remote else "No", "test_type": _format_test_types(meta), } ) summary = {"plan": plan.intent, "top": len(final_results)} return final_results, summary, debug_payload app = FastAPI() app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=False, # '*' cannot be used with credentials allow_methods=["*"], allow_headers=["*"], ) # Serve frontend assets app.mount("/static", StaticFiles(directory="frontend"), name="static") # Simple in-process rate limiter (max 5 requests per second) _timestamps = deque() _RATE_LIMIT = 5 _WINDOW = 1.0 def _allow_request() -> bool: now = time.time() while _timestamps and now - _timestamps[0] > _WINDOW: _timestamps.popleft() if len(_timestamps) < _RATE_LIMIT: _timestamps.append(now) return True return False @app.post("/chat") def chat(req: ChatRequest): if not _allow_request(): return {"error": "rate limit exceeded"} trace_id = str(uuid.uuid4()) final_results, summary, debug_payload = _run_pipeline(req.query, verbose=req.verbose) payload = {"trace_id": trace_id, "final_results": final_results} if req.verbose: payload["summary"] = summary payload["debug"] = _sanitize_debug(debug_payload) return payload @app.post("/recommend") def recommend(req: RecommendRequest, response: Response): if not _allow_request(): return {"error": "rate limit exceeded"} key = _cache_key(req.query, req.llm_model, req.verbose, endpoint="/recommend") cached = _cache_get(key) if cached is not None: response.headers["X-Cache"] = "HIT" return cached response.headers["X-Cache"] = "MISS" llm_model = req.llm_model or "Qwen/Qwen2.5-1.5B-Instruct" verbose = bool(req.verbose) if req.verbose is not None else False final_results, summary, debug_payload = _run_pipeline( req.query, verbose=verbose, llm_model=llm_model ) resp = {"recommended_assessments": final_results} if verbose: resp["debug"] = _sanitize_debug(debug_payload) resp["summary"] = summary _cache_set(key, resp) return resp @app.get("/health") def health(): return {"status": "healthy"} @app.get("/") def index(): # Serve the SPA entry point return FileResponse("frontend/index.html")