|
|
|
|
|
|
|
|
''' |
|
|
@license: (C) Copyright 2025, Hey. |
|
|
@author: Hey |
|
|
@email: sanyuan.hy@alibaba-inc.com |
|
|
@tel: 137****6540 |
|
|
@datetime: 2025/12/30 11:35 |
|
|
@project: lucaone |
|
|
@file: modeling_lucaone |
|
|
@desc: modeling_lucaone |
|
|
''' |
|
|
import math |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from transformers import PreTrainedModel |
|
|
from transformers.modeling_outputs import BaseModelOutput |
|
|
from transformers.modeling_outputs import MaskedLMOutput |
|
|
from transformers.modeling_outputs import SequenceClassifierOutput |
|
|
from transformers.modeling_outputs import TokenClassifierOutput |
|
|
from typing import Optional, List, Union, Tuple |
|
|
from .configuration_lucaone import LucaGPLMConfig |
|
|
# Prefer NVIDIA Apex's fused LayerNorm kernel when it is installed; it is a
# drop-in, numerically equivalent replacement for torch.nn.LayerNorm with a
# faster fused CUDA implementation.
try:

    from apex.normalization import FusedLayerNorm as _FusedLayerNorm

    class LucaGPLM1bLayerNorm(_FusedLayerNorm):
        """Apex FusedLayerNorm pinned to the input tensor's CUDA device."""

        @torch.jit.unused  # the apex kernel is not scriptable
        def forward(self, x):
            if not x.is_cuda:
                # CPU tensors fall back to the parent (non-fused) path.
                return super().forward(x)
            else:
                # Apex assumes the current CUDA device matches the input's
                # device; set it explicitly for multi-GPU safety.
                with torch.cuda.device(x.device):
                    return super().forward(x)

except ImportError:
    # Apex not available: plain torch LayerNorm behaves the same (unfused).
    from torch.nn import LayerNorm as LucaGPLM1bLayerNorm
|
|
|
|
|
def gelu(x):
    """Exact (erf-based) Gaussian Error Linear Unit: x * Phi(x)."""
    cdf = 1.0 + torch.erf(x / math.sqrt(2.0))
    return 0.5 * x * cdf
|
|
|
|
|
def rotate_half(x):
    """Rotate the last dimension by half: (a, b) -> (-b, a), per rotary-embedding convention."""
    first_half, second_half = x.chunk(2, dim=-1)
    return torch.cat((-second_half, first_half), dim=-1)
|
|
|
|
|
def apply_rotary_pos_emb(x, cos, sin):
    """Apply rotary position embedding to x using precomputed cos/sin tables.

    The tables are truncated to x's sequence length (second-to-last dim)
    before the elementwise rotation is applied.
    """
    seq_len = x.shape[-2]
    cos_table = cos[:, :seq_len, :]
    sin_table = sin[:, :seq_len, :]
    return x * cos_table + rotate_half(x) * sin_table
