import hashlib
import json
from pathlib import Path
from typing import List, Tuple

import torch
import torch.nn.functional as F

from rag.config import Settings
from rag.data import Doc
from rag.logging_utils import get_logger

logger = get_logger(__name__)


def last_token_pool(last_hidden_states, attention_mask):
    """Pool the hidden state of each sequence's last non-padding token."""
    # With left padding, every row's final position holds a real token.
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    # With right padding, gather each row's last real position instead.
    sequence_lengths = attention_mask.sum(dim=1) - 1
    batch_size = last_hidden_states.shape[0]
    return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
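

# Illustrative sketch (not part of the original module; shapes and values are
# made up): with right padding, last_token_pool gathers each row's final real
# token; with left padding it can simply take position -1 for every row.
def _example_last_token_pool():
    hidden = torch.randn(2, 4, 8)           # [batch=2, seq=4, dim=8]
    mask = torch.tensor([[1, 1, 1, 0],      # right-padded: last real token at index 2
                         [1, 1, 1, 1]])     # full row: last token at index 3
    pooled = last_token_pool(hidden, mask)  # -> [2, 8]
    assert torch.equal(pooled[0], hidden[0, 2])
    assert torch.equal(pooled[1], hidden[1, 3])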


def _fingerprint(docs: List[Doc], settings: Settings) -> str:
    """Hash the embedding config and document contents to detect cache staleness."""
    h = hashlib.sha256()
    h.update(settings.embed_model_id.encode("utf-8"))
    h.update(str(settings.embed_max_len).encode("utf-8"))
    for d in docs:
        h.update(d.doc_name.encode("utf-8"))
        h.update(d.company.encode("utf-8"))
        h.update(d.text.encode("utf-8"))
    return h.hexdigest()
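

# Illustrative sketch (uses stand-in objects, not the real Settings/Doc
# constructors): any change to the model id, max length, or document contents
# changes the fingerprint, which is what invalidates the on-disk cache below.
def _example_fingerprint_invalidation():
    from types import SimpleNamespace
    s = SimpleNamespace(embed_model_id="demo-model", embed_max_len=512)
    d1 = SimpleNamespace(doc_name="a.txt", company="ACME", text="hello")
    d2 = SimpleNamespace(doc_name="a.txt", company="ACME", text="hello world")
    assert _fingerprint([d1], s) != _fingerprint([d2], s)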


def ensure_index_dir(settings: Settings):
    Path(settings.index_dir).mkdir(parents=True, exist_ok=True)


@torch.no_grad()
def build_or_load_doc_embeddings(
    docs: List[Doc],
    embed_tokenizer,
    embed_model,
    settings: Settings,
) -> Tuple[torch.Tensor, str]:
    """
    Returns (doc_embeddings [N, D] on CPU, fingerprint).
    Caches to data/index/doc_embeds.pt and reuses the cache while the
    fingerprint (model id, max length, document contents) is unchanged.
    """
    ensure_index_dir(settings)
    fp = _fingerprint(docs, settings)
    cache_file = settings.doc_embeds_file()
    meta_file = settings.doc_meta_file()

    # Reuse the cache only if its recorded fingerprint matches the current one.
    if cache_file.exists() and meta_file.exists():
        try:
            meta = json.loads(meta_file.read_text(encoding="utf-8"))
            if meta.get("fingerprint") == fp:
                logger.info("Loading cached doc embeddings: %s", str(cache_file))
                payload = torch.load(cache_file, map_location="cpu")
                return payload["embeddings"], fp
        except Exception as e:
            logger.warning("Failed to load cache, rebuilding. Reason: %s", e)

    logger.info("Building doc embeddings cache (%d docs)...", len(docs))
    doc_texts = [d.text for d in docs]
    embs = []
    for i in range(0, len(doc_texts), settings.embed_batch_size):
        batch = doc_texts[i : i + settings.embed_batch_size]
        d_inputs = embed_tokenizer(
            batch,
            max_length=settings.embed_max_len,
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to(embed_model.device)
        d_outputs = embed_model(**d_inputs)
        batch_emb = last_token_pool(d_outputs.last_hidden_state, d_inputs["attention_mask"])
        # L2-normalize so downstream dot products are cosine similarities.
        batch_emb = F.normalize(batch_emb, p=2, dim=1)
        embs.append(batch_emb.detach().to("cpu"))
    doc_embs = torch.cat(embs, dim=0)

    torch.save({"embeddings": doc_embs}, cache_file)
    meta_file.write_text(json.dumps({"fingerprint": fp, "n_docs": len(docs)}, indent=2), encoding="utf-8")
    logger.info("Saved embeddings cache: %s", str(cache_file))
    return doc_embs, fp
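

# Usage sketch (the Hugging Face loading calls are an assumption about how the
# embedder is wired up elsewhere in the repo, not something this module pins down):
def _example_build_embeddings(docs: List[Doc], settings: Settings):
    from transformers import AutoModel, AutoTokenizer
    tok = AutoTokenizer.from_pretrained(settings.embed_model_id)
    model = AutoModel.from_pretrained(settings.embed_model_id).eval()
    # doc_embs: [N, D] on CPU, L2-normalized; fp identifies the cached build.
    doc_embs, fp = build_or_load_doc_embeddings(docs, tok, model, settings)
    return doc_embs, fp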


@torch.no_grad()
def embed_query(query: str, embed_tokenizer, embed_model, settings: Settings) -> torch.Tensor:
    # Instruction-tuned embedders expect queries (but not documents) to carry a
    # task prefix; documents are embedded raw in build_or_load_doc_embeddings.
    query_text = (
        "Instruct: Given a user query, retrieve relevant passages that answer the query.\n"
        f"Query: {query}"
    )
    q_inputs = embed_tokenizer(
        [query_text],
        max_length=settings.embed_max_len,
        padding=True,
        truncation=True,
        return_tensors="pt",
    ).to(embed_model.device)
    q_outputs = embed_model(**q_inputs)
    q_emb = last_token_pool(q_outputs.last_hidden_state, q_inputs["attention_mask"])
    q_emb = F.normalize(q_emb, p=2, dim=1)
    return q_emb.detach().to("cpu")  # keep retrieval ops on CPU
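

# Sketch (the query string is hypothetical): embed_query returns a [1, D] unit
# vector on CPU, ready for the dot-product scoring in topk_retrieval below.
def _example_embed_query(embed_tokenizer, embed_model, settings: Settings):
    q = embed_query("What was ACME's 2023 revenue?", embed_tokenizer, embed_model, settings)
    assert q.shape[0] == 1 and abs(q.norm().item() - 1.0) < 1e-4
    return q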


def topk_retrieval(q_emb_cpu: torch.Tensor, doc_embs_cpu: torch.Tensor, k: int) -> List[int]:
    # q_emb_cpu: [1, D], doc_embs_cpu: [N, D]; both are unit-norm, so the
    # matmul yields cosine similarities.
    scores = (q_emb_cpu @ doc_embs_cpu.T).squeeze(0)
    k = min(k, scores.shape[0])
    return torch.topk(scores, k=k).indices.tolist()
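

# Illustrative sketch (random stand-in vectors, not real embeddings): with
# unit-norm rows the matmul in topk_retrieval ranks by cosine similarity, so a
# query identical to a document puts that document first.
def _example_topk_retrieval():
    docs = F.normalize(torch.randn(5, 8), p=2, dim=1)  # [N=5, D=8]
    q = docs[3:4].clone()                              # query == doc 3
    assert topk_retrieval(q, docs, k=2)[0] == 3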


@torch.no_grad()
def rerank(
    query: str,
    candidate_docs: List[Doc],
    rerank_tokenizer,
    rerank_model,
    settings: Settings,
    k: int,
) -> Tuple[List[int], torch.Tensor]:
    """Score (query, doc) pairs with a cross-encoder and return top-k indices plus all scores."""
    pairs = [[query, d.text] for d in candidate_docs]
    r_inputs = rerank_tokenizer(
        pairs,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=settings.rerank_max_len,
    ).to(rerank_model.device)
    # One relevance logit per pair; move to CPU for the top-k selection.
    r_scores = rerank_model(**r_inputs, return_dict=True).logits.view(-1).float().detach().to("cpu")
    k = min(k, len(candidate_docs))
    top_idx = torch.topk(r_scores, k=k).indices.tolist()
    return top_idx, r_scores
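

# End-to-end sketch (model wiring and the k values are assumptions; a typical
# pairing would be an AutoModel embedder with an
# AutoModelForSequenceClassification cross-encoder): cheap dense retrieval for
# recall, then cross-encoder reranking for precision.
def _example_retrieve_then_rerank(query: str, docs: List[Doc], embed_tok, embed_model,
                                  rerank_tok, rerank_model, settings: Settings) -> List[Doc]:
    doc_embs, _ = build_or_load_doc_embeddings(docs, embed_tok, embed_model, settings)
    q_emb = embed_query(query, embed_tok, embed_model, settings)
    cand_idx = topk_retrieval(q_emb, doc_embs, k=20)  # recall stage
    candidates = [docs[i] for i in cand_idx]
    top_idx, _ = rerank(query, candidates, rerank_tok, rerank_model, settings, k=5)
    return [candidates[i] for i in top_idx]           # precision stage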