File size: 3,621 Bytes
45e9462
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import faiss
import numpy as np
from sentence_transformers import CrossEncoder
from collections import defaultdict
from data import debug_print,call_genai_embedding_api
from nodes.intent import CreditCardState
from data import df
from recommender.vectordb import chunk_name_mapping,chunk_embeddings

# FAISS vector DB retrieval

# Cross-encoder used to re-rank retrieved cards against the user's original query.
# Instantiated once at module import so every node call reuses the same model.
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2")

def ranked_card_retrieval_node(state: CreditCardState) -> dict:
    """Retrieve candidate cards by vector similarity, then re-rank them.

    The chunks belonging to the cards in ``state["cards"]`` are loaded into a
    temporary FAISS inner-product index (vectors are L2-normalised, so the
    score is cosine similarity). Every query in ``state["multi_queries"]`` is
    embedded and searched against that index; the best-scoring chunk per card
    is kept across all queries. The surviving cards are finally re-ranked
    with the module-level cross-encoder against ``state["query"]``.

    Args:
        state: Graph state. Reads ``"query"`` (the original user query),
            ``"cards"`` (container of candidate card names) and, optionally,
            ``"multi_queries"`` (expanded queries). When ``"multi_queries"``
            is missing or empty the original query is used for retrieval.

    Returns:
        ``{"ranked_cards": [...]}`` with at most five entries, each a dict
        with ``"name"``, ``"description"`` and ``"similarity"`` keys. The
        list is empty when no chunk matches the candidate cards.

    Raises:
        KeyError: If ``"query"`` or ``"cards"`` is absent from ``state``.
    """
    debug_print("NODE", "Entering ranked_card_retrieval_node with multi-query support")

    top_k_chunks = 50  # chunks requested from FAISS per query
    top_n_cards = 5    # cards returned after re-ranking

    original_query = state["query"]
    cards = state["cards"]
    # Without expanded queries, still retrieve using the original query
    # instead of silently returning no cards.
    queries = state.get("multi_queries") or [original_query]

    # Indices of the chunks that belong to the cards retrieved from the neo4j graph.
    rel_idxs = [i for i, name in chunk_name_mapping.items() if name in cards]
    if not rel_idxs:
        debug_print("NODE", "No relevant indices found for cards, returning empty ranked_cards")
        return {"ranked_cards": []}

    # Build a throw-away FAISS index over just the relevant chunks.
    # List indexing copies, so normalising in place leaves chunk_embeddings
    # untouched; FAISS itself requires C-contiguous float32 input.
    embs = np.ascontiguousarray(chunk_embeddings[rel_idxs], dtype=np.float32)
    mapping = {j: chunk_name_mapping[i] for j, i in enumerate(rel_idxs)}
    idx = faiss.IndexFlatIP(embs.shape[1])
    faiss.normalize_L2(embs)
    idx.add(embs)

    # Loop-invariant lookup of card records by name; built once, not per query.
    card_dict = {card["name"]: card for card in df.to_dict(orient="records")}

    # Best hit per card across ALL queries. Dict insertion order keeps the
    # first-seen order, which the no-re-rank fallback below relies on.
    best_by_name = {}

    for query in queries:
        # NOTE(review): the model name is passed both positionally and as
        # ``model=`` — confirm against call_genai_embedding_api's signature.
        query_embedding = call_genai_embedding_api(
            "models/embedding-001",
            model="models/embedding-001",
            content=query,
            task_type="RETRIEVAL_QUERY"
        )["embedding"]
        query_embedding = np.array(query_embedding, dtype=np.float32).reshape(1, -1)
        faiss.normalize_L2(query_embedding)

        D, I = idx.search(query_embedding, top_k_chunks)

        for score, chunk_idx in zip(D[0], I[0]):
            # FAISS pads the result with -1 when fewer than k vectors exist.
            if chunk_idx == -1:
                continue
            card_name = mapping.get(int(chunk_idx))
            if not card_name:
                continue
            # Plain float keeps NumPy scalars out of the returned state.
            similarity = float(score)
            current = best_by_name.get(card_name)
            if current is None:
                best_by_name[card_name] = {
                    "name": card_name,
                    "description": card_dict.get(card_name, {}).get("description", ""),
                    "similarity": similarity,
                }
            elif similarity > current["similarity"]:
                current["similarity"] = similarity

    deduped_cards = list(best_by_name.values())
    if not deduped_cards:
        debug_print("NODE", "No deduplicated cards found, returning empty ranked_cards")
        return {"ranked_cards": []}

    # Re-rank against the original query (not the expanded ones).
    pairs = [[original_query, c["description"]] for c in deduped_cards]
    scores = np.asarray(cross_encoder.predict(pairs), dtype=float)

    if scores.max() == scores.min():
        # All scores identical (including the single-card case): nothing to re-rank.
        ranked_cards = deduped_cards[:top_n_cards]
        debug_print("NODE", "Using first 5 cards (no meaningful re-rank)")
    else:
        # Min-max normalisation is monotonic, so sorting the raw scores gives
        # the same order; sorted() is stable, so ties keep retrieval order.
        sorted_cards = sorted(zip(scores, deduped_cards), key=lambda x: x[0], reverse=True)
        ranked_cards = [c for _, c in sorted_cards[:top_n_cards]]

    debug_print("NODE", f"Exiting ranked_card_retrieval_node with {len(ranked_cards)} ranked cards")
    return {"ranked_cards": ranked_cards}