| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class HyperNet(nn.Module):
    """Small MLP that emits FiLM-style modulation parameters.

    Given a conditioning embedding, produces a per-dimension ``scale`` in
    (0, 1) (via sigmoid) and ``bias`` in (-1, 1) (via tanh), each of size
    ``output_dim``.

    Args:
        input_dim: Size of the conditioning embedding.
        hidden_dim: Width of the first hidden layer (the second is half of it).
        output_dim: Size of each of the two emitted modulation vectors.
        dropout: Dropout probability applied after each hidden activation.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, dropout: float = 0.1):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hypernet = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            # Final layer emits scale and bias concatenated on the last axis.
            nn.Linear(hidden_dim // 2, output_dim * 2),
        )

        # Small-std weights and zero biases keep the initial modulation near
        # a neutral point (scale ~= sigmoid(0) = 0.5, bias ~= tanh(0) = 0).
        for layer in self.hypernet:
            if isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, std=0.01)
                nn.init.zeros_(layer.bias)

    def forward(self, embedding: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute (scale, bias) for a batch of conditioning embeddings.

        Args:
            embedding: [..., input_dim] conditioning tensor.

        Returns:
            ``(scale, bias)`` with ``scale`` in (0, 1) and ``bias`` in
            (-1, 1), each shaped [..., output_dim].
        """
        hyper_output = self.hypernet(embedding)
        # Ellipsis indexing generalizes the original 2-D-only slicing
        # (`[:, :d]`) to inputs with any number of leading batch dimensions.
        scale = torch.sigmoid(hyper_output[..., : self.output_dim])
        bias = torch.tanh(hyper_output[..., self.output_dim :])
        return scale, bias
|
|
|
|
@dataclass(frozen=True)
class ModelConfig:
    """Immutable hyper-parameter bundle for the IMRNN adapter model."""

    input_dim: int  # dimensionality of the cached input embeddings
    output_dim: int = 256  # size of the projected / modulated working space
    hidden_dim: int = 128  # hidden width of each HyperNet MLP
    dropout: float = 0.1  # dropout probability used inside each HyperNet
|
|
|
|
class IMRNN(nn.Module):
    """Adapter-only IMRNN operating on pre-computed dense embeddings.

    Legacy module names (``query_hypernet``, ``doc_hypernet``, ``query_norm``,
    ``doc_norm``) are deliberately preserved so that existing
    ``bihypernet_*.pt`` checkpoints remain loadable via key remapping with
    ``strict=False``.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        # Shared linear map from the cached-embedding space into the working
        # space; `project` L2-normalizes its output.
        self.projector = nn.Linear(config.input_dim, config.output_dim)
        hyper_args = (config.output_dim, config.hidden_dim, config.output_dim, config.dropout)
        self.query_hypernet = HyperNet(*hyper_args)
        self.doc_hypernet = HyperNet(*hyper_args)
        self.query_norm = nn.LayerNorm(config.output_dim)
        self.doc_norm = nn.LayerNorm(config.output_dim)

    def project(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Map embeddings into the working space and L2-normalize them."""
        projected = self.projector(embeddings)
        return F.normalize(projected, p=2, dim=-1)

    def modulate_documents(
        self,
        query_embeddings: torch.Tensor,
        document_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """Apply a query-conditioned scale/shift to each candidate document."""
        scale, shift = self.query_hypernet(query_embeddings)
        # Broadcast the per-query modulation over the docs_per_query axis.
        modulated = document_embeddings * scale.unsqueeze(1) + shift.unsqueeze(1)
        return self.doc_norm(modulated)

    def modulate_query(
        self,
        query_embeddings: torch.Tensor,
        modulated_documents: torch.Tensor,
    ) -> torch.Tensor:
        """Condition the query on the mean of its modulated documents."""
        summary = modulated_documents.mean(dim=1)
        scale, shift = self.doc_hypernet(summary)
        return self.query_norm(query_embeddings * scale + shift)

    def forward(
        self,
        query_embeddings: torch.Tensor,
        document_embeddings: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            query_embeddings: [batch, input_dim]
            document_embeddings: [batch, docs_per_query, input_dim]

        Returns:
            Modulated queries, modulated documents, and cosine-similarity
            scores of shape [batch, docs_per_query].
        """
        queries = self.project(query_embeddings)
        documents = self.project(document_embeddings)
        mod_docs = self.modulate_documents(queries, documents)
        mod_queries = self.modulate_query(queries, mod_docs)
        unit_queries = F.normalize(mod_queries, p=2, dim=-1)
        unit_docs = F.normalize(mod_docs, p=2, dim=-1)
        # Batched dot products; equivalent to einsum("bd,bkd->bk").
        scores = torch.matmul(unit_docs, unit_queries.unsqueeze(-1)).squeeze(-1)
        return mod_queries, mod_docs, scores

    def score_candidates(
        self,
        query_embedding: torch.Tensor,
        candidate_document_embeddings: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Score one query against its candidates, accepting unbatched input."""
        if query_embedding.dim() == 1:
            query_embedding = query_embedding.unsqueeze(0)
        if candidate_document_embeddings.dim() == 2:
            candidate_document_embeddings = candidate_document_embeddings.unsqueeze(0)
        mod_query, mod_docs, scores = self.forward(query_embedding, candidate_document_embeddings)
        return mod_query.squeeze(0), mod_docs.squeeze(0), scores.squeeze(0)
|
|
|
|
# Historical public name for IMRNN; kept so legacy imports and
# checkpoint-loading code paths that reference `BiHyperNetIR` still resolve.
BiHyperNetIR = IMRNN
"""Backward-compatible alias retained for legacy checkpoints and code paths."""
|
|