"""
Reranking module for improving retrieval precision.

Uses cross-encoder models to reorder initial retrieval results
based on query-document relevance.
"""

from typing import List, Dict, Any
from dataclasses import dataclass

# Model cache for lazy loading
_reranker_model = None
_reranker_model_name = None


@dataclass
class RerankResult:
    """Result from reranking operation."""

    chunks: List[Dict[str, Any]]
    model_used: str
    reranked: bool


def _get_cross_encoder(model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
    """
    Lazy load and cache cross-encoder model.

    Args:
        model_name: HuggingFace model name for cross-encoder

    Returns:
        CrossEncoder model instance
    """
    global _reranker_model, _reranker_model_name

    if _reranker_model is None or _reranker_model_name != model_name:
        try:
            from sentence_transformers import CrossEncoder

            _reranker_model = CrossEncoder(model_name)
            _reranker_model_name = model_name
        except ImportError:
            raise ImportError(
                "sentence-transformers not installed. "
                "Install with: pip install sentence-transformers"
            )
        except Exception as e:
            raise RuntimeError(f"Failed to load cross-encoder model: {e}")

    return _reranker_model
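
# Note: the cache above holds a single model at a time; requesting a different
# model_name reloads and replaces the cached instance rather than adding to it.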


def rerank_chunks(
    query: str,
    chunks: List[Dict[str, Any]],
    top_k: int = 5,
    model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
) -> RerankResult:
    """
    Rerank chunks using a cross-encoder model.

    Cross-encoders process the query and document together, enabling
    more nuanced relevance scoring than bi-encoders.

    Args:
        query: User query
        chunks: List of chunks to rerank
        top_k: Number of top results to return
        model_name: Cross-encoder model to use

    Returns:
        RerankResult with reranked chunks
    """
    if not chunks:
        return RerankResult(chunks=[], model_used="none", reranked=False)
    if len(chunks) <= 1:
        return RerankResult(chunks=chunks, model_used="none", reranked=False)

    try:
        model = _get_cross_encoder(model_name)

        # Prepare query-document pairs; fall back to metadata text if needed
        pairs = []
        for chunk in chunks:
            text = chunk.get("text", "")
            if not text and "metadata" in chunk:
                text = chunk["metadata"].get("text", "")
            pairs.append((query, text))

        # Get relevance scores
        scores = model.predict(pairs)

        # Combine chunks with scores and sort by descending relevance
        scored_chunks = list(zip(chunks, scores))
        scored_chunks.sort(key=lambda x: x[1], reverse=True)

        # Build result chunks with rerank scores attached
        results = []
        for chunk, score in scored_chunks[:top_k]:
            result_chunk = chunk.copy()
            result_chunk["rerank_score"] = float(score)
            results.append(result_chunk)

        return RerankResult(
            chunks=results,
            model_used=model_name,
            reranked=True,
        )
    except Exception as e:
        # Fallback: return original chunks without reranking
        return RerankResult(
            chunks=chunks[:top_k],
            model_used=f"fallback (error: {str(e)[:50]})",
            reranked=False,
        )
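
# A minimal usage sketch (assumes sentence-transformers is installed and chunks
# shaped like {"text": ...}, the format consumed above):
#
#     chunks = [
#         {"text": "Paris is the capital of France."},
#         {"text": "Bananas are rich in potassium."},
#     ]
#     result = rerank_chunks("What is the capital of France?", chunks, top_k=1)
#     if result.reranked:
#         print(result.chunks[0]["rerank_score"])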


def rerank_with_llm(
    query: str,
    chunks: List[Dict[str, Any]],
    top_k: int = 5,
) -> RerankResult:
    """
    Rerank chunks using LLM-based scoring (fallback method).

    More expensive than a cross-encoder, but works without additional models.

    Args:
        query: User query
        chunks: List of chunks to rerank
        top_k: Number of top results to return

    Returns:
        RerankResult with reranked chunks
    """
    if not chunks:
        return RerankResult(chunks=[], model_used="none", reranked=False)
    if len(chunks) <= 1:
        return RerankResult(chunks=chunks, model_used="none", reranked=False)

    try:
        from src.llm_providers import call_llm

        # Build the scoring prompt from truncated chunk texts
        chunk_texts = []
        for i, chunk in enumerate(chunks):
            text = chunk.get("text", "")[:500]  # Truncate to limit prompt size
            chunk_texts.append(f"[{i}] {text}")

        prompt = f"""Rate the relevance of each document to the query on a scale of 0-10.
Return ONLY a comma-separated list of scores in order (e.g., "8,3,7,5").
Query: {query}
Documents:
{chr(10).join(chunk_texts)}
Scores:"""

        response = call_llm(prompt=prompt, temperature=0.0, max_tokens=100)
        scores_text = response.get("text", "").strip()

        # Parse scores; fall back to the original order if parsing fails
        try:
            scores = [float(s.strip()) for s in scores_text.split(",")]
            if len(scores) != len(chunks):
                raise ValueError("Score count mismatch")
        except (ValueError, AttributeError):
            return RerankResult(
                chunks=chunks[:top_k],
                model_used="llm_parse_failed",
                reranked=False,
            )

        # Sort by descending LLM score
        scored_chunks = list(zip(chunks, scores))
        scored_chunks.sort(key=lambda x: x[1], reverse=True)

        results = []
        for chunk, score in scored_chunks[:top_k]:
            result_chunk = chunk.copy()
            result_chunk["rerank_score"] = float(score)
            results.append(result_chunk)

        return RerankResult(
            chunks=results,
            model_used="llm",
            reranked=True,
        )
    except Exception as e:
        return RerankResult(
            chunks=chunks[:top_k],
            model_used=f"llm_fallback (error: {str(e)[:50]})",
            reranked=False,
        )
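

if __name__ == "__main__":
    # Smoke test: a rough sketch. The cross-encoder path assumes
    # sentence-transformers is installed; both functions degrade gracefully
    # (reranked=False) when their dependencies are unavailable.
    demo_chunks = [
        {"text": "The Eiffel Tower is in Paris."},
        {"text": "Photosynthesis converts light into chemical energy."},
        {"text": "Paris is the capital and largest city of France."},
    ]
    for fn in (rerank_chunks, rerank_with_llm):
        result = fn("capital of France", demo_chunks, top_k=2)
        print(f"{fn.__name__}: model_used={result.model_used} reranked={result.reranked}")
        for chunk in result.chunks:
            print(" ", chunk.get("rerank_score"), chunk["text"])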