|
|
|
|
|
class LucaGPLMRotaryEmbedding(torch.nn.Module):
    """Rotary position embedding (RoPE) with a lazily built cos/sin cache.

    `dim` is the per-head dimension. Extra positional args/kwargs are accepted
    and ignored for signature compatibility with other rotary implementations.
    """

    def __init__(self, dim: int, *_, **__):
        super().__init__()
        # Standard RoPE inverse frequencies: 1 / 10000^(2i/dim) per channel pair.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Cache of the cos/sin tables; rebuilt only when the sequence length
        # or device changes (see _update_cos_sin_tables).
        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x, seq_dimension=1):
        """Return (cos, sin) tables of shape (1, seq_len, dim), rebuilding on
        sequence-length or device change. Note: dtype changes of `x` do not
        invalidate the cache (tables stay in inv_freq's float dtype)."""
        seq_len = x.shape[seq_dimension]

        if (seq_len != self._seq_len_cached or
                self._cos_cached is None or
                self._sin_cached is None or
                self._cos_cached.device != x.device):
            self._seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
            # Outer product: position index times inverse frequency.
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Duplicate so the table covers both halves rotated by rotate_half.
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

            self._cos_cached = emb.cos()[None, :, :]
            self._sin_cached = emb.sin()[None, :, :]

        return self._cos_cached, self._sin_cached

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Rotate q and k. Tables are sized from k's length (dim -2); q's table
        is truncated inside apply_rotary_pos_emb, so k may be longer than q."""
        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)

        return (
            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
        )
|
|
|
|
|
class LucaGPLMGlobalMaskWeightedAttentionPooling1D(nn.Module):
    """Global attention pooling over the sequence dimension.

    Each position is scored by a learned vector ``W`` (plus an optional scalar
    bias), the scores are softmaxed over the sequence, and the weighted sum of
    the input embeddings is returned.

    Args:
        embed_size: channel dimension of the input.
        use_bias: add a learned scalar bias to the per-position scores.
    """

    def __init__(self, embed_size, use_bias=False):
        super().__init__()
        self.embed_size = embed_size
        self.use_bias = use_bias

        # One learned score weight per embedding channel.
        self.W = nn.Parameter(torch.Tensor(self.embed_size))
        nn.init.trunc_normal_(self.W, std=0.01)
        if self.use_bias:
            self.b = nn.Parameter(torch.Tensor(1))
            nn.init.trunc_normal_(self.b, std=0.01)

    def forward(self, x, mask=None):
        """Pool x: (batch, seq_len, embed_size) -> (batch, embed_size).

        mask: optional float-like (batch, seq_len) with 1 at valid positions
        and 0 at padding; padded positions receive a -10000 score so their
        softmax weight is ~0.
        """
        # Per-position scalar scores: (batch, seq_len).
        logits = torch.matmul(x, self.W)
        if self.use_bias:
            logits = logits + self.b

        if mask is not None:
            logits = logits + (1.0 - mask) * -10000
        # F.softmax avoids constructing a throwaway nn.Softmax module per call.
        attention_probs = F.softmax(logits, dim=-1)
        # Weighted sum over the sequence dimension.
        return torch.sum(attention_probs.unsqueeze(-1) * x, dim=1)

    def __repr__(self):
        return f"{self.__class__.__name__} ({self.embed_size}, bias={self.use_bias!r})"
|
|
|
|
|
class LucaGPLMGlobalMaskContextAttentionPooling1D(nn.Module):
    """Additive (Bahdanau-style) context attention pooling over the sequence.

    Two linear projections of the input are combined through tanh, scored
    against a learned context vector ``c``, softmaxed over the sequence, and
    used to compute a weighted sum of the input embeddings.

    Args:
        embed_size: channel dimension of the input.
        units: hidden width of the additive projection (defaults to embed_size).
        use_additive_bias: add a bias inside the tanh.
        use_attention_bias: add a scalar bias to the attention scores.
    """

    def __init__(self, embed_size, units=None, use_additive_bias=False, use_attention_bias=False):
        super().__init__()
        self.embed_size = embed_size
        self.use_additive_bias = use_additive_bias
        self.use_attention_bias = use_attention_bias
        self.units = units if units else embed_size

        # NOTE: parameter creation/init order is kept stable so seeded runs
        # reproduce the same initial weights.
        self.U = nn.Parameter(torch.Tensor(self.embed_size, self.units))
        self.V = nn.Parameter(torch.Tensor(self.embed_size, self.units))
        if self.use_additive_bias:
            self.b1 = nn.Parameter(torch.Tensor(self.units))
            nn.init.trunc_normal_(self.b1, std=0.01)
        if self.use_attention_bias:
            self.b2 = nn.Parameter(torch.Tensor(1))
            nn.init.trunc_normal_(self.b2, std=0.01)

        # Learned context vector the tanh-combined features are scored against.
        self.c = nn.Parameter(torch.Tensor(self.units))

        nn.init.trunc_normal_(self.U, std=0.01)
        nn.init.trunc_normal_(self.V, std=0.01)
        nn.init.trunc_normal_(self.c, std=0.01)

    def forward(self, x, mask=None):
        """Pool x: (batch, seq_len, embed_size) -> (batch, embed_size).

        mask: optional float-like (batch, seq_len), 1 = valid, 0 = padding.
        """
        q = torch.matmul(x, self.U)
        k = torch.matmul(x, self.V)
        if self.use_additive_bias:
            h = torch.tanh(q + k + self.b1)
        else:
            h = torch.tanh(q + k)

        # Scalar score per position: (batch, seq_len).
        if self.use_attention_bias:
            e = torch.matmul(h, self.c) + self.b2
        else:
            e = torch.matmul(h, self.c)
        # Padded positions get a large negative score before the softmax;
        # F.softmax avoids constructing a throwaway nn.Softmax module per call.
        if mask is not None:
            attention_probs = F.softmax(e + (1.0 - mask) * -10000, dim=-1)
        else:
            attention_probs = F.softmax(e, dim=-1)
        return torch.sum(attention_probs.unsqueeze(-1) * x, dim=1)

    def __repr__(self):
        return (f"{self.__class__.__name__} ({self.embed_size} -> {self.units}, "
                f"bias=({self.use_additive_bias!r}, {self.use_attention_bias!r}))")
|
|
|
|
|
class LucaGPLMGlobalMaskValueAttentionPooling1D(nn.Module):
    """Per-channel ("value") attention pooling over the sequence dimension.

    Unlike the scalar-score variants, the attention score here is a full
    embed_size-dimensional vector per position, so each output channel gets
    its own softmax weighting over the sequence (softmax along dim=1).

    Args:
        embed_size: channel dimension of the input.
        units: hidden width of the additive projection (defaults to embed_size).
        use_additive_bias: add a bias inside the tanh.
        use_attention_bias: add a per-channel bias to the attention scores.
    """

    def __init__(self, embed_size, units=None, use_additive_bias=False, use_attention_bias=False):
        super().__init__()
        self.embed_size = embed_size
        self.use_additive_bias = use_additive_bias
        self.use_attention_bias = use_attention_bias
        self.units = units if units else embed_size

        # NOTE: parameter creation/init order is kept stable so seeded runs
        # reproduce the same initial weights.
        self.U = nn.Parameter(torch.Tensor(self.embed_size, self.units))
        self.V = nn.Parameter(torch.Tensor(self.embed_size, self.units))
        if self.use_additive_bias:
            self.b1 = nn.Parameter(torch.Tensor(self.units))
            nn.init.trunc_normal_(self.b1, std=0.01)
        if self.use_attention_bias:
            self.b2 = nn.Parameter(torch.Tensor(self.embed_size))
            nn.init.trunc_normal_(self.b2, std=0.01)

        # Projects the tanh features back to per-channel scores.
        self.W = nn.Parameter(torch.Tensor(self.units, self.embed_size))

        nn.init.trunc_normal_(self.U, std=0.01)
        nn.init.trunc_normal_(self.V, std=0.01)
        nn.init.trunc_normal_(self.W, std=0.01)

    def forward(self, x, mask=None):
        """Pool x: (batch, seq_len, embed_size) -> (batch, embed_size).

        mask: optional float-like (batch, seq_len), 1 = valid, 0 = padding.
        """
        q = torch.matmul(x, self.U)
        k = torch.matmul(x, self.V)
        if self.use_additive_bias:
            h = torch.tanh(q + k + self.b1)
        else:
            h = torch.tanh(q + k)

        # Vector score per position: (batch, seq_len, embed_size).
        if self.use_attention_bias:
            e = torch.matmul(h, self.W) + self.b2
        else:
            e = torch.matmul(h, self.W)
        # Softmax over the sequence (dim=1), one distribution per channel;
        # F.softmax avoids constructing a throwaway nn.Softmax module per call.
        if mask is not None:
            attention_probs = F.softmax(e + ((1.0 - mask) * -10000).unsqueeze(-1), dim=1)
        else:
            attention_probs = F.softmax(e, dim=1)
        return torch.sum(attention_probs * x, dim=1)

    def __repr__(self):
        return (f"{self.__class__.__name__} ({self.embed_size} -> {self.units}, "
                f"bias=({self.use_additive_bias!r}, {self.use_attention_bias!r}))")
|
|
|
|
|
class LucaGPLM1LayerNorm(nn.Module):
    """Layer normalization over the trailing `hidden_size` dimension(s), ESM-1 style.

    Uses the biased variance estimate with eps added inside the square root.
    """

    def __init__(self, hidden_size, eps=1e-12, affine=True):
        super().__init__()
        if isinstance(hidden_size, int):
            self.hidden_size = (hidden_size,)
        else:
            self.hidden_size = tuple(hidden_size)
        self.eps = eps
        self.affine = bool(affine)
        if self.affine:
            # Elementwise affine transform, initialised to the identity.
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.bias = nn.Parameter(torch.zeros(hidden_size))
        else:
            self.weight = None
            self.bias = None

    def forward(self, x):
        # Normalise over the last len(hidden_size) dimensions.
        norm_dims = tuple(range(-len(self.hidden_size), 0))
        centered = x - x.mean(norm_dims, keepdim=True)
        variance = centered.pow(2).mean(norm_dims, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.eps)
        if self.affine:
            return self.weight * normalized + self.bias
        return normalized
|
|
|
|
|
class LucaGPLMMultiheadAttention(nn.Module):
    """Fairseq/ESM-style multi-head attention over (seq_len, batch, embed_dim)
    inputs, with optional rotary position embeddings and optional learned
    bias_k/bias_v pseudo-tokens appended to the key/value sequences.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv: bool = False,
        add_zero_attn: bool = False,
        self_attention: bool = False,
        encoder_decoder_attention: bool = False,
        use_rotary_embeddings: bool = False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        # Key/value inputs may have their own widths; default to embed_dim.
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout  # dropout rate applied to attention probabilities
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        # 1/sqrt(head_dim) scaling applied to q before QK^T.
        self.scaling = self.head_dim**-0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, (
            "Self-attention requires query, key and " "value to be of the same size"
        )

        self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            # One learned pseudo key/value slot appended to every sequence.
            self.bias_k = nn.Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = nn.Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        # NOTE(review): add_zero_attn is stored but never applied in forward().
        self.add_zero_attn = add_zero_attn

        self.reset_parameters()

        self.rot_emb = None
        if use_rotary_embeddings:
            self.rot_emb = LucaGPLMRotaryEmbedding(dim=self.head_dim)

    def reset_parameters(self):
        """Xavier-uniform init of the projections (with relu gain) and zero
        out_proj bias; normal init for the optional bias_k/bias_v tokens."""
        nn.init.xavier_uniform_(self.k_proj.weight, gain=nn.init.calculate_gain("relu"))
        nn.init.xavier_uniform_(self.v_proj.weight, gain=nn.init.calculate_gain("relu"))
        nn.init.xavier_uniform_(self.q_proj.weight, gain=nn.init.calculate_gain("relu"))
        nn.init.xavier_uniform_(self.out_proj.weight, gain=nn.init.calculate_gain("relu"))

        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    def forward(
        self,
        query,
        key: Optional[torch.Tensor] = None,
        value: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[torch.Tensor] = None,
        need_head_weights: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute attention.

        query/key/value: (seq_len, batch, dim). key/value are ignored when
        self_attention is True. key_padding_mask: (batch, src_len), truthy at
        padded positions. attn_mask: additive mask broadcast over batch*heads.

        Returns (attn_output, weights) where weights is
        (num_heads, batch, tgt_len, src_len) if need_head_weights, the
        head-averaged (batch, tgt_len, src_len) if only need_weights, else None.
        """
        # Per-head weights are a superset of the averaged weights.
        if need_head_weights:
            need_weights = True

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(key)  # NOTE(review): projects `key`? see below
            v = self.v_proj(value)

        # Scale queries once, before the QK^T product.
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            # Append the learned pseudo key/value slot; extend masks by one.
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                    ],
                    dim=1,
                )

        # Fold heads into the batch dimension: (bsz*heads, len, head_dim).
        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        assert k is not None
        src_len = k.size(1)

        if self.rot_emb:
            q, k = self.rot_emb(q, k)

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            # Additive mask, broadcast over the batch*heads dimension.
            attn_mask = attn_mask.unsqueeze(0)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # Mask out padded key positions with -inf before the softmax.
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        # Softmax in float, then cast back to the working dtype.
        attn_weights_float = F.softmax(attn_weights, dim=-1)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = F.dropout(
            attn_weights_float.type_as(attn_weights),
            p=self.dropout,
            training=self.training,
        )

        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        # Back to (tgt_len, bsz, embed_dim) and project out.
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        attn_weights_output: Optional[torch.Tensor] = None
        if need_weights:
            # (num_heads, bsz, tgt_len, src_len) after the transpose.
            attn_weights_output = attn_weights_float.view(
                bsz, self.num_heads, tgt_len, src_len
            ).type_as(attn).transpose(1, 0)
            if not need_head_weights:
                # Average over heads.
                attn_weights_output = attn_weights_output.mean(dim=0)

        return attn, attn_weights_output
|
|
|
|
|
class LucaGPLMMultiheadAttentionWithSDPA(nn.Module):
    """Multi-head attention that dispatches to
    torch.nn.functional.scaled_dot_product_attention (SDPA) when per-head
    attention weights are NOT requested, and falls back to the explicit
    bmm/softmax implementation (mirroring LucaGPLMMultiheadAttention)
    otherwise. Inputs are (seq_len, batch, embed_dim).

    NOTE(review): on the SDPA fast path the returned weights are always None,
    even when need_weights=True — only the fallback path can return weights.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv: bool = False,
        add_zero_attn: bool = False,
        self_attention: bool = False,
        encoder_decoder_attention: bool = False,
        use_rotary_embeddings: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        # Key/value inputs may have their own widths; default to embed_dim.
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout  # dropout rate applied to attention probabilities
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        # Used only on the fallback path; SDPA applies its own 1/sqrt(d) scale.
        self.scaling = self.head_dim**-0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, (
            "Self-attention requires query, key and " "value to be of the same size"
        )

        self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            # One learned pseudo key/value slot appended to every sequence.
            self.bias_k = nn.Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = nn.Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        # NOTE(review): add_zero_attn is stored but never applied in forward().
        self.add_zero_attn = add_zero_attn

        self.reset_parameters()

        self.rot_emb = None
        if use_rotary_embeddings:
            self.rot_emb = LucaGPLMRotaryEmbedding(dim=self.head_dim)

    def reset_parameters(self):
        """Xavier-uniform init of the projections (with relu gain) and zero
        out_proj bias; normal init for the optional bias_k/bias_v tokens."""
        nn.init.xavier_uniform_(self.k_proj.weight, gain=nn.init.calculate_gain("relu"))
        nn.init.xavier_uniform_(self.v_proj.weight, gain=nn.init.calculate_gain("relu"))
        nn.init.xavier_uniform_(self.q_proj.weight, gain=nn.init.calculate_gain("relu"))
        nn.init.xavier_uniform_(self.out_proj.weight, gain=nn.init.calculate_gain("relu"))

        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    def forward(
        self,
        query,
        key: Optional[torch.Tensor] = None,
        value: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[torch.Tensor] = None,
        need_head_weights: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute attention; see class docstring for the two code paths.

        query/key/value: (seq_len, batch, dim); key/value ignored when
        self_attention. key_padding_mask: (batch, src_len), truthy at padded
        positions. attn_mask: additive mask. Returns (attn_output, weights);
        weights is None on the SDPA fast path.
        """
        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)

        if self.bias_k is not None:
            assert self.bias_v is not None
            # Append the learned pseudo key/value slot; extend masks by one.
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                    ],
                    dim=1,
                )

        # ---- SDPA fast path: used only when per-head weights are not needed
        # and the running torch exposes scaled_dot_product_attention. ----
        if not need_head_weights and hasattr(F, "scaled_dot_product_attention"):

            # Reshape to SDPA's (batch, heads, len, head_dim) layout.
            q_sdpa = q.view(tgt_len, bsz, self.num_heads, self.head_dim).permute(1, 2, 0, 3)
            k_sdpa = k.view(-1, bsz, self.num_heads, self.head_dim).permute(1, 2, 0, 3)
            v_sdpa = v.view(-1, bsz, self.num_heads, self.head_dim).permute(1, 2, 0, 3)

            if self.rot_emb:
                # Rotary tables broadcast over (batch, heads) leading dims.
                q_sdpa, k_sdpa = self.rot_emb(q_sdpa, k_sdpa)

            # Build one additive float mask combining padding and attn masks.
            sdpa_mask = None
            if attn_mask is not None or key_padding_mask is not None:
                target_shape = (bsz, 1, tgt_len, k_sdpa.size(2))
                sdpa_mask = torch.zeros(target_shape, device=q.device, dtype=q.dtype)

                if key_padding_mask is not None:
                    # -inf at padded key positions.
                    sdpa_mask = sdpa_mask.masked_fill(
                        key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                        float("-inf")
                    )

                if attn_mask is not None:
                    if attn_mask.dim() == 2:
                        sdpa_mask = sdpa_mask + attn_mask.unsqueeze(0).unsqueeze(0)
                    elif attn_mask.dim() == 3:
                        # NOTE(review): a 3-D attn_mask is silently ignored on
                        # this path (the fallback path would add it) — confirm
                        # whether any caller passes a per-head 3-D mask.
                        pass
                    else:
                        sdpa_mask = sdpa_mask + attn_mask

            # No explicit q scaling here: SDPA applies 1/sqrt(head_dim) itself.
            attn_output = F.scaled_dot_product_attention(
                q_sdpa,
                k_sdpa,
                v_sdpa,
                attn_mask=sdpa_mask,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=False
            )

            # Back to (tgt_len, bsz, embed_dim).
            attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(tgt_len, bsz, self.embed_dim)

            attn_output = self.out_proj(attn_output)

            # Fast path never materialises attention weights.
            return attn_output, None

        # ---- Fallback path: explicit bmm/softmax (matches
        # LucaGPLMMultiheadAttention); q is pre-scaled here. ----
        q = q * self.scaling

        # Fold heads into the batch dimension: (bsz*heads, len, head_dim).
        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        assert k is not None
        src_len = k.size(1)

        if self.rot_emb:
            q, k = self.rot_emb(q, k)

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            # Additive mask, broadcast over the batch*heads dimension.
            attn_mask = attn_mask.unsqueeze(0)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # Mask out padded key positions with -inf before the softmax.
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        # Softmax in float, then cast back to the working dtype.
        attn_weights_float = F.softmax(attn_weights, dim=-1)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = F.dropout(
            attn_weights_float.type_as(attn_weights),
            p=self.dropout,
            training=self.training,
        )

        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        # Back to (tgt_len, bsz, embed_dim) and project out.
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        attn_weights_output: Optional[torch.Tensor] = None
        if need_weights:
            # (num_heads, bsz, tgt_len, src_len) after the transpose.
            attn_weights_output = attn_weights_float.view(
                bsz, self.num_heads, tgt_len, src_len
            ).type_as(attn).transpose(1, 0)
            if not need_head_weights:
                # Average over heads.
                attn_weights_output = attn_weights_output.mean(dim=0)

        return attn, attn_weights_output
|
|
|
|
|
class LucaGPLMRobertaLMHead(nn.Module):
    """RoBERTa-style masked-LM head: dense -> gelu -> layer norm -> vocab projection."""

    def __init__(self, embed_dim, output_dim):
        super().__init__()
        self.dense = nn.Linear(embed_dim, embed_dim)
        self.layer_norm = LucaGPLM1bLayerNorm(embed_dim)

        # The decoder projects hidden states back to the vocabulary; its bias
        # is a separate parameter so the weight can be tied independently.
        self.decoder = nn.Linear(embed_dim, output_dim, bias=False)
        self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, features):
        """Map hidden features (..., embed_dim) to vocab logits (..., output_dim)."""
        hidden = self.layer_norm(gelu(self.dense(features)))
        return self.decoder(hidden) + self.bias
|
|
|
|
|
class LucaGPLMTransformerLayer(nn.Module):
    """Pre-LayerNorm transformer encoder layer: self-attention and a GELU
    feed-forward block, each wrapped in a residual connection.

    Operates on (seq_len, batch, embed_dim) tensors (fairseq layout).
    """

    def __init__(
        self,
        embed_dim,
        ffn_embed_dim,
        attention_heads,
        add_bias_kv=True,
        use_lucagplm1b_layer_norm=False,
        use_rotary_embeddings: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.ffn_embed_dim = ffn_embed_dim
        self.attention_heads = attention_heads
        self.use_rotary_embeddings = use_rotary_embeddings

        # Pick the ESM-1b (torch/apex) or ESM-1 (manual) LayerNorm flavour.
        norm_cls = LucaGPLM1bLayerNorm if use_lucagplm1b_layer_norm else LucaGPLM1LayerNorm

        self.pre_layer_norm = norm_cls(self.embed_dim)

        self.self_attn = LucaGPLMMultiheadAttentionWithSDPA(
            self.embed_dim,
            self.attention_heads,
            add_bias_kv=add_bias_kv,
            add_zero_attn=False,
            self_attention=True,
            use_rotary_embeddings=self.use_rotary_embeddings,
        )

        self.post_layer_norm = norm_cls(self.embed_dim)

        self.fc1 = nn.Linear(self.embed_dim, self.ffn_embed_dim)
        self.fc2 = nn.Linear(self.ffn_embed_dim, self.embed_dim)

    def forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
        need_head_weights=False
    ):
        """Run one encoder layer; returns (hidden_states, attention_weights)."""
        # Self-attention sub-block (pre-LN) with residual connection.
        attn_input = self.pre_layer_norm(x)
        attn_out, attn = self.self_attn(
            query=attn_input,
            key=attn_input,
            value=attn_input,
            key_padding_mask=self_attn_padding_mask,
            need_weights=True,
            need_head_weights=need_head_weights,
            attn_mask=self_attn_mask,
        )
        x = x + attn_out

        # Feed-forward sub-block (pre-LN) with residual connection.
        ffn_out = self.fc2(gelu(self.fc1(self.post_layer_norm(x))))
        x = x + ffn_out

        return x, attn
|
|
|
|
|
class LucaGPLMEmbeddings(nn.Module):
    """Input embedding module: scaled token embeddings plus optional position
    and token-type embeddings, optional embedding LayerNorm, ESM-style
    mask-token dropout rescaling (training only), and zeroing of padding rows.
    """

    def __init__(self, config: LucaGPLMConfig):
        super().__init__()

        # Feature switches, all read defensively off the config.
        self.no_position_embeddings = getattr(config, 'no_position_embeddings', False)
        self.no_token_type_embeddings = getattr(config, 'no_token_type_embeddings', False)
        self.use_embed_layer_norm = getattr(config, 'use_embed_layer_norm', True)
        self.embed_scale = getattr(config, 'embed_scale', 1.0)
        self.token_dropout = getattr(config, 'token_dropout', False)

        # Special-token ids used by the masking/padding logic below.
        self.mask_idx = getattr(config, 'mask_token_id', 4)
        self.padding_idx = getattr(config, 'pad_token_id', 0)

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)

        # Learned absolute position embeddings (optional).
        if not self.no_position_embeddings:
            self.embed_pos = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        else:
            self.embed_pos = None

        # Token-type (segment) embeddings (optional).
        if not self.no_token_type_embeddings:
            self.embed_type = nn.Embedding(config.type_vocab_size, config.hidden_size)
        else:
            self.embed_type = None

        if self.use_embed_layer_norm:
            self.embed_layer_norm = LucaGPLM1bLayerNorm(config.hidden_size)
        else:
            self.embed_layer_norm = None

    def forward(
        self,
        input_ids: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed input_ids: (batch, seq_len) -> (batch, seq_len, hidden_size)."""
        input_shape = input_ids.size()
        seq_length = input_shape[1]

        inputs_embeds = self.embed_scale * self.embed_tokens(input_ids)

        # Add position embeddings; default position_ids are 0..seq_len-1.
        if not self.no_position_embeddings and self.embed_pos is not None:
            if position_ids is None:
                position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
                position_ids = position_ids.unsqueeze(0).expand(input_shape)
            position_embeddings = self.embed_scale * self.embed_pos(position_ids)
            inputs_embeds = inputs_embeds + position_embeddings

        # Add token-type embeddings; default type is 0 everywhere.
        if not self.no_token_type_embeddings and self.embed_type is not None:
            if token_type_ids is None:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=input_ids.device)
            token_type_embeddings = self.embed_scale * self.embed_type(token_type_ids)
            inputs_embeds = inputs_embeds + token_type_embeddings

        if self.use_embed_layer_norm and self.embed_layer_norm is not None:
            embeddings = self.embed_layer_norm(inputs_embeds)
        else:
            embeddings = inputs_embeds

        # ESM-style token dropout: zero the embeddings of <mask> tokens during
        # training and rescale the rest so the expected magnitude matches.
        if self.token_dropout and self.training:
            embeddings = embeddings.masked_fill((input_ids == self.mask_idx).unsqueeze(-1), 0.0)

            # 0.15 * 0.8: presumably the pre-training masking schedule (15% of
            # tokens selected, 80% of those replaced by <mask>) — TODO confirm
            # against the training config.
            mask_ratio_train = 0.15 * 0.8
            padding_mask = input_ids.eq(self.padding_idx)
            src_lengths = (~padding_mask).sum(-1)
            # NOTE(review): an all-padding row makes src_lengths 0 here
            # (division by zero); callers presumably never pass one — verify.
            mask_ratio_observed = (input_ids == self.mask_idx).sum(-1).to(embeddings.dtype) / src_lengths
            embeddings = embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]

        # Zero out embeddings at padding positions (the LayerNorm above may
        # have made them non-zero even with padding_idx in the embedding).
        padding_mask = input_ids.eq(self.padding_idx)
        if padding_mask.any():
            embeddings = embeddings * (1 - padding_mask.unsqueeze(-1).type_as(embeddings))

        return embeddings
