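"""Data utilities for training on cached embeddings.

Loads precomputed document embeddings and mined hard negatives for a BEIR-style
split, encodes (and caches) query embeddings with a SentenceTransformer, and
exposes everything as a contrastive PyTorch ``Dataset`` with a matching collate
function.
"""
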
from __future__ import annotations

import json
from dataclasses import dataclass
from pathlib import Path

import torch
from sentence_transformers import SentenceTransformer
from torch.utils.data import Dataset

from .beir_data import DatasetSplit
from .encoders import EncoderSpec


@dataclass(frozen=True)
class CachedSplit:
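    """A dataset split bundled with its cached embeddings and mined hard negatives."""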
split: DatasetSplit
document_embeddings: dict[str, torch.Tensor]
query_embeddings: dict[str, torch.Tensor]
negatives: dict[str, list[str]]


def _query_cache_path(cache_dir: Path, split_name: str, encoder_key: str) -> Path:
return cache_dir / split_name / f"query_embeddings_{encoder_key}.pt"


def load_document_embeddings(cache_dir: Path, split_name: str) -> dict[str, torch.Tensor]:
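    """Load the precomputed ``doc_id -> embedding`` map for a split from the cache."""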
return torch.load(cache_dir / split_name / "embeddings.pt", map_location="cpu", weights_only=True)


def load_negatives(cache_dir: Path, split_name: str) -> dict[str, list[str]]:
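    """Load the mined ``qid -> [negative doc ids]`` map for a split."""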
with open(cache_dir / split_name / "negatives.json") as handle:
return json.load(handle)


def encode_queries(
queries: dict[str, str],
encoder_spec: EncoderSpec,
cache_dir: Path,
split_name: str,
device: str,
batch_size: int = 64,
) -> dict[str, torch.Tensor]:
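    """Encode query texts with the configured SentenceTransformer.

    Embeddings are keyed by query id, moved to CPU, and cached on disk per
    encoder key so repeated calls reuse the stored file.
    """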
cache_path = _query_cache_path(cache_dir, split_name, encoder_spec.key)
if cache_path.exists():
return torch.load(cache_path, map_location="cpu", weights_only=True)
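    # Cache miss: encode the queries from scratch and persist them for next time.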
cache_path.parent.mkdir(parents=True, exist_ok=True)
model = SentenceTransformer(encoder_spec.model_name, device=device)
query_ids = list(queries.keys())
texts = [encoder_spec.query_prefix + queries[qid] for qid in query_ids]
encoded = model.encode(
texts,
batch_size=batch_size,
convert_to_tensor=True,
show_progress_bar=True,
device=device,
)
query_embeddings = {qid: embedding.cpu() for qid, embedding in zip(query_ids, encoded)}
torch.save(query_embeddings, cache_path)
return query_embeddings


def load_cached_split(
cache_dir: Path,
split_name: str,
dataset_source: DatasetSplit,
encoder_spec: EncoderSpec,
device: str,
) -> CachedSplit:
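    """Assemble a CachedSplit for ``split_name``.

    Only queries that have mined negatives, a query text, and relevance
    judgments are kept; document embeddings are read from the cache and query
    embeddings are encoded (or loaded) on demand.
    """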
negatives = load_negatives(cache_dir, split_name)
cached_qids = list(negatives.keys())
filtered_queries = {
qid: dataset_source.queries[qid]
for qid in cached_qids
if qid in dataset_source.queries and qid in dataset_source.qrels
}
filtered_qrels = {qid: dataset_source.qrels[qid] for qid in filtered_queries}
filtered_split = DatasetSplit(
corpus=dataset_source.corpus,
queries=filtered_queries,
qrels=filtered_qrels,
)
return CachedSplit(
split=filtered_split,
document_embeddings=load_document_embeddings(cache_dir, split_name),
query_embeddings=encode_queries(filtered_split.queries, encoder_spec, cache_dir, split_name, device),
negatives=negatives,
)


class ContrastiveCachedDataset(Dataset):
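    """Contrastive training dataset backed entirely by cached embeddings.

    Each example pairs a query embedding with its first relevant document and
    up to ``num_negatives`` mined hard negatives.
    """
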
def __init__(
self,
cached_split: CachedSplit,
num_negatives: int,
) -> None:
self.cached_split = cached_split
self.num_negatives = num_negatives
self.examples: list[tuple[str, str, list[str]]] = []
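        # Build (qid, positive doc id, negative doc ids) triples, skipping queries
        # or documents whose embeddings are missing from the cache.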
for qid, qrel in cached_split.split.qrels.items():
if qid not in cached_split.query_embeddings:
continue
positives = [doc_id for doc_id, rel in qrel.items() if rel > 0 and doc_id in cached_split.document_embeddings]
negatives = [doc_id for doc_id in cached_split.negatives.get(qid, []) if doc_id in cached_split.document_embeddings]
if not positives or not negatives:
continue
self.examples.append((qid, positives[0], negatives[:num_negatives]))

    def __len__(self) -> int:
return len(self.examples)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor | str]:
qid, positive_id, negative_ids = self.examples[index]
query_embedding = self.cached_split.query_embeddings[qid].float()
positive_embedding = self.cached_split.document_embeddings[positive_id].float()
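        # Pad the negative list to a fixed length (falling back to the positive if
        # it is empty) so every example stacks to the same shape.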
normalized_negative_ids = list(negative_ids[: self.num_negatives])
if not normalized_negative_ids:
normalized_negative_ids = [positive_id] * self.num_negatives
while len(normalized_negative_ids) < self.num_negatives:
normalized_negative_ids.append(normalized_negative_ids[-1])
negative_embeddings = [
self.cached_split.document_embeddings[doc_id].float() for doc_id in normalized_negative_ids
]
documents = torch.stack([positive_embedding, *negative_embeddings], dim=0)
return {
"qid": qid,
"query_embedding": query_embedding,
"documents": documents,
}


def collate_contrastive_batch(
    batch: list[dict[str, torch.Tensor | str]],
) -> dict[str, torch.Tensor | list[str]]:
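    """Stack per-example tensors into batch tensors; ``qids`` stays a list of strings."""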
return {
"qids": [item["qid"] for item in batch],
"query_embeddings": torch.stack([item["query_embedding"] for item in batch], dim=0),
"documents": torch.stack([item["documents"] for item in batch], dim=0),
}
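

# Usage sketch (hypothetical paths and values; the real cache directory, split,
# and EncoderSpec are configured elsewhere in this repo):
#
#   from torch.utils.data import DataLoader
#
#   cached = load_cached_split(Path("cache/my_dataset"), "train", split, encoder_spec, "cuda")
#   train_ds = ContrastiveCachedDataset(cached, num_negatives=8)
#   loader = DataLoader(train_ds, batch_size=32, shuffle=True,
#                       collate_fn=collate_contrastive_batch)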