| | |
| | |
| | ''' |
| | @license: (C) Copyright 2025, Hey. |
| | @author: Hey |
| | @email: sanyuan.hy@alibaba-inc.com |
| | @tel: 137****6540 |
| | @datetime: 2025/12/30 11:35 |
| | @project: lucaone |
| | @file: modeling_lucaone |
| | @desc: modeling_lucaone |
| | ''' |
| | import math |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from transformers import PreTrainedModel |
| | from transformers.modeling_outputs import BaseModelOutput |
| | from transformers.modeling_outputs import MaskedLMOutput |
| | from transformers.modeling_outputs import SequenceClassifierOutput |
| | from transformers.modeling_outputs import TokenClassifierOutput |
| | from typing import Optional, List, Union, Tuple |
| | from .configuration_lucaone import LucaGPLMConfig |
# Use apex's fused LayerNorm when apex is importable; otherwise bind the same
# name to the stock ``torch.nn.LayerNorm`` so the rest of the file can use
# ``LucaGPLM1bLayerNorm`` unconditionally.
try:
    from apex.normalization import FusedLayerNorm as _FusedLayerNorm

    class LucaGPLM1bLayerNorm(_FusedLayerNorm):
        """Thin wrapper around apex ``FusedLayerNorm``.

        For CUDA inputs the parent forward is executed inside a
        ``torch.cuda.device`` context for the input's device; for non-CUDA
        inputs it is called directly.
        """

        @torch.jit.unused
        def forward(self, x):
            if not x.is_cuda:
                return super().forward(x)
            else:
                # Make the input's device the current CUDA device for the call.
                with torch.cuda.device(x.device):
                    return super().forward(x)
except ImportError:
    from torch.nn import LayerNorm as LucaGPLM1bLayerNorm
| |
|
def gelu(x):
    """Exact (erf-based) Gaussian Error Linear Unit, applied element-wise.

    Evaluates ``x * 0.5 * (1 + erf(x / sqrt(2)))``.
    """
    half_x = x * 0.5
    cdf_term = 1.0 + torch.erf(x / math.sqrt(2.0))
    return half_x * cdf_term
| |
|
def rotate_half(x):
    """Return ``[-b, a]`` where ``a, b`` are the two chunks of ``x`` along the last axis."""
    front, back = torch.chunk(x, 2, dim=-1)
    return torch.cat([back.neg(), front], dim=-1)
| |
|
def apply_rotary_pos_emb(x, cos, sin):
    """Apply a rotary position rotation to ``x``.

    ``cos`` and ``sin`` are sliced along their second axis to the length of
    ``x``'s second-to-last axis, then the result is
    ``x * cos + rotate_half(x) * sin`` (the half-rotation is written inline).
    """
    seq_len = x.shape[-2]
    cos_table = cos[:, :seq_len, :]
    sin_table = sin[:, :seq_len, :]

    # Half-rotation: split the last axis in two chunks (a, b) and form (-b, a).
    lower, upper = x.chunk(2, dim=-1)
    rotated = torch.cat((-upper, lower), dim=-1)

    return (x * cos_table) + (rotated * sin_table)
| |
|
class LucaGPLMRotaryEmbedding(torch.nn.Module):
    """Rotary position embedding applied jointly to a query and a key tensor.

    The cos/sin tables are built lazily from the key's sequence length and
    cached; they are rebuilt when the length or the device changes.
    """

    def __init__(self, dim: int, *_, **__):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents))

        # Nothing is cached until the first forward pass.
        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x, seq_dimension=1):
        """Return ``(cos, sin)`` tables for ``x.shape[seq_dimension]`` positions.

        Each table has shape ``(1, seq_len, 2 * len(inv_freq))``.
        """
        seq_len = x.shape[seq_dimension]

        cache_is_valid = (
            seq_len == self._seq_len_cached
            and self._cos_cached is not None
            and self._sin_cached is not None
            and self._cos_cached.device == x.device
        )
        if not cache_is_valid:
            self._seq_len_cached = seq_len
            positions = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            angles = torch.einsum("i,j->ij", positions, self.inv_freq)
            # Duplicate the angles so the table spans both halves of the last axis.
            table = torch.cat((angles, angles), dim=-1).to(x.device)

            self._cos_cached = table.cos()[None, :, :]
            self._sin_cached = table.sin()[None, :, :]

        return self._cos_cached, self._sin_cached

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Rotate ``q`` and ``k`` using tables sized from ``k``'s second-to-last axis."""
        cos, sin = self._update_cos_sin_tables(k, seq_dimension=-2)
        self._cos_cached, self._sin_cached = cos, sin

        rotated_q = apply_rotary_pos_emb(q, cos, sin)
        rotated_k = apply_rotary_pos_emb(k, cos, sin)
        return rotated_q, rotated_k
| |
|
class LucaGPLMGlobalMaskWeightedAttentionPooling1D(nn.Module):
    """Pool over axis 1 with softmax weights from a single learned vector.

    Each position is scored by ``x @ W`` (plus an optional scalar bias); the
    scores are soft-maxed over the last score axis and used to average ``x``.
    """

    def __init__(self, embed_size, use_bias=False):
        super().__init__()
        self.embed_size = embed_size
        self.use_bias = use_bias

        self.W = nn.Parameter(torch.Tensor(self.embed_size))
        nn.init.trunc_normal_(self.W, std=0.01)
        if self.use_bias:
            self.b = nn.Parameter(torch.Tensor(1))
            nn.init.trunc_normal_(self.b, std=0.01)

    def forward(self, x, mask=None):
        """Return the weighted sum of ``x`` over axis 1.

        Where ``mask`` is 0, ``-10000`` is added to the score so the position
        receives (near) zero weight.
        """
        scores = torch.matmul(x, self.W)
        if self.use_bias:
            scores = scores + self.b

        if mask is not None:
            scores = scores + (1.0 - mask) * -10000
        weights = torch.softmax(scores, dim=-1)

        return torch.sum(weights.unsqueeze(-1) * x, dim=1)

    def __repr__(self):
        bias_part = ', bias=%r)' % self.use_bias
        return self.__class__.__name__ + ' (' + str(self.embed_size) + bias_part
