"""
Minimal chat backend (FastAPI) that delegates to the agent app pipeline.

Run:
  uvicorn agent.server:app --reload --port 8000
"""

from __future__ import annotations

import uuid
import json
import traceback
from typing import Optional, Callable
from collections import deque, OrderedDict
import time
import math

from fastapi import FastAPI, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from functools import lru_cache
import os
from data.catalog_loader import load_catalog
from recommenders.bm25 import BM25Recommender
from recommenders.vector_recommender import VectorRecommender
from retrieval.vector_index import VectorIndex
from models.embedding_model import EmbeddingModel
from rerankers.cross_encoder import CrossEncoderReranker
from tools.query_plan_tool import build_query_plan
from tools.query_plan_tool_llm import build_query_plan_llm
from llm.nu_extract import NuExtractWrapper, default_query_rewrite_examples
from llm.qwen_rewriter import QwenRewriter
from tools.retrieve_tool import retrieve_candidates
from tools.rerank_tool import rerank_candidates
from tools.constraints_tool import apply_constraints


# Keep Hugging Face downloads in a writable cache directory (no-op if HF_HOME is already set).
os.environ.setdefault("HF_HOME", "/home/user/.cache/huggingface")

import hashlib

# ---------------------------
# Simple in-memory TTL + LRU cache for responses
# ---------------------------
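# Each value is stored as key -> (expires_at, payload); the OrderedDict's insertion
# order doubles as LRU recency (move_to_end on hit, popitem(last=False) to evict the
# oldest entry once _CACHE_MAX_ITEMS is exceeded).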
_CACHE_MAX_ITEMS = int(os.getenv("RECO_CACHE_MAX_ITEMS", "500"))
_CACHE_TTL_SECONDS = int(os.getenv("RECO_CACHE_TTL_SECONDS", str(24 * 3600)))

_reco_cache: OrderedDict[str, tuple[float, dict]] = OrderedDict()

def _normalize_query(q: str) -> str:
    return " ".join((q or "").lower().split())

def _cache_key(query: str, llm_model: str | None, verbose: bool, endpoint: str) -> str:
    model = (llm_model or os.getenv("LLM_MODEL", "") or "default").strip().lower()
    raw = f"ep={endpoint}|q={_normalize_query(query)}|m={model}|v={int(verbose)}"
    return hashlib.sha256(raw.encode("utf-8")).hexdigest()

def _cache_get(key: str):
    item = _reco_cache.get(key)
    if not item:
        return None
    expires_at, value = item
    if time.time() > expires_at:
        _reco_cache.pop(key, None)
        return None
    _reco_cache.move_to_end(key)  # LRU refresh
    return value

def _cache_set(key: str, value: dict):
    _reco_cache[key] = (time.time() + _CACHE_TTL_SECONDS, value)
    _reco_cache.move_to_end(key)
    while len(_reco_cache) > _CACHE_MAX_ITEMS:
        _reco_cache.popitem(last=False)

class ChatRequest(BaseModel):
    query: str
    clarification_answer: Optional[str] = None
    verbose: bool = False

class RecommendRequest(BaseModel):
    query: str
    llm_model: Optional[str] = None
    verbose: bool = False


def _make_catalog_lookup(df_catalog) -> Callable[[str], dict]:
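    """Build a lookup(assessment_id) -> catalog metadata dict (empty dict if unknown)."""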
    cat = df_catalog.set_index("assessment_id")
    def lookup(aid: str) -> dict:
        if aid in cat.index:
            return cat.loc[aid].to_dict()
        return {}
    return lookup


@lru_cache(maxsize=1)
def load_resources(llm_model_override: Optional[str] = None):
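    """
    Load and cache the heavy pipeline resources: catalog, BM25 and vector
    retrievers, cross-encoder reranker, role vocabulary, and the optional
    LLM query rewriter.

    Note: with maxsize=1, calling this with a different llm_model_override
    evicts the previous entry and rebuilds everything from scratch.
    """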
    df_catalog, _, _ = load_catalog("data/catalog_docs_rich.jsonl")
    bm25 = BM25Recommender(df_catalog)
    embed = EmbeddingModel("BAAI/bge-small-en-v1.5")
    index = VectorIndex.load("data/faiss_index/index_bge.faiss")
    with open("data/embeddings_bge/assessment_ids.json") as f:
        ids = json.load(f)
    vec = VectorRecommender(embed, index, df_catalog, ids, k_candidates=200)
    reranker = CrossEncoderReranker(model_name="models/reranker_crossenc/v0.1.0")
    lookup = _make_catalog_lookup(df_catalog)
    catalog_by_id = {row["assessment_id"]: row for _, row in df_catalog.iterrows()}
    vocab = {}
    vocab_path = "data/catalog_role_vocab.json"
    if os.path.exists(vocab_path):
        try:
            with open(vocab_path) as vf:
                vocab = json.load(vf)
        except Exception:
            vocab = {}
    # Optional LLM rewriter; choose via request override or env LLM_MODEL
    llm_extractor = None
    llm_model = llm_model_override or os.getenv("LLM_MODEL", "").strip()
    if not llm_model:
        llm_model = "Qwen/Qwen2.5-1.5B-Instruct"
    try:
        if "qwen" in llm_model.lower():
            llm_extractor = QwenRewriter(model_name=llm_model, default_examples=default_query_rewrite_examples())
        elif not os.getenv("GOOGLE_API_KEY"):
            llm_extractor = NuExtractWrapper(default_examples=default_query_rewrite_examples())
    except Exception as e:
        print("LLM init failed:", repr(e))
        traceback.print_exc()
        llm_extractor = None
    return df_catalog, bm25, vec, reranker, lookup, vocab, llm_extractor, catalog_by_id


def _infer_remote_adaptive(meta: dict) -> tuple[Optional[bool], Optional[bool]]:
    """Best-effort inference of remote/adaptive support from catalog metadata."""
    # Default remote support to True when the field is missing or None.
    remote = meta.get("remote_support")
    if remote is None:
        remote = True
    adaptive = meta.get("adaptive_support")
    # Fall back to a keyword check over the descriptive text when the flag is absent.
    text_blob = " ".join([str(meta.get("name", "")), str(meta.get("description", "")), str(meta.get("doc_text", ""))]).lower()
    if adaptive is None and "adaptive" in text_blob:
        adaptive = True
    return remote, adaptive


def _build_plan_with_fallback(query: str, vocab: dict, llm_extractor):
    """
    Build the query plan using the LLM rewriter (Qwen) when available, otherwise
    fall back to deterministic rewrite. No Gemini refinement to keep behavior predictable.
    """
    try:
        return build_query_plan(query, vocab=vocab, llm_extractor=llm_extractor)
    except Exception:
        return build_query_plan(query, vocab=vocab)


def _safe_num(val):
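    """Coerce a value to a finite float; return None for None, NaN, inf, or unparsable input."""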
    try:
        if val is None:
            return None
        f = float(val)
        if math.isfinite(f):
            return f
    except Exception:
        return None
    return None


def _sanitize_debug(obj):
    """Recursively replace NaN/inf with None to keep JSON safe."""
    if isinstance(obj, dict):
        return {k: _sanitize_debug(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_sanitize_debug(v) for v in obj]
    if isinstance(obj, tuple):
        return tuple(_sanitize_debug(v) for v in obj)
    if isinstance(obj, (int, float)):
        return _safe_num(obj)
    return obj


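# Single-letter test-type codes used in the catalog, expanded to their full
# category names by _format_test_types below.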
CODE_TO_FULL = {
    "A": "Ability & Aptitude",
    "B": "Biodata & Situational Judgement",
    "C": "Competencies",
    "D": "Development & 360",
    "E": "Assessment Exercises",
    "K": "Knowledge & Skills",
    "P": "Personality & Behavior",
    "S": "Simulations",
}


def _format_test_types(meta: dict) -> list[str]:
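    """Return an assessment's test types as a list of full names, mapping single-letter codes via CODE_TO_FULL."""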
    if meta.get("test_type_full"):
        raw = meta["test_type_full"]
    elif meta.get("test_type"):
        raw = meta["test_type"]
    else:
        return []
    if isinstance(raw, list):
        vals = raw
    else:
        vals = str(raw).replace("/", ",").split(",")
    out = []
    for v in vals:
        v = v.strip()
        if not v:
            continue
        # Map letter codes to full names when applicable
        if len(v) == 1 and v in CODE_TO_FULL:
            out.append(CODE_TO_FULL[v])
        else:
            out.append(v)
    return out


def _run_pipeline(query: str, topn: int = 200, verbose: bool = False, llm_model: Optional[str] = None):
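    """
    End-to-end recommendation pipeline: build a query plan, retrieve hybrid
    BM25 + vector candidates, rerank with the cross-encoder, apply constraints,
    and format the top results (collecting debug info when verbose).
    """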

    df_catalog, bm25, vec, reranker, lookup, vocab, llm_extractor, catalog_by_id = load_resources(llm_model_override=llm_model)
    plan = _build_plan_with_fallback(query, vocab=vocab, llm_extractor=llm_extractor)
    cand_set = retrieve_candidates(plan, bm25, vec, topn=topn, catalog_df=df_catalog)
    ranked = rerank_candidates(plan, cand_set, reranker, df_catalog, k=10)
    final_list = apply_constraints(plan, ranked, catalog_by_id, k=10)

    debug_payload = {}
    if verbose:
        debug_payload["plan"] = plan.dict()
        # If plan carries a source (from planner), include it
        if hasattr(plan, "plan_source"):
            debug_payload["plan_source"] = getattr(plan, "plan_source")
        # Capture NuExtract LLM debug if present
        if hasattr(plan, "llm_debug") and plan.llm_debug:
            debug_payload["llm_debug"] = plan.llm_debug
        if hasattr(cand_set, "fusion") and cand_set.fusion:
            debug_payload["fusion"] = cand_set.fusion
        debug_payload["candidates"] = [
            {
                "assessment_id": c.assessment_id,
                "bm25_rank": c.bm25_rank,
                "vector_rank": c.vector_rank,
                "hybrid_rank": c.hybrid_rank,
                "bm25_score": _safe_num(c.bm25_score),
                "vector_score": _safe_num(c.vector_score),
                "score": _safe_num(c.score),
            }
            for c in cand_set.candidates[: min(20, len(cand_set.candidates))]
        ]
        debug_payload["rerank"] = [
            {"assessment_id": r.assessment_id, "score": _safe_num(r.score)}
            for r in ranked.items[: min(20, len(ranked.items))]
        ]
        debug_payload["constraints"] = [
            {
                "assessment_id": r.assessment_id,
                "score": _safe_num(r.score),
                "debug": r.debug,
            }
            for r in final_list.items
        ]

    final_results = []
    for item in final_list.items:
        meta = lookup(item.assessment_id)
        remote, adaptive = _infer_remote_adaptive(meta)
        score = _safe_num(item.score)
        duration = _safe_num(meta.get("duration_minutes") or meta.get("duration"))
        duration_int = int(duration) if duration is not None else None
        description = meta.get("description") or meta.get("doc_text") or ""
        test_types = _format_test_types(meta)
        final_results.append(
            {
                "url": meta.get("url"),
                "name": meta.get("name"),
                "adaptive_support": "Yes" if adaptive else "No",
                "description": description,
                "duration": duration_int if duration_int is not None else 0,
                "remote_support": "Yes" if remote else "No",
                "test_type": test_types,
            }
        )
    # Guarantee at least one result if pipeline produced candidates
    if not final_results and ranked.items:
        item = ranked.items[0]
        meta = lookup(item.assessment_id)
        remote, adaptive = _infer_remote_adaptive(meta)
        duration = _safe_num(meta.get("duration_minutes") or meta.get("duration"))
        duration_int = int(duration) if duration is not None else 0
        final_results.append(
            {
                "url": meta.get("url"),
                "name": meta.get("name"),
                "adaptive_support": "Yes" if adaptive else "No",
                "description": meta.get("description") or meta.get("doc_text") or "",
                "duration": duration_int,
                "remote_support": "Yes" if remote else "No",
                "test_type": _format_test_types(meta),
            }
        )
    summary = {"plan": plan.intent, "top": len(final_results)}
    return final_results, summary, debug_payload


app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,  # '*' cannot be used with credentials
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve frontend assets
app.mount("/static", StaticFiles(directory="frontend"), name="static")

# Simple in-process rate limiter (max 5 requests per second)
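# The window state is process-global and shared by all clients (not per-IP), and each
# worker process keeps its own deque. Over-limit requests currently receive a JSON
# error body with HTTP 200 rather than a 429 status.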
_timestamps = deque()
_RATE_LIMIT = 5
_WINDOW = 1.0

def _allow_request() -> bool:
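    """Sliding-window check: allow at most _RATE_LIMIT requests within the last _WINDOW seconds."""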
    now = time.time()
    while _timestamps and now - _timestamps[0] > _WINDOW:
        _timestamps.popleft()
    if len(_timestamps) < _RATE_LIMIT:
        _timestamps.append(now)
        return True
    return False


@app.post("/chat")
def chat(req: ChatRequest):
    if not _allow_request():
        return {"error": "rate limit exceeded"}
    trace_id = str(uuid.uuid4())
    final_results, summary, debug_payload = _run_pipeline(req.query, verbose=req.verbose)
    payload = {"trace_id": trace_id, "final_results": final_results}
    if req.verbose:
        payload["summary"] = summary
        payload["debug"] = _sanitize_debug(debug_payload)
    return payload


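# Example request (a sketch, assuming the server runs locally on port 8000;
# the query text is illustrative):
#   curl -s -X POST http://localhost:8000/recommend \
#     -H "Content-Type: application/json" \
#     -d '{"query": "Java developer assessment under 40 minutes"}'
# The response contains {"recommended_assessments": [...]}; the X-Cache header
# reports HIT or MISS for the in-memory response cache.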
@app.post("/recommend")
def recommend(req: RecommendRequest, response: Response):
    if not _allow_request():
        return {"error": "rate limit exceeded"}

    key = _cache_key(req.query, req.llm_model, req.verbose, endpoint="/recommend")
    cached = _cache_get(key)
    if cached is not None:
        response.headers["X-Cache"] = "HIT"
        return cached

    response.headers["X-Cache"] = "MISS"

    llm_model = req.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
    verbose = bool(req.verbose)
    final_results, summary, debug_payload = _run_pipeline(
        req.query, verbose=verbose, llm_model=llm_model
    )
    resp = {"recommended_assessments": final_results}
    if verbose:
        resp["debug"] = _sanitize_debug(debug_payload)
        resp["summary"] = summary

    _cache_set(key, resp)
    return resp

@app.get("/health")
def health():
    return {"status": "healthy"}


@app.get("/")
def index():
    # Serve the SPA entry point
    return FileResponse("frontend/index.html")