| import os |
| import json |
| import hashlib |
| import random |
| import pandas as pd |
| from typing import Union, Tuple |
| from threading import Lock |
| import torch |
|
|
| from pyserini.search.lucene import LuceneSearcher |
| from pyserini.index.lucene import Document |
| from pyserini.analysis import get_lucene_analyzer |
| from pyserini.pyclass import autoclass |
| from scipy.stats import hypergeom |
| from subspace.tool import SubspaceBERTScore |
|
|
|
|
# Process-wide caches guarded by a single lock: Lucene searchers are expensive
# to open and the HuggingFace analyzer / JVM query-parser class are expensive
# to construct, so each is created at most once per process.
_CACHE_LOCK = Lock()
_SEARCHER_CACHE = {}  # index_path -> LuceneSearcher
_HGF_ANALYZER = None  # lazily built HuggingFace-tokenizer analyzer
_STANDARD_QUERY_PARSER_CLASS = None  # lazily resolved Lucene JVM class
|
|
|
|
def _get_cached_searcher(index_path: str):
    """Return a process-wide LuceneSearcher for *index_path*, creating it once.

    Thread-safe: creation and lookup happen under the shared cache lock.
    """
    with _CACHE_LOCK:
        if index_path not in _SEARCHER_CACHE:
            _SEARCHER_CACHE[index_path] = LuceneSearcher(index_path)
        return _SEARCHER_CACHE[index_path]
|
|
|
|
def _get_cached_hgf_analyzer():
    """Lazily build and memoize the HuggingFace-tokenizer-backed Lucene analyzer.

    Uses the 'bert-base-uncased' tokenizer so index-time and query-time
    tokenization agree. Thread-safe via the shared cache lock.
    """
    global _HGF_ANALYZER
    with _CACHE_LOCK:
        if _HGF_ANALYZER is not None:
            return _HGF_ANALYZER
        _HGF_ANALYZER = get_lucene_analyzer(
            language='hgf_tokenizer',
            huggingFaceTokenizer='bert-base-uncased'
        )
        return _HGF_ANALYZER
|
|
| |
| |
| |
|
|
def get_standard_query(query: str, field: str = "contents", analyzer=None):
    """
    Runs Lucene's StandardQueryParser to get a parsed query object.

    Parameters
    ----------
    query : raw Lucene query string (may contain phrases, OR, etc.).
    field : default field the parser targets (index stores text in "contents").
    analyzer : Lucene analyzer to tokenize the query; defaults to the
        standard Lucene analyzer when not provided.

    Returns the parsed JVM Query object.
    """
    if analyzer is None:
        analyzer = get_lucene_analyzer()

    # Resolve the JVM class once per process. Guarded by the shared cache
    # lock for consistency with the other lazily-initialized module caches
    # (the unguarded version raced benignly but duplicated JVM lookups).
    global _STANDARD_QUERY_PARSER_CLASS
    with _CACHE_LOCK:
        if _STANDARD_QUERY_PARSER_CLASS is None:
            _STANDARD_QUERY_PARSER_CLASS = autoclass(
                'org.apache.lucene.queryparser.flexible.standard.StandardQueryParser'
            )

    # A fresh parser per call: parser instances carry mutable state
    # (setAnalyzer), so sharing one across threads would be unsafe.
    query_parser = _STANDARD_QUERY_PARSER_CLASS()
    query_parser.setAnalyzer(analyzer)

    return query_parser.parse(query, field)
|
|
def _json_doc_to_row(jd: dict, score) -> dict:
    """Build one result row from a decoded index document JSON.

    Copies the external id and contents, attaches the BM25 score (None for
    random padding docs), and flattens the JSON-encoded "metadata" field,
    when present, into top-level keys.
    """
    row = {
        'id': jd.get("id"),
        'content': jd.get("contents", ""),
        'score': score
    }
    # "metadata" is stored as a JSON string inside the document JSON —
    # decode and merge it so metadata becomes DataFrame columns.
    if jd.get("metadata"):
        row.update(json.loads(jd["metadata"]))
    return row


