from __future__ import annotations

import json
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Union

import torch
from sentence_transformers import SentenceTransformer
from torch.utils.data import Dataset

from .beir_data import DatasetSplit
from .encoders import EncoderSpec
|
| |
|
@dataclass(frozen=True)
class CachedSplit:
    """A dataset split bundled with its precomputed embeddings and mined negatives."""

    # The (possibly filtered) corpus / queries / qrels this cache was built from.
    split: DatasetSplit
    # doc id -> embedding tensor; loaded with map_location="cpu" downstream.
    document_embeddings: dict[str, torch.Tensor]
    # query id -> embedding tensor; stored on CPU by encode_queries.
    query_embeddings: dict[str, torch.Tensor]
    # query id -> ordered list of hard-negative doc ids.
    negatives: dict[str, list[str]]
| |
|
| |
|
| | def _query_cache_path(cache_dir: Path, split_name: str, encoder_key: str) -> Path: |
| | return cache_dir / split_name / f"query_embeddings_{encoder_key}.pt" |
| |
|
| |
|
def load_document_embeddings(cache_dir: Path, split_name: str) -> dict[str, torch.Tensor]:
    """Load the precomputed corpus embeddings (doc id -> tensor) for *split_name* onto CPU."""
    source = cache_dir / split_name / "embeddings.pt"
    return torch.load(source, map_location="cpu", weights_only=True)
| |
|
| |
|
def load_negatives(cache_dir: Path, split_name: str) -> dict[str, list[str]]:
    """Read the mined hard negatives (query id -> doc ids) for *split_name* from JSON."""
    path = cache_dir / split_name / "negatives.json"
    with path.open() as handle:
        return json.load(handle)
| |
|
| |
|
def encode_queries(
    queries: dict[str, str],
    encoder_spec: EncoderSpec,
    cache_dir: Path,
    split_name: str,
    device: str,
    batch_size: int = 64,
) -> dict[str, torch.Tensor]:
    """Embed *queries* with the configured encoder, memoizing the result on disk.

    Returns a mapping from query id to a CPU tensor. If a cache file for this
    (split, encoder) pair already exists it is loaded and returned unchanged;
    otherwise the model is instantiated, all queries are encoded (with the
    encoder's query prefix prepended), and the result is persisted.
    """
    target = _query_cache_path(cache_dir, split_name, encoder_spec.key)
    if target.exists():
        return torch.load(target, map_location="cpu", weights_only=True)

    target.parent.mkdir(parents=True, exist_ok=True)
    transformer = SentenceTransformer(encoder_spec.model_name, device=device)
    ordered_ids = list(queries)
    prefixed_texts = [encoder_spec.query_prefix + queries[qid] for qid in ordered_ids]
    encoded_rows = transformer.encode(
        prefixed_texts,
        batch_size=batch_size,
        convert_to_tensor=True,
        show_progress_bar=True,
        device=device,
    )
    # Move each row off the encoding device so the cache file is device-agnostic.
    embeddings_by_id = {qid: row.cpu() for qid, row in zip(ordered_ids, encoded_rows)}
    torch.save(embeddings_by_id, target)
    return embeddings_by_id
| |
|
| |
|
def load_cached_split(
    cache_dir: Path,
    split_name: str,
    dataset_source: DatasetSplit,
    encoder_spec: EncoderSpec,
    device: str,
) -> CachedSplit:
    """Assemble a CachedSplit restricted to queries that have cached negatives.

    Only queries present in both the source split's queries and qrels (and in
    the negatives file) are kept; document embeddings are read from disk and
    query embeddings are encoded (or loaded from their own cache).
    """
    negatives = load_negatives(cache_dir, split_name)
    # Keep only queries we can actually train on: text + relevance judgments.
    kept_queries: dict[str, str] = {}
    for qid in negatives:
        if qid in dataset_source.queries and qid in dataset_source.qrels:
            kept_queries[qid] = dataset_source.queries[qid]
    subset = DatasetSplit(
        corpus=dataset_source.corpus,
        queries=kept_queries,
        qrels={qid: dataset_source.qrels[qid] for qid in kept_queries},
    )
    return CachedSplit(
        split=subset,
        document_embeddings=load_document_embeddings(cache_dir, split_name),
        query_embeddings=encode_queries(subset.queries, encoder_spec, cache_dir, split_name, device),
        negatives=negatives,
    )
| |
|
| |
|
class ContrastiveCachedDataset(Dataset):
    """Triplet-style dataset over precomputed embeddings.

    Each example pairs a query embedding with a stack of document embeddings
    whose first row is the single chosen positive and whose remaining rows are
    up to ``num_negatives`` hard negatives (padded by repetition if needed).
    """

    def __init__(
        self,
        cached_split: CachedSplit,
        num_negatives: int,
    ) -> None:
        self.cached_split = cached_split
        self.num_negatives = num_negatives
        self.examples: list[tuple[str, str, list[str]]] = []

        doc_vectors = cached_split.document_embeddings
        for qid, judgments in cached_split.split.qrels.items():
            # Skip queries whose embedding was never cached.
            if qid not in cached_split.query_embeddings:
                continue
            positive_ids = [did for did, rel in judgments.items() if rel > 0 and did in doc_vectors]
            negative_ids = [did for did in cached_split.negatives.get(qid, []) if did in doc_vectors]
            if positive_ids and negative_ids:
                # First relevant document is the anchor positive.
                self.examples.append((qid, positive_ids[0], negative_ids[:num_negatives]))

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        qid, positive_id, negative_ids = self.examples[index]
        doc_vectors = self.cached_split.document_embeddings

        # Pad the negative list up to a fixed width by repeating the last id
        # (or, degenerately, the positive if no negatives survived filtering).
        padded = list(negative_ids[: self.num_negatives])
        if not padded:
            padded = [positive_id] * self.num_negatives
        while len(padded) < self.num_negatives:
            padded.append(padded[-1])

        documents = torch.stack(
            [doc_vectors[positive_id].float()] + [doc_vectors[did].float() for did in padded],
            dim=0,
        )
        return {
            "qid": qid,
            "query_embedding": self.cached_split.query_embeddings[qid].float(),
            "documents": documents,
        }
| |
|
| |
|
def collate_contrastive_batch(batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor | list[str]]:
    """Collate dataset items into a batch.

    Fix: modernized the annotations from ``typing.List/Dict/Union`` aliases to
    the builtin-generic / ``|`` forms used everywhere else in this module (the
    file already relies on ``from __future__ import annotations``, so this is
    safe on all supported interpreters). Runtime behavior is unchanged.

    Returns a dict with:
      - "qids": list of query id strings, batch order preserved.
      - "query_embeddings": (batch, dim) tensor of stacked query embeddings.
      - "documents": (batch, 1 + num_negatives, dim) tensor of document stacks.
    """
    return {
        "qids": [item["qid"] for item in batch],
        "query_embeddings": torch.stack([item["query_embedding"] for item in batch], dim=0),
        "documents": torch.stack([item["documents"] for item in batch], dim=0),
    }
| |
|