import torch
import torch.nn as nn
from transformers import HubertConfig, HubertModel
from typing import List


class HuBERTECGConfig(HubertConfig):
    """HubertConfig extended with the ensemble size and the vocabulary sizes of the pretraining targets."""

    model_type = "hubert_ecg"

    def __init__(self, ensemble_length: int = 1, vocab_sizes: List[int] = [100], **kwargs):
        super().__init__(**kwargs)
        self.ensemble_length = ensemble_length
        # Accept either a single vocabulary size or a list with one entry per ensemble member.
        self.vocab_sizes = vocab_sizes if isinstance(vocab_sizes, list) else [vocab_sizes]


class HuBERTECG(HubertModel):

    config_class = HuBERTECGConfig

    def __init__(self, config: HuBERTECGConfig):
        super().__init__(config)
        self.config = config

        # Vocabulary sizes of the pretraining targets, one per ensemble member.
        self.pretraining_vocab_sizes = config.vocab_sizes

        assert config.ensemble_length > 0 and config.ensemble_length == len(config.vocab_sizes), \
            f"ensemble_length must be positive and equal to len(vocab_sizes); got ensemble_length {config.ensemble_length} and len(vocab_sizes) {len(config.vocab_sizes)}"

        # One projection head per ensemble member, mapping encoder states to the label-embedding space.
        self.final_proj = nn.ModuleList([nn.Linear(config.hidden_size, config.classifier_proj_size) for _ in range(config.ensemble_length)])

        # One label-embedding table per ensemble member, with one row per pretraining codeword.
        self.label_embedding = nn.ModuleList([nn.Embedding(vocab_size, config.classifier_proj_size) for vocab_size in config.vocab_sizes])

        assert len(self.final_proj) == len(self.label_embedding), "final_proj and label_embedding must have the same length"

    def logits(self, transformer_output: torch.Tensor) -> List[torch.Tensor]:
        # Project the encoder output (batch, frames, hidden_size) into each ensemble member's label-embedding space.
        projected_outputs = [final_projection(transformer_output) for final_projection in self.final_proj]

        # For each ensemble member, compute the cosine similarity between every projected frame and every
        # label embedding, scaled by a temperature of 0.1; each entry has shape (batch, frames, vocab_size).
        ensemble_logits = [torch.cosine_similarity(
            projected_output.unsqueeze(2),
            label_emb.weight.unsqueeze(0).unsqueeze(0),
            dim=-1,
        ) / 0.1 for projected_output, label_emb in zip(projected_outputs, self.label_embedding)]

        return ensemble_logits
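

# Example usage: a minimal sketch of how the model and its ensemble logits might be exercised,
# assuming the default HubertConfig backbone. The batch size, input length (5000 samples), and
# vocabulary size below are illustrative assumptions, not values from the original training setup.
if __name__ == "__main__":
    config = HuBERTECGConfig(ensemble_length=1, vocab_sizes=[100])
    model = HuBERTECG(config)
    model.eval()

    dummy_signal = torch.randn(2, 5000)  # (batch, samples): raw 1-D input, as HubertModel expects

    with torch.no_grad():
        encoder_states = model(dummy_signal).last_hidden_state  # (batch, frames, hidden_size)
        ensemble_logits = model.logits(encoder_states)

    # One (batch, frames, vocab_size) tensor per ensemble member.
    print([logits.shape for logits in ensemble_logits])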