|
|
|
|
|
class LucaGPLMEncoder(nn.Module):
    """Stack of LucaGPLMTransformerLayer blocks with optional final LayerNorm,
    HF-style outputs, and optional per-layer representation collection.
    """

    def __init__(self, config: LucaGPLMConfig):
        super().__init__()

        # All layers share the same shape: 4x FFN expansion, rotary
        # embeddings, ESM-1b LayerNorm, no bias_k/bias_v tokens.
        self.layers = nn.ModuleList([
            LucaGPLMTransformerLayer(
                config.hidden_size,
                4 * config.hidden_size,
                config.num_attention_heads,
                add_bias_kv=False,
                use_lucagplm1b_layer_norm=True,
                use_rotary_embeddings=True,
            )
            for _ in range(config.num_hidden_layers)
        ])

        self.use_last_layer_norm = getattr(config, 'use_last_layer_norm', True)
        if self.use_last_layer_norm:
            self.last_layer_norm = LucaGPLM1bLayerNorm(config.hidden_size)
        else:
            self.last_layer_norm = None

        self.padding_idx = config.pad_token_id
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        need_head_weights: bool = False,
        repr_layers: Optional[List[int]] = None,
        use_last_layer_norm: bool = True,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
        """Run the layer stack over (batch, seq_len, hidden) hidden_states.

        attention_mask follows the HF convention (1 = attend, 0 = pad);
        repr_layers selects layer indices (negatives allowed, -1 = last) whose
        outputs are collected. NOTE(review): the collected
        `hidden_representations` dict is built but never returned from this
        method — presumably consumed by a caller/subclass in a fuller
        version; confirm.
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        if repr_layers is None:
            repr_layers = [-1]

        # Map negative indices into [0, num_layers]: e.g. -1 -> num_layers
        # (final output), 0 -> the raw embedding input.
        layer_size = len(self.layers)
        repr_layers = [(i + layer_size + 1) % (layer_size + 1) for i in repr_layers]
        repr_layers = set(repr_layers)
        hidden_representations = {}

        # Derive a boolean padding mask (True at padded positions).
        if attention_mask is None:
            padding_mask = hidden_states.new_zeros(hidden_states.shape[:2]).eq(self.padding_idx)
        else:
            padding_mask = attention_mask.eq(0)

        if 0 in repr_layers:
            hidden_representations[0] = hidden_states

        # Layers use the fairseq (seq_len, batch, hidden) layout.
        hidden_states = hidden_states.transpose(0, 1)

        # Skip masking entirely when no sequence has padding.
        if not padding_mask.any():
            padding_mask = None

        if need_head_weights or output_attentions:
            attn_weights = []

        # NOTE(review): if config.num_hidden_layers == 0 this loop never runs
        # and `layer_idx` below is unbound — assumed never configured that way.
        for layer_idx, layer_module in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states.transpose(0, 1),)

            if self.gradient_checkpointing and self.training:
                # Positional args must match LucaGPLMTransformerLayer.forward:
                # (x, self_attn_mask, self_attn_padding_mask, need_head_weights).
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    None,
                    padding_mask,
                    need_head_weights or output_attentions,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    self_attn_mask=None,
                    self_attn_padding_mask=padding_mask,
                    need_head_weights=need_head_weights or output_attentions,
                )

            hidden_states, attn = layer_outputs

            if (layer_idx + 1) in repr_layers:
                hidden_representations[layer_idx + 1] = hidden_states.transpose(0, 1)

            if need_head_weights or output_attentions:
                # attn: (heads, batch, tgt, src) -> (batch, heads, tgt, src).
                attn_weights.append(attn.transpose(1, 0))

        if self.last_layer_norm is not None and use_last_layer_norm:
            hidden_states = self.last_layer_norm(hidden_states)

        # Back to (batch, seq_len, hidden).
        hidden_states = hidden_states.transpose(0, 1)

        # Re-store the final layer's representation post-LayerNorm (overwrites
        # the pre-norm copy saved inside the loop).
        if (layer_idx + 1) in repr_layers:
            hidden_representations[layer_idx + 1] = hidden_states

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if need_head_weights or output_attentions:
            if attn_weights:
                # Stack per-layer weights: (batch, layers, heads, tgt, src).
                all_attentions = torch.stack(attn_weights, 1)
                if padding_mask is not None:
                    # Zero attention entries involving padded positions.
                    attention_mask_expanded = 1 - padding_mask.type_as(all_attentions)
                    attention_mask_expanded = attention_mask_expanded.unsqueeze(1) * attention_mask_expanded.unsqueeze(2)
                    all_attentions = all_attentions * attention_mask_expanded[:, None, None, :, :]

            # Weights computed only for need_head_weights are not exposed.
            if not output_attentions:
                all_attentions = None

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
|
|
|
|
|
class LucaGPLMPreTrainedModel(PreTrainedModel):
    """Base class wiring LucaGPLM modules into the transformers
    save/load/initialisation machinery."""

    config_class = LucaGPLMConfig
    base_model_prefix = "lucaone"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LucaGPLMTransformerLayer"]

    def _init_weights(self, module):
        """Initialise weights: normal(0, initializer_range) for linear and
        embedding weights, zeros for biases and the padding embedding row,
        identity (ones/zeros) for both LayerNorm flavours."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, (LucaGPLM1LayerNorm, LucaGPLM1bLayerNorm)):
            # LucaGPLM1LayerNorm may have weight/bias set to None (affine=False),
            # so check presence before touching them.
            norm_weight = getattr(module, 'weight', None)
            if norm_weight is not None:
                norm_weight.data.fill_(1.0)
            norm_bias = getattr(module, 'bias', None)
            if norm_bias is not None:
                norm_bias.data.zero_()
