from __future__ import annotations

import time
from pathlib import Path

import faiss
from huggingface_hub import hf_hub_download
import joblib
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer

try:
    from .runtime_utils import (
        load_duplicate_threshold,
        load_model_config,
        resolve_model_dir,
        resolve_model_reference,
    )
except ImportError:  # pragma: no cover
    from runtime_utils import (
        load_duplicate_threshold,
        load_model_config,
        resolve_model_dir,
        resolve_model_reference,
    )


class CachedDuplicateDetectionEngine:
    """Duplicate-ticket detector backed by precomputed embeddings and FAISS.

    On construction the engine downloads a ticket dataset, the matching
    ticket ids and a precomputed embedding matrix from the Hugging Face Hub,
    L2-normalises the embeddings and loads them into an inner-product FAISS
    index (so scores are cosine similarities).  Rows present at construction
    time are the "baseline"; tickets appended later through ``add_ticket``
    receive indices ``>= initial_index_size``.
    """

    # Hub repository that hosts the dataset, ticket ids and embeddings.
    _ASSET_REPO_ID = "Eklavya73/ticket-duplicate-assets"

    def __init__(self, base_dir: str | Path | None = None):
        """Load configuration and assets, then build the FAISS index.

        Args:
            base_dir: Directory handed to the ``runtime_utils`` helpers.
                Defaults to the directory containing this file.

        Raises:
            ValueError: If the dataset has no ``text`` column, the embedding
                matrix is not 2D, or its row count differs from the dataset.
        """
        self.base_dir = (
            Path(base_dir).resolve()
            if base_dir is not None
            else Path(__file__).resolve().parent
        )
        self.model_dir = resolve_model_dir(self.base_dir)
        self.model_config = load_model_config(self.base_dir)
        self.duplicate_threshold = load_duplicate_threshold(self.base_dir)

        dataset_path = hf_hub_download(
            repo_id=self._ASSET_REPO_ID,
            filename="Domain-A_Dataset_Clean.csv",
        )
        self.dataset = pd.read_csv(dataset_path)
        if "text" not in self.dataset.columns:
            raise ValueError("Duplicate dataset must contain a 'text' column.")
        self.db_texts = self.dataset["text"].astype(str).tolist()

        ticket_id_path = hf_hub_download(
            repo_id=self._ASSET_REPO_ID,
            filename="ticket_ids.pkl",
        )
        # NOTE(security): joblib.load unpickles the file and can execute
        # arbitrary code; this is only safe while the Hub repository above
        # is trusted.
        loaded_ids = joblib.load(ticket_id_path)
        self.db_ids = [str(ticket_id) for ticket_id in loaded_ids]
        if len(self.db_ids) != len(self.db_texts):
            # The pickled ids do not line up with the dataset rows: fall back
            # to the dataset's own ids, or to positional ids as a last resort.
            if "ticket_id" in self.dataset.columns:
                self.db_ids = self.dataset["ticket_id"].astype(str).tolist()
            else:
                self.db_ids = [str(i) for i in range(len(self.db_texts))]

        embeddings_path = hf_hub_download(
            repo_id=self._ASSET_REPO_ID,
            filename="db_embeddings.npy",
        )
        self.db_embeddings = np.load(embeddings_path).astype("float32")
        if self.db_embeddings.ndim != 2:
            raise ValueError(
                f"Expected 2D duplicate embedding matrix, got shape={self.db_embeddings.shape}"
            )
        if self.db_embeddings.shape[0] != len(self.db_texts):
            raise ValueError(
                "Embedding count does not match dataset rows: "
                f"{self.db_embeddings.shape[0]} embeddings vs {len(self.db_texts)} texts"
            )

        # In-place normalisation: inner product then equals cosine similarity.
        faiss.normalize_L2(self.db_embeddings)
        self.embedding_dim = int(self.db_embeddings.shape[1])
        # _load_faiss_meta reads embedding_dim/db_texts and _build_index reads
        # faiss_meta, so the order of the next statements matters.
        self.faiss_meta = self._load_faiss_meta()
        self.index = self._build_index(self.db_embeddings)
        self.index.add(self.db_embeddings)
        self.initial_index_size = int(self.index.ntotal)
        # The sentence encoder is created lazily by _get_encoder().
        self._encoder: SentenceTransformer | None = None

    def _load_faiss_meta(self) -> dict:
        """Return index metadata from ``faiss_index_meta.pkl`` or defaults.

        The file is used only when it exists in ``model_dir`` and contains a
        dict; otherwise a flat-index description is synthesised.
        """
        meta_path = self.model_dir / "faiss_index_meta.pkl"
        if meta_path.exists():
            # NOTE(security): unpickles a local model artefact.
            loaded = joblib.load(meta_path)
            if isinstance(loaded, dict):
                return loaded
        return {
            "dimension": self.embedding_dim,
            "index_type": "flat",
            "size": len(self.db_texts),
        }

    def _build_index(self, embeddings: np.ndarray):
        """Create an empty inner-product index described by ``faiss_meta``.

        An IVF index is built (and trained on ``embeddings``) only when the
        metadata requests ``"ivf"`` and there are at least ``max(64, nlist)``
        vectors; every other case yields a flat index.  The caller is
        responsible for adding the vectors.
        """
        index_type = str(self.faiss_meta.get("index_type", "flat")).lower()
        nlist = max(1, int(self.faiss_meta.get("nlist", 256)))
        nprobe = max(1, int(self.faiss_meta.get("nprobe", 48)))

        if index_type == "ivf" and len(embeddings) >= max(64, nlist):
            quantizer = faiss.IndexFlatIP(self.embedding_dim)
            index = faiss.IndexIVFFlat(
                quantizer,
                self.embedding_dim,
                nlist,
                faiss.METRIC_INNER_PRODUCT,
            )
            index.train(embeddings)
            index.nprobe = min(nprobe, nlist)
            return index

        return faiss.IndexFlatIP(self.embedding_dim)

    @property
    def index_size(self) -> int:
        """Number of vectors currently stored in the FAISS index."""
        return int(self.index.ntotal)

    def _get_encoder(self) -> SentenceTransformer:
        """Return the sentence encoder, instantiating it on first use."""
        if self._encoder is None:
            # resolve_model_reference comes from runtime_utils; presumably it
            # maps the configured name to a local path or hub id (verify there).
            model_ref = resolve_model_reference(
                self.model_config.get(
                    "duplicate_sbert_model",
                    "Eklavya73/duplicate_sbert",
                ),
                base_dir=self.base_dir,
                model_dir=self.model_dir,
                default="all-mpnet-base-v2",
            )
            self._encoder = SentenceTransformer(model_ref)
        return self._encoder

    def _encode(
        self,
        texts,
        *,
        batch_size: int = 64,
        show_progress_bar: bool = False,
    ) -> np.ndarray:
        """Encode ``texts`` into normalised float32 embeddings."""
        encoder = self._get_encoder()
        embeddings = encoder.encode(
            list(texts),
            batch_size=batch_size,
            show_progress_bar=show_progress_bar,
            normalize_embeddings=True,
        )
        return np.asarray(embeddings, dtype="float32")

    def _normalize_query(self, embedding) -> np.ndarray:
        """Return ``embedding`` as an L2-normalised ``(1, dim)`` float32 copy."""
        # Copy first: faiss.normalize_L2 works in place and must not modify
        # the caller's array (it may be a view into db_embeddings).
        query = np.asarray(embedding, dtype="float32").reshape(1, -1).copy()
        faiss.normalize_L2(query)
        return query

    def _search(self, embedding, *, k: int = 20):
        """Search the index for the ``k`` nearest neighbours of ``embedding``.

        ``k`` is clamped to ``[1, index_size]``.  Returns the
        ``(scores, indices)`` pair produced by ``index.search``, or two empty
        ``(1, 0)`` arrays when the index holds no vectors.
        """
        if self.index_size == 0:
            return np.empty((1, 0), dtype="float32"), np.empty((1, 0), dtype=int)
        query = self._normalize_query(embedding)
        return self.index.search(query, min(max(1, int(k)), self.index_size))

    def find_best_match(
        self,
        embedding,
        *,
        k: int = 20,
        exclude_indices=None,
        include_baseline: bool = False,
    ) -> dict | None:
        """Return the first admissible search hit for ``embedding``.

        Args:
            embedding: Query vector.
            k: Number of admissible candidates to look for.
            exclude_indices: Optional iterable of index positions to skip
                (lists, sets and NumPy arrays are all accepted).
            include_baseline: When False, rows loaded at construction time
                (index ``< initial_index_size``) are skipped, so only tickets
                appended through ``add_ticket`` can match.

        Returns:
            A dict with ``index``, ``ticket_id``, ``duplicate_of``,
            ``matched_text`` and ``similarity``, or ``None`` when no
            admissible neighbour is found.
        """
        # Explicit None test: ``exclude_indices or []`` raises for NumPy
        # arrays with more than one element (ambiguous truth value).
        excluded = (
            set()
            if exclude_indices is None
            else {int(idx) for idx in exclude_indices}
        )

        # Skipped hits still occupy result slots, so widen the window by the
        # number of rows that may be filtered out; _search clamps the value
        # to the valid [1, index_size] range.
        effective_k = int(k) + len(excluded)
        if not include_baseline:
            effective_k += self.initial_index_size

        scores, indices = self._search(embedding, k=effective_k)

        for score, idx in zip(scores[0], indices[0]):
            idx = int(idx)
            # idx < 0 marks an unfilled result slot.
            if idx < 0 or idx in excluded:
                continue
            if not include_baseline and idx < self.initial_index_size:
                continue
            ticket_id = self.db_ids[idx] if idx < len(self.db_ids) else None
            return {
                "index": idx,
                "ticket_id": ticket_id,
                "duplicate_of": ticket_id,
                "matched_text": self.db_texts[idx] if idx < len(self.db_texts) else None,
                "similarity": float(score),
            }
        return None

    def detect_duplicate(
        self,
        text: str | None = None,
        *,
        embedding=None,
        k: int = 20,
        exclude_indices=None,
        include_baseline: bool = False,
    ) -> dict | None:
        """Return the best match if it reaches ``duplicate_threshold``.

        Either ``text`` or a precomputed ``embedding`` must be supplied; the
        text is encoded only when no embedding is given.

        Raises:
            ValueError: If both ``text`` and ``embedding`` are ``None``.
        """
        if embedding is None:
            if text is None:
                raise ValueError("Either text or embedding must be provided.")
            embedding = self._encode([str(text)])[0]

        match = self.find_best_match(
            embedding,
            k=k,
            exclude_indices=exclude_indices,
            include_baseline=include_baseline,
        )
        if match is None:
            return None
        if float(match["similarity"]) < float(self.duplicate_threshold):
            return None
        return match

    def add_ticket(
        self,
        ticket_id: str,
        text: str,
        *,
        embedding=None,
    ) -> None:
        """Append a ticket to the index and the parallel id/text/embedding stores.

        The text is encoded when no ``embedding`` is provided.
        """
        if embedding is None:
            embedding = self._encode([str(text)])[0]
        query = self._normalize_query(embedding)
        # Add to the index first so that a failure there leaves the parallel
        # lists untouched.
        self.index.add(query)
        self.db_ids.append(str(ticket_id))
        self.db_texts.append(str(text))
        # NOTE: vstack copies the whole matrix on every call.
        self.db_embeddings = np.vstack([self.db_embeddings, query]).astype("float32")
        self.faiss_meta["size"] = int(self.index.ntotal)

    def benchmark_duplicate_detection(self, *, num_queries: int = 200, k: int = 5) -> dict:
        """Compare index search against brute-force search on sampled rows.

        Up to ``num_queries`` stored embeddings are sampled with a fixed seed
        and each is searched against all other rows, once with an exact
        matrix product and once through the index.  Reported are the mean
        latencies, how often the exact top-1 neighbour appears among the
        index's top ``k`` (``recall_at_k``), and precision/recall/F1 of the
        thresholded duplicate decision with the exact search as ground truth.
        """
        if self.index_size <= 1:
            return {
                "exact_latency_ms": 0.0,
                "faiss_latency_ms": 0.0,
                "speedup_vs_exact": 0.0,
                "recall_at_k": 0.0,
                "duplicate_precision": 0.0,
                "duplicate_recall": 0.0,
                "duplicate_f1": 0.0,
                "duplicate_eval_pairs": 0,
            }

        k = max(1, int(k))
        threshold = float(self.duplicate_threshold)
        rng = np.random.default_rng(42)
        # Clamp at zero: rng.choice rejects a negative sample size.
        query_count = min(max(0, int(num_queries)), self.index_size)
        sampled_indices = rng.choice(self.index_size, size=query_count, replace=False)

        exact_hits = 0
        tp = 0
        fp = 0
        fn = 0
        exact_latencies = []
        faiss_latencies = []

        for query_idx in sampled_indices:
            query_idx = int(query_idx)
            query_embedding = self.db_embeddings[query_idx]

            exact_start = time.perf_counter()
            similarities = self.db_embeddings @ query_embedding
            # A row must never match itself.
            similarities[query_idx] = -np.inf
            exact_top = np.argsort(-similarities)[:k]
            exact_score = float(similarities[int(exact_top[0])]) if exact_top.size else 0.0
            exact_is_duplicate = exact_score >= threshold
            exact_latencies.append((time.perf_counter() - exact_start) * 1000.0)

            faiss_start = time.perf_counter()
            # k + 1 because the query row itself is expected among the hits.
            distances, neighbors = self._search(query_embedding, k=k + 1)
            faiss_latencies.append((time.perf_counter() - faiss_start) * 1000.0)

            faiss_candidates = []
            faiss_best_score = 0.0
            for score, neighbor_idx in zip(distances[0], neighbors[0]):
                neighbor_idx = int(neighbor_idx)
                if neighbor_idx < 0 or neighbor_idx == query_idx:
                    continue
                faiss_candidates.append(neighbor_idx)
                if len(faiss_candidates) == 1:
                    faiss_best_score = float(score)
                if len(faiss_candidates) >= k:
                    break

            if exact_top.size and int(exact_top[0]) in faiss_candidates:
                exact_hits += 1

            pred_is_duplicate = bool(faiss_candidates) and faiss_best_score >= threshold
            if pred_is_duplicate and exact_is_duplicate:
                tp += 1
            elif pred_is_duplicate and not exact_is_duplicate:
                fp += 1
            elif exact_is_duplicate and not pred_is_duplicate:
                fn += 1

        precision = tp / max(tp + fp, 1)
        recall = tp / max(tp + fn, 1)
        f1 = (
            0.0
            if (precision + recall) == 0.0
            else (2.0 * precision * recall) / (precision + recall)
        )
        exact_latency_ms = float(np.mean(exact_latencies)) if exact_latencies else 0.0
        faiss_latency_ms = float(np.mean(faiss_latencies)) if faiss_latencies else 0.0

        return {
            "exact_latency_ms": exact_latency_ms,
            "faiss_latency_ms": faiss_latency_ms,
            "speedup_vs_exact": (
                float(exact_latency_ms / faiss_latency_ms)
                if faiss_latency_ms > 0.0
                else 0.0
            ),
            "recall_at_k": float(exact_hits / max(query_count, 1)),
            "duplicate_precision": float(precision),
            "duplicate_recall": float(recall),
            "duplicate_f1": float(f1),
            "duplicate_eval_pairs": int(query_count),
        }

    def get_duplicate_metrics(self) -> dict:
        """Return the duplicate threshold and a summary of the index metadata.

        ``index_type``, ``nlist`` and ``nprobe`` are echoed from
        ``faiss_meta`` (configuration), so they may differ from the index
        actually built when ``_build_index`` fell back to a flat index.
        """
        return {
            "duplicate_threshold": float(self.duplicate_threshold),
            "faiss_meta": {
                "dimension": int(self.embedding_dim),
                "index_type": str(self.faiss_meta.get("index_type", "flat")),
                "nlist": int(self.faiss_meta.get("nlist", 0)),
                "nprobe": int(self.faiss_meta.get("nprobe", 0)),
                "size": int(self.index_size),
            },
        }