# Source: IMRNNs/src/imrnns/model.py
# Uploaded by yashsaxena21 ("Upload folder using huggingface_hub"),
# commit 11bc1ef (verified).
from __future__ import annotations
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
class HyperNet(nn.Module):
    """Small MLP that maps an embedding to a per-dimension (scale, bias) pair.

    The last linear layer emits ``output_dim * 2`` values per input vector.
    The first half is squashed with a sigmoid to form the scale (values in
    ``(0, 1)``) and the second half with a tanh to form the bias (values in
    ``(-1, 1)``).

    Args:
        input_dim: Size of the last dimension of the incoming embedding.
        hidden_dim: Width of the first hidden layer; the second hidden layer
            uses ``hidden_dim // 2`` units.
        output_dim: Size of each of the returned scale and bias vectors.
        dropout: Dropout probability applied after each hidden activation.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, dropout: float = 0.1):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        # The attribute name and layer order are kept as-is so state-dict keys
        # (``hypernet.0``, ``hypernet.3``, ``hypernet.6``) stay loadable.
        self.hypernet = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, output_dim * 2),
        )

        # Small-weight, zero-bias initialisation keeps the initial
        # pre-activations close to zero.
        for layer in self.hypernet:
            if isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, std=0.01)
                nn.init.zeros_(layer.bias)

    def forward(self, embedding: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(scale, bias)`` computed from ``embedding``.

        Args:
            embedding: Tensor whose last dimension has size ``input_dim``.
                Any number of leading dimensions (including none) is accepted.

        Returns:
            A ``(scale, bias)`` tuple. Both tensors have the same leading
            dimensions as ``embedding`` and a last dimension of ``output_dim``.
        """
        hyper_output = self.hypernet(embedding)
        # Split along the last (feature) axis rather than axis 1 so that inputs
        # with extra or missing leading dimensions are handled correctly.
        scale = torch.sigmoid(hyper_output[..., : self.output_dim])
        bias = torch.tanh(hyper_output[..., self.output_dim :])
        return scale, bias
@dataclass(frozen=True)
class ModelConfig:
    """Immutable hyperparameters used to build an ``IMRNN`` model.

    Attributes:
        input_dim: Dimensionality of the incoming embeddings (required).
        output_dim: Dimensionality of the projected embedding space.
        hidden_dim: Hidden width passed to each ``HyperNet``.
        dropout: Dropout probability passed to each ``HyperNet``.
    """

    # Required: there is no sensible default for the upstream embedding size.
    input_dim: int
    # Optional settings with their defaults.
    output_dim: int = 256
    hidden_dim: int = 128
    dropout: float = 0.1
