Text Classification
Transformers
Safetensors
English
emcoder
emotion-recognition
bayesian-deep-learning
mc-dropout
uncertainty-quantification
multi-label-classification
custom_code
Eval Results (legacy)
Instructions to use yezdata/EmCoder with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use yezdata/EmCoder with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-classification", model="yezdata/EmCoder", trust_remote_code=True)
```

```python
# Load model directly
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("yezdata/EmCoder", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from .rope_embeddings import RotaryEmbedding | |
| from transformers import PreTrainedModel, AutoConfig, AutoModelForSequenceClassification | |
| from transformers.modeling_outputs import SequenceClassifierOutput | |
| from .configuration_emcoder import EmCoderConfig | |
class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learnable per-feature gain.

    Every vector along the last dimension is divided by the square root of
    its mean squared value (stabilised by ``eps``) and then multiplied
    element-wise by ``weight``. No mean is subtracted and there is no bias.

    Args:
        dim: Size of the last (feature) dimension; length of ``weight``.
        eps: Constant added to the mean square before the reciprocal root.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Gain starts at one, so a freshly built layer only rescales.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` normalised over its last dimension and scaled by the gain."""
        mean_square = torch.square(x).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight
class SwiGLU(nn.Module):
    """Gated feed-forward block using a SiLU gate.

    A single bias-free projection ``wi`` produces ``2 * d_ffn`` features that
    are split in half along the last dimension. The first half is multiplied
    by ``silu`` of the second half, and the product is projected back to
    ``d_model`` by the bias-free ``wo``.

    Args:
        d_model: Input and output feature size.
        d_ffn: Hidden size of each of the two halves.
    """

    def __init__(self, d_model: int, d_ffn: int):
        super().__init__()
        # One fused matmul yields both the value half and the gate half.
        self.wi = nn.Linear(d_model, 2 * d_ffn, bias=False)
        self.wo = nn.Linear(d_ffn, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated projection to the last dimension of ``x``."""
        value, gate = torch.chunk(self.wi(x), 2, dim=-1)
        gated = value * F.silu(gate)
        return self.wo(gated)
class EmCoderEncoderLayer(nn.Module):
    """Pre-norm Transformer encoder layer: RoPE self-attention plus SwiGLU FFN.

    Both sub-layers normalise their input with ``RMSNorm`` first and add the
    (dropout-regularised) result back onto the residual stream. Attention is
    computed with ``F.scaled_dot_product_attention``.

    Args:
        config: Supplies ``d_model``, ``n_head``, ``d_ffn`` and ``dropout``.
        rope: Rotary-embedding module whose ``rotate_queries_or_keys`` is
            applied to the per-head queries and keys. The instance is stored
            as given, not copied.
    """

    def __init__(self, config: EmCoderConfig, rope: RotaryEmbedding):
        super().__init__()
        self.n_head = config.n_head
        self.d_head = config.d_model // config.n_head
        self.rope = rope

        # Attention projections (all bias-free, d_model -> d_model).
        width = config.d_model
        self.q_proj = nn.Linear(width, width, bias=False)
        self.k_proj = nn.Linear(width, width, bias=False)
        self.v_proj = nn.Linear(width, width, bias=False)
        self.out_proj = nn.Linear(width, width, bias=False)

        self.ln1 = RMSNorm(width)
        self.ln2 = RMSNorm(width)
        self.ffn = SwiGLU(width, config.d_ffn)
        self.dropout = nn.Dropout(config.dropout)

        # Tag the two projections that write into the residual stream so the
        # weight initialiser can recognise them.
        for residual_writer in (self.out_proj, self.ffn.wo):
            residual_writer._is_residual = True

    def _to_heads(self, t: torch.Tensor, batch: int, seq_len: int) -> torch.Tensor:
        """Reshape ``(B, S, d_model)`` into per-head layout ``(B, H, S, d_head)``."""
        return t.view(batch, seq_len, self.n_head, self.d_head).transpose(1, 2)

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
        """Run attention and feed-forward sub-layers on ``x``.

        Args:
            x: Hidden states of shape ``(B, S, d_model)``.
            attn_mask: Mask handed unchanged to
                ``F.scaled_dot_product_attention``.

        Returns:
            Updated hidden states with the same shape as ``x``.
        """
        batch, seq_len, _ = x.shape

        # --- self-attention sub-layer ---
        normed = self.ln1(x)
        q = self._to_heads(self.q_proj(normed), batch, seq_len)
        k = self._to_heads(self.k_proj(normed), batch, seq_len)
        v = self._to_heads(self.v_proj(normed), batch, seq_len)

        # Positional rotation touches queries and keys only, never values.
        q = self.rope.rotate_queries_or_keys(q)
        k = self.rope.rotate_queries_or_keys(k)

        # Attention dropout follows the dropout module's own mode, so it is
        # active whenever that module is switched to training.
        attn_dropout_p = self.dropout.p if self.dropout.training else 0.0
        context = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=attn_dropout_p
        )

        # Merge heads back: (B, H, S, d_head) -> (B, S, d_model).
        context = context.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        x = x + self.dropout(self.out_proj(context))

        # --- feed-forward sub-layer ---
        return x + self.dropout(self.ffn(self.ln2(x)))
class EmCoderEncoder(nn.Module):
    """Backbone encoder: token embedding followed by a stack of encoder layers.

    Token ids are embedded, RMS-normalised and passed through dropout, then
    fed through ``config.n_layers`` ``EmCoderEncoderLayer`` modules that all
    receive the same ``RotaryEmbedding`` instance. A final ``RMSNorm`` is
    applied to the output.

    Args:
        config: Supplies ``vocab_size``, ``d_model``, ``n_head``,
            ``n_layers`` and ``dropout`` (plus whatever the layers read).
    """

    def __init__(self, config: EmCoderConfig):
        super().__init__()
        self.token_embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.embed_norm = RMSNorm(config.d_model)
        self.dropout = nn.Dropout(config.dropout)

        # One rotary module sized to a single head, handed to every layer.
        head_dim = config.d_model // config.n_head
        self.rope = RotaryEmbedding(dim=head_dim)

        self.layers = nn.ModuleList()
        for _ in range(config.n_layers):
            self.layers.append(EmCoderEncoderLayer(config, self.rope))

        self.final_norm = RMSNorm(config.d_model)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Encode a batch of token ids.

        Args:
            x: Token ids consumed by the embedding table.
            mask: 2-D ``(B, S)`` mask; it is cast to ``bool`` and reshaped to
                ``(B, 1, 1, S)`` before being given to every layer.

        Returns:
            Final-normalised hidden states, one vector per input token.
        """
        hidden = self.dropout(self.embed_norm(self.token_embedding(x)))

        batch, seq_len = mask.shape
        # Singleton axes let the same key mask broadcast over heads and
        # query positions inside the attention call.
        key_mask = mask.view(batch, 1, 1, seq_len).to(dtype=torch.bool)

        for block in self.layers:
            hidden = block(hidden, key_mask)

        return self.final_norm(hidden)
class EmCoder(PreTrainedModel):
    """Full EmCoder model: encoder backbone plus a multi-label classification head.

    The encoder output is mean-pooled over the positions selected by the
    attention mask and passed through a two-layer MLP head. ``forward`` is a
    single deterministic-or-training pass; ``mc_forward`` repeats the pass
    with dropout forced on to produce Monte Carlo samples of the logits.
    """

    config_class = EmCoderConfig

    def __init__(self, config: EmCoderConfig):
        super().__init__(config)
        self.encoder = EmCoderEncoder(config)
        self.classifier = nn.Sequential(
            nn.Linear(config.d_model, config.d_model),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_model, config.num_labels),
        )
        self.post_init()

    def _init_weights(self, module: nn.Module) -> None:
        """Initialise one submodule.

        Linear weights use a truncated normal with std 0.02, except layers
        flagged with ``_is_residual`` which use ``0.02 / sqrt(2 * n_layers)``;
        biases are zeroed. Embeddings use std 0.02 and ``RMSNorm`` gains are
        reset to one.
        """
        if isinstance(module, nn.Linear):
            # Scale down the init for layers that write into the residual stream.
            if getattr(module, "_is_residual", False):
                std = 0.02 / ((2 * self.config.n_layers) ** 0.5)
            else:
                std = 0.02
            nn.init.trunc_normal_(module.weight, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.trunc_normal_(module.weight, std=0.02)
        elif isinstance(module, RMSNorm):
            nn.init.ones_(module.weight)

    def _set_mc_dropout(self, active: bool = True) -> None:
        """Put every ``nn.Dropout`` submodule into training (or eval) mode."""
        for m in self.modules():
            if isinstance(m, nn.Dropout):
                m.train(active)

    @staticmethod
    def _masked_mean_pooling(
        features: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        """Average ``features`` over the sequence positions where ``mask`` is set.

        Args:
            features: Hidden states of shape ``(B, S, D)``.
            mask: ``(B, S)`` mask (bool, integer or float); non-zero keeps a
                position.

        Returns:
            Pooled tensor of shape ``(B, D)`` in the dtype of ``features``.
            Rows whose mask is entirely zero are divided by one instead of
            zero.
        """
        # Accumulate in at least float32 so half-precision inputs neither lose
        # the token count nor change the dtype seen by the classifier head.
        acc_dtype = torch.promote_types(features.dtype, torch.float32)
        weights = mask.unsqueeze(-1).to(dtype=acc_dtype)  # (B, S, 1)
        summed = (features.to(dtype=acc_dtype) * weights).sum(dim=1)  # (B, D)
        # Counts are whole numbers, so clamping at 1 only affects empty rows.
        counts = weights.sum(dim=1).clamp(min=1.0)  # (B, 1)
        return (summed / counts).to(dtype=features.dtype)

    @staticmethod
    def _require_inputs(
        input_ids: torch.Tensor | None, attention_mask: torch.Tensor | None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return both inputs, raising ``ValueError`` if either is missing."""
        if input_ids is None or attention_mask is None:
            raise ValueError("input_ids and attention_mask must be provided")
        return input_ids, attention_mask

    @staticmethod
    def _multilabel_loss(
        logits: torch.Tensor, labels: torch.Tensor | None
    ) -> torch.Tensor | None:
        """Return BCE-with-logits loss against ``labels``, or ``None`` without labels."""
        if labels is None:
            return None
        targets = labels.to(dtype=logits.dtype).view(logits.shape)
        return nn.BCEWithLogitsLoss()(logits, targets)

    @staticmethod
    def _pack_output(
        loss: torch.Tensor | None, logits: torch.Tensor, return_dict: bool
    ) -> tuple[torch.Tensor, ...] | SequenceClassifierOutput:
        """Wrap results as a ``SequenceClassifierOutput`` or a plain tuple."""
        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )

    def mc_forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        n_samples: int = 10,
        max_batch_size: int | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor, ...] | SequenceClassifierOutput:
        """Run Monte Carlo Dropout inference to quantify uncertainty.

        The input batch is tiled and pushed through the model with all
        dropout layers active and gradients disabled; each tile yields one
        stochastic sample of the logits.

        Args:
            input_ids: Input token IDs of shape (B, S).
            attention_mask: Attention mask of shape (B, S).
            labels: Optional multi-label targets, reshaped to (B, num_labels).
                When given, the loss is computed on the sample-mean logits.
            n_samples: Total number of Monte Carlo samples (>= 1).
            max_batch_size: Maximum number of Monte Carlo samples evaluated
                in one forward pass (each pass therefore holds up to
                ``max_batch_size * B`` rows). Defaults to ``n_samples``.
            return_dict: Return a ``SequenceClassifierOutput`` when true or
                ``None``; otherwise a tuple.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            ``SequenceClassifierOutput`` (or ``(loss, logits)`` / ``(logits,)``)
            whose ``logits`` have shape (n_samples, B, num_labels).

        Raises:
            ValueError: If an input is missing, or ``n_samples`` /
                ``max_batch_size`` is not positive.
        """
        return_dict = True if return_dict is None else return_dict
        x, mask = self._require_inputs(input_ids, attention_mask)

        if n_samples < 1:
            raise ValueError("n_samples must be a positive integer")
        if max_batch_size is None:
            max_batch_size = n_samples
        elif max_batch_size < 1:
            # A non-positive step would skip the loop and leave the output
            # buffer uninitialised.
            raise ValueError("max_batch_size must be a positive integer")

        B, _ = x.shape
        num_labels = self.classifier[-1].out_features
        all_logits = torch.empty((n_samples, B, num_labels), device=x.device)

        # Remember each dropout layer's own mode so it can be restored exactly.
        dropout_states = [
            (m, m.training) for m in self.modules() if isinstance(m, nn.Dropout)
        ]
        self._set_mc_dropout(active=True)
        try:
            with torch.no_grad():
                for start in range(0, n_samples, max_batch_size):
                    batch_samples = min(max_batch_size, n_samples - start)
                    x_stacked = x.repeat(batch_samples, 1)  # (batch_samples * B, S)
                    mask_stacked = mask.repeat(batch_samples, 1)  # (batch_samples * B, S)
                    features = self.encoder(
                        x_stacked, mask_stacked
                    )  # (batch_samples * B, S, D)
                    pooled = self._masked_mean_pooling(features, mask_stacked)
                    logits = self.classifier(pooled)  # (batch_samples * B, num_labels)
                    all_logits[start : start + batch_samples] = logits.view(
                        batch_samples, B, -1
                    )
        finally:
            for m, was_training in dropout_states:
                m.train(was_training)

        loss = None
        if labels is not None:
            loss = self._multilabel_loss(all_logits.mean(dim=0), labels)

        return self._pack_output(loss, all_logits, return_dict)

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor, ...] | SequenceClassifierOutput:
        """Standard forward pass without MC Dropout.

        Args:
            input_ids: Input token IDs of shape (B, S).
            attention_mask: Attention mask of shape (B, S).
            labels: Optional multi-label targets, reshaped to (B, num_labels);
                when given, a BCE-with-logits loss is returned as well.
            return_dict: Return a ``SequenceClassifierOutput`` when true or
                ``None``; otherwise a tuple.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            ``SequenceClassifierOutput`` (or ``(loss, logits)`` / ``(logits,)``)
            with logits of shape (B, num_labels).

        Raises:
            ValueError: If ``input_ids`` or ``attention_mask`` is missing.
        """
        return_dict = True if return_dict is None else return_dict
        x, mask = self._require_inputs(input_ids, attention_mask)

        features = self.encoder(x, mask)
        pooled = self._masked_mean_pooling(features, mask)
        logits = self.classifier(pooled)

        loss = self._multilabel_loss(logits, labels)
        return self._pack_output(loss, logits, return_dict)