import faiss
import numpy as np
from sentence_transformers import CrossEncoder
from collections import defaultdict

from data import debug_print, call_genai_embedding_api
from data import df
from nodes.intent import CreditCardState
from recommender.vectordb import chunk_name_mapping, chunk_embeddings

# FAISS vector-db retrieval.
# Cross-encoder used to re-rank FAISS-retrieved cards against the original query.
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2")


def ranked_card_retrieval_node(state: CreditCardState):
    """Retrieve and re-rank candidate credit cards for the user's query.

    Builds a temporary FAISS inner-product (cosine, after L2-normalization)
    index over only the embedding chunks belonging to the graph-retrieved
    cards in ``state["cards"]``, searches it with every query in
    ``state["multi_queries"]`` (falling back to ``state["query"]`` when the
    list is empty), de-duplicates hits per card keeping the best similarity,
    then re-ranks the unique cards against the original query with the
    cross-encoder and returns the top 5.

    Returns:
        dict with key ``"ranked_cards"``: a list of up to 5 card dicts, each
        containing ``name``, ``description`` and ``similarity``.
    """
    debug_print("NODE", f"Entering ranked_card_retrieval_node with multi-query support")

    original_query = state["query"]
    # Fall back to the original query so an empty multi_queries list does not
    # silently produce zero results (previously the search loop never ran).
    queries = state.get("multi_queries") or [original_query]
    cards = set(state["cards"])  # O(1) membership instead of list scans

    # Indices (into the global chunk store) of chunks belonging to the cards
    # retrieved from the neo4j graph.
    rel_idxs = [i for i, name in chunk_name_mapping.items() if name in cards]
    if not rel_idxs:
        debug_print("NODE", "No relevant indices found for cards, returning empty ranked_cards")
        return {"ranked_cards": []}

    # Build a small per-request FAISS index over just the relevant chunks.
    # Fancy indexing copies, and FAISS requires contiguous float32, so the
    # in-place normalize_L2 below cannot mutate the global chunk_embeddings.
    embs = np.ascontiguousarray(chunk_embeddings[rel_idxs], dtype=np.float32)
    local_to_name = {j: chunk_name_mapping[i] for j, i in enumerate(rel_idxs)}
    index = faiss.IndexFlatIP(embs.shape[1])
    faiss.normalize_L2(embs)
    index.add(embs)

    # Hoisted out of the per-query loop: this lookup table is loop-invariant
    # (it was previously rebuilt from the whole DataFrame on every query).
    card_dict = {card["name"]: card for card in df.to_dict(orient="records")}

    # card name -> best similarity seen across ALL queries. A plain dict
    # preserves first-seen insertion order, matching the old dedup order,
    # while keeping the maximum score instead of the first query's score.
    best_similarity = {}
    for query in queries:
        # NOTE(review): the model id is passed both positionally and as the
        # ``model=`` keyword — confirm call_genai_embedding_api's signature.
        query_embedding = call_genai_embedding_api(
            "models/embedding-001",
            model="models/embedding-001",
            content=query,
            task_type="RETRIEVAL_QUERY",
        )["embedding"]
        query_embedding = np.array(query_embedding, dtype=np.float32).reshape(1, -1)
        faiss.normalize_L2(query_embedding)
        D, I = index.search(query_embedding, 50)

        for score, chunk_idx in zip(D[0], I[0]):
            if chunk_idx == -1:  # FAISS pads with -1 when fewer than k hits exist
                continue
            card_name = local_to_name.get(chunk_idx)
            if card_name and score > best_similarity.get(card_name, float("-inf")):
                best_similarity[card_name] = float(score)

    deduped_cards = [
        {
            "name": name,
            "description": card_dict.get(name, {}).get("description", ""),
            "similarity": sim,
        }
        for name, sim in best_similarity.items()
    ]
    if not deduped_cards:
        debug_print("NODE", "No deduplicated cards found, returning empty ranked_cards")
        return {"ranked_cards": []}

    # Re-rank against the original (un-expanded) query with the cross-encoder.
    pairs = [[original_query, c["description"]] for c in deduped_cards]
    scores = cross_encoder.predict(pairs)

    if len(scores) == 0 or (max(scores) - min(scores)) == 0:
        # All scores identical (or empty): normalization would divide by zero
        # and the ordering carries no signal, so keep retrieval order.
        ranked_cards = deduped_cards[:5]
        debug_print("NODE", "Using first 5 cards (no meaningful re-rank)")
    else:
        norm_scores = (np.array(scores) - np.min(scores)) / (np.max(scores) - np.min(scores))
        # Stable sort on score only, so ties keep retrieval order.
        sorted_cards = sorted(zip(norm_scores, deduped_cards), key=lambda x: x[0], reverse=True)
        ranked_cards = [c for _, c in sorted_cards[:5]]

    debug_print("NODE", f"Exiting ranked_card_retrieval_node with {len(ranked_cards)} ranked cards")
    return {"ranked_cards": ranked_cards}