def query_bm25_index(index_path: str, keywords: list, doc_count: int = None) -> pd.DataFrame:
    """Load index, run BM25 phrase search using custom HuggingFace analyzer, and return results.

    Parameters
    ----------
    index_path : path to the Lucene index (searcher is cached per path).
    keywords : list of keywords/phrases; quoted and joined into an
        OR-of-phrases query string.
    doc_count : number of rows to return; defaults to the full index size.
        When BM25 matches fewer than ``doc_count`` documents, the result is
        padded with non-matching documents sampled deterministically (the
        shuffle is seeded from the query string), with ``score=None``.

    Returns
    -------
    pd.DataFrame with columns id/content/score plus any flattened metadata
    keys, BM25 hits first (in score order), padding rows after.
    """
    searcher = _get_cached_searcher(index_path)
    analyzer = _get_cached_hgf_analyzer()

    # Quote each keyword so multi-token keywords are searched as phrases.
    query_string = " OR ".join([f'"{kw}"' for kw in keywords])
    phrase_q = get_standard_query(query_string, analyzer=analyzer)

    if doc_count is None:
        doc_count = searcher.num_docs

    hits = searcher.search(phrase_q, doc_count)

    results = []
    for hit in hits:
        jd = json.loads(Document(hit.lucene_document).raw())
        results.append(_json_doc_to_row(jd, hit.score))

    # Track external ids (from the document JSON) so padding never
    # duplicates a returned document. (The previous version also collected
    # internal hit.docid values into a set that was never read.)
    returned_ext_ids = {r['id'] for r in results}

    if len(results) < doc_count:
        needed = doc_count - len(results)
        total = searcher.num_docs

        # NOTE(review): building the padding pool scans the whole index,
        # O(total) per under-filled query — acceptable for small indexes only.
        pool = []
        for docnum in range(total):
            jd = json.loads(Document(searcher.doc(docnum)).raw())
            if jd.get("id") not in returned_ext_ids:
                pool.append(docnum)

        # Seed the shuffle from the query text so padding is reproducible
        # across runs for the same keyword set.
        md5 = hashlib.md5(query_string.encode("utf-8")).hexdigest()
        rng = random.Random(int(md5, 16) % 2**32)
        rng.shuffle(pool)

        for docnum in pool[:needed]:
            jd = json.loads(Document(searcher.doc(docnum)).raw())
            results.append(_json_doc_to_row(jd, None))

    return pd.DataFrame(results)
|
|
|
|
def warmup_bm25(index_path: str, warmup_keyword: str = "community", warmup_k: int = 1) -> None:
    """Pre-warm Lucene searcher/analyzer/parser so first user query is faster.

    Raises ValueError when *index_path* is empty or not an existing directory.
    The warmup issues one tiny phrase query, which forces the searcher, the
    HuggingFace analyzer, and the JVM query-parser class to initialize.
    """
    if not (index_path and os.path.isdir(index_path)):
        raise ValueError(f"Invalid index path for warmup: {index_path}")

    analyzer = _get_cached_hgf_analyzer()
    parsed = get_standard_query(f'"{warmup_keyword}"', analyzer=analyzer)
    _get_cached_searcher(index_path).search(parsed, max(1, int(warmup_k)))
|
|
| |
| |
| |
|
|
| def _resolve_k(df, k): |
| """Convert float percentages to absolute k or return k as an int.""" |
| if isinstance(k, float) and 0.0 < k <= 1.0: |
| return int(len(df) * k) |
| return int(k) |
|
|
def precision_at_k(df: pd.DataFrame, correct_demographic: str, k: Union[int, float]) -> float:
    """Calculate precision at k for a target demographic.

    *df* must be ordered by retrieval rank and carry a 'demographic' column;
    *k* may be an absolute count or a float fraction in (0, 1].
    Returns 0.0 when the resolved k is non-positive.
    """
    k_abs = _resolve_k(df, k)
    if k_abs <= 0:
        return 0.0
    top_k = df['demographic'].iloc[:k_abs]
    hits = int((top_k == correct_demographic).sum())
    return hits / float(k_abs)
