File size: 5,195 Bytes
a608d21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from __future__ import annotations

import json
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Union

import torch
from sentence_transformers import SentenceTransformer
from torch.utils.data import Dataset

from .beir_data import DatasetSplit
from .encoders import EncoderSpec


@dataclass(frozen=True)
class CachedSplit:
    """Immutable bundle of a dataset split together with its cached artifacts.

    Instances are frozen: assigning to any field after construction raises
    ``dataclasses.FrozenInstanceError``.
    """

    # The (possibly filtered) split whose queries/qrels these artifacts cover.
    split: DatasetSplit
    # Document id -> embedding tensor.
    document_embeddings: dict[str, torch.Tensor]
    # Query id -> embedding tensor.
    query_embeddings: dict[str, torch.Tensor]
    # Query id -> list of candidate negative document ids.
    negatives: dict[str, list[str]]


def _query_cache_path(cache_dir: Path, split_name: str, encoder_key: str) -> Path:
    return cache_dir / split_name / f"query_embeddings_{encoder_key}.pt"


def load_document_embeddings(cache_dir: Path, split_name: str) -> dict[str, torch.Tensor]:
    """Load the cached document-embedding mapping for ``split_name``.

    Reads ``<cache_dir>/<split_name>/embeddings.pt`` onto the CPU. The
    ``weights_only=True`` flag restricts unpickling to tensors and plain
    containers.
    """
    embeddings_file = cache_dir.joinpath(split_name, "embeddings.pt")
    return torch.load(
        embeddings_file,
        map_location="cpu",
        weights_only=True,
    )


def load_negatives(cache_dir: Path, split_name: str) -> dict[str, list[str]]:
    """Load the cached negative-document lists for ``split_name``.

    Reads ``<cache_dir>/<split_name>/negatives.json``, which is expected to map
    each query id to a list of document ids.

    The file is decoded explicitly as UTF-8 (the encoding JSON interchange
    mandates) instead of the platform's locale default, so the same cache
    loads identically on every OS, including ids with non-ASCII characters.
    """
    negatives_path = cache_dir / split_name / "negatives.json"
    with negatives_path.open(encoding="utf-8") as handle:
        return json.load(handle)


def encode_queries(
    queries: dict[str, str],
    encoder_spec: EncoderSpec,
    cache_dir: Path,
    split_name: str,
    device: str,
    batch_size: int = 64,
) -> dict[str, torch.Tensor]:
    """Encode ``queries`` with the encoder's model, caching the result on disk.

    The cache file is ``_query_cache_path(cache_dir, split_name, encoder_spec.key)``.
    An existing cache is reused only when it contains every requested query
    id. Otherwise the queries are re-encoded and the cache is rewritten; this
    covers a missing cache and one left from a different query set.

    Args:
        queries: Query id -> query text.
        encoder_spec: Supplies ``model_name``, ``query_prefix`` (prepended to
            every query text) and ``key`` (part of the cache file name).
        cache_dir: Root cache directory.
        split_name: Split sub-directory under ``cache_dir``.
        device: Device string handed to ``SentenceTransformer`` for encoding.
        batch_size: Encoding batch size.

    Returns:
        Query id -> CPU embedding tensor. When served from cache, the mapping
        may also contain ids that were not requested.
    """
    cache_path = _query_cache_path(cache_dir, split_name, encoder_spec.key)
    if cache_path.exists():
        cached = torch.load(cache_path, map_location="cpu", weights_only=True)
        # The cache key does not encode which queries were embedded, so a cache
        # built for another query set would otherwise be returned silently and
        # the missing queries dropped by downstream consumers.
        if all(qid in cached for qid in queries):
            return cached

    if not queries:
        # Nothing to encode; skip loading the model entirely.
        return {}

    cache_path.parent.mkdir(parents=True, exist_ok=True)
    model = SentenceTransformer(encoder_spec.model_name, device=device)
    query_ids = list(queries)
    texts = [encoder_spec.query_prefix + queries[qid] for qid in query_ids]
    encoded = model.encode(
        texts,
        batch_size=batch_size,
        convert_to_tensor=True,
        show_progress_bar=True,
        device=device,
    )
    # One device-to-host transfer for the whole matrix instead of one per row.
    encoded = encoded.cpu()
    query_embeddings = dict(zip(query_ids, encoded))

    # Save to a sibling temp file, then rename over the target, so an
    # interrupted save never leaves a truncated cache for later runs to load.
    tmp_path = cache_path.with_name(cache_path.name + ".tmp")
    torch.save(query_embeddings, tmp_path)
    tmp_path.replace(cache_path)
    return query_embeddings


