yashsaxena21 commited on
Commit
11bc1ef
·
verified ·
1 Parent(s): 4ae6a27

Upload folder using huggingface_hub

Browse files
src/imrnns/__init__.py CHANGED
@@ -1,15 +1,16 @@
1
  """IMRNNs package."""
2
 
3
  from .api import cache_embeddings, evaluate, run, train
 
4
  from .hub import DEFAULT_REPO_ID, download_checkpoint, get_download_count, load_pretrained
5
- from .model import BiHyperNetIR, HyperNet, IMRNN, ModelConfig
6
 
7
  __all__ = [
8
- "BiHyperNetIR",
9
  "DEFAULT_REPO_ID",
10
- "HyperNet",
11
  "IMRNN",
12
  "ModelConfig",
 
13
  "cache_embeddings",
14
  "download_checkpoint",
15
  "evaluate",
 
1
  """IMRNNs package."""
2
 
3
  from .api import cache_embeddings, evaluate, run, train
4
+ from .adapter import IMRNNAdapter, RetrievalResult
5
  from .hub import DEFAULT_REPO_ID, download_checkpoint, get_download_count, load_pretrained
6
+ from .model import IMRNN, ModelConfig
7
 
8
  __all__ = [
 
9
  "DEFAULT_REPO_ID",
10
+ "IMRNNAdapter",
11
  "IMRNN",
12
  "ModelConfig",
13
+ "RetrievalResult",
14
  "cache_embeddings",
15
  "download_checkpoint",
16
  "evaluate",
src/imrnns/adapter.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Sequence
5
+
6
+ import torch
7
+ from sentence_transformers import SentenceTransformer
8
+
9
+ from .encoders import EncoderSpec
10
+ from .hub import DEFAULT_REPO_ID, load_pretrained
11
+ from .model import IMRNN
12
+
13
+
14
+ @dataclass(frozen=True)
15
+ class RetrievalResult:
16
+ rank: int
17
+ index: int
18
+ text: str
19
+ score: float
20
+
21
+
22
+ def _format_query(text: str, encoder_spec: EncoderSpec) -> str:
23
+ return f"{encoder_spec.query_prefix}{text}" if encoder_spec.query_prefix else text
24
+
25
+
26
+ def _format_document(text: str, encoder_spec: EncoderSpec) -> str:
27
+ return f"{encoder_spec.passage_prefix}{text}" if encoder_spec.passage_prefix else text
28
+
29
+
30
+ class IMRNNAdapter:
31
+ """Inference wrapper for applying a pretrained IMRNN adapter to a base retriever."""
32
+
33
+ def __init__(
34
+ self,
35
+ *,
36
+ model: IMRNN,
37
+ encoder: SentenceTransformer,
38
+ encoder_spec: EncoderSpec,
39
+ metadata: dict[str, Any],
40
+ device: str,
41
+ ) -> None:
42
+ self.model = model
43
+ self.encoder = encoder
44
+ self.encoder_spec = encoder_spec
45
+ self.metadata = metadata
46
+ self.device = device
47
+
48
+ @classmethod
49
+ def from_pretrained(
50
+ cls,
51
+ *,
52
+ encoder: str,
53
+ dataset: str,
54
+ repo_id: str = DEFAULT_REPO_ID,
55
+ device: str = "cpu",
56
+ ) -> "IMRNNAdapter":
57
+ model, metadata, encoder_spec = load_pretrained(
58
+ encoder=encoder,
59
+ dataset=dataset,
60
+ repo_id=repo_id,
61
+ device=device,
62
+ )
63
+ encoder_model = SentenceTransformer(encoder_spec.model_name, device=device)
64
+ return cls(
65
+ model=model,
66
+ encoder=encoder_model,
67
+ encoder_spec=encoder_spec,
68
+ metadata=metadata,
69
+ device=device,
70
+ )
71
+
72
+ def score(self, query: str, documents: Sequence[str], top_k: int | None = None) -> list[RetrievalResult]:
73
+ if not documents:
74
+ return []
75
+
76
+ formatted_query = _format_query(query, self.encoder_spec)
77
+ formatted_documents = [_format_document(document, self.encoder_spec) for document in documents]
78
+
79
+ with torch.no_grad():
80
+ query_embedding = self.encoder.encode(
81
+ [formatted_query],
82
+ convert_to_tensor=True,
83
+ show_progress_bar=False,
84
+ device=self.device,
85
+ )[0].to(self.device)
86
+ document_embeddings = self.encoder.encode(
87
+ formatted_documents,
88
+ convert_to_tensor=True,
89
+ show_progress_bar=False,
90
+ device=self.device,
91
+ ).to(self.device)
92
+ _, _, scores = self.model.score_candidates(query_embedding, document_embeddings)
93
+
94
+ ranked_indices = torch.argsort(scores, descending=True).tolist()
95
+ if top_k is not None:
96
+ ranked_indices = ranked_indices[:top_k]
97
+
98
+ return [
99
+ RetrievalResult(
100
+ rank=rank,
101
+ index=index,
102
+ text=documents[index],
103
+ score=float(scores[index].item()),
104
+ )
105
+ for rank, index in enumerate(ranked_indices, start=1)
106
+ ]
src/imrnns/api.py CHANGED
@@ -9,7 +9,7 @@ from .checkpoints import default_checkpoint_name, load_model, save_checkpoint
9
  from .data import ContrastiveCachedDataset, load_cached_split
10
  from .encoders import get_encoder_spec
11
  from .evaluation import evaluate_model
12
- from .model import BiHyperNetIR, ModelConfig
13
  from .training import TrainingConfig, train_model
14
 
15
 
@@ -66,7 +66,7 @@ def train(
66
  val_split = load_cached_split(cache_dir, "val", beir_source, encoder_spec, device)
67
  test_split = load_cached_split(cache_dir, "test", beir_source, encoder_spec, device)
68
 
69
- model = BiHyperNetIR(
70
  ModelConfig(
71
  input_dim=encoder_spec.embedding_dim,
72
  output_dim=output_dim,
 
9
  from .data import ContrastiveCachedDataset, load_cached_split
10
  from .encoders import get_encoder_spec
11
  from .evaluation import evaluate_model
12
+ from .model import IMRNN, ModelConfig
13
  from .training import TrainingConfig, train_model
14
 
15
 
 
66
  val_split = load_cached_split(cache_dir, "val", beir_source, encoder_spec, device)
67
  test_split = load_cached_split(cache_dir, "test", beir_source, encoder_spec, device)
68
 
69
+ model = IMRNN(
70
  ModelConfig(
71
  input_dim=encoder_spec.embedding_dim,
72
  output_dim=output_dim,
src/imrnns/checkpoints.py CHANGED
@@ -7,7 +7,7 @@ from typing import Any
7
  import torch
8
 
9
  from .encoders import normalize_encoder_name
10
- from .model import BiHyperNetIR, ModelConfig
11
 
12
 
13
  def default_checkpoint_name(encoder: str, dataset: str) -> str:
@@ -29,7 +29,7 @@ def sanitize_legacy_state_dict(state_dict: dict[str, Any]) -> dict[str, Any]:
29
 
30
  def save_checkpoint(
31
  path: Path,
32
- model: BiHyperNetIR,
33
  metadata: dict[str, Any],
34
  ) -> None:
35
  payload = {
@@ -56,9 +56,9 @@ def load_model(
56
  checkpoint_path: Path,
57
  model_config: ModelConfig,
58
  device: str,
59
- ) -> tuple[BiHyperNetIR, dict[str, Any], list[str], list[str]]:
60
  state_dict, metadata = load_checkpoint(checkpoint_path)
61
- model = BiHyperNetIR(model_config)
62
  missing, unexpected = model.load_state_dict(state_dict, strict=False)
63
  model.to(device)
64
  model.eval()
 
7
  import torch
8
 
9
  from .encoders import normalize_encoder_name
10
+ from .model import IMRNN, ModelConfig
11
 
12
 
13
  def default_checkpoint_name(encoder: str, dataset: str) -> str:
 
29
 
30
  def save_checkpoint(
31
  path: Path,
32
+ model: IMRNN,
33
  metadata: dict[str, Any],
34
  ) -> None:
35
  payload = {
 
56
  checkpoint_path: Path,
57
  model_config: ModelConfig,
58
  device: str,
59
+ ) -> tuple[IMRNN, dict[str, Any], list[str], list[str]]:
60
  state_dict, metadata = load_checkpoint(checkpoint_path)
61
+ model = IMRNN(model_config)
62
  missing, unexpected = model.load_state_dict(state_dict, strict=False)
63
  model.to(device)
64
  model.eval()
src/imrnns/cli.py CHANGED
@@ -19,7 +19,7 @@ from .checkpoints import default_checkpoint_name, load_model, save_checkpoint
19
  from .data import ContrastiveCachedDataset, load_cached_split
20
  from .encoders import get_encoder_spec
21
  from .evaluation import evaluate_model
22
- from .model import BiHyperNetIR, ModelConfig
23
  from .training import TrainingConfig, train_model
24
 
25
 
@@ -85,7 +85,7 @@ def _command_cache(args: argparse.Namespace) -> int:
85
 
86
  def _command_train(args: argparse.Namespace) -> int:
87
  encoder_spec, cache_dir, train_split, val_split, test_split = _load_training_inputs(args)
88
- model = BiHyperNetIR(
89
  ModelConfig(
90
  input_dim=encoder_spec.embedding_dim,
91
  output_dim=args.output_dim,
 
19
  from .data import ContrastiveCachedDataset, load_cached_split
20
  from .encoders import get_encoder_spec
21
  from .evaluation import evaluate_model
22
+ from .model import IMRNN, ModelConfig
23
  from .training import TrainingConfig, train_model
24
 
25
 
 
85
 
86
  def _command_train(args: argparse.Namespace) -> int:
87
  encoder_spec, cache_dir, train_split, val_split, test_split = _load_training_inputs(args)
88
+ model = IMRNN(
89
  ModelConfig(
90
  input_dim=encoder_spec.embedding_dim,
91
  output_dim=args.output_dim,
src/imrnns/evaluation.py CHANGED
@@ -10,7 +10,7 @@ import torch.nn.functional as F
10
  from tqdm import tqdm
11
 
12
  from .data import CachedSplit
13
- from .model import BiHyperNetIR
14
 
15
  try:
16
  import faiss # type: ignore
@@ -73,7 +73,7 @@ def _compute_metrics(ranked_doc_ids: list[str], qrel: dict[str, int], k_values:
73
 
74
 
75
  def evaluate_model(
76
- model: BiHyperNetIR,
77
  cached_split: CachedSplit,
78
  device: str,
79
  feedback_k: int = 100,
@@ -120,10 +120,10 @@ def evaluate_model(
120
  dim=0,
121
  ).to(device)
122
 
123
- _, _, rerank_scores = model.encode_candidates(query_embedding.float().to(device), candidate_embeddings)
124
- rerank_scores = rerank_scores.cpu().tolist()
125
  reranked = [
126
- doc_id for doc_id, _ in sorted(zip(candidate_ids, rerank_scores), key=lambda item: item[1], reverse=True)
127
  ][:ranking_k]
128
 
129
  metrics = _compute_metrics(reranked, cached_split.split.qrels[qid], k_values)
 
10
  from tqdm import tqdm
11
 
12
  from .data import CachedSplit
13
+ from .model import IMRNN
14
 
15
  try:
16
  import faiss # type: ignore
 
73
 
74
 
75
  def evaluate_model(
76
+ model: IMRNN,
77
  cached_split: CachedSplit,
78
  device: str,
79
  feedback_k: int = 100,
 
120
  dim=0,
121
  ).to(device)
122
 
123
+ _, _, adapted_scores = model.score_candidates(query_embedding.float().to(device), candidate_embeddings)
124
+ adapted_scores = adapted_scores.cpu().tolist()
125
  reranked = [
126
+ doc_id for doc_id, _ in sorted(zip(candidate_ids, adapted_scores), key=lambda item: item[1], reverse=True)
127
  ][:ranking_k]
128
 
129
  metrics = _compute_metrics(reranked, cached_split.split.qrels[qid], k_values)
src/imrnns/hub.py CHANGED
@@ -9,7 +9,7 @@ from huggingface_hub import HfApi, hf_hub_download
9
 
10
  from .checkpoints import default_checkpoint_name, load_model
11
  from .encoders import EncoderSpec, get_encoder_spec, normalize_encoder_name
12
- from .model import BiHyperNetIR, ModelConfig
13
 
14
  DEFAULT_REPO_ID = "yashsaxena21/IMRNNs"
15
  CONFIG_FILENAME = "config.json"
@@ -92,7 +92,7 @@ def load_pretrained(
92
  revision: Optional[str] = None,
93
  cache_dir: Optional[Path] = None,
94
  local_files_only: bool = False,
95
- ) -> tuple[BiHyperNetIR, dict[str, Any], EncoderSpec]:
96
  encoder_spec = get_encoder_spec(encoder)
97
  pretrained = download_checkpoint(
98
  encoder=encoder,
 
9
 
10
  from .checkpoints import default_checkpoint_name, load_model
11
  from .encoders import EncoderSpec, get_encoder_spec, normalize_encoder_name
12
+ from .model import IMRNN, ModelConfig
13
 
14
  DEFAULT_REPO_ID = "yashsaxena21/IMRNNs"
15
  CONFIG_FILENAME = "config.json"
 
92
  revision: Optional[str] = None,
93
  cache_dir: Optional[Path] = None,
94
  local_files_only: bool = False,
95
+ ) -> tuple[IMRNN, dict[str, Any], EncoderSpec]:
96
  encoder_spec = get_encoder_spec(encoder)
97
  pretrained = download_checkpoint(
98
  encoder=encoder,
src/imrnns/model.py CHANGED
@@ -99,7 +99,7 @@ class IMRNN(nn.Module):
99
  scores = torch.einsum("bd,bkd->bk", F.normalize(modulated_queries, p=2, dim=-1), F.normalize(modulated_documents, p=2, dim=-1))
100
  return modulated_queries, modulated_documents, scores
101
 
102
- def encode_candidates(
103
  self,
104
  query_embedding: torch.Tensor,
105
  candidate_document_embeddings: torch.Tensor,
@@ -112,5 +112,5 @@ class IMRNN(nn.Module):
112
  return modulated_query.squeeze(0), modulated_docs.squeeze(0), scores.squeeze(0)
113
 
114
 
115
- class BiHyperNetIR(IMRNN):
116
- """Backward-compatible alias for legacy checkpoints and code paths."""
 
99
  scores = torch.einsum("bd,bkd->bk", F.normalize(modulated_queries, p=2, dim=-1), F.normalize(modulated_documents, p=2, dim=-1))
100
  return modulated_queries, modulated_documents, scores
101
 
102
+ def score_candidates(
103
  self,
104
  query_embedding: torch.Tensor,
105
  candidate_document_embeddings: torch.Tensor,
 
112
  return modulated_query.squeeze(0), modulated_docs.squeeze(0), scores.squeeze(0)
113
 
114
 
115
+ BiHyperNetIR = IMRNN
116
+ """Backward-compatible alias retained for legacy checkpoints and code paths."""
src/imrnns/training.py CHANGED
@@ -8,7 +8,7 @@ from torch.utils.data import DataLoader
8
  from tqdm import tqdm
9
 
10
  from .data import ContrastiveCachedDataset, collate_contrastive_batch
11
- from .model import BiHyperNetIR
12
 
13
 
14
  class MultipleNegativesRankingLoss(torch.nn.Module):
@@ -45,7 +45,7 @@ def build_dataloader(dataset: ContrastiveCachedDataset, batch_size: int, shuffle
45
 
46
 
47
  def evaluate_loss(
48
- model: BiHyperNetIR,
49
  dataloader: DataLoader,
50
  device: str,
51
  loss_fn: MultipleNegativesRankingLoss,
@@ -67,7 +67,7 @@ def evaluate_loss(
67
 
68
 
69
  def train_model(
70
- model: BiHyperNetIR,
71
  train_dataset: ContrastiveCachedDataset,
72
  val_dataset: ContrastiveCachedDataset,
73
  config: TrainingConfig,
 
8
  from tqdm import tqdm
9
 
10
  from .data import ContrastiveCachedDataset, collate_contrastive_batch
11
+ from .model import IMRNN
12
 
13
 
14
  class MultipleNegativesRankingLoss(torch.nn.Module):
 
45
 
46
 
47
  def evaluate_loss(
48
+ model: IMRNN,
49
  dataloader: DataLoader,
50
  device: str,
51
  loss_fn: MultipleNegativesRankingLoss,
 
67
 
68
 
69
  def train_model(
70
+ model: IMRNN,
71
  train_dataset: ContrastiveCachedDataset,
72
  val_dataset: ContrastiveCachedDataset,
73
  config: TrainingConfig,