File size: 6,137 Bytes
5a3b322
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
from __future__ import annotations

"""
Chat-style agent using Gemini for planning + explanation, deterministic tools for retrieval/rerank.
Set GOOGLE_API_KEY in your environment.
"""
import json
import os
from typing import Callable

import pandas as pd

from data.catalog_loader import load_catalog
from recommenders.bm25 import BM25Recommender
from recommenders.vector_recommender import VectorRecommender
from retrieval.vector_index import VectorIndex
from models.embedding_model import EmbeddingModel
from rerankers.cross_encoder import CrossEncoderReranker

from tools.query_plan_tool_llm import build_query_plan_llm
from tools.query_plan_tool import build_query_plan as deterministic_plan
from tools.retrieve_tool import retrieve_candidates
from tools.rerank_tool import rerank_candidates
from tools.constraints_tool import apply_constraints
from tools.explain_tool import explain
from schemas.query_plan import QueryPlan


def load_resources():
    """Load the catalog and build the retrieval components for the chat agent.

    Reads the catalog from ``data/catalog_docs_rich.jsonl``, the vector index
    from ``data/faiss_index/index_bge.faiss`` and the id list from
    ``data/embeddings_bge/assessment_ids.json`` (all relative to the current
    working directory).

    Returns:
        A 4-tuple ``(df_catalog, bm25, vec, catalog_by_id)``: the catalog
        frame, a ``BM25Recommender`` built over it, a ``VectorRecommender``
        wired to the embedding model and loaded index, and a dict mapping each
        ``assessment_id`` to its catalog row.
    """
    # load_catalog returns three values; only the first (the frame) is used.
    catalog_frame, _, _ = load_catalog("data/catalog_docs_rich.jsonl")

    lexical = BM25Recommender(catalog_frame)

    encoder = EmbeddingModel("BAAI/bge-small-en-v1.5")
    faiss_index = VectorIndex.load("data/faiss_index/index_bge.faiss")
    with open("data/embeddings_bge/assessment_ids.json") as ids_file:
        assessment_ids = json.load(ids_file)
    dense = VectorRecommender(
        encoder,
        faiss_index,
        catalog_frame,
        assessment_ids,
        k_candidates=200,
    )

    # Row lookup keyed by assessment id; a later duplicate id overwrites an
    # earlier one.
    rows_by_id = {}
    for _, record in catalog_frame.iterrows():
        rows_by_id[record["assessment_id"]] = record

    return catalog_frame, lexical, dense, rows_by_id


def make_catalog_lookup(df_catalog: pd.DataFrame) -> Callable[[str], dict]:
    """Return a function that maps an assessment id to its catalog fields.

    The catalog is re-indexed by ``assessment_id`` once, up front; the
    returned closure then answers each lookup from that indexed frame.

    Args:
        df_catalog: Catalog frame containing an ``assessment_id`` column.

    Returns:
        A callable taking an assessment id and returning the matching row as
        a dict (without the ``assessment_id`` key, which became the index),
        or an empty dict when the id is not in the catalog.
    """
    indexed = df_catalog.set_index("assessment_id")

    def lookup(aid: str) -> dict:
        # Unknown ids yield {} so callers can chain .get() without guarding.
        return indexed.loc[aid].to_dict() if aid in indexed.index else {}

    return lookup


def _maybe_clarify(plan: QueryPlan, cand_count: int, topn: int) -> str | None:
    """Decide whether the user should be asked a clarifying question.

    Checks, in order:
      1. the plan itself flags clarification and carries a non-empty question;
      2. the candidate pool is thin (fewer than ``max(10, 25% of topn)``);
      3. the intent is BEHAVIORAL, UNKNOWN or MIXED and the pool is below
         ``max(20, 50% of topn)``.

    Args:
        plan: Query plan; ``needs_clarification``, ``clarifying_question`` and
            ``intent`` are read.
        cand_count: Number of retrieved candidates.
        topn: Requested retrieval depth the thresholds are scaled from.

    Returns:
        The question to ask, or ``None`` when no clarification is needed.
    """
    # The planner's own question takes precedence over the coverage heuristics.
    planner_question = plan.clarifying_question if plan.needs_clarification else None
    if planner_question:
        return planner_question

    thin_floor = max(10, int(0.25 * topn))
    if cand_count < thin_floor:
        return "Results look thin. Clarify: are you looking for (1) personality/culture fit, (2) leadership judgment (SJT), or (3) role capability?"

    ambiguous_intent = plan.intent in {"BEHAVIORAL", "UNKNOWN", "MIXED"}
    behavioral_floor = max(20, int(0.5 * topn))
    if ambiguous_intent and cand_count < behavioral_floor:
        return "For culture/behavioral focus, choose: (1) personality/culture fit, (2) leadership judgment (SJT), or (3) role capability. Please pick one."

    return None


def _plan_with_fallback(text: str, vocab: dict, model_name: str) -> tuple[QueryPlan, str]:
    """Build a query plan with the LLM planner, falling back to the deterministic one.

    Args:
        text: Query text to plan for.
        vocab: Role vocabulary passed through to both planners.
        model_name: Model identifier passed to the LLM planner.

    Returns:
        ``(plan, source)`` where ``source`` is ``"llm"`` when the LLM plan was
        built and passed schema validation, otherwise
        ``"deterministic (llm_fail=<error text>)"``.
    """
    try:
        plan = build_query_plan_llm(text, vocab=vocab, model_name=model_name)
        QueryPlan.model_validate(plan.model_dump())  # schema guard
    except Exception as exc:
        # Deliberate best-effort boundary: any planner or validation failure
        # degrades to the deterministic planner instead of aborting the chat.
        return deterministic_plan(text, vocab=vocab), f"deterministic (llm_fail={exc})"
    return plan, "llm"


