| import torch | |
| from torch import nn | |
class DummyPooling(nn.Module):
    """Pass-through "pooling" module that performs no pooling at all.

    ``forward`` exposes the incoming ``'token_embeddings'`` tensor, unchanged,
    under the ``'sentence_embedding'`` key. The class defines no parameters or
    configuration, so ``save`` writes nothing and ``load`` simply builds a
    fresh instance.

    NOTE(review): the ``forward``/``save``/``load`` signatures resemble the
    sentence-transformers custom-module contract; presumably this is a test
    fixture or placeholder - confirm against callers.
    """

    def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
        """Return the token embeddings as the sentence embedding.

        Args:
            features: Mapping that must contain ``'token_embeddings'``.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            A new dict whose only key, ``'sentence_embedding'``, is bound to the
            very same tensor object as ``features['token_embeddings']`` (no copy
            and no reduction over any axis).

        Raises:
            KeyError: If ``features`` has no ``'token_embeddings'`` entry.
        """
        # A fresh dict is returned, so any other keys present in ``features``
        # are not propagated to the output.
        return {'sentence_embedding': features['token_embeddings']}

    def save(self, save_dir: str, **kwargs) -> None:
        """Do nothing: this module carries no state that needs persisting.

        Args:
            save_dir: Target directory; unused.
            **kwargs: Ignored.
        """
        pass

    @staticmethod
    def load(load_dir: str, **kwargs) -> "DummyPooling":
        """Create a new ``DummyPooling``; nothing is read from ``load_dir``.

        Declared ``@staticmethod`` because the signature has no ``self``.
        Without the decorator, ``instance.load(path)`` would bind the instance
        to ``load_dir`` and fail with ``TypeError`` on the extra argument.

        Args:
            load_dir: Source directory; unused.
            **kwargs: Ignored.

        Returns:
            A freshly constructed ``DummyPooling``.
        """
        return DummyPooling()