from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Union
import torch
from sentence_transformers import SentenceTransformer
from torch.utils.data import Dataset
from .beir_data import DatasetSplit
from .encoders import EncoderSpec
@dataclass(frozen=True)
class CachedSplit:
    """Immutable bundle of a dataset split and its precomputed caches.

    Ties together the (filtered) split, per-document and per-query embedding
    tensors, and the cached negative document ids for each query.
    """

    # The corpus/queries/qrels this cache was built for.
    split: DatasetSplit
    # Document id -> embedding tensor (loaded with map_location="cpu" elsewhere
    # in this module).
    document_embeddings: dict[str, torch.Tensor]
    # Query id -> embedding tensor (moved to CPU before caching).
    query_embeddings: dict[str, torch.Tensor]
    # Query id -> list of negative document ids (read from negatives.json).
    negatives: dict[str, list[str]]
def _query_cache_path(cache_dir: Path, split_name: str, encoder_key: str) -> Path:
return cache_dir / split_name / f"query_embeddings_{encoder_key}.pt"
def load_document_embeddings(cache_dir: Path, split_name: str) -> dict[str, torch.Tensor]:
    """Load the cached document embeddings for *split_name*, mapped to CPU."""
    embeddings_path = cache_dir / split_name / "embeddings.pt"
    return torch.load(embeddings_path, map_location="cpu", weights_only=True)
def load_negatives(cache_dir: Path, split_name: str) -> dict[str, list[str]]:
    """Load the cached negatives (query id -> document ids) for a split.

    The file is decoded as UTF-8 explicitly: JSON is defined over UTF-8, and
    relying on the platform default encoding can mis-decode the file on
    systems where that default differs (e.g. Windows code pages).
    """
    negatives_path = cache_dir / split_name / "negatives.json"
    with open(negatives_path, encoding="utf-8") as handle:
        return json.load(handle)
def encode_queries(
    queries: dict[str, str],
    encoder_spec: EncoderSpec,
    cache_dir: Path,
    split_name: str,
    device: str,
    batch_size: int = 64,
) -> dict[str, torch.Tensor]:
    """Encode *queries* with the spec's model, caching the result on disk.

    Returns a mapping from query id to a CPU embedding tensor. On a cache
    hit the tensors are loaded from disk and the model is never built.
    """
    cache_path = cache_dir / split_name / f"query_embeddings_{encoder_spec.key}.pt"
    if cache_path.exists():
        return torch.load(cache_path, map_location="cpu", weights_only=True)

    cache_path.parent.mkdir(parents=True, exist_ok=True)
    model = SentenceTransformer(encoder_spec.model_name, device=device)
    ordered_ids = list(queries)
    # Each query text is prefixed per the encoder's convention (e.g. "query: ").
    prefixed_texts = [encoder_spec.query_prefix + queries[qid] for qid in ordered_ids]
    encoded = model.encode(
        prefixed_texts,
        batch_size=batch_size,
        convert_to_tensor=True,
        show_progress_bar=True,
        device=device,
    )
    query_embeddings: dict[str, torch.Tensor] = {}
    for qid, row in zip(ordered_ids, encoded):
        # Detach from the encoding device so the cache is device-agnostic.
        query_embeddings[qid] = row.cpu()
    torch.save(query_embeddings, cache_path)
    return query_embeddings
def load_cached_split(
    cache_dir: Path,
    split_name: str,
    dataset_source: DatasetSplit,
    encoder_spec: EncoderSpec,
    device: str,
) -> CachedSplit:
    """Assemble a :class:`CachedSplit` from on-disk caches and the source split.

    Only queries that have cached negatives AND appear in both the source
    queries and qrels are kept; query embeddings are computed (or loaded
    from cache) for exactly that subset.
    """
    negatives = load_negatives(cache_dir, split_name)

    retained_queries: dict[str, str] = {}
    for qid in negatives:
        if qid in dataset_source.queries and qid in dataset_source.qrels:
            retained_queries[qid] = dataset_source.queries[qid]

    filtered_split = DatasetSplit(
        corpus=dataset_source.corpus,
        queries=retained_queries,
        qrels={qid: dataset_source.qrels[qid] for qid in retained_queries},
    )
    return CachedSplit(
        split=filtered_split,
        document_embeddings=load_document_embeddings(cache_dir, split_name),
        query_embeddings=encode_queries(retained_queries, encoder_spec, cache_dir, split_name, device),
        negatives=negatives,
    )
class ContrastiveCachedDataset(Dataset):
    """Contrastive training examples built from cached embeddings.

    Each item pairs a query embedding with a stacked document tensor whose
    first row is the positive document and whose remaining rows are the
    query's negatives, padded so every item has the same number of rows.
    """

    def __init__(
        self,
        cached_split: CachedSplit,
        num_negatives: int,
    ) -> None:
        self.cached_split = cached_split
        self.num_negatives = num_negatives
        self.examples: list[tuple[str, str, list[str]]] = []
        doc_embeddings = cached_split.document_embeddings
        for qid, judgments in cached_split.split.qrels.items():
            # Skip queries whose embedding was not cached.
            if qid not in cached_split.query_embeddings:
                continue
            positives = [
                doc_id
                for doc_id, relevance in judgments.items()
                if relevance > 0 and doc_id in doc_embeddings
            ]
            mined = cached_split.negatives.get(qid, [])
            usable_negatives = [doc_id for doc_id in mined if doc_id in doc_embeddings]
            # A usable example needs at least one positive and one negative.
            if not positives or not usable_negatives:
                continue
            self.examples.append((qid, positives[0], usable_negatives[:num_negatives]))

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        qid, positive_id, negative_ids = self.examples[index]
        doc_embeddings = self.cached_split.document_embeddings
        query_embedding = self.cached_split.query_embeddings[qid].float()
        # Pad the negative list up to num_negatives by repeating the last id
        # (or, defensively, the positive id if the list is somehow empty).
        padded_ids = list(negative_ids[: self.num_negatives])
        if not padded_ids:
            padded_ids = [positive_id] * self.num_negatives
        while len(padded_ids) < self.num_negatives:
            padded_ids.append(padded_ids[-1])
        rows = [doc_embeddings[positive_id].float()]
        rows.extend(doc_embeddings[doc_id].float() for doc_id in padded_ids)
        return {
            "qid": qid,
            "query_embedding": query_embedding,
            "documents": torch.stack(rows, dim=0),
        }
def collate_contrastive_batch(
    batch: list[dict[str, torch.Tensor | str]],
) -> dict[str, torch.Tensor | list[str]]:
    """Collate dataset items into a batched training dict.

    Annotations use builtin generics and ``|`` unions for consistency with
    the rest of this module (enabled by ``from __future__ import
    annotations``); the item annotation also reflects that ``"qid"`` maps to
    a ``str``, not a tensor.

    Returns "qids" (list of query ids), "query_embeddings" stacked along a
    new leading batch dimension, and "documents" stacked the same way.
    """
    return {
        "qids": [item["qid"] for item in batch],
        "query_embeddings": torch.stack([item["query_embedding"] for item in batch], dim=0),
        "documents": torch.stack([item["documents"] for item in batch], dim=0),
    }
|