def run_chat(
    user_text: str,
    vocab_path: str = "data/catalog_role_vocab.json",
    model_name: str = "gemini-2.5-flash-lite",
    clarification_answer: str | None = None,
    topn: int = 200,
    verbose: bool = False,
):
    """Run one chat turn: plan, retrieve, optionally clarify, rerank, constrain, explain.

    Prints a JSON document to stdout: the full trace log when ``verbose`` is
    true, otherwise just the trace id and the final results.

    Args:
        user_text: Raw user query.
        vocab_path: Path to a JSON role vocabulary; an empty vocabulary is
            used when the path is falsy or the file does not exist.
        model_name: Model identifier passed to the LLM planner.
        clarification_answer: The user's answer to a previously returned
            clarifying question. When a clarification is triggered and this is
            given, the query is re-planned and re-retrieved with it appended.
        topn: Retrieval depth passed to ``retrieve_candidates``; also scales
            the clarification thresholds.
        verbose: Include intermediate stages in the printed log.

    Returns:
        The value returned by ``explain`` for the final list, or the string
        ``"Clarification needed: <question>"`` when a clarification is
        required and no ``clarification_answer`` was supplied.
    """
    vocab = {}
    if vocab_path and os.path.exists(vocab_path):
        # Context manager closes the handle promptly; JSON is read as UTF-8
        # rather than whatever the platform locale happens to be.
        with open(vocab_path, encoding="utf-8") as vocab_file:
            vocab = json.load(vocab_file)
    df_catalog, bm25, vec, catalog_by_id = load_resources()
    catalog_lookup = make_catalog_lookup(df_catalog)

    # NOTE: built-in str hashing is salted per process unless PYTHONHASHSEED
    # is fixed, so this id is stable within a run but not across runs.
    trace_id = f"trace-{abs(hash(user_text))}"
    log = {"trace_id": trace_id, "raw_query": user_text}

    # Plan with LLM; fallback deterministic if LLM fails
    plan, plan_source = _plan_with_fallback(user_text, vocab, model_name)
    log["plan_source"] = plan_source
    log["query_plan"] = plan.model_dump()

    # Retrieve union
    cand_set = retrieve_candidates(plan, bm25, vec, topn=topn, catalog_df=df_catalog)
    if verbose:
        log["candidates"] = [c.model_dump() for c in cand_set.candidates[:10]]

    # Clarification loop
    question = _maybe_clarify(plan, cand_count=len(cand_set.candidates), topn=topn)
    if question and not clarification_answer:
        log["clarification"] = question
        if verbose:
            print(json.dumps(log, indent=2))
        return f"Clarification needed: {question}"
    if question and clarification_answer:
        clarified_text = f"{user_text}\nUser clarification: {clarification_answer}"
        plan, clarified_source = _plan_with_fallback(clarified_text, vocab, model_name)
        # Record which planner produced the clarified plan, mirroring
        # "plan_source" above, so an LLM failure here is not lost.
        log["plan_source_clarified"] = clarified_source
        log["query_plan_clarified"] = plan.model_dump()
        cand_set = retrieve_candidates(plan, bm25, vec, topn=topn, catalog_df=df_catalog)
        if verbose:
            log["candidates_clarified"] = [c.model_dump() for c in cand_set.candidates[:10]]

    # Rerank
    reranker = CrossEncoderReranker(model_name="models/reranker_crossenc/v0.1.0")
    ranked = rerank_candidates(plan, cand_set, reranker, df_catalog, k=10)
    log["rerank"] = [item.model_dump() for item in ranked.items]

    # Constraints
    final_list = apply_constraints(plan, ranked, catalog_by_id, k=10)
    log["final"] = [item.model_dump() for item in final_list.items]

    # Explain
    summary = explain(plan, final_list, catalog_lookup)
    log["summary"] = summary

    # Compact output: top-10 with metadata
    final_results = []
    for item in final_list.items:
        meta = catalog_lookup(item.assessment_id)
        final_results.append(
            {
                "assessment_id": item.assessment_id,
                "score": item.score,
                "name": meta.get("name"),
                "url": meta.get("url"),
                # Prefer the richer field, falling back to the short one.
                "test_type_full": meta.get("test_type_full") or meta.get("test_type"),
                "duration": meta.get("duration_minutes") or meta.get("duration"),
            }
        )

    if verbose:
        log["final_results"] = final_results
        print(json.dumps(log, indent=2))
    else:
        print(json.dumps({"trace_id": trace_id, "final_results": final_results}, indent=2))

    return summary


if __name__ == "__main__":
    import sys

    # Only a warning: execution continues without the key.
    api_key_present = "GOOGLE_API_KEY" in os.environ
    if not api_key_present:
        print("Please set GOOGLE_API_KEY for Gemini.")

    # All CLI arguments form the query; with none (or only empty ones) a
    # sample query is used instead.
    user_text = " ".join(sys.argv[1:])
    if not user_text:
        user_text = "Find a 1 hour culture fit assessment for a COO"

    result = run_chat(user_text, verbose=False)
    print(result)