| |
|
class LucaGPLMGlobalMaskContextAttentionPooling1D(nn.Module):
    """Pool over axis 1 using additive attention against a learned context vector.

    Scores are ``tanh(x @ U + x @ V [+ b1]) @ c [+ b2]``; they are soft-maxed
    over the last score axis and used to average ``x``.
    """

    def __init__(self, embed_size, units=None, use_additive_bias=False, use_attention_bias=False):
        super().__init__()
        self.embed_size = embed_size
        self.use_additive_bias = use_additive_bias
        self.use_attention_bias = use_attention_bias
        # Fall back to the embedding width when no projection width is given.
        self.units = units if units else embed_size

        self.U = nn.Parameter(torch.Tensor(self.embed_size, self.units))
        self.V = nn.Parameter(torch.Tensor(self.embed_size, self.units))
        if self.use_additive_bias:
            self.b1 = nn.Parameter(torch.Tensor(self.units))
            nn.init.trunc_normal_(self.b1, std=0.01)
        if self.use_attention_bias:
            self.b2 = nn.Parameter(torch.Tensor(1))
            nn.init.trunc_normal_(self.b2, std=0.01)

        self.c = nn.Parameter(torch.Tensor(self.units))

        for weight in (self.U, self.V, self.c):
            nn.init.trunc_normal_(weight, std=0.01)

    def forward(self, x, mask=None):
        """Return the attention-weighted sum of ``x`` over axis 1.

        Where ``mask`` is 0, ``-10000`` is added to the score so the position
        receives (near) zero weight.
        """
        pre_activation = torch.matmul(x, self.U) + torch.matmul(x, self.V)
        if self.use_additive_bias:
            pre_activation = pre_activation + self.b1
        hidden = torch.tanh(pre_activation)

        scores = torch.matmul(hidden, self.c)
        if self.use_attention_bias:
            scores = scores + self.b2
        if mask is not None:
            scores = scores + (1.0 - mask) * -10000
        weights = torch.softmax(scores, dim=-1)

        return torch.sum(weights.unsqueeze(-1) * x, dim=1)

    def __repr__(self):
        bias_flags = (self.use_additive_bias, self.use_attention_bias)
        return self.__class__.__name__ + ' (' + str(self.embed_size) + ' -> ' + str(self.units) + ', bias=(%r, %r))' % bias_flags
| |
|
class LucaGPLMGlobalMaskValueAttentionPooling1D(nn.Module):
    """Pool over axis 1 with a separate softmax weight per feature.

    Scores are ``tanh(x @ U + x @ V [+ b1]) @ W [+ b2]``, which has the same
    trailing width as ``x``; the softmax is taken over axis 1 independently
    for each feature.
    """

    def __init__(self, embed_size, units=None, use_additive_bias=False, use_attention_bias=False):
        super().__init__()
        self.embed_size = embed_size
        self.use_additive_bias = use_additive_bias
        self.use_attention_bias = use_attention_bias
        # Fall back to the embedding width when no projection width is given.
        self.units = units if units else embed_size

        self.U = nn.Parameter(torch.Tensor(self.embed_size, self.units))
        self.V = nn.Parameter(torch.Tensor(self.embed_size, self.units))
        if self.use_additive_bias:
            self.b1 = nn.Parameter(torch.Tensor(self.units))
            nn.init.trunc_normal_(self.b1, std=0.01)
        if self.use_attention_bias:
            self.b2 = nn.Parameter(torch.Tensor(self.embed_size))
            nn.init.trunc_normal_(self.b2, std=0.01)

        self.W = nn.Parameter(torch.Tensor(self.units, self.embed_size))

        for weight in (self.U, self.V, self.W):
            nn.init.trunc_normal_(weight, std=0.01)

    def forward(self, x, mask=None):
        """Return the per-feature attention-weighted sum of ``x`` over axis 1.

        Where ``mask`` is 0, ``-10000`` is added to every feature's score at
        that position so it receives (near) zero weight.
        """
        pre_activation = torch.matmul(x, self.U) + torch.matmul(x, self.V)
        if self.use_additive_bias:
            pre_activation = pre_activation + self.b1
        hidden = torch.tanh(pre_activation)

        scores = torch.matmul(hidden, self.W)
        if self.use_attention_bias:
            scores = scores + self.b2
        if mask is not None:
            penalty = (1.0 - mask) * -10000
            scores = scores + torch.unsqueeze(penalty, dim=-1)
        weights = torch.softmax(scores, dim=1)

        return torch.sum(weights * x, dim=1)

    def __repr__(self):
        bias_flags = (self.use_additive_bias, self.use_attention_bias)
        return self.__class__.__name__ + ' (' + str(self.embed_size) + ' -> ' + str(self.units) + ', bias=(%r, %r))' % bias_flags
| |
|
class LucaGPLM1LayerNorm(nn.Module):
    """Layer normalisation written out with plain tensor ops.

    Normalises over the trailing ``len(hidden_size)`` axes and, when
    ``affine`` is true, applies a learned scale and shift.
    """

    def __init__(self, hidden_size, eps=1e-12, affine=True):
        super().__init__()
        if isinstance(hidden_size, int):
            self.hidden_size = (hidden_size,)
        else:
            self.hidden_size = tuple(hidden_size)
        self.eps = eps
        self.affine = bool(affine)

        self.weight = nn.Parameter(torch.ones(hidden_size)) if self.affine else None
        self.bias = nn.Parameter(torch.zeros(hidden_size)) if self.affine else None

    def forward(self, x):
        # Trailing axes to reduce over: -1, -2, ..., one per normalised dim.
        reduce_dims = tuple(range(-1, -len(self.hidden_size) - 1, -1))

        centered = x - x.mean(reduce_dims, keepdim=True)
        variance = centered.pow(2).mean(reduce_dims, keepdim=True)
        normalised = centered / torch.sqrt(variance + self.eps)

        if not self.affine:
            return normalised
        return (self.weight * normalised) + self.bias
