Spaces:
Sleeping
Sleeping
File size: 4,023 Bytes
a34068e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 | import logging
import time
from collections import defaultdict
from app.core.bm25 import BM25Index
from app.core.embedder import EmbedderService
from app.core.vectorstore import VectorStoreService
from app.models.document import DocumentMetadata
from app.models.schemas import RetrievedChunk, SearchFilters
logger = logging.getLogger(__name__)
class HybridRetriever:
    """Combine dense (vector) and sparse (BM25) search with weighted RRF.

    Each query is sent to both retrievers with an over-fetched candidate pool
    (``2 * top_k``). The two ranked lists are merged with weighted Reciprocal
    Rank Fusion, and the best ``top_k`` fused hits are returned as
    ``RetrievedChunk`` models.
    """

    def __init__(
        self,
        vectorstore: VectorStoreService,
        bm25: BM25Index,
        embedder: EmbedderService,
    ):
        """Store the collaborating services.

        Args:
            vectorstore: Store searched with the query embedding (dense side).
            bm25: Keyword index searched with the raw query text (sparse side).
            embedder: Service used to embed the query string.
        """
        self.vectorstore = vectorstore
        self.bm25 = bm25
        self.embedder = embedder

    def retrieve(
        self,
        query: str,
        top_k: int = 10,
        filters: SearchFilters | None = None,
        dense_weight: float = 0.6,
        sparse_weight: float = 0.4,
    ) -> list[RetrievedChunk]:
        """Return up to ``top_k`` chunks for ``query`` using hybrid search.

        Args:
            query: Free-text search query.
            top_k: Maximum number of chunks to return. A value of ``0`` or
                less returns an empty list without querying any backend.
            filters: Optional metadata filters. They are passed to the vector
                store and applied as a post-filter to the BM25 hits.
            dense_weight: RRF weight of the dense (vector) ranking.
            sparse_weight: RRF weight of the sparse (BM25) ranking.

        Returns:
            Chunks ordered by descending fused score. ``rank`` is the 0-based
            position in that order.
        """
        if top_k <= 0:
            # Nothing was requested, so skip the embedding and both searches.
            return []

        start = time.perf_counter()
        # Over-fetch from each retriever so that fusion sees a deeper pool
        # than the final cut.
        candidate_limit = top_k * 2

        query_vector = self.embedder.embed_query(query)
        # Dense search via the vector store (Qdrant). Filters are pushed down.
        dense_results = self.vectorstore.search(
            query_vector=query_vector,
            limit=candidate_limit,
            filters=filters,
        )

        # Sparse search via BM25. The index call takes no filters, so they are
        # applied afterwards; this can leave fewer than candidate_limit hits.
        sparse_results = self.bm25.search(query, top_k=candidate_limit)
        if filters and filters.has_filters():
            sparse_results = self._apply_filters(sparse_results, filters)

        fused = self.rrf_fuse(
            [dense_results, sparse_results],
            weights=[dense_weight, sparse_weight],
        )
        # rrf_fuse emits each chunk_id exactly once, so a slice is a correct
        # top-k. (The old "append then check" loop returned 1 item for top_k=0.)
        top = fused[:top_k]

        results = [
            RetrievedChunk(
                chunk_id=item["chunk_id"],
                document_id=item.get("document_id", ""),
                text=item["text"],
                score=item["fused_score"],
                # Tolerate a metadata key that is missing *or* explicitly None.
                metadata=DocumentMetadata(**(item.get("metadata") or {})),
                rank=i,
            )
            for i, item in enumerate(top)
        ]

        elapsed = (time.perf_counter() - start) * 1000
        logger.info(
            "Hybrid retrieval: %d dense + %d sparse → %d results in %.0fms",
            len(dense_results),
            len(sparse_results),
            len(results),
            elapsed,
        )
        return results

    @staticmethod
    def rrf_fuse(
        result_lists: list[list[dict]],
        k: int = 60,
        weights: list[float] | None = None,
    ) -> list[dict]:
        """Fuse ranked result lists with weighted Reciprocal Rank Fusion.

        An item at 1-based position ``r`` in list ``i`` adds
        ``weights[i] / (k + r)`` to the score of its ``chunk_id``. Scores are
        summed across lists.

        Args:
            result_lists: Ranked lists (best first) of dicts with a
                ``"chunk_id"`` key.
            k: RRF smoothing constant. Larger values flatten the rank bonus.
            weights: Per-list weights. Defaults to ``1.0`` for every list.

        Returns:
            One dict per distinct ``chunk_id``, sorted by descending
            ``"fused_score"``. Each dict is a copy of the first occurrence seen
            (earlier lists win) with a ``"fused_score"`` key added. Ties keep
            first-seen order.

        Raises:
            ValueError: If ``weights`` and ``result_lists`` differ in length.
        """
        if weights is None:
            weights = [1.0] * len(result_lists)
        elif len(weights) != len(result_lists):
            # zip() would silently drop the unmatched lists or weights.
            raise ValueError(
                f"got {len(weights)} weights for {len(result_lists)} result lists"
            )

        scores: dict[str, float] = defaultdict(float)
        docs: dict[str, dict] = {}
        for result_list, weight in zip(result_lists, weights):
            # RRF ranks are 1-based: the top hit contributes 1 / (k + 1).
            for rank, item in enumerate(result_list, start=1):
                chunk_id = item["chunk_id"]
                scores[chunk_id] += weight / (k + rank)
                docs.setdefault(chunk_id, item)

        ranked = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
        return [{**docs[chunk_id], "fused_score": score} for chunk_id, score in ranked]

    @staticmethod
    def _apply_filters(results: list[dict], filters: SearchFilters) -> list[dict]:
        """Keep only results whose ``"metadata"`` satisfies ``filters``.

        ``source`` and ``doc_type`` must match exactly when set. ``tags``
        matches if the result carries at least one of the requested tags.
        A missing or ``None`` metadata or tags value is treated as empty.
        """
        filtered = []
        for r in results:
            meta = r.get("metadata") or {}
            if filters.source and meta.get("source") != filters.source:
                continue
            if filters.doc_type and meta.get("doc_type") != filters.doc_type:
                continue
            if filters.tags:
                doc_tags = meta.get("tags") or []
                if not any(t in doc_tags for t in filters.tags):
                    continue
            filtered.append(r)
        return filtered
|