|
|
|
|
|
class LucaGPLMModel(LucaGPLMPreTrainedModel):
    """
    The LucaGPLM backbone for extracting sequence representations and optionally
    predicting residue-residue contact maps.

    Based on the original LucaGPLM implementation but restructured to use modern
    transformers architecture (an embedding module followed by an encoder stack).
    """

    def __init__(self, config: LucaGPLMConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = LucaGPLMEmbeddings(self.config)
        self.encoder = LucaGPLMEncoder(self.config)
        # Initialize weights and apply any final processing.
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.embed_tokens

    def set_input_embeddings(self, value):
        self.embeddings.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_contacts: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        need_head_weights: Optional[bool] = None,
        repr_layers: Optional[List[int]] = None,
        use_last_layer_norm: Optional[bool] = True,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
        """
        Encode a batch of sequences.

        Args:
            input_ids: Token ids of shape ``(batch, seq_len)``. Mutually
                exclusive with ``inputs_embeds``.
            attention_mask: 1 for real tokens, 0 for padding. Defaults to all
                ones when omitted.
            token_type_ids / position_ids: Forwarded to the embedding module.
            inputs_embeds: Pre-computed embeddings of shape
                ``(batch, seq_len, hidden)``; bypasses ``self.embeddings``.
            output_attentions / output_hidden_states: Fall back to the config
                when ``None``.
            return_contacts: When True, also compute a symmetrized contact map
                from the attention weights (forces attention output on).
            return_dict: Return a ``BaseModelOutput`` instead of a tuple.
            need_head_weights: Ask the encoder for per-head attention weights;
                defaults to ``return_contacts``.
            repr_layers: Layer indices whose hidden states the encoder should
                collect.
            use_last_layer_norm: Whether the encoder applies its final norm.

        Returns:
            ``BaseModelOutput`` (with an extra ``contacts`` attribute when
            requested) or the equivalent tuple with ``None`` entries dropped.

        Raises:
            ValueError: If both or neither of ``input_ids`` / ``inputs_embeds``
                are given.
        """
        # Resolve option defaults from the config where not given explicitly.
        output_attentions = output_attentions if output_attentions is not None else getattr(self.config, 'output_attentions', False)
        output_hidden_states = output_hidden_states if output_hidden_states is not None else getattr(self.config, 'output_hidden_states', False)
        return_contacts = return_contacts if return_contacts is not None else False
        return_dict = return_dict if return_dict is not None else getattr(self.config, 'use_return_dict', True)
        need_head_weights = need_head_weights if need_head_weights is not None else return_contacts
        use_last_layer_norm = use_last_layer_norm if use_last_layer_norm is not None else True

        # Contact prediction requires the per-head attention maps.
        if return_contacts:
            output_attentions = True
            need_head_weights = True

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # Unpacking also validates that the input is 2-D (batch, seq_len).
        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)

        if inputs_embeds is None:
            embedding_output = self.embeddings(
                input_ids=input_ids,
                position_ids=position_ids,
                token_type_ids=token_type_ids,
            )
        else:
            embedding_output = inputs_embeds

        # Always request the structured output so that attentions/hidden states
        # can be read by attribute below; the caller's tuple form is rebuilt at
        # the end. (Previously `.attentions` was accessed on the encoder output
        # even when return_dict=False, which raised AttributeError whenever
        # return_contacts=True was combined with a tuple return.)
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            need_head_weights=need_head_weights,
            repr_layers=repr_layers,
            use_last_layer_norm=use_last_layer_norm,
        )

        sequence_output = encoder_outputs.last_hidden_state

        # Simple contact estimate: average the attention maps over layers and
        # heads, then symmetrize. `attentions` is stacked by the encoder as
        # (batch, num_layers, num_heads, seq_len, seq_len).
        contacts = None
        if return_contacts and encoder_outputs.attentions is not None:
            attentions = encoder_outputs.attentions
            averaged_attention = attentions.mean(dim=(1, 2))
            contacts = (averaged_attention + averaged_attention.transpose(-1, -2)) / 2

        if not return_dict:
            # Mirror the encoder's own tuple layout: present entries only.
            outputs = (sequence_output,)
            if encoder_outputs.hidden_states is not None:
                outputs = outputs + (encoder_outputs.hidden_states,)
            if encoder_outputs.attentions is not None:
                outputs = outputs + (encoder_outputs.attentions,)
            if contacts is not None:
                outputs = outputs + (contacts,)
            return outputs

        output = BaseModelOutput(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

        # ModelOutput supports dynamic attributes, so contacts are exposed
        # alongside the standard fields when requested.
        if contacts is not None:
            output.contacts = contacts

        return output
|
|
|
|
|
class LucaGPLMForMaskedLM(LucaGPLMPreTrainedModel):
    """LucaGPLM backbone with a masked-language-modeling head on top."""

    def __init__(self, config):
        super().__init__(config)
        self.lucaone = LucaGPLMModel(config)

        # Roberta-style head projecting hidden states back to vocabulary logits.
        self.lm_head = LucaGPLMRobertaLMHead(
            embed_dim=config.hidden_size,
            output_dim=config.vocab_size
        )
        # Tie the input embedding matrix to the LM head decoder weight.
        self._tied_weights_keys = [
            "lucaone.embeddings.embed_tokens.weight",
            "lm_head.decoder.weight"
        ]

        self.post_init()

    def get_input_embeddings(self):
        return self.lucaone.get_input_embeddings()

    def get_output_embeddings(self):
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        self.lm_head.decoder = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        """
        Compute MLM logits and, when ``labels`` is given, the cross-entropy loss.

        Args:
            labels: Target token ids of shape ``(batch, seq_len)``; positions
                holding ``-100`` are excluded from the loss.

        Returns:
            ``MaskedLMOutput`` or the equivalent tuple when ``return_dict=False``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.lucaone(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        prediction_scores = self.lm_head(sequence_output)

        masked_lm_loss = None
        if labels is not None:
            # -100 marks positions that were not masked out for prediction.
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            # The base model's tuple is (sequence_output, [hidden_states],
            # [attentions]) with no pooled output, so everything after index 0
            # must be kept. (Was `outputs[2:]`, which silently dropped the
            # first extra entry.)
            output = (prediction_scores,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
|
class LucaGPLMForSequenceClassification(LucaGPLMPreTrainedModel):
    """LucaGPLM with a sequence-level classification/regression head.

    Supports the task types ``multi_class``, ``binary_class``, ``regression``
    and ``multi_label``; the pooling strategy and loss weighting are taken
    from the config.
    """

    @staticmethod
    def _as_weight_tensor(value, repeat, allow_list=False):
        """Normalize a config-supplied class weight into a float32 tensor.

        ``value`` may be a numeric string, int, or float (broadcast to
        ``repeat`` entries) or — when ``allow_list`` — a list used verbatim.
        Falsy or unsupported values yield ``None`` (i.e. unweighted loss).
        """
        if not value:
            return None
        if isinstance(value, str) or isinstance(value, int):
            return torch.tensor([float(value)] * repeat, dtype=torch.float32)
        if isinstance(value, float):
            return torch.tensor([value] * repeat, dtype=torch.float32)
        if allow_list and isinstance(value, list):
            return torch.tensor(value, dtype=torch.float32)
        return None

    def __init__(self, config):
        # A per-head label count overrides the generic num_labels when set.
        if hasattr(config, "classifier_num_labels") and config.classifier_num_labels > 0:
            config.num_labels = config.classifier_num_labels
        super().__init__(config)
        self.num_labels = config.num_labels
        self.task_level = config.task_level
        self.task_type = config.task_type
        assert self.task_level == "seq_level"
        self.classifier_pooling_type = config.classifier_pooling_type
        self.classifier_loss_type = config.classifier_loss_type
        self.classifier_loss_reduction = config.classifier_loss_reduction
        self.classifier_pos_weight = config.classifier_pos_weight
        self.classifier_weight = config.classifier_weight
        self.lucaone = LucaGPLMModel(config)

        # Attention-based poolers get a module; "cls"/"mean" are handled inline
        # in forward() with self.pooler left as None.
        if self.classifier_pooling_type == "value_attention":
            self.pooler = LucaGPLMGlobalMaskValueAttentionPooling1D(config.hidden_size)
        elif self.classifier_pooling_type == "context_attention":
            self.pooler = LucaGPLMGlobalMaskContextAttentionPooling1D(embed_size=config.hidden_size)
        elif self.classifier_pooling_type == "weighted_attention":
            self.pooler = LucaGPLMGlobalMaskWeightedAttentionPooling1D(embed_size=config.hidden_size)
        else:
            self.pooler = None
        self.dropout = nn.Dropout(config.classifier_dropout_prob)

        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Build the loss once at construction time, with optional class
        # weighting from the config.
        if self.task_type == "multi_class":
            # Only multi_class accepts an explicit per-class weight list.
            weight = self._as_weight_tensor(self.classifier_weight, self.num_labels, allow_list=True)
            self.loss_fct = nn.CrossEntropyLoss(weight=weight, reduction="mean")
        elif self.task_type == "binary_class":
            pos_weight = self._as_weight_tensor(self.classifier_pos_weight, 1)
            self.loss_fct = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction="mean")
        elif self.task_type == "regression":
            if self.classifier_loss_type == "mae":
                self.loss_fct = nn.L1Loss(reduction="mean")
            else:
                self.loss_fct = nn.MSELoss(reduction="mean")
        elif self.task_type == "multi_label":
            pos_weight = self._as_weight_tensor(self.classifier_pos_weight, self.num_labels)
            self.loss_fct = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction=self.classifier_loss_reduction)
        else:
            raise ValueError("Invalid task type: %s" % self.task_type)
        self.post_init()

    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
        return_dict=None
    ):
        """Classify/score each sequence; returns loss + logits when labels given."""
        return_dict = return_dict if return_dict is not None else getattr(self.config, 'use_return_dict', True)
        outputs = self.lucaone(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            return_dict=return_dict
        )
        if self.pooler is not None:
            # NOTE(review): the mask-aware pooler is called without the
            # attention mask, so padding positions contribute — confirm the
            # pooler's signature/defaults against its definition.
            pooled_output = self.pooler(outputs[0])
        elif self.classifier_pooling_type == "cls":
            # First token representation.
            pooled_output = outputs[0][:, 0, :]
        elif self.classifier_pooling_type == "mean":
            # Unmasked mean over the sequence dimension (includes padding).
            pooled_output = outputs[0].mean(dim=1)
        else:
            raise ValueError("Invalid classifier pooling type: %s" % self.classifier_pooling_type)

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.task_type == "multi_class":
                loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.task_type == "binary_class":
                loss = self.loss_fct(logits.view(-1), labels.view(-1).float())
            elif self.task_type == "regression":
                loss = self.loss_fct(logits.view(-1), labels.view(-1))
            elif self.task_type == "multi_label":
                loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels).float())
            else:
                raise ValueError("Invalid task type: %s" % self.task_type)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(loss=loss, logits=logits)
|
|
|
|
|
class LucaGPLMForTokenClassification(LucaGPLMPreTrainedModel):
    """LucaGPLM with a token-level classification/regression head.

    Supports the task types ``multi_class``, ``binary_class``, ``regression``
    and ``multi_label``. Predictions are made per token, excluding the first
    and last positions (special tokens).
    """

    @staticmethod
    def _as_weight_tensor(value, repeat, allow_list=False):
        """Normalize a config-supplied class weight into a float32 tensor.

        ``value`` may be a numeric string, int, or float (broadcast to
        ``repeat`` entries) or — when ``allow_list`` — a list used verbatim.
        Falsy or unsupported values yield ``None`` (i.e. unweighted loss).
        """
        if not value:
            return None
        if isinstance(value, str) or isinstance(value, int):
            return torch.tensor([float(value)] * repeat, dtype=torch.float32)
        if isinstance(value, float):
            return torch.tensor([value] * repeat, dtype=torch.float32)
        if allow_list and isinstance(value, list):
            return torch.tensor(value, dtype=torch.float32)
        return None

    def __init__(self, config):
        # A per-head label count overrides the generic num_labels when set.
        if hasattr(config, "classifier_num_labels") and config.classifier_num_labels > 0:
            config.num_labels = config.classifier_num_labels
        super().__init__(config)
        self.num_labels = config.num_labels
        self.task_level = config.task_level
        self.task_type = config.task_type
        assert self.task_level == "token_level"
        self.classifier_pooling_type = config.classifier_pooling_type
        self.classifier_loss_type = config.classifier_loss_type
        self.classifier_loss_reduction = config.classifier_loss_reduction
        self.classifier_pos_weight = config.classifier_pos_weight
        self.classifier_weight = config.classifier_weight
        self.lucaone = LucaGPLMModel(config)
        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Build the loss once at construction time, with optional class
        # weighting from the config.
        if self.task_type == "multi_class":
            # Only multi_class accepts an explicit per-class weight list.
            weight = self._as_weight_tensor(self.classifier_weight, self.num_labels, allow_list=True)
            self.loss_fct = nn.CrossEntropyLoss(weight=weight, reduction="mean")
        elif self.task_type == "binary_class":
            pos_weight = self._as_weight_tensor(self.classifier_pos_weight, 1)
            self.loss_fct = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction="mean")
        elif self.task_type == "regression":
            if self.classifier_loss_type == "mae":
                self.loss_fct = nn.L1Loss(reduction="mean")
            else:
                self.loss_fct = nn.MSELoss(reduction="mean")
        elif self.task_type == "multi_label":
            pos_weight = self._as_weight_tensor(self.classifier_pos_weight, self.num_labels)
            self.loss_fct = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction=self.classifier_loss_reduction)
        else:
            raise ValueError("Invalid task type: %s" % self.task_type)
        self.post_init()

    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
        return_dict=None
    ):
        """Score each token; labels must align with the sequence minus the
        first/last (special) positions, which are stripped before the head."""
        return_dict = return_dict if return_dict is not None else getattr(self.config, 'use_return_dict', True)
        outputs = self.lucaone(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            return_dict=return_dict
        )
        # Drop the first and last token representations (special tokens).
        sequence_output = outputs[0][:, 1:-1, :]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            if self.task_type == "multi_class":
                loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.task_type == "binary_class":
                loss = self.loss_fct(logits.view(-1), labels.view(-1).float())
            elif self.task_type == "regression":
                loss = self.loss_fct(logits.view(-1), labels.view(-1))
            elif self.task_type == "multi_label":
                loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels).float())
            else:
                raise ValueError("Invalid task type: %s" % self.task_type)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output
        return TokenClassifierOutput(loss=loss, logits=logits)
|
|
|
|
|
# Public API of this module (what `from modeling_lucaone import *` exposes).
__all__ = [
    "LucaGPLMModel",
    "LucaGPLMPreTrainedModel",
    "LucaGPLMForMaskedLM",
    "LucaGPLMForSequenceClassification",
    "LucaGPLMForTokenClassification"
]
|
|
|