def load_cached_split(
    cache_dir: Path,
    split_name: str,
    dataset_source: DatasetSplit,
    encoder_spec: EncoderSpec,
    device: str,
) -> CachedSplit:
    """Assemble a ``CachedSplit`` from the on-disk cache for ``split_name``.

    A query is kept only if it appears in the cached negatives AND in both
    ``dataset_source.queries`` and ``dataset_source.qrels``. Kept queries follow
    the order of the negatives mapping. The corpus is carried over unchanged.
    Document embeddings are loaded from the cache first. Query embeddings are
    then produced by ``encode_queries`` for the kept queries.
    """
    negatives = load_negatives(cache_dir, split_name)
    source_queries = dataset_source.queries
    source_qrels = dataset_source.qrels

    kept_queries = {}
    kept_qrels = {}
    for qid in negatives:
        # Usable only when both the query text and its judgments exist.
        if qid not in source_queries or qid not in source_qrels:
            continue
        kept_queries[qid] = source_queries[qid]
        kept_qrels[qid] = source_qrels[qid]

    filtered_split = DatasetSplit(
        corpus=dataset_source.corpus,
        queries=kept_queries,
        qrels=kept_qrels,
    )

    document_embeddings = load_document_embeddings(cache_dir, split_name)
    query_embeddings = encode_queries(filtered_split.queries, encoder_spec, cache_dir, split_name, device)
    return CachedSplit(
        split=filtered_split,
        document_embeddings=document_embeddings,
        query_embeddings=query_embeddings,
        negatives=negatives,
    )


class ContrastiveCachedDataset(Dataset):
    """Map-style dataset of (query, positive, negatives) examples over cached embeddings.

    Each example pairs a query with the first judged-relevant (``rel > 0``)
    document that has a cached embedding, plus up to ``num_negatives`` cached
    negatives that are not judged relevant for that query. A query is skipped
    in three cases: it has no cached query embedding, no such positive, or no
    usable negative.
    """

    def __init__(
        self,
        cached_split: CachedSplit,
        num_negatives: int,
    ) -> None:
        """Build the example list from ``cached_split``.

        Args:
            cached_split: Split plus cached query/document embeddings and negatives.
            num_negatives: Number of negatives per item; must be >= 0.

        Raises:
            ValueError: If ``num_negatives`` is negative.
        """
        if num_negatives < 0:
            raise ValueError(f"num_negatives must be >= 0, got {num_negatives}")
        self.cached_split = cached_split
        self.num_negatives = num_negatives
        self.examples: list[tuple[str, str, list[str]]] = []

        document_embeddings = cached_split.document_embeddings
        for qid, qrel in cached_split.split.qrels.items():
            if qid not in cached_split.query_embeddings:
                continue
            positives = [
                doc_id for doc_id, rel in qrel.items() if rel > 0 and doc_id in document_embeddings
            ]
            relevant_ids = {doc_id for doc_id, rel in qrel.items() if rel > 0}
            # Mined negative lists may include judged-relevant documents; using one
            # as a negative would contradict the positive label, so drop them.
            negatives = [
                doc_id
                for doc_id in cached_split.negatives.get(qid, [])
                if doc_id in document_embeddings and doc_id not in relevant_ids
            ]
            if not positives or not negatives:
                continue
            self.examples.append((qid, positives[0], negatives[:num_negatives]))

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, index: int) -> dict[str, Union[str, torch.Tensor]]:
        """Return one example with embeddings cast to float32.

        Returns:
            ``{"qid": str, "query_embedding": Tensor, "documents": Tensor}``.
            ``documents`` stacks the positive embedding first, followed by
            ``num_negatives`` negative embeddings. A shorter negative list is
            padded by repeating its last entry.
        """
        qid, positive_id, negative_ids = self.examples[index]
        embeddings = self.cached_split.document_embeddings
        query_embedding = self.cached_split.query_embeddings[qid].float()
        positive_embedding = embeddings[positive_id].float()
        padded_ids = self._pad_negatives(negative_ids, self.num_negatives)
        negative_embeddings = [embeddings[doc_id].float() for doc_id in padded_ids]
        documents = torch.stack([positive_embedding, *negative_embeddings], dim=0)
        return {
            "qid": qid,
            "query_embedding": query_embedding,
            "documents": documents,
        }

    @staticmethod
    def _pad_negatives(negative_ids: list[str], count: int) -> list[str]:
        """Truncate ``negative_ids`` to ``count`` and pad with its last id up to ``count``.

        An empty input stays empty. Stored examples always have at least one
        negative whenever ``count > 0``, and the positive must never be used
        as a stand-in negative.
        """
        padded = list(negative_ids[:count])
        if padded:
            padded.extend([padded[-1]] * (count - len(padded)))
        return padded


def collate_contrastive_batch(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, Union[torch.Tensor, List[str]]]:
    """Merge per-example dicts into one batch dict.

    Query ids are collected into a list. Query embeddings and document stacks
    are each stacked along a new leading batch dimension.
    """
    qid_list = [example["qid"] for example in batch]
    stacked_queries = torch.stack([example["query_embedding"] for example in batch], dim=0)
    stacked_documents = torch.stack([example["documents"] for example in batch], dim=0)
    return {
        "qids": qid_list,
        "query_embeddings": stacked_queries,
        "documents": stacked_documents,
    }