import os
import re

import numpy as np
from loguru import logger
from rank_bm25 import BM25Okapi

from app.constants import SAMPLE_STORE_KEYS, build_key_search_text
from app.models import KeyMapping
from app.services.embedding_service import EmbeddingService


class KeyMapper:
    """Hybrid key mapper: combines dense semantic search with BM25 keyword
    matching, fused via Reciprocal Rank Fusion (RRF).

    TODO: maybe add cross-encoder reranking later if needed.
    """
|
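    # Usage sketch (illustrative only; assumes EmbeddingService is
    # constructible as elsewhere in the app and SAMPLE_STORE_KEYS is
    # non-empty):
    #
    #   mapper = KeyMapper(EmbeddingService())
    #   for m in mapper.map_keys("how much is whole milk", top_k=3):
    #       print(m.user_phrase, "->", m.mapped_to, m.similarity)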
|
    def __init__(self, embedding_service: EmbeddingService):
        try:
            self.embed_service = embedding_service
            # Fusion and filtering knobs are configurable via environment.
            self.rrf_k = int(os.getenv("RRF_K", "60"))
            self.threshold = float(os.getenv("SIM_THRESHOLD", "0.7"))

            logger.info("Initializing KeyMapper...")
            self.keys = SAMPLE_STORE_KEYS

            # One searchable text per key, shared by both retrieval paths.
            self.key_texts = [build_key_search_text(k) for k in self.keys]
            logger.debug(f"Built {len(self.key_texts)} key search texts")

            # Dense path: precompute one embedding per key.
            logger.info("Computing key embeddings...")
            self.key_embeddings = self.embed_service.embed_batch(self.key_texts)
            logger.debug(f"Key embeddings shape: {self.key_embeddings.shape}")

            # Sparse path: BM25 index over the same key texts.
            logger.info("Building BM25 index...")
            self.tokenized_keys = [self.tokenize(text) for text in self.key_texts]
            self.bm25 = BM25Okapi(self.tokenized_keys)
            logger.success("KeyMapper initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize KeyMapper: {str(e)}")
            raise
|
    def tokenize(self, text):
        """Lowercase the text and split it into word tokens."""
        try:
            tokens = re.findall(r'\w+', text.lower())
            return tokens
        except Exception as e:
            logger.error(f"Tokenization failed: {str(e)}")
            return []
|
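    # Example (hypothetical input): self.tokenize("2% Milk, 1 Gal") returns
    # ['2', 'milk', '1', 'gal'] -- punctuation is dropped, case is folded.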
|
    def extract_key_phrases(self, prompt):
        """Extract candidate phrases from the prompt: the full prompt, token
        bigrams and trigrams, and individual tokens longer than three
        characters, deduplicated and capped at 15."""
        try:
            phrases = []

            # The full prompt itself is always a candidate.
            phrases.append(prompt.strip())

            tokens = self.tokenize(prompt)

            # Bigrams.
            for i in range(len(tokens) - 1):
                phrases.append(f"{tokens[i]} {tokens[i+1]}")

            # Trigrams.
            for i in range(len(tokens) - 2):
                phrases.append(f"{tokens[i]} {tokens[i+1]} {tokens[i+2]}")

            # Individual tokens long enough to carry meaning.
            phrases.extend([t for t in tokens if len(t) > 3])

            # Deduplicate while preserving order.
            seen = set()
            unique = []
            for p in phrases:
                if p not in seen:
                    seen.add(p)
                    unique.append(p)

            return unique[:15]
        except Exception as e:
            logger.error(f"Phrase extraction failed: {str(e)}")
            return [prompt]
|
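    # Example (hypothetical prompt): "price of milk" yields, in order,
    # ["price of milk", "price of", "of milk", "price", "milk"] -- the
    # trigram duplicates the full prompt and is deduplicated away, and
    # "of" is dropped for being too short.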
|
    def compute_dense_ranks(self, prompt):
        """Rank all keys by cosine similarity between the prompt embedding
        and the precomputed key embeddings. Returns (1-based rank positions,
        raw similarities)."""
        try:
            prompt_emb = self.embed_service.embed_single(prompt)

            similarities = self.embed_service.batch_cosine_similarity(
                prompt_emb,
                self.key_embeddings
            )

            # Indices sorted by descending similarity.
            ranks = np.argsort(-similarities)

            # Invert into per-key rank positions (1 = best).
            rank_positions = np.zeros(len(self.keys), dtype=int)
            for pos, idx in enumerate(ranks):
                rank_positions[idx] = pos + 1

            return rank_positions, similarities
        except Exception as e:
            logger.error(f"Dense ranking failed: {str(e)}")
            # Fall back to an arbitrary but valid ranking with zero scores.
            default_ranks = np.arange(1, len(self.keys) + 1)
            default_sims = np.zeros(len(self.keys))
            return default_ranks, default_sims
|
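    # Worked example of the rank inversion used above (and again in
    # compute_sparse_ranks below): similarities [0.2, 0.9, 0.5] give argsort
    # order [1, 2, 0], so rank_positions becomes [3, 1, 2] -- key 1 is
    # ranked 1st, key 2 is 2nd, key 0 is 3rd.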
|
    def compute_sparse_ranks(self, prompt):
        """Rank all keys by BM25 score against the tokenized prompt.
        Returns (1-based rank positions, raw BM25 scores)."""
        try:
            prompt_tokens = self.tokenize(prompt)
            bm25_scores = self.bm25.get_scores(prompt_tokens)

            # Same rank inversion as in compute_dense_ranks.
            ranks = np.argsort(-bm25_scores)

            rank_positions = np.zeros(len(self.keys), dtype=int)
            for pos, idx in enumerate(ranks):
                rank_positions[idx] = pos + 1

            return rank_positions, bm25_scores
        except Exception as e:
            logger.error(f"Sparse ranking failed: {str(e)}")
            default_ranks = np.arange(1, len(self.keys) + 1)
            default_scores = np.zeros(len(self.keys))
            return default_ranks, default_scores
|
    def apply_rrf(self, dense_ranks, sparse_ranks):
        """Fuse the two rankings with Reciprocal Rank Fusion:
        score = 1/(k + dense_rank) + 1/(k + sparse_rank)."""
        try:
            rrf_scores = (1.0 / (self.rrf_k + dense_ranks)) + \
                         (1.0 / (self.rrf_k + sparse_ranks))
            return rrf_scores
        except Exception as e:
            logger.error(f"RRF fusion failed: {str(e)}")
            # Fall back to the dense ranking alone.
            return 1.0 / (self.rrf_k + dense_ranks)
|
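    # Worked example (default rrf_k = 60): a key ranked 1st dense and 3rd
    # sparse scores 1/61 + 1/63 ≈ 0.0323, while a key ranked 2nd in both
    # scores 2/62 ≈ 0.0323 -- RRF rewards consistent mid-ranks about as
    # much as a single top rank.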
|
    def map_keys(self, prompt, top_k=5):
        """Map a user prompt to actual store keys."""
        try:
            logger.info(f"Mapping keys for prompt: {prompt[:50]}...")

            # Rank every key along both retrieval paths.
            dense_ranks, dense_sims = self.compute_dense_ranks(prompt)
            sparse_ranks, sparse_scores = self.compute_sparse_ranks(prompt)

            # Fuse the rankings and sort keys by fused score.
            rrf_scores = self.apply_rrf(dense_ranks, sparse_ranks)
            sorted_indices = np.argsort(-rrf_scores)

            # Candidate phrases, embedded once and reused for every key.
            key_phrases = self.extract_key_phrases(prompt)
            phrase_embs = [self.embed_service.embed_single(p) for p in key_phrases]

            # Best possible RRF score: rank 1 in both rankings.
            max_rrf = 2.0 / (self.rrf_k + 1)

            mappings = []
            for idx in sorted_indices:
                normalized_score = float(rrf_scores[idx] / max_rrf)

                # Attribute the match to the prompt phrase most similar to
                # this key's embedding; default to the full prompt.
                key_emb = self.key_embeddings[idx]
                best_phrase = prompt
                best_phrase_sim = dense_sims[idx]
                for phrase, phrase_emb in zip(key_phrases, phrase_embs):
                    phrase_sim = self.embed_service.cosine_similarity(phrase_emb, key_emb)
                    if phrase_sim > best_phrase_sim:
                        best_phrase = phrase
                        best_phrase_sim = phrase_sim

                mappings.append(KeyMapping(
                    user_phrase=best_phrase[:50],
                    mapped_to=self.keys[idx]['value'],
                    similarity=float(np.clip(normalized_score, 0.0, 1.0))
                ))

                if len(mappings) >= top_k:
                    break

            logger.success(f"Mapped {len(mappings)} keys successfully")
            return mappings
        except Exception as e:
            logger.error(f"Key mapping failed: {str(e)}")
            return []
|
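    # Normalization example (default rrf_k = 60): max_rrf = 2/61 ≈ 0.0328,
    # so the fused score of ≈ 0.0323 from the example above is reported as
    # a similarity of roughly 0.98.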
|
    def get_top_keys(self, prompt, top_k=5, min_similarity=None):
        """Get top keys with full metadata, filtered by similarity threshold."""
        try:
            threshold = min_similarity if min_similarity is not None else self.threshold

            # Over-fetch so that threshold filtering still leaves top_k results.
            mappings = self.map_keys(prompt, top_k=top_k * 2)
            filtered = [m for m in mappings if m.similarity >= threshold]

            # Re-attach the full key metadata to each surviving mapping.
            result = []
            for mapping in filtered[:top_k]:
                key_obj = next((k for k in self.keys if k['value'] == mapping.mapped_to), None)
                if key_obj:
                    result.append({
                        **key_obj,
                        'similarity': mapping.similarity,
                        'matched_phrase': mapping.user_phrase
                    })

            return result
        except Exception as e:
            logger.error(f"get_top_keys failed: {str(e)}")
            return []
|
|
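if __name__ == "__main__":
    # Minimal smoke test (a sketch: assumes EmbeddingService() needs no
    # arguments here; adjust construction to match the actual service).
    mapper = KeyMapper(EmbeddingService())
    for key in mapper.get_top_keys("price of whole milk", top_k=3):
        print(f"{key['value']}: {key['similarity']:.3f} via '{key['matched_phrase']}'")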
|