class IMRNN(nn.Module):
    """Adapter-only IMRNN implementation over cached dense embeddings.

    Queries and documents are projected into a shared ``output_dim`` space and
    then modulated in two steps: a hypernetwork conditioned on the query
    produces a scale/bias applied to every candidate document, and a second
    hypernetwork conditioned on the mean of the modulated documents produces a
    scale/bias applied to the query.

    The model keeps the legacy module names (`query_hypernet`, `doc_hypernet`,
    `query_norm`, `doc_norm`) so existing `bihypernet_*.pt` checkpoints can be
    loaded with key remapping and `strict=False`.
    """

    def __init__(self, config: ModelConfig):
        """Build the projector, both hypernetworks and both layer norms.

        Args:
            config: Dimensions and dropout used to size every sub-module.
        """
        super().__init__()
        self.config = config
        self.projector = nn.Linear(config.input_dim, config.output_dim)
        # Both hypernetworks operate in the projected space, so their input
        # and output sizes are both ``output_dim``.
        self.query_hypernet = HyperNet(config.output_dim, config.hidden_dim, config.output_dim, config.dropout)
        self.doc_hypernet = HyperNet(config.output_dim, config.hidden_dim, config.output_dim, config.dropout)
        self.query_norm = nn.LayerNorm(config.output_dim)
        self.doc_norm = nn.LayerNorm(config.output_dim)

    def project(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Linearly project ``embeddings`` and L2-normalise the last dimension."""
        return F.normalize(self.projector(embeddings), p=2, dim=-1)

    def modulate_documents(
        self,
        query_embeddings: torch.Tensor,
        document_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """Apply a query-conditioned scale and bias to every document.

        Args:
            query_embeddings: Projected queries, ``[batch, output_dim]``.
            document_embeddings: Projected documents,
                ``[batch, docs_per_query, output_dim]``.

        Returns:
            Layer-normalised modulated documents with the same shape as
            ``document_embeddings``.
        """
        q_scale, q_bias = self.query_hypernet(query_embeddings)
        # Insert a docs axis so one (scale, bias) pair per query broadcasts
        # over all of that query's documents.
        return self.doc_norm(
            document_embeddings * q_scale.unsqueeze(1) + q_bias.unsqueeze(1)
        )

    def modulate_query(
        self,
        query_embeddings: torch.Tensor,
        modulated_documents: torch.Tensor,
    ) -> torch.Tensor:
        """Apply a document-conditioned scale and bias to each query.

        Args:
            query_embeddings: Projected queries, ``[batch, output_dim]``.
            modulated_documents: Output of :meth:`modulate_documents`,
                ``[batch, docs_per_query, output_dim]``.

        Returns:
            Layer-normalised modulated queries, ``[batch, output_dim]``.
        """
        # Summarise each query's candidate set by averaging over the docs axis.
        document_summary = modulated_documents.mean(dim=1)
        d_scale, d_bias = self.doc_hypernet(document_summary)
        return self.query_norm(query_embeddings * d_scale + d_bias)

    def forward(
        self,
        query_embeddings: torch.Tensor,
        document_embeddings: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project, modulate and score queries against their documents.

        Args:
            query_embeddings: [batch, input_dim]
            document_embeddings: [batch, docs_per_query, input_dim]

        Returns:
            A tuple ``(modulated_queries, modulated_documents, scores)`` where
            ``scores`` is ``[batch, docs_per_query]`` and holds the cosine
            similarity between each modulated query and its modulated
            documents.
        """
        projected_queries = self.project(query_embeddings)
        projected_documents = self.project(document_embeddings)
        modulated_documents = self.modulate_documents(projected_queries, projected_documents)
        modulated_queries = self.modulate_query(projected_queries, modulated_documents)
        # LayerNorm output is not unit-length, so re-normalise both sides to
        # make the dot product a cosine similarity.
        unit_queries = F.normalize(modulated_queries, p=2, dim=-1)
        unit_documents = F.normalize(modulated_documents, p=2, dim=-1)
        scores = torch.einsum("bd,bkd->bk", unit_queries, unit_documents)
        return modulated_queries, modulated_documents, scores

    def score_candidates(
        self,
        query_embedding: torch.Tensor,
        candidate_document_embeddings: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Score one query against its candidate documents.

        A 1-D query and a 2-D candidate matrix are each given a leading batch
        axis before the forward pass. ``squeeze(0)`` is then applied to every
        output, so a leading axis of size 1 is removed from the results
        (a leading axis of any other size is left untouched).

        Args:
            query_embedding: ``[input_dim]`` or ``[1, input_dim]``.
            candidate_document_embeddings: ``[num_candidates, input_dim]`` or
                ``[1, num_candidates, input_dim]``.

        Returns:
            ``(modulated_query, modulated_documents, scores)`` as produced by
            :meth:`forward`, each passed through ``squeeze(0)``.
        """
        if query_embedding.dim() == 1:
            query_embedding = query_embedding.unsqueeze(0)
        if candidate_document_embeddings.dim() == 2:
            candidate_document_embeddings = candidate_document_embeddings.unsqueeze(0)
        # Call the module itself rather than ``forward`` directly so that any
        # registered forward hooks also run on this code path.
        modulated_query, modulated_docs, scores = self(query_embedding, candidate_document_embeddings)
        return modulated_query.squeeze(0), modulated_docs.squeeze(0), scores.squeeze(0)
# Backward-compatible alias retained for legacy checkpoints and code paths.
BiHyperNetIR = IMRNN