"""Cached-embedding dataset utilities for contrastive training on BEIR-style splits."""

from __future__ import annotations

import json
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Union

import torch
from sentence_transformers import SentenceTransformer
from torch.utils.data import Dataset

from .beir_data import DatasetSplit
from .encoders import EncoderSpec


@dataclass(frozen=True)
class CachedSplit:
    """A dataset split together with its precomputed embeddings and mined negatives."""

    split: DatasetSplit
    document_embeddings: dict[str, torch.Tensor]
    query_embeddings: dict[str, torch.Tensor]
    negatives: dict[str, list[str]]


def _query_cache_path(cache_dir: Path, split_name: str, encoder_key: str) -> Path:
    return cache_dir / split_name / f"query_embeddings_{encoder_key}.pt"


def load_document_embeddings(cache_dir: Path, split_name: str) -> dict[str, torch.Tensor]:
    return torch.load(
        cache_dir / split_name / "embeddings.pt",
        map_location="cpu",
        weights_only=True,
    )


def load_negatives(cache_dir: Path, split_name: str) -> dict[str, list[str]]:
    with open(cache_dir / split_name / "negatives.json") as handle:
        return json.load(handle)


def encode_queries(
    queries: dict[str, str],
    encoder_spec: EncoderSpec,
    cache_dir: Path,
    split_name: str,
    device: str,
    batch_size: int = 64,
) -> dict[str, torch.Tensor]:
    """Encode queries with the given encoder, reusing a per-encoder on-disk cache."""
    cache_path = _query_cache_path(cache_dir, split_name, encoder_spec.key)
    if cache_path.exists():
        return torch.load(cache_path, map_location="cpu", weights_only=True)

    cache_path.parent.mkdir(parents=True, exist_ok=True)
    model = SentenceTransformer(encoder_spec.model_name, device=device)
    query_ids = list(queries.keys())
    texts = [encoder_spec.query_prefix + queries[qid] for qid in query_ids]
    encoded = model.encode(
        texts,
        batch_size=batch_size,
        convert_to_tensor=True,
        show_progress_bar=True,
        device=device,
    )
    query_embeddings = {qid: embedding.cpu() for qid, embedding in zip(query_ids, encoded)}
    torch.save(query_embeddings, cache_path)
    return query_embeddings


def load_cached_split(
    cache_dir: Path,
    split_name: str,
    dataset_source: DatasetSplit,
    encoder_spec: EncoderSpec,
    device: str,
) -> CachedSplit:
    """Assemble a CachedSplit, keeping only queries that have mined negatives and qrels."""
    negatives = load_negatives(cache_dir, split_name)
    cached_qids = list(negatives.keys())
    filtered_queries = {
        qid: dataset_source.queries[qid]
        for qid in cached_qids
        if qid in dataset_source.queries and qid in dataset_source.qrels
    }
    filtered_qrels = {qid: dataset_source.qrels[qid] for qid in filtered_queries}
    filtered_split = DatasetSplit(
        corpus=dataset_source.corpus,
        queries=filtered_queries,
        qrels=filtered_qrels,
    )
    return CachedSplit(
        split=filtered_split,
        document_embeddings=load_document_embeddings(cache_dir, split_name),
        query_embeddings=encode_queries(
            filtered_split.queries, encoder_spec, cache_dir, split_name, device
        ),
        negatives=negatives,
    )


class ContrastiveCachedDataset(Dataset):
    """One example per query: its embedding plus one positive and up to num_negatives negatives.

    Shorter negative lists are padded by repeating the last negative so that every
    example yields a fixed-size document stack of 1 + num_negatives embeddings.
    """

    def __init__(
        self,
        cached_split: CachedSplit,
        num_negatives: int,
    ) -> None:
        self.cached_split = cached_split
        self.num_negatives = num_negatives
        self.examples: list[tuple[str, str, list[str]]] = []
        for qid, qrel in cached_split.split.qrels.items():
            if qid not in cached_split.query_embeddings:
                continue
            positives = [
                doc_id
                for doc_id, rel in qrel.items()
                if rel > 0 and doc_id in cached_split.document_embeddings
            ]
            negatives = [
                doc_id
                for doc_id in cached_split.negatives.get(qid, [])
                if doc_id in cached_split.document_embeddings
            ]
            # Skip queries without at least one usable positive and one usable negative.
            if not positives or not negatives:
                continue
            self.examples.append((qid, positives[0], negatives[:num_negatives]))

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, index: int) -> dict[str, Union[str, torch.Tensor]]:
        qid, positive_id, negative_ids = self.examples[index]
        query_embedding = self.cached_split.query_embeddings[qid].float()
        positive_embedding = self.cached_split.document_embeddings[positive_id].float()
        normalized_negative_ids = list(negative_ids[: self.num_negatives])
        # Defensive fallback: __init__ never emits examples without negatives, but if the
        # list is empty, fall back to repeating the positive; otherwise pad with the last
        # negative until the stack has exactly num_negatives entries.
        if not normalized_negative_ids:
            normalized_negative_ids = [positive_id] * self.num_negatives
        while len(normalized_negative_ids) < self.num_negatives:
            normalized_negative_ids.append(normalized_negative_ids[-1])
        negative_embeddings = [
            self.cached_split.document_embeddings[doc_id].float()
            for doc_id in normalized_negative_ids
        ]
        # Positive document always sits at index 0 of the stacked documents.
        documents = torch.stack([positive_embedding, *negative_embeddings], dim=0)
        return {
            "qid": qid,
            "query_embedding": query_embedding,
            "documents": documents,
        }


def collate_contrastive_batch(
    batch: List[Dict[str, torch.Tensor]],
) -> Dict[str, Union[torch.Tensor, List[str]]]:
    """Stack per-example tensors into batch tensors and collect query ids."""
    return {
        "qids": [item["qid"] for item in batch],
        "query_embeddings": torch.stack([item["query_embedding"] for item in batch], dim=0),
        "documents": torch.stack([item["documents"] for item in batch], dim=0),
    }
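

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module's API). It shows how
# the helpers above are meant to compose into a training DataLoader. The cache
# directory, split name, encoder arguments, and the `load_split` helper from
# `.beir_data` are assumptions made for this example, not guaranteed names.
#
#   from torch.utils.data import DataLoader
#
#   cache_dir = Path("artifacts/cache/scifact")        # hypothetical cache root
#   dataset_source = load_split("scifact", "train")    # hypothetical DatasetSplit loader
#   encoder_spec = EncoderSpec(
#       key="minilm",
#       model_name="sentence-transformers/all-MiniLM-L6-v2",
#       query_prefix="",
#   )
#   cached = load_cached_split(cache_dir, "train", dataset_source, encoder_spec, device="cuda")
#   train_dataset = ContrastiveCachedDataset(cached, num_negatives=4)
#   train_loader = DataLoader(
#       train_dataset,
#       batch_size=32,
#       shuffle=True,
#       collate_fn=collate_contrastive_batch,
#   )
#   # Each batch: "query_embeddings" has shape (B, D) and "documents" has shape
#   # (B, 1 + num_negatives, D), with the positive document at index 0.
# ---------------------------------------------------------------------------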