|
|
def lift_at_k(df: pd.DataFrame, correct_demographic: str, k: Union[int, float]) -> float:
    """Lift@k: ratio of precision@k to the overall proportion of relevant items.

    Returns 0.0 when k resolves to a non-positive count, when *df* is empty,
    or when the target demographic never appears (base rate of zero).
    """
    k_abs = _resolve_k(df, k)
    if k_abs <= 0 or df.empty:
        return 0.0

    base_rate = (df['demographic'] == correct_demographic).mean()
    if base_rate == 0:
        return 0.0

    return precision_at_k(df, correct_demographic, k) / base_rate
|
|
def hypergeometric_significance_test(df: pd.DataFrame, correct_demographic: str, k: Union[int, float], alpha: float = 0.05) -> Tuple[float, Tuple[int, int], Tuple[float, float]]:
    """Hypergeometric statistical significance test for the retrieval.

    Models the top-n slice as drawing n items without replacement from a
    population of N with K relevant items, and returns:
    (one-sided p-value of observing >= k_obs relevant in the top n,
     (L, U) count bounds at level *alpha*,
     (L/n, U/n) the same bounds as proportions).
    Returns all-zero results when there are no relevant items or n <= 0.
    """
    n = _resolve_k(df, k)
    N = len(df)

    relevant = df['demographic'] == correct_demographic
    K = relevant.sum()
    k_obs = relevant.iloc[:n].sum()

    if K == 0 or n <= 0:
        return 0.0, (0, 0), (0.0, 0.0)

    # P(X >= k_obs) via the survival function; sf(k_obs - 1) includes k_obs.
    p_value = hypergeom.sf(k_obs - 1, N, K, n)
    lower = int(hypergeom.ppf(alpha / 2, N, K, n))
    upper = int(hypergeom.isf(alpha / 2, N, K, n))

    return p_value, (lower, upper), (lower / n, upper / n)
|
|
def lift_ci(df: pd.DataFrame, correct_demographic: str, k: Union[int, float], alpha: float = 0.05) -> Tuple[float, float, float]:
    """Calculate confidence interval for lift@k using hypergeometric distribution.

    Returns (p_value, lower_lift, upper_lift): the one-sided p-value from
    the hypergeometric test plus the count bounds rescaled into lift units
    by dividing the bound proportions by the overall base rate.
    Returns (0.0, 0.0, 0.0) when the test is degenerate (no relevant items,
    non-positive n, or zero base rate).
    """
    n = _resolve_k(df, k)
    N = len(df)

    K = (df['demographic'] == correct_demographic).sum()
    base_rate = K / float(N)

    if K == 0 or n <= 0 or base_rate == 0:
        return 0.0, 0.0, 0.0

    pval, (lower, upper), _ = hypergeometric_significance_test(df, correct_demographic, k, alpha)
    return pval, (lower / n) / base_rate, (upper / n) / base_rate
|
|
| |
| |
| |
|
|
| _GLOBAL_SCORER = None |
|
|
def compute_keyword_similarity(set1: list, set2: list, device: str = None) -> dict:
    """
    Computes precision, recall, and F-score similarity metrics between two keyword sets.
    Mirrors the subspace-based BERT scoring logic handling keyword lists.

    Each keyword list is joined into a single comma-separated sentence before
    scoring. The scorer is created once and cached at module level; NOTE
    (review): a *device* passed on a later call is ignored once the cached
    scorer exists — confirm this is intended.
    """
    global _GLOBAL_SCORER

    device = ('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device

    if _GLOBAL_SCORER is None:
        print(f"Initializing BERT model on {device}...")
        _GLOBAL_SCORER = SubspaceBERTScore(device=device, model_name_or_path='bert-base-uncased')

    joined_1 = [", ".join(set1)]
    joined_2 = [", ".join(set2)]
    scores = _GLOBAL_SCORER(joined_1, joined_2)

    return {
        'Precision': scores[0].item(),
        'Recall': scores[1].item(),
        'F-Score': scores[2].item()
    }
|
|