| |
|
class LucaGPLMMultiheadAttention(nn.Module):
    """Multi-head attention with optional rotary position embeddings.

    Inputs are laid out sequence-first: ``forward`` unpacks ``query.size()``
    as ``(tgt_len, bsz, embed_dim)``.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv: bool = False,
        add_zero_attn: bool = False,
        self_attention: bool = False,
        encoder_decoder_attention: bool = False,
        use_rotary_embeddings: bool = False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = embed_dim if kdim is None else kdim
        self.vdim = embed_dim if vdim is None else vdim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim**-0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, (
            "Self-attention requires query, key and " "value to be of the same size"
        )

        self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        # Optional learned key/value rows appended to every sequence.
        if add_bias_kv:
            self.bias_k = nn.Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = nn.Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self.reset_parameters()

        self.rot_emb = LucaGPLMRotaryEmbedding(dim=self.head_dim) if use_rotary_embeddings else None

    def reset_parameters(self):
        """Xavier-initialise the projections and zero the output bias."""
        relu_gain = nn.init.calculate_gain("relu")
        for projection in (self.k_proj, self.v_proj, self.q_proj, self.out_proj):
            nn.init.xavier_uniform_(projection.weight, gain=relu_gain)

        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)
        for extra in (self.bias_k, self.bias_v):
            if extra is not None:
                nn.init.xavier_normal_(extra)

    def forward(
        self,
        query,
        key: Optional[torch.Tensor] = None,
        value: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[torch.Tensor] = None,
        need_head_weights: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run attention and return ``(output, attention_weights_or_None)``.

        ``attn_mask`` is added to the raw scores; positions where
        ``key_padding_mask`` is truthy are filled with ``-inf``. Weights are
        returned per head (heads on axis 0) when ``need_head_weights`` is set,
        averaged over heads when only ``need_weights`` is set, and ``None``
        otherwise.
        """
        # Asking for per-head weights implies asking for weights.
        need_weights = need_weights or need_head_weights

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim

        if self.self_attention:
            key_source = value_source = query
        else:
            assert key is not None and value is not None
            key_source, value_source = key, value
        q = self.q_proj(query)
        k = self.k_proj(key_source)
        v = self.v_proj(value_source)

        q = q * self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            # Extend both masks by one always-visible column for the extra key.
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                    ],
                    dim=1,
                )

        # Fold heads into the batch axis: (len, bsz, embed) -> (bsz * heads, len, head_dim).
        head_batch = bsz * self.num_heads
        q = q.contiguous().view(tgt_len, head_batch, self.head_dim).transpose(0, 1)
        k = k.contiguous().view(-1, head_batch, self.head_dim).transpose(0, 1)
        v = v.contiguous().view(-1, head_batch, self.head_dim).transpose(0, 1)

        src_len = k.size(1)

        if self.rot_emb is not None:
            q, k = self.rot_emb(q, k)

        scores = torch.bmm(q, k.transpose(1, 2))
        assert list(scores.size()) == [head_batch, tgt_len, src_len]

        if attn_mask is not None:
            scores += attn_mask.unsqueeze(0)

        if key_padding_mask is not None:
            blocked = key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool)
            scores = scores.view(bsz, self.num_heads, tgt_len, src_len)
            scores = scores.masked_fill(blocked, float("-inf"))
            scores = scores.view(head_batch, tgt_len, src_len)

        probs_float = F.softmax(scores, dim=-1)
        probs = probs_float.type_as(scores)
        dropped = F.dropout(probs, p=self.dropout, training=self.training)

        attn = torch.bmm(dropped, v)
        assert list(attn.size()) == [head_batch, tgt_len, self.head_dim]
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        if not need_weights:
            return attn, None

        # (bsz, heads, tgt, src) -> (heads, bsz, tgt, src)
        weights_out = probs_float.view(
            bsz, self.num_heads, tgt_len, src_len
        ).type_as(attn).transpose(1, 0)
        if not need_head_weights:
            weights_out = weights_out.mean(dim=0)
        return attn, weights_out
