File size: 4,470 Bytes
8bd9348
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from __future__ import annotations
import re, json, hashlib
from dataclasses import dataclass
from typing import List, Dict, Any, Tuple

import numpy as np
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, CrossEncoder

def cache_key(obj: Any) -> str:
    return hashlib.sha256(json.dumps(obj, sort_keys=True).encode()).hexdigest()

def normalize_text(s: str) -> str:
    return re.sub(r"\s+", " ", s).strip()

@dataclass
class Doc:
    id: str
    text: str
    metadata: Dict[str, Any]

class HybridIndex:
    def __init__(self, dense_model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
        self.dense_model_name = dense_model_name
        self.embedder = SentenceTransformer(dense_model_name)
        self.docs: List[Doc] = []
        self.bm25 = None
        self.embeddings = None

    def add(self, docs: List[Doc]):
        self.docs.extend(docs)

    def build(self):
        # Build only if we have docs
        if not self.docs:
            self.bm25, self.embeddings = None, None
            return
        corpus = [d.text for d in self.docs]
        tokenized = [c.split() for c in corpus]
        self.bm25 = BM25Okapi(tokenized)
        self.embeddings = self.embedder.encode(
            corpus, convert_to_numpy=True, normalize_embeddings=True
        )

    def ready(self) -> bool:
        return (self.bm25 is not None) and (self.embeddings is not None) and (len(self.docs) > 0)

    def search(self, query: str, k: int = 8, alpha: float = 0.5, filters: Dict[str, Any] | None = None):
        # If index isn't ready, return empty (UI/pipeline should guide the user)
        if not self.ready():
            return []

        # Dense embedding for query
        query_vec = self.embedder.encode([query], convert_to_numpy=True, normalize_embeddings=True)[0]

        # BM25 + dense scores
        q_tokens = query.split()
        try:
            bm25_scores = self.bm25.get_scores(q_tokens)
        except Exception:
            # Fallback if BM25 hiccups (e.g., empty tokens)
            bm25_scores = np.zeros(len(self.docs), dtype=float)
        dense_scores = (self.embeddings @ query_vec)

        # NumPy 2.x-safe normalization
        def _norm(x: np.ndarray) -> np.ndarray:
            x = np.asarray(x, dtype=float)
            rng = np.ptp(x)
            return (x - x.min()) / (rng + 1e-8)

        bm25_norm = _norm(bm25_scores)
        dense_norm = _norm(dense_scores)
        scores = alpha * bm25_norm + (1 - alpha) * dense_norm

        # Optional metadata filters
        idxs = np.arange(len(self.docs))
        if filters:
            def ok(d: Doc) -> bool:
                for kf, vf in filters.items():
                    if kf not in d.metadata:
                        return False
                    dv = str(d.metadata[kf]).lower()
                    if isinstance(vf, (list, tuple, set)):
                        if not any(str(x).lower() in dv for x in vf):
                            return False
                    else:
                        if str(vf).lower() not in dv:
                            return False
                return True

            keep = [i for i in idxs if ok(self.docs[int(i)])]
            if not keep:
                return []
            idxs = np.array(keep, dtype=int)
            scores = scores[idxs]

        # Top-k results
        order = np.argsort(-scores)[:k]
        return [(self.docs[int(idxs[i])], float(scores[i])) for i in order]

class Reranker:
    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        self.model = CrossEncoder(model_name)

    def rerank(self, query: str, docs: List[Tuple[Doc, float]], top_k: int = 5) -> List[Tuple[Doc, float]]:
        if not docs:
            return []
        pairs = [(query, d.text) for d, _ in docs]
        scores = self.model.predict(pairs)
        rescored = list(zip([d for d,_ in docs], [float(s) for s in scores]))
        rescored.sort(key=lambda x: -x[1])
        return rescored[:top_k]

def select_fewshots(query: str, fewshots: List[Dict[str, str]], embedder: SentenceTransformer, n: int = 2):
    if not fewshots:
        return []
    qv = embedder.encode([query], convert_to_numpy=True, normalize_embeddings=True)[0]
    ex_vecs = embedder.encode([fs["q"] for fs in fewshots], convert_to_numpy=True, normalize_embeddings=True)
    sims = ex_vecs @ qv
    order = np.argsort(-sims)[:n]
    return [fewshots[i] for i in order]