# Scrape artifact from the hosting page (kept for provenance):
# Sulaiman8's picture — "Upload all the files" — commit 45e9462 (verified)
import faiss
import numpy as np
from sentence_transformers import CrossEncoder
from collections import defaultdict
from data import debug_print,call_genai_embedding_api
from nodes.intent import CreditCardState
from data import df
from recommender.vectordb import chunk_name_mapping,chunk_embeddings
# Cross-encoder used to re-rank the FAISS-retrieved cards against the
# user's original query (loaded once at import time).
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2")
def ranked_card_retrieval_node(state: CreditCardState):
    """Retrieve and re-rank candidate credit cards for the user's query.

    Builds a FAISS inner-product index over the description chunks of the
    cards retrieved by the graph step, searches it with every expanded
    query, pools and deduplicates the hits, then re-ranks the pool against
    the original query with the module-level cross-encoder.

    Args:
        state: graph state; reads "multi_queries", "query" and "cards".

    Returns:
        {"ranked_cards": [...]} — at most 5 card dicts with keys
        "name", "description" and "similarity"; empty list when no
        relevant chunks or candidates are found.
    """
    debug_print("NODE", "Entering ranked_card_retrieval_node with multi-query support")
    original_query = state["query"]
    # Fall back to the original query so the node still produces results
    # when the multi-query expansion step yielded nothing.
    queries = state.get("multi_queries", []) or [original_query]
    cards = state["cards"]

    # Indices of the chunks belonging to the cards retrieved from the
    # neo4j graph. Set membership: O(1) per chunk instead of a list scan.
    card_set = set(cards)
    rel_idxs = [i for i, name in chunk_name_mapping.items() if name in card_set]
    if not rel_idxs:
        debug_print("NODE", "No relevant indices found for cards, returning empty ranked_cards")
        return {"ranked_cards": []}

    # Build a small FAISS index over just the relevant chunk embeddings.
    # Fancy indexing copies, so the in-place L2 normalisation below does
    # not mutate the global chunk_embeddings array.
    embs = chunk_embeddings[rel_idxs]
    mapping = {j: chunk_name_mapping[i] for j, i in enumerate(rel_idxs)}
    idx = faiss.IndexFlatIP(embs.shape[1])
    faiss.normalize_L2(embs)
    idx.add(embs)

    # Loop-invariant: build the name -> record lookup once, not per query.
    card_dict = {card["name"]: card for card in df.to_dict(orient="records")}

    all_retrieved = []
    for query in queries:
        # NOTE(review): the original passed the model id both positionally
        # AND as model= (a duplicate-argument bug if the first parameter
        # is named `model`); a single keyword argument is kept — confirm
        # against call_genai_embedding_api's signature.
        query_embedding = call_genai_embedding_api(
            model="models/embedding-001",
            content=query,
            task_type="RETRIEVAL_QUERY",
        )["embedding"]
        query_embedding = np.array(query_embedding, dtype=np.float32).reshape(1, -1)
        faiss.normalize_L2(query_embedding)
        D, I = idx.search(query_embedding, 50)
        similarity_scores = D[0]

        # Keep only the best-scoring chunk per card for this query.
        card_similarity = defaultdict(float)
        unique_cards = {}
        for i, chunk_idx in enumerate(I[0]):
            if chunk_idx == -1:  # FAISS pads with -1 when k exceeds index size
                continue
            card_name = mapping.get(chunk_idx)
            if card_name and (card_name not in unique_cards
                              or similarity_scores[i] > card_similarity[card_name]):
                card_similarity[card_name] = similarity_scores[i]
                unique_cards[card_name] = {
                    "name": card_name,
                    "description": card_dict.get(card_name, {}).get("description", ""),
                    "similarity": similarity_scores[i],
                }
        all_retrieved.extend(unique_cards.values())

    # Deduplicate across queries, keeping the first occurrence of each card.
    seen = set()
    deduped_cards = []
    for card in all_retrieved:
        if card["name"] not in seen:
            seen.add(card["name"])
            deduped_cards.append(card)
    if not deduped_cards:
        debug_print("NODE", "No deduplicated cards found, returning empty ranked_cards")
        return {"ranked_cards": []}

    # Re-rank the pooled candidates against the *original* query; fall back
    # to retrieval order when the cross-encoder scores carry no signal
    # (all equal — min-max normalisation would divide by zero).
    pairs = [[original_query, c["description"]] for c in deduped_cards]
    scores = cross_encoder.predict(pairs)
    if len(scores) == 0 or (max(scores) - min(scores)) == 0:
        ranked_cards = deduped_cards[:5]
        debug_print("NODE", "Using first 5 cards (no meaningful re-rank)")
    else:
        norm_scores = (np.array(scores) - np.min(scores)) / (np.max(scores) - np.min(scores))
        sorted_cards = sorted(zip(norm_scores, deduped_cards), key=lambda x: x[0], reverse=True)
        ranked_cards = [c for _, c in sorted_cards[:5]]
    debug_print("NODE", f"Exiting ranked_card_retrieval_node with {len(ranked_cards)} ranked cards")
    return {"ranked_cards": ranked_cards}