| |
|
class LucaGPLMMultiheadAttentionWithSDPA(nn.Module):
    """Multi-head attention that uses ``F.scaled_dot_product_attention`` when possible.

    Inputs are laid out sequence-first: ``forward`` unpacks ``query.size()``
    as ``(tgt_len, bsz, embed_dim)``. When per-head attention weights are
    requested (or SDPA is unavailable) an explicit softmax path is used
    instead, because the fused kernel does not return weights.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv: bool = False,
        add_zero_attn: bool = False,
        self_attention: bool = False,
        encoder_decoder_attention: bool = False,
        use_rotary_embeddings: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim**-0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, (
            "Self-attention requires query, key and " "value to be of the same size"
        )

        self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        # Optional learned key/value rows appended to every sequence.
        if add_bias_kv:
            self.bias_k = nn.Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = nn.Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self.reset_parameters()

        self.rot_emb = None
        if use_rotary_embeddings:
            self.rot_emb = LucaGPLMRotaryEmbedding(dim=self.head_dim)

    def reset_parameters(self):
        """Xavier-initialise the projections and zero the output bias."""
        relu_gain = nn.init.calculate_gain("relu")
        for projection in (self.k_proj, self.v_proj, self.q_proj, self.out_proj):
            nn.init.xavier_uniform_(projection.weight, gain=relu_gain)

        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    def _broadcast_attn_mask(self, attn_mask: torch.Tensor, bsz: int) -> torch.Tensor:
        """Reshape ``attn_mask`` so it broadcasts against ``(bsz, heads, tgt, src)``.

        Accepted layouts:
          * 2-D ``(tgt, src)``: shared by every batch item and head.
          * 3-D ``(bsz * heads, tgt, src)``: one mask per item and head.
          * 3-D ``(bsz, tgt, src)`` or ``(1, tgt, src)``: one mask per item.
          * anything else is passed through unchanged and must already broadcast.

        Raises:
            ValueError: for a 3-D mask whose leading size matches none of the
                layouts above (previously such masks were silently ignored).
        """
        if attn_mask.dim() == 2:
            return attn_mask.unsqueeze(0).unsqueeze(0)
        if attn_mask.dim() == 3:
            leading = attn_mask.size(0)
            if leading == bsz * self.num_heads:
                return attn_mask.reshape(
                    bsz, self.num_heads, attn_mask.size(1), attn_mask.size(2)
                )
            if leading in (1, bsz):
                return attn_mask.unsqueeze(1)
            raise ValueError(
                "3-D attn_mask must have leading size bsz * num_heads, bsz or 1; "
                f"got {leading} (bsz={bsz}, num_heads={self.num_heads})"
            )
        return attn_mask

    def forward(
        self,
        query,
        key: Optional[torch.Tensor] = None,
        value: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[torch.Tensor] = None,
        need_head_weights: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run attention and return ``(output, attention_weights_or_None)``.

        ``attn_mask`` is an additive mask (see ``_broadcast_attn_mask`` for the
        accepted layouts); positions where ``key_padding_mask`` is truthy are
        filled with ``-inf``.

        When ``need_head_weights`` is false and SDPA is available, the fused
        kernel is used and the returned weights are always ``None``, even if
        ``need_weights`` is true. Otherwise the explicit path runs and returns
        per-head weights (heads on axis 0), or head-averaged weights if only
        ``need_weights`` is set.
        """
        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            # Extend both masks by one always-visible column for the extra key.
            # NOTE: the attn_mask extension below only fits a 2-D mask.
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                    ],
                    dim=1,
                )

        # ---- Fast path: fused scaled-dot-product attention (no weights) ----
        if not need_head_weights and hasattr(F, "scaled_dot_product_attention"):
            # (len, bsz, embed) -> (bsz, heads, len, head_dim)
            q_sdpa = q.view(tgt_len, bsz, self.num_heads, self.head_dim).permute(1, 2, 0, 3)
            k_sdpa = k.view(-1, bsz, self.num_heads, self.head_dim).permute(1, 2, 0, 3)
            v_sdpa = v.view(-1, bsz, self.num_heads, self.head_dim).permute(1, 2, 0, 3)

            if self.rot_emb:
                q_sdpa, k_sdpa = self.rot_emb(q_sdpa, k_sdpa)

            # Merge padding and attention masks into one additive float mask.
            sdpa_mask = None
            if attn_mask is not None or key_padding_mask is not None:
                target_shape = (bsz, 1, tgt_len, k_sdpa.size(2))
                sdpa_mask = torch.zeros(target_shape, device=q.device, dtype=q.dtype)

                if key_padding_mask is not None:
                    sdpa_mask = sdpa_mask.masked_fill(
                        key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                        float("-inf"),
                    )

                if attn_mask is not None:
                    sdpa_mask = sdpa_mask + self._broadcast_attn_mask(attn_mask, bsz)

            # SDPA applies the 1/sqrt(head_dim) scaling itself, so q is not pre-scaled.
            attn_output = F.scaled_dot_product_attention(
                q_sdpa,
                k_sdpa,
                v_sdpa,
                attn_mask=sdpa_mask,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=False,
            )

            # (bsz, heads, len, head_dim) -> (len, bsz, embed)
            attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(tgt_len, bsz, self.embed_dim)
            attn_output = self.out_proj(attn_output)

            return attn_output, None

        # ---- Explicit path: materialises the attention weights ----
        q = q * self.scaling

        # Fold heads into the batch axis: (len, bsz, embed) -> (bsz * heads, len, head_dim).
        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        src_len = k.size(1)

        if self.rot_emb:
            q, k = self.rot_emb(q, k)

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            # Add in the 4-D layout so every mask shape the fast path accepts
            # is handled identically here; in-place keeps the score dtype.
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights += self._broadcast_attn_mask(attn_mask, bsz)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if key_padding_mask is not None:
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights_float = F.softmax(attn_weights, dim=-1)
        attn_probs = F.dropout(
            attn_weights_float.type_as(attn_weights),
            p=self.dropout,
            training=self.training,
        )

        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        attn_weights_output: Optional[torch.Tensor] = None
        if need_weights or need_head_weights:
            # (bsz, heads, tgt, src) -> (heads, bsz, tgt, src)
            attn_weights_output = attn_weights_float.view(
                bsz, self.num_heads, tgt_len, src_len
            ).type_as(attn).transpose(1, 0)
            if not need_head_weights:
                attn_weights_output = attn_weights_output.mean(dim=0)

        return attn, attn_weights_output
| |
|
class LucaGPLMRobertaLMHead(nn.Module):
    """Language-model head: dense -> GELU -> layer norm -> output projection.

    The output projection has no bias of its own; a separate ``bias``
    parameter is added to its result.
    """

    def __init__(self, embed_dim, output_dim):
        super().__init__()
        self.dense = nn.Linear(embed_dim, embed_dim)
        self.layer_norm = LucaGPLM1bLayerNorm(embed_dim)

        self.decoder = nn.Linear(embed_dim, output_dim, bias=False)
        self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, features):
        """Map ``features`` to logits over ``output_dim`` classes."""
        hidden = gelu(self.dense(features))
        hidden = self.layer_norm(hidden)

        logits = self.decoder(hidden)
        return logits + self.bias
| |
|
class LucaGPLMTransformerLayer(nn.Module):
    """Pre-norm transformer block: self-attention followed by a GELU feed-forward.

    Each sub-layer normalises its input first and adds its output back onto
    the un-normalised input (residual connection).
    """

    def __init__(
        self,
        embed_dim,
        ffn_embed_dim,
        attention_heads,
        add_bias_kv=True,
        use_lucagplm1b_layer_norm=False,
        use_rotary_embeddings: bool=True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.ffn_embed_dim = ffn_embed_dim
        self.attention_heads = attention_heads
        self.use_rotary_embeddings = use_rotary_embeddings

        # Choose which of the two layer-norm implementations this block uses.
        if use_lucagplm1b_layer_norm:
            norm_cls = LucaGPLM1bLayerNorm
        else:
            norm_cls = LucaGPLM1LayerNorm

        self.pre_layer_norm = norm_cls(self.embed_dim)

        self.self_attn = LucaGPLMMultiheadAttentionWithSDPA(
            self.embed_dim,
            self.attention_heads,
            add_bias_kv=add_bias_kv,
            add_zero_attn=False,
            self_attention=True,
            use_rotary_embeddings=self.use_rotary_embeddings,
        )

        self.post_layer_norm = norm_cls(self.embed_dim)

        # Position-wise feed-forward network.
        self.fc1 = nn.Linear(self.embed_dim, self.ffn_embed_dim)
        self.fc2 = nn.Linear(self.ffn_embed_dim, self.embed_dim)

    def forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
        need_head_weights=False
    ):
        """Return ``(hidden_states, attention)`` for one block.

        ``attention`` is whatever the attention module returns as its second
        value.
        """
        # Self-attention sub-layer.
        normed = self.pre_layer_norm(x)
        attn_out, attn = self.self_attn(
            query=normed,
            key=normed,
            value=normed,
            key_padding_mask=self_attn_padding_mask,
            need_weights=True,
            need_head_weights=need_head_weights,
            attn_mask=self_attn_mask,
        )
        x = x + attn_out

        # Feed-forward sub-layer.
        ffn_out = self.fc2(gelu(self.fc1(self.post_layer_norm(x))))
        x = x + ffn_out

        return x, attn
| |
|
class LucaGPLMEmbeddings(nn.Module):
    """Input embeddings: token (+ optional position and token-type) embeddings.

    The sum is optionally layer-normalised, optionally rescaled for token
    dropout during training, and finally zeroed at padding positions.
    """

    def __init__(self, config: LucaGPLMConfig):
        super().__init__()

        # Optional behaviour switches, with defaults for configs that omit them.
        self.no_position_embeddings = getattr(config, 'no_position_embeddings', False)
        self.no_token_type_embeddings = getattr(config, 'no_token_type_embeddings', False)
        self.use_embed_layer_norm = getattr(config, 'use_embed_layer_norm', True)
        self.embed_scale = getattr(config, 'embed_scale', 1.0)
        self.token_dropout = getattr(config, 'token_dropout', False)

        self.mask_idx = getattr(config, 'mask_token_id', 4)
        self.padding_idx = getattr(config, 'pad_token_id', 0)

        # Use the same (defaulted) padding index that forward() relies on.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.padding_idx)

        if not self.no_position_embeddings:
            self.embed_pos = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        else:
            self.embed_pos = None

        if not self.no_token_type_embeddings:
            self.embed_type = nn.Embedding(config.type_vocab_size, config.hidden_size)
        else:
            self.embed_type = None

        if self.use_embed_layer_norm:
            self.embed_layer_norm = LucaGPLM1bLayerNorm(config.hidden_size)
        else:
            self.embed_layer_norm = None

    def forward(
        self,
        input_ids: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` (indexed as ``(batch, seq_len)``).

        Missing ``position_ids`` default to ``0..seq_len-1`` per row and
        missing ``token_type_ids`` default to all zeros.
        """
        input_shape = input_ids.size()
        seq_length = input_shape[1]
        padding_mask = input_ids.eq(self.padding_idx)

        inputs_embeds = self.embed_scale * self.embed_tokens(input_ids)

        if not self.no_position_embeddings and self.embed_pos is not None:
            if position_ids is None:
                position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
                position_ids = position_ids.unsqueeze(0).expand(input_shape)
            position_embeddings = self.embed_scale * self.embed_pos(position_ids)
            inputs_embeds = inputs_embeds + position_embeddings

        if not self.no_token_type_embeddings and self.embed_type is not None:
            if token_type_ids is None:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=input_ids.device)
            token_type_embeddings = self.embed_scale * self.embed_type(token_type_ids)
            inputs_embeds = inputs_embeds + token_type_embeddings

        if self.use_embed_layer_norm and self.embed_layer_norm is not None:
            embeddings = self.embed_layer_norm(inputs_embeds)
        else:
            embeddings = inputs_embeds

        if self.token_dropout and self.training:
            # Zero the mask-token positions, then rescale every row by
            # (1 - assumed mask ratio) / (1 - observed mask ratio).
            is_mask_token = input_ids == self.mask_idx
            embeddings = embeddings.masked_fill(is_mask_token.unsqueeze(-1), 0.0)

            mask_ratio_train = 0.15 * 0.8
            # Clamp so a row made up entirely of padding divides by 1 rather
            # than 0, which would otherwise turn the whole row into NaN.
            src_lengths = (~padding_mask).sum(-1).clamp(min=1)
            mask_ratio_observed = is_mask_token.sum(-1).to(embeddings.dtype) / src_lengths
            embeddings = embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]

        # Zero out the embeddings at padding positions.
        if padding_mask.any():
            embeddings = embeddings * (1 - padding_mask.unsqueeze(-1).type_as(embeddings))

        return embeddings
| |
|
class LucaGPLMEncoder(nn.Module):
    """Stack of ``LucaGPLMTransformerLayer`` blocks with an optional final layer norm."""

    def __init__(self, config: LucaGPLMConfig):
        super().__init__()

        self.layers = nn.ModuleList([
            LucaGPLMTransformerLayer(
                config.hidden_size,
                4 * config.hidden_size,
                config.num_attention_heads,
                add_bias_kv=False,
                use_lucagplm1b_layer_norm=True,
                use_rotary_embeddings=True,
            )
            for _ in range(config.num_hidden_layers)
        ])

        self.use_last_layer_norm = getattr(config, 'use_last_layer_norm', True)
        if self.use_last_layer_norm:
            self.last_layer_norm = LucaGPLM1bLayerNorm(config.hidden_size)
        else:
            self.last_layer_norm = None

        self.padding_idx = config.pad_token_id
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        need_head_weights: bool = False,
        repr_layers: Optional[List[int]] = None,
        use_last_layer_norm: bool = True,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
        """Run every layer over ``hidden_states`` (batch-first on entry and exit).

        Args:
            hidden_states: embeddings; transposed to sequence-first for the layers.
            attention_mask: positions equal to 0 are treated as padding. ``None``
                means no position is padding.
            output_attentions: return the stacked per-layer attention weights.
            output_hidden_states: return the input of every layer plus the final output.
            return_dict: return a ``BaseModelOutput`` instead of a tuple.
            need_head_weights: compute per-head weights in the layers (they are
                only returned when ``output_attentions`` is set).
            repr_layers: accepted for interface compatibility; it does not
                affect anything this method returns.
            use_last_layer_norm: apply the final layer norm if the module has one.
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # Without an attention mask nothing is padding. (Comparing a zero
        # tensor to ``padding_idx`` would flag *every* position as padding
        # whenever the pad id is 0, masking all keys and producing NaNs.)
        padding_mask = None
        if attention_mask is not None:
            padding_mask = attention_mask.eq(0)
            if not padding_mask.any():
                padding_mask = None

        # (batch, seq, hidden) -> (seq, batch, hidden)
        hidden_states = hidden_states.transpose(0, 1)

        collect_attentions = need_head_weights or output_attentions
        attn_weights = [] if collect_attentions else None

        for layer_module in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states.transpose(0, 1),)

            if self.gradient_checkpointing and self.training:
                # Positional order matches the layer's forward signature:
                # (x, self_attn_mask, self_attn_padding_mask, need_head_weights).
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    None,
                    padding_mask,
                    collect_attentions,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    self_attn_mask=None,
                    self_attn_padding_mask=padding_mask,
                    need_head_weights=collect_attentions,
                )

            hidden_states, attn = layer_outputs

            if collect_attentions:
                # Swap the first two axes of the layer's attention tensor.
                attn_weights.append(attn.transpose(1, 0))

        if self.last_layer_norm is not None and use_last_layer_norm:
            hidden_states = self.last_layer_norm(hidden_states)

        # (seq, batch, hidden) -> (batch, seq, hidden)
        hidden_states = hidden_states.transpose(0, 1)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if collect_attentions:
            if attn_weights:
                # Layers are stacked on axis 1.
                all_attentions = torch.stack(attn_weights, 1)
                if padding_mask is not None:
                    # Zero attention rows/columns that involve a padded position.
                    keep = 1 - padding_mask.type_as(all_attentions)
                    keep = keep.unsqueeze(1) * keep.unsqueeze(2)
                    all_attentions = all_attentions * keep[:, None, None, :, :]

            if not output_attentions:
                all_attentions = None

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
| |
|
class LucaGPLMPreTrainedModel(PreTrainedModel):
    """Base class that wires the LucaGPLM config and weight initialisation into ``PreTrainedModel``."""

    config_class = LucaGPLMConfig
    base_model_prefix = "lucaone"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LucaGPLMTransformerLayer"]

    def _init_weights(self, module):
        """Initialise one sub-module according to its type.

        Linear and embedding weights are drawn from
        ``N(0, config.initializer_range)`` (biases and the padding row are
        zeroed); layer norms are reset to weight 1 and bias 0.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
            return

        if isinstance(module, (LucaGPLM1LayerNorm, LucaGPLM1bLayerNorm)):
            # Either attribute may be absent or None, depending on the norm's configuration.
            scale = getattr(module, 'weight', None)
            shift = getattr(module, 'bias', None)
            if scale is not None:
                scale.data.fill_(1.0)
            if shift is not None:
                shift.data.zero_()
| |
|
class LucaGPLMModel(LucaGPLMPreTrainedModel):
    """
    The LucaGPLM model for extracting sequence representations and optionally predicting contacts.
    Based on the original LucaGPLM implementation but restructured to use modern transformers architecture.

    The model is an embedding module (``self.embeddings``) followed by an
    encoder (``self.encoder``); it has no pooling layer.
    """

    def __init__(self, config: LucaGPLMConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = LucaGPLMEmbeddings(self.config)
        self.encoder = LucaGPLMEncoder(self.config)
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embeddings.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.embeddings.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_contacts: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        need_head_weights: Optional[bool] = None,
        repr_layers: Optional[List[int]] = None,
        use_last_layer_norm: Optional[bool] = True,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
        """Embed the input and run it through the encoder.

        Args:
            input_ids: Token ids of shape (batch_size, seq_length). Mutually
                exclusive with ``inputs_embeds``.
            attention_mask: Mask forwarded to the encoder; defaults to all
                ones over the input shape.
            token_type_ids: Forwarded to the embedding module.
            position_ids: Forwarded to the embedding module.
            inputs_embeds: Pre-computed embeddings used instead of
                ``input_ids``; the embedding module is skipped in that case.
            output_attentions: Return attention weights. Defaults to the
                config value; forced on by ``return_contacts``.
            output_hidden_states: Return per-layer hidden states. Defaults to
                the config value.
            return_contacts: Also derive a contact map from the attentions.
            return_dict: Return a ``BaseModelOutput`` instead of a tuple.
                Defaults to the config value.
            need_head_weights: Forwarded to the encoder; defaults to
                ``return_contacts``.
            repr_layers: Forwarded to the encoder.
            use_last_layer_norm: Forwarded to the encoder; ``None`` means True.

        Returns:
            A ``BaseModelOutput`` (with an extra ``contacts`` attribute when a
            contact map was computed), or a tuple of the non-None values among
            (last hidden state, hidden states, attentions) followed by the
            contact map when one was computed.

        Raises:
            ValueError: If both or neither of ``input_ids`` / ``inputs_embeds``
                are given, or if the input is not (batch_size, seq_length).
        """
        if output_attentions is None:
            output_attentions = getattr(self.config, 'output_attentions', False)
        if output_hidden_states is None:
            output_hidden_states = getattr(self.config, 'output_hidden_states', False)
        if return_contacts is None:
            return_contacts = False
        if return_dict is None:
            return_dict = getattr(self.config, 'use_return_dict', True)
        if need_head_weights is None:
            need_head_weights = return_contacts
        if use_last_layer_norm is None:
            use_last_layer_norm = True

        # Contacts are derived from the attention maps, so they must be kept.
        if return_contacts:
            output_attentions = True
            need_head_weights = True

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is not None:
            input_shape = input_ids.size()
            device = input_ids.device
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            device = inputs_embeds.device
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # Explicit replacement for the former unused `batch_size, seq_length`
        # unpacking, which rejected non-2D inputs with a ValueError as well.
        if len(input_shape) != 2:
            raise ValueError(
                "Expected inputs of shape (batch_size, seq_length), got %s" % (tuple(input_shape),)
            )

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)

        if inputs_embeds is None:
            embedding_output = self.embeddings(
                input_ids=input_ids,
                position_ids=position_ids,
                token_type_ids=token_type_ids,
            )
        else:
            embedding_output = inputs_embeds

        # Always ask the encoder for the structured output: the contact
        # computation below reads `.attentions`, which does not exist on the
        # tuple the encoder returns for return_dict=False.
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            need_head_weights=need_head_weights,
            repr_layers=repr_layers,
            use_last_layer_norm=use_last_layer_norm,
        )

        sequence_output = encoder_outputs.last_hidden_state
        all_hidden_states = encoder_outputs.hidden_states
        attentions = encoder_outputs.attentions

        contacts = None
        if return_contacts and attentions is not None:
            # Average over dims 1 and 2 (presumably layers and heads — TODO
            # confirm against the encoder), then symmetrise the map.
            averaged_attention = attentions.mean(dim=(1, 2))
            contacts = (averaged_attention + averaged_attention.transpose(-1, -2)) / 2

        if not return_dict:
            # Same layout as the encoder's own tuple output: None entries are
            # omitted, and the contact map, if any, comes last.
            outputs = tuple(
                value
                for value in (sequence_output, all_hidden_states, attentions)
                if value is not None
            )
            if contacts is not None:
                outputs = outputs + (contacts,)
            return outputs

        output = BaseModelOutput(
            last_hidden_state=sequence_output,
            hidden_states=all_hidden_states,
            attentions=attentions,
        )

        # BaseModelOutput has no `contacts` field, so it is attached as a
        # plain attribute.
        if contacts is not None:
            output.contacts = contacts

        return output
| |
|
class LucaGPLMForMaskedLM(LucaGPLMPreTrainedModel):
    """LucaGPLM backbone with a language-modelling head that predicts vocabulary tokens."""

    def __init__(self, config):
        super().__init__(config)

        self.lucaone = LucaGPLMModel(config)

        self.lm_head = LucaGPLMRobertaLMHead(
            embed_dim=config.hidden_size,
            output_dim=config.vocab_size
        )
        self._tied_weights_keys = [
            "lucaone.embeddings.embed_tokens.weight",
            "lm_head.decoder.weight"
        ]

        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        return self.lucaone.get_input_embeddings()

    def get_output_embeddings(self):
        """Return the decoder module of the language-modelling head."""
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        """Replace the decoder module of the language-modelling head."""
        self.lm_head.decoder = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        """Run the backbone and score every position over the vocabulary.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids,
            output_attentions, output_hidden_states: Forwarded to the backbone.
            labels: Target token ids; positions set to -100 are ignored by the
                loss. When None, no loss is computed.
            return_dict: Return a ``MaskedLMOutput`` instead of a tuple.
                Defaults to ``config.use_return_dict``.

        Returns:
            A ``MaskedLMOutput``, or the tuple
            ``([loss,] logits, *backbone_extras)`` where the extras are the
            hidden states and/or attentions the backbone returned.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.lucaone(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            # The backbone has no pooled output, so everything after the
            # sequence output (index 0) is an extra to pass through. Slicing
            # from index 2 would silently drop the first extra.
            output = (prediction_scores,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
| |
|
class LucaGPLMForSequenceClassification(LucaGPLMPreTrainedModel):
    """
    LucaGPLM backbone with a sequence-level prediction head.

    The sequence output is pooled to one vector per sample (attention pooling
    module, first position, or mean over positions, per
    ``config.classifier_pooling_type``), passed through dropout and a linear
    layer.  ``config.task_type`` selects the loss: "multi_class",
    "binary_class", "regression" or "multi_label".
    """

    def __init__(self, config):
        # A positive classifier_num_labels overrides the generic num_labels.
        if hasattr(config, "classifier_num_labels") and config.classifier_num_labels > 0:
            config.num_labels = config.classifier_num_labels
        super().__init__(config)
        self.num_labels = config.num_labels
        self.task_level = config.task_level
        self.task_type = config.task_type
        assert self.task_level == "seq_level"
        self.classifier_pooling_type = config.classifier_pooling_type
        self.classifier_loss_type = config.classifier_loss_type
        self.classifier_loss_reduction = config.classifier_loss_reduction
        self.classifier_pos_weight = config.classifier_pos_weight
        self.classifier_weight = config.classifier_weight
        self.lucaone = LucaGPLMModel(config)
        if self.classifier_pooling_type == "value_attention":
            self.pooler = LucaGPLMGlobalMaskValueAttentionPooling1D(config.hidden_size)
        elif self.classifier_pooling_type == "context_attention":
            self.pooler = LucaGPLMGlobalMaskContextAttentionPooling1D(embed_size=config.hidden_size)
        elif self.classifier_pooling_type == "weighted_attention":
            self.pooler = LucaGPLMGlobalMaskWeightedAttentionPooling1D(embed_size=config.hidden_size)
        else:
            # "cls" / "mean" pooling need no module; handled in forward().
            self.pooler = None
        self.dropout = nn.Dropout(config.classifier_dropout_prob)

        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        if self.task_type == "multi_class":
            weight = self._make_weight(self.classifier_weight, self.num_labels, allow_list=True)
            self.loss_fct = nn.CrossEntropyLoss(weight=weight, reduction="mean")
        elif self.task_type == "binary_class":
            pos_weight = self._make_weight(self.classifier_pos_weight, 1)
            self.loss_fct = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction="mean")
        elif self.task_type == "regression":
            if self.classifier_loss_type == "mae":
                self.loss_fct = nn.L1Loss(reduction="mean")
            else:
                self.loss_fct = nn.MSELoss(reduction="mean")
        elif self.task_type == "multi_label":
            pos_weight = self._make_weight(self.classifier_pos_weight, self.num_labels)
            self.loss_fct = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction=self.classifier_loss_reduction)
        else:
            raise ValueError("Invalid task type: %s" % self.task_type)
        self.post_init()

    @staticmethod
    def _make_weight(value, size, allow_list=False):
        """Convert a configured loss weight into a float32 tensor.

        A truthy str/int/float is broadcast to ``size`` entries; a list is
        used as-is only when ``allow_list`` is set.  Falsy or unsupported
        values yield None (no weighting).
        """
        if not value:
            return None
        if isinstance(value, (str, int, float)):
            return torch.tensor([float(value)] * size, dtype=torch.float32)
        if allow_list and isinstance(value, list):
            return torch.tensor(value, dtype=torch.float32)
        return None

    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
        return_dict=None
    ):
        """Classify (or regress on) whole sequences.

        Args:
            input_ids, token_type_ids, attention_mask: Forwarded to the backbone.
            labels: Targets for the configured task; when None no loss is
                computed.
            return_dict: Return a ``SequenceClassifierOutput`` instead of a
                tuple. Defaults to the config value.

        Returns:
            A ``SequenceClassifierOutput``, or the tuple
            ``([loss,] logits, *backbone_extras)``.

        Raises:
            ValueError: On an unknown pooling type or task type.
        """
        return_dict = return_dict if return_dict is not None else getattr(self.config, 'use_return_dict', True)
        outputs = self.lucaone(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            return_dict=return_dict
        )
        sequence_output = outputs[0]
        if self.pooler is not None:
            # NOTE(review): the pooler receives only the hidden states, not
            # the attention mask — confirm the pooling modules do not need it.
            pooled_output = self.pooler(sequence_output)
        elif self.classifier_pooling_type == "cls":
            # Representation at the first position of each sequence.
            pooled_output = sequence_output[:, 0, :]
        elif self.classifier_pooling_type == "mean":
            # Plain mean over dim 1; the attention mask is not applied here.
            pooled_output = sequence_output.mean(dim=1)
        else:
            raise ValueError("Invalid classifier pooling type: %s" % self.classifier_pooling_type)

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.task_type == "multi_class":
                loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.task_type == "binary_class":
                loss = self.loss_fct(logits.view(-1), labels.view(-1).float())
            elif self.task_type == "regression":
                loss = self.loss_fct(logits.view(-1), labels.view(-1))
            elif self.task_type == "multi_label":
                loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels).float())
            else:
                raise ValueError("Invalid task type: %s" % self.task_type)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        # Pass through whatever extras the backbone produced (None unless
        # enabled in the config), matching what the tuple path returns.
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
| |
|
class LucaGPLMForTokenClassification(LucaGPLMPreTrainedModel):
    """
    LucaGPLM backbone with a token-level prediction head.

    Every position except the first and last of the sequence output goes
    through dropout and a linear layer.  ``config.task_type`` selects the
    loss: "multi_class", "binary_class", "regression" or "multi_label".
    """

    def __init__(self, config):
        # A positive classifier_num_labels overrides the generic num_labels.
        if hasattr(config, "classifier_num_labels") and config.classifier_num_labels > 0:
            config.num_labels = config.classifier_num_labels
        super().__init__(config)
        self.num_labels = config.num_labels
        self.task_level = config.task_level
        self.task_type = config.task_type
        assert self.task_level == "token_level"
        self.classifier_pooling_type = config.classifier_pooling_type
        self.classifier_loss_type = config.classifier_loss_type
        self.classifier_loss_reduction = config.classifier_loss_reduction
        self.classifier_pos_weight = config.classifier_pos_weight
        self.classifier_weight = config.classifier_weight
        self.lucaone = LucaGPLMModel(config)
        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        if self.task_type == "multi_class":
            weight = self._make_weight(self.classifier_weight, self.num_labels, allow_list=True)
            self.loss_fct = nn.CrossEntropyLoss(weight=weight, reduction="mean")
        elif self.task_type == "binary_class":
            pos_weight = self._make_weight(self.classifier_pos_weight, 1)
            self.loss_fct = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction="mean")
        elif self.task_type == "regression":
            if self.classifier_loss_type == "mae":
                self.loss_fct = nn.L1Loss(reduction="mean")
            else:
                self.loss_fct = nn.MSELoss(reduction="mean")
        elif self.task_type == "multi_label":
            pos_weight = self._make_weight(self.classifier_pos_weight, self.num_labels)
            self.loss_fct = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction=self.classifier_loss_reduction)
        else:
            raise ValueError("Invalid task type: %s" % self.task_type)
        self.post_init()

    @staticmethod
    def _make_weight(value, size, allow_list=False):
        """Convert a configured loss weight into a float32 tensor.

        A truthy str/int/float is broadcast to ``size`` entries; a list is
        used as-is only when ``allow_list`` is set.  Falsy or unsupported
        values yield None (no weighting).
        """
        if not value:
            return None
        if isinstance(value, (str, int, float)):
            return torch.tensor([float(value)] * size, dtype=torch.float32)
        if allow_list and isinstance(value, list):
            return torch.tensor(value, dtype=torch.float32)
        return None

    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
        return_dict=None
    ):
        """Predict a label (or value) for each inner token.

        Args:
            input_ids, token_type_ids, attention_mask: Forwarded to the backbone.
            labels: Targets for the configured task; when None no loss is
                computed.
            return_dict: Return a ``TokenClassifierOutput`` instead of a
                tuple. Defaults to the config value.

        Returns:
            A ``TokenClassifierOutput``, or the tuple
            ``([loss,] logits, *backbone_extras)``.

        Raises:
            ValueError: On an unknown task type when labels are given.
        """
        return_dict = return_dict if return_dict is not None else getattr(self.config, 'use_return_dict', True)
        outputs = self.lucaone(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            return_dict=return_dict
        )
        # Drop the first and last position along dim 1 — presumably the
        # special start/end tokens, so labels should align with the remaining
        # positions (TODO confirm against the tokenizer).
        sequence_output = outputs[0][:, 1:-1, :]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            if self.task_type == "multi_class":
                loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.task_type == "binary_class":
                loss = self.loss_fct(logits.view(-1), labels.view(-1).float())
            elif self.task_type == "regression":
                loss = self.loss_fct(logits.view(-1), labels.view(-1))
            elif self.task_type == "multi_label":
                loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels).float())
            else:
                raise ValueError("Invalid task type: %s" % self.task_type)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        # Pass through whatever extras the backbone produced (None unless
        # enabled in the config), matching what the tuple path returns.
        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
| |
|
# Public API of this module: the bare backbone, its base class, and the
# task-specific heads.
__all__ = [
    "LucaGPLMModel", "LucaGPLMPreTrainedModel",
    "LucaGPLMForMaskedLM",
    "LucaGPLMForSequenceClassification", "LucaGPLMForTokenClassification",
]
| |
|