# LucaOne-default-step36M / modeling_lucaone.py
# LucaGroup's picture
# Update weights and modeling code to latest version
# 6c59832 verified
#!/usr/bin/env python
# encoding: utf-8
'''
@license: (C) Copyright 2025, Hey.
@author: Hey
@email: sanyuan.hy@alibaba-inc.com
@tel: 137****6540
@datetime: 2025/12/30 11:35
@project: lucaone
@file: modeling_lucaone
@desc: modeling_lucaone
'''
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput
from transformers.modeling_outputs import MaskedLMOutput
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.modeling_outputs import TokenClassifierOutput
from typing import Optional, List, Union, Tuple
from .configuration_lucaone import LucaGPLMConfig
# Prefer apex's fused CUDA LayerNorm kernel when apex is installed; otherwise
# fall back to torch.nn.LayerNorm under the same alias so the rest of the file
# can use `LucaGPLM1bLayerNorm` without caring which backend is active.
try:
    from apex.normalization import FusedLayerNorm as _FusedLayerNorm

    class LucaGPLM1bLayerNorm(_FusedLayerNorm):
        @torch.jit.unused
        def forward(self, x):
            # CPU tensors use the plain (non-fused) implementation; CUDA
            # tensors run the fused kernel with the tensor's device current,
            # so the kernel launches on the correct GPU in multi-GPU setups.
            if not x.is_cuda:
                return super().forward(x)
            else:
                with torch.cuda.device(x.device):
                    return super().forward(x)
except ImportError:
    from torch.nn import LayerNorm as LucaGPLM1bLayerNorm
def gelu(x):
    """Exact (erf-based) GELU activation: ``x * Phi(x)`` with the standard normal CDF."""
    cdf = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
    return x * cdf
def rotate_half(x):
    """Rotary helper: split the last dim into halves (x1, x2) and return (-x2, x1)."""
    first_half, second_half = x.chunk(2, dim=-1)
    return torch.cat((-second_half, first_half), dim=-1)
def apply_rotary_pos_emb(x, cos, sin):
    """Apply precomputed rotary cos/sin tables to ``x``; the position axis is dim -2."""
    seq_len = x.shape[-2]
    cos_table = cos[:, :seq_len, :]
    sin_table = sin[:, :seq_len, :]
    return (x * cos_table) + (rotate_half(x) * sin_table)
class LucaGPLMRotaryEmbedding(torch.nn.Module):
    """Rotary position embeddings (RoPE) with a lazily-built cos/sin cache.

    The cache is rebuilt whenever the incoming sequence length or device
    changes; otherwise the stored tables are reused across forward calls.
    """

    def __init__(self, dim: int, *_, **__):
        super().__init__()
        # Standard RoPE inverse frequencies: 10000^(-2i/dim) for even i.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x, seq_dimension=1):
        """Return (cos, sin) tables of shape (1, seq_len, dim), rebuilding if stale."""
        seq_len = x.shape[seq_dimension]
        cache_stale = (
            seq_len != self._seq_len_cached
            or self._cos_cached is None
            or self._sin_cached is None
            or self._cos_cached.device != x.device
        )
        if cache_stale:
            self._seq_len_cached = seq_len
            positions = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
            angles = torch.einsum("i,j->ij", positions, self.inv_freq)
            # Duplicate angles so the table covers both halves of the head dim.
            table = torch.cat((angles, angles), dim=-1).to(x.device)
            self._cos_cached = table.cos()[None, :, :]
            self._sin_cached = table.sin()[None, :, :]
        return self._cos_cached, self._sin_cached

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Tables are sized from k; q and k share the same position axis (-2).
        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)
        rotated_q = apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached)
        rotated_k = apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached)
        return rotated_q, rotated_k
class LucaGPLMGlobalMaskWeightedAttentionPooling1D(nn.Module):
    """Masked attention pooling over the length axis with a single learned score vector."""

    def __init__(self, embed_size, use_bias=False):
        super(LucaGPLMGlobalMaskWeightedAttentionPooling1D, self).__init__()
        self.embed_size = embed_size
        self.use_bias = use_bias
        # One scalar attention logit per position: score = x . W (+ b).
        self.W = nn.Parameter(torch.Tensor(self.embed_size))
        nn.init.trunc_normal_(self.W, std=0.01)
        if self.use_bias:
            self.b = nn.Parameter(torch.Tensor(1))
            nn.init.trunc_normal_(self.b, std=0.01)

    def forward(self, x, mask=None):
        # (B, Len, Embed) @ (Embed,) -> (B, Len)
        scores = torch.matmul(x, self.W)
        if self.use_bias:
            scores += self.b
        # Mask convention: 1 = keep, 0 = pad. Padded positions receive a large
        # negative logit so softmax assigns them ~zero probability.
        if mask is None:
            probs = nn.Softmax(dim=-1)(scores)
        else:
            probs = nn.Softmax(dim=-1)(scores + (1.0 - mask) * -10000)
        # Probability-weighted sum over the length axis -> (B, Embed).
        return torch.sum(probs.unsqueeze(-1) * x, dim=1)

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.embed_size) + (', bias=%r)' % self.use_bias)
class LucaGPLMGlobalMaskContextAttentionPooling1D(nn.Module):
    """Masked additive ("context") attention pooling over the length axis."""

    def __init__(self, embed_size, units=None, use_additive_bias=False, use_attention_bias=False):
        super(LucaGPLMGlobalMaskContextAttentionPooling1D, self).__init__()
        self.embed_size = embed_size
        self.use_additive_bias = use_additive_bias
        self.use_attention_bias = use_attention_bias
        self.units = units if units else embed_size
        # Two projections feed a tanh hidden state scored against context vector c.
        self.U = nn.Parameter(torch.Tensor(self.embed_size, self.units))
        self.V = nn.Parameter(torch.Tensor(self.embed_size, self.units))
        if self.use_additive_bias:
            self.b1 = nn.Parameter(torch.Tensor(self.units))
            nn.init.trunc_normal_(self.b1, std=0.01)
        if self.use_attention_bias:
            self.b2 = nn.Parameter(torch.Tensor(1))
            nn.init.trunc_normal_(self.b2, std=0.01)
        self.c = nn.Parameter(torch.Tensor(self.units))
        nn.init.trunc_normal_(self.U, std=0.01)
        nn.init.trunc_normal_(self.V, std=0.01)
        nn.init.trunc_normal_(self.c, std=0.01)

    def forward(self, x, mask=None):
        # (B, Len, Embed) @ (Embed, Units) -> (B, Len, Units)
        proj_u = torch.matmul(x, self.U)
        proj_v = torch.matmul(x, self.V)
        if self.use_additive_bias:
            hidden = torch.tanh(proj_u + proj_v + self.b1)
        else:
            hidden = torch.tanh(proj_u + proj_v)
        # Score each position against the learned context vector -> (B, Len).
        scores = torch.matmul(hidden, self.c)
        if self.use_attention_bias:
            scores = scores + self.b2
        # Mask convention: 1 = keep, 0 = pad (large negative logit before softmax).
        if mask is None:
            probs = nn.Softmax(dim=-1)(scores)
        else:
            probs = nn.Softmax(dim=-1)(scores + (1.0 - mask) * -10000)
        # Probability-weighted sum over the length axis -> (B, Embed).
        return torch.sum(probs.unsqueeze(-1) * x, dim=1)

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.embed_size) + ' -> ' + str(self.units) + ', bias=(%r, %r))' % (self.use_additive_bias, self.use_attention_bias)
class LucaGPLMGlobalMaskValueAttentionPooling1D(nn.Module):
    """Masked attention pooling with per-channel (value-wise) attention weights.

    Unlike the context variant, the softmax runs over the length axis (dim=1)
    independently for every embedding channel.
    """

    def __init__(self, embed_size, units=None, use_additive_bias=False, use_attention_bias=False):
        super(LucaGPLMGlobalMaskValueAttentionPooling1D, self).__init__()
        self.embed_size = embed_size
        self.use_additive_bias = use_additive_bias
        self.use_attention_bias = use_attention_bias
        self.units = units if units else embed_size
        self.U = nn.Parameter(torch.Tensor(self.embed_size, self.units))
        self.V = nn.Parameter(torch.Tensor(self.embed_size, self.units))
        if self.use_additive_bias:
            self.b1 = nn.Parameter(torch.Tensor(self.units))
            nn.init.trunc_normal_(self.b1, std=0.01)
        if self.use_attention_bias:
            self.b2 = nn.Parameter(torch.Tensor(self.embed_size))
            nn.init.trunc_normal_(self.b2, std=0.01)
        # Projects the tanh hidden state back to embedding width for scoring.
        self.W = nn.Parameter(torch.Tensor(self.units, self.embed_size))
        nn.init.trunc_normal_(self.U, std=0.01)
        nn.init.trunc_normal_(self.V, std=0.01)
        nn.init.trunc_normal_(self.W, std=0.01)

    def forward(self, x, mask=None):
        # (B, Len, Embed) @ (Embed, Units) -> (B, Len, Units)
        proj_u = torch.matmul(x, self.U)
        proj_v = torch.matmul(x, self.V)
        if self.use_additive_bias:
            hidden = torch.tanh(proj_u + proj_v + self.b1)
        else:
            hidden = torch.tanh(proj_u + proj_v)
        # (B, Len, Units) @ (Units, Embed) -> (B, Len, Embed)
        if self.use_attention_bias:
            scores = torch.matmul(hidden, self.W) + self.b2
        else:
            scores = torch.matmul(hidden, self.W)
        # Per-channel softmax over the length axis; masked positions get -10000.
        if mask is None:
            probs = nn.Softmax(dim=1)(scores)
        else:
            probs = nn.Softmax(dim=1)(scores + ((1.0 - mask) * -10000).unsqueeze(-1))
        return torch.sum(probs * x, dim=1)

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.embed_size) + ' -> ' + str(self.units) + ', bias=(%r, %r))' % (self.use_additive_bias, self.use_attention_bias)
class LucaGPLM1LayerNorm(nn.Module):
    """Plain layer normalization over the trailing ``hidden_size`` dims, optional affine."""

    def __init__(self, hidden_size, eps=1e-12, affine=True):
        super().__init__()
        # Normalize a tuple of trailing dims; a bare int means one trailing dim.
        self.hidden_size = (hidden_size,) if isinstance(hidden_size, int) else tuple(hidden_size)
        self.eps = eps
        self.affine = bool(affine)
        if self.affine:
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.bias = nn.Parameter(torch.zeros(hidden_size))
        else:
            self.weight, self.bias = None, None

    def forward(self, x):
        # Negative dim indices for the last len(hidden_size) dimensions.
        norm_dims = tuple(-(i + 1) for i in range(len(self.hidden_size)))
        centered = x - x.mean(norm_dims, keepdim=True)
        variance = centered.pow(2).mean(norm_dims, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.eps)
        if not self.affine:
            return normalized
        return (self.weight * normalized) + self.bias
class LucaGPLMMultiheadAttention(nn.Module):
    """Multi-head attention operating in (seq_len, batch, embed_dim) layout,
    with optional rotary position embeddings and optional learned bias_k/bias_v
    vectors appended as one extra source position.

    This reference implementation always materializes the full attention
    weight matrix, so per-head weights can be returned to the caller.
    """
    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv: bool = False,
        add_zero_attn: bool = False,
        self_attention: bool = False,
        encoder_decoder_attention: bool = False,
        use_rotary_embeddings: bool = False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        # Key/value dims default to embed_dim (plain self-attention).
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        # 1/sqrt(head_dim): applied to q before the dot product.
        self.scaling = self.head_dim**-0.5
        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention
        assert not self.self_attention or self.qkv_same_dim, (
            "Self-attention requires query, key and " "value to be of the same size"
        )
        self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        if add_bias_kv:
            # Learned key/value vectors concatenated as an extra source position.
            self.bias_k = nn.Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = nn.Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None
        # NOTE(review): add_zero_attn is stored but never applied in forward().
        self.add_zero_attn = add_zero_attn
        self.reset_parameters()
        self.rot_emb = None
        if use_rotary_embeddings:
            self.rot_emb = LucaGPLMRotaryEmbedding(dim=self.head_dim)

    def reset_parameters(self):
        """Xavier-uniform init (relu gain) for all projections; zero out_proj bias."""
        nn.init.xavier_uniform_(self.k_proj.weight, gain=nn.init.calculate_gain("relu"))
        nn.init.xavier_uniform_(self.v_proj.weight, gain=nn.init.calculate_gain("relu"))
        nn.init.xavier_uniform_(self.q_proj.weight, gain=nn.init.calculate_gain("relu"))
        nn.init.xavier_uniform_(self.out_proj.weight, gain=nn.init.calculate_gain("relu"))
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    def forward(
        self,
        query,
        key: Optional[torch.Tensor] = None,
        value: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[torch.Tensor] = None,
        need_head_weights: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute attention.

        Args:
            query/key/value: (seq_len, batch, embed_dim); key/value are ignored
                when self_attention is True (query is reused for all three).
            key_padding_mask: (batch, src_len); nonzero/True marks PAD columns.
            need_weights: return attention weights (head-averaged unless
                need_head_weights is set).
            attn_mask: additive mask, broadcast over the batch*head dimension.
            need_head_weights: return per-head weights; implies need_weights.

        Returns:
            (attn, weights): attn is (tgt_len, batch, embed_dim); weights is
            (num_heads, batch, tgt_len, src_len) if need_head_weights,
            (batch, tgt_len, src_len) if only need_weights, else None.
        """
        if need_head_weights:
            need_weights = True
        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        # Pre-scale q so the bmm below directly yields scaled dot products.
        q *= self.scaling
        if self.bias_k is not None:
            assert self.bias_v is not None
            # Append the learned bias key/value as one extra source position,
            # and extend the masks with a matching (unmasked) column.
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                    ],
                    dim=1,
                )
        # (L, B, E) -> (B*H, L, D) for batched matrix multiply.
        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        assert k is not None
        src_len = k.size(1)
        if self.rot_emb:
            q, k = self.rot_emb(q, k)
        attn_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            attn_weights += attn_mask
        if key_padding_mask is not None:
            # Remove PAD columns from the softmax by setting them to -inf.
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
        # Softmax in float32 for stability; cast back for the value matmul.
        attn_weights_float = F.softmax(attn_weights, dim=-1)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = F.dropout(
            attn_weights_float.type_as(attn_weights),
            p=self.dropout,
            training=self.training,
        )
        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        # (B*H, T, D) -> (T, B, E)
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        attn_weights_output: Optional[torch.Tensor] = None
        if need_weights:
            # (B*H, T, S) -> (H, B, T, S)
            attn_weights_output = attn_weights_float.view(
                bsz, self.num_heads, tgt_len, src_len
            ).type_as(attn).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights_output = attn_weights_output.mean(dim=0)
        return attn, attn_weights_output
class LucaGPLMMultiheadAttentionWithSDPA(nn.Module):
    """Multi-head attention in (seq_len, batch, embed_dim) layout that uses
    ``F.scaled_dot_product_attention`` (flash / memory-efficient kernels) when
    the caller does not need per-head attention weights, and otherwise falls
    back to the reference matmul implementation.

    Note: on the SDPA fast path the attention weights are never materialized,
    so ``None`` is returned in their place even when ``need_weights=True`` —
    an intentional speed/introspection trade-off.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv: bool = False,
        add_zero_attn: bool = False,
        self_attention: bool = False,
        encoder_decoder_attention: bool = False,
        use_rotary_embeddings: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        # Key/value dims default to embed_dim (plain self-attention).
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        # 1/sqrt(head_dim): only applied explicitly on the fallback path;
        # the SDPA kernel applies the same default scale internally.
        self.scaling = self.head_dim**-0.5
        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention
        assert not self.self_attention or self.qkv_same_dim, (
            "Self-attention requires query, key and " "value to be of the same size"
        )
        self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        if add_bias_kv:
            # Learned key/value vectors concatenated as one extra source position.
            self.bias_k = nn.Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = nn.Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None
        # NOTE(review): add_zero_attn is stored but never applied in forward().
        self.add_zero_attn = add_zero_attn
        self.reset_parameters()
        self.rot_emb = None
        if use_rotary_embeddings:
            self.rot_emb = LucaGPLMRotaryEmbedding(dim=self.head_dim)

    def reset_parameters(self):
        """Xavier-uniform init (relu gain) for all projections; zero out_proj bias."""
        nn.init.xavier_uniform_(self.k_proj.weight, gain=nn.init.calculate_gain("relu"))
        nn.init.xavier_uniform_(self.v_proj.weight, gain=nn.init.calculate_gain("relu"))
        nn.init.xavier_uniform_(self.q_proj.weight, gain=nn.init.calculate_gain("relu"))
        nn.init.xavier_uniform_(self.out_proj.weight, gain=nn.init.calculate_gain("relu"))
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    def forward(
        self,
        query,
        key: Optional[torch.Tensor] = None,
        value: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[torch.Tensor] = None,
        need_head_weights: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute attention.

        Args:
            query/key/value: (seq_len, batch, embed_dim); key/value are ignored
                when self_attention is True.
            key_padding_mask: (batch, src_len); nonzero/True marks PAD columns.
            need_weights: only honored on the fallback path (SDPA returns None).
            attn_mask: additive float mask, either (tgt_len, src_len) or
                (batch * num_heads, tgt_len, src_len).
            need_head_weights: forces the fallback path and returns per-head
                weights of shape (num_heads, batch, tgt_len, src_len).

        Returns:
            (attn, weights-or-None) with attn of shape (tgt_len, batch, embed_dim).
        """
        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        if self.bias_k is not None:
            assert self.bias_v is not None
            # Append the learned bias key/value as one extra source position,
            # and extend the masks with a matching (unmasked) column.
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                    ],
                    dim=1,
                )
        # ------------------------------------------------------------------
        # Fast path: fused scaled-dot-product attention (no weights returned).
        # q is deliberately NOT pre-scaled here: SDPA applies its default
        # 1/sqrt(head_dim) scaling internally, which matches self.scaling.
        # ------------------------------------------------------------------
        if not need_head_weights and hasattr(F, "scaled_dot_product_attention"):
            # (L, B, E) -> (B, H, L, D), the layout SDPA expects.
            q_sdpa = q.view(tgt_len, bsz, self.num_heads, self.head_dim).permute(1, 2, 0, 3)
            k_sdpa = k.view(-1, bsz, self.num_heads, self.head_dim).permute(1, 2, 0, 3)
            v_sdpa = v.view(-1, bsz, self.num_heads, self.head_dim).permute(1, 2, 0, 3)
            if self.rot_emb:
                # Rotary tables index the position axis (-2) and broadcast
                # over the batch and head dimensions.
                q_sdpa, k_sdpa = self.rot_emb(q_sdpa, k_sdpa)
            # Fold both masks into one additive float mask broadcastable to
            # (B, H, T, S); key_padding_mask is (B, S) with True at padding.
            sdpa_mask = None
            if attn_mask is not None or key_padding_mask is not None:
                target_shape = (bsz, 1, tgt_len, k_sdpa.size(2))
                sdpa_mask = torch.zeros(target_shape, device=q.device, dtype=q.dtype)
                if key_padding_mask is not None:
                    # (B, S) -> (B, 1, 1, S); -inf removes PAD columns from softmax.
                    sdpa_mask = sdpa_mask.masked_fill(
                        key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                        float("-inf")
                    )
                if attn_mask is not None:
                    if attn_mask.dim() == 2:
                        # (T, S): shared across batch and heads.
                        sdpa_mask = sdpa_mask + attn_mask.unsqueeze(0).unsqueeze(0)
                    elif attn_mask.dim() == 3:
                        # Bug fix: a 3-D mask was previously silently ignored
                        # on this path. Interpret it as (B*H, T, S) — the same
                        # layout the fallback path's q/k use — and fold it in.
                        sdpa_mask = sdpa_mask + attn_mask.reshape(
                            bsz, self.num_heads, tgt_len, attn_mask.size(-1)
                        )
                    else:
                        # Already 4-D and broadcastable to (B, H, T, S).
                        sdpa_mask = sdpa_mask + attn_mask
            attn_output = F.scaled_dot_product_attention(
                q_sdpa,
                k_sdpa,
                v_sdpa,
                attn_mask=sdpa_mask,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=False
            )
            # (B, H, T, D) -> (T, B, H, D) -> (T, B, E)
            attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(tgt_len, bsz, self.embed_dim)
            attn_output = self.out_proj(attn_output)
            # Weights are not materialized on this path (see class docstring).
            return attn_output, None
        # ------------------------------------------------------------------
        # Fallback path (weights required or SDPA unavailable): explicit
        # scaling, bmm, and softmax so the weight matrix can be returned.
        # ------------------------------------------------------------------
        q = q * self.scaling
        # (L, B, E) -> (B*H, L, D) for batched matrix multiply.
        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        assert k is not None
        src_len = k.size(1)
        if self.rot_emb:
            q, k = self.rot_emb(q, k)
        attn_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            attn_weights += attn_mask
        if key_padding_mask is not None:
            # Remove PAD columns from the softmax by setting them to -inf.
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
        # Softmax in float32 for stability; cast back for the value matmul.
        attn_weights_float = F.softmax(attn_weights, dim=-1)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = F.dropout(
            attn_weights_float.type_as(attn_weights),
            p=self.dropout,
            training=self.training,
        )
        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        # (B*H, T, D) -> (T, B, E)
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        attn_weights_output: Optional[torch.Tensor] = None
        if need_weights:
            # (B*H, T, S) -> (H, B, T, S)
            attn_weights_output = attn_weights_float.view(
                bsz, self.num_heads, tgt_len, src_len
            ).type_as(attn).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights_output = attn_weights_output.mean(dim=0)
        return attn, attn_weights_output
class LucaGPLMRobertaLMHead(nn.Module):
    """Masked-LM head: dense -> GELU -> layer norm -> vocabulary projection + bias."""

    def __init__(self, embed_dim, output_dim):
        super().__init__()
        self.dense = nn.Linear(embed_dim, embed_dim)
        self.layer_norm = LucaGPLM1bLayerNorm(embed_dim)
        # Vocabulary projection kept bias-free; the bias lives as a separate
        # parameter so it can be initialized/tied independently of the weight.
        self.decoder = nn.Linear(embed_dim, output_dim, bias=False)
        self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, features):
        hidden = self.layer_norm(gelu(self.dense(features)))
        # Project back to vocabulary size and add the standalone bias.
        return self.decoder(hidden) + self.bias
class LucaGPLMTransformerLayer(nn.Module):
    """Pre-LayerNorm transformer encoder block: self-attention and a GELU MLP,
    each wrapped in a residual connection."""

    def __init__(
        self,
        embed_dim,
        ffn_embed_dim,
        attention_heads,
        add_bias_kv=True,
        use_lucagplm1b_layer_norm=False,
        use_rotary_embeddings: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.ffn_embed_dim = ffn_embed_dim
        self.attention_heads = attention_heads
        self.use_rotary_embeddings = use_rotary_embeddings
        # Pick the (possibly apex-fused) LayerNorm implementation.
        norm_cls = LucaGPLM1bLayerNorm if use_lucagplm1b_layer_norm else LucaGPLM1LayerNorm
        self.pre_layer_norm = norm_cls(self.embed_dim)
        self.self_attn = LucaGPLMMultiheadAttentionWithSDPA(
            self.embed_dim,
            self.attention_heads,
            add_bias_kv=add_bias_kv,
            add_zero_attn=False,
            self_attention=True,
            use_rotary_embeddings=self.use_rotary_embeddings,
        )
        self.post_layer_norm = norm_cls(self.embed_dim)
        # Position-wise feed-forward: expand to ffn_embed_dim, project back.
        self.fc1 = nn.Linear(self.embed_dim, self.ffn_embed_dim)
        self.fc2 = nn.Linear(self.ffn_embed_dim, self.embed_dim)

    def forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
        need_head_weights=False
    ):
        # --- Self-attention sub-block (pre-norm + residual) ---
        shortcut = x
        normed = self.pre_layer_norm(x)
        attn_out, attn = self.self_attn(
            query=normed,
            key=normed,
            value=normed,
            key_padding_mask=self_attn_padding_mask,
            need_weights=True,
            need_head_weights=need_head_weights,
            attn_mask=self_attn_mask,
        )
        x = shortcut + attn_out
        # --- Feed-forward sub-block (pre-norm + residual) ---
        shortcut = x
        x = shortcut + self.fc2(gelu(self.fc1(self.post_layer_norm(x))))
        return x, attn
class LucaGPLMEmbeddings(nn.Module):
    """Token embeddings plus optional position / token-type embeddings, with
    optional layer norm, ESM-style masked-token dropout rescaling, and zeroing
    of embeddings at padding positions."""

    def __init__(self, config: LucaGPLMConfig):
        super().__init__()
        # Behaviour switches copied from the config (permissive defaults).
        self.no_position_embeddings = getattr(config, 'no_position_embeddings', False)
        self.no_token_type_embeddings = getattr(config, 'no_token_type_embeddings', False)
        self.use_embed_layer_norm = getattr(config, 'use_embed_layer_norm', True)
        self.embed_scale = getattr(config, 'embed_scale', 1.0)
        self.token_dropout = getattr(config, 'token_dropout', False)
        # Special-token ids (defaults match the original LucaGPLM vocabulary).
        self.mask_idx = getattr(config, 'mask_token_id', 4)
        self.padding_idx = getattr(config, 'pad_token_id', 0)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        # Optional sub-modules stay None when disabled so the state_dict is minimal.
        if self.no_position_embeddings:
            self.embed_pos = None
        else:
            self.embed_pos = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        if self.no_token_type_embeddings:
            self.embed_type = None
        else:
            self.embed_type = nn.Embedding(config.type_vocab_size, config.hidden_size)
        if self.use_embed_layer_norm:
            self.embed_layer_norm = LucaGPLM1bLayerNorm(config.hidden_size)
        else:
            self.embed_layer_norm = None

    def forward(
        self,
        input_ids: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        shape = input_ids.size()
        seq_len = shape[1]
        device = input_ids.device
        # Scaled token embeddings.
        out = self.embed_scale * self.embed_tokens(input_ids)
        # Learned absolute position embeddings (when enabled).
        if not self.no_position_embeddings and self.embed_pos is not None:
            if position_ids is None:
                position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
                position_ids = position_ids.unsqueeze(0).expand(shape)
            out = out + self.embed_scale * self.embed_pos(position_ids)
        # Segment / token-type embeddings (when enabled).
        if not self.no_token_type_embeddings and self.embed_type is not None:
            if token_type_ids is None:
                token_type_ids = torch.zeros(shape, dtype=torch.long, device=device)
            out = out + self.embed_scale * self.embed_type(token_type_ids)
        if self.use_embed_layer_norm and self.embed_layer_norm is not None:
            out = self.embed_layer_norm(out)
        # ESM-style token dropout: zero [MASK] embeddings during training and
        # rescale by (1 - expected mask ratio) / (1 - observed mask ratio).
        if self.token_dropout and self.training:
            out = out.masked_fill((input_ids == self.mask_idx).unsqueeze(-1), 0.0)
            mask_ratio_train = 0.15 * 0.8
            src_lengths = (~input_ids.eq(self.padding_idx)).sum(-1)
            mask_ratio_observed = (input_ids == self.mask_idx).sum(-1).to(out.dtype) / src_lengths
            out = out * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]
        # Zero out embeddings at padding positions.
        pad_mask = input_ids.eq(self.padding_idx)
        if pad_mask.any():
            out = out * (1 - pad_mask.unsqueeze(-1).type_as(out))
        return out
class LucaGPLMEncoder(nn.Module):
    """Stack of LucaGPLM transformer layers.

    Internally works in (seq_len, batch, hidden) layout like the original
    LucaGPLM code, applies an optional final layer norm, and supports the
    original ``repr_layers`` indexing (0 = embedding output, i = layer i).
    """
    def __init__(self, config: LucaGPLMConfig):
        super().__init__()
        self.layers = nn.ModuleList([
            LucaGPLMTransformerLayer(
                config.hidden_size,
                4 * config.hidden_size,  # ffn_embed_dim = 4 * embed_dim
                config.num_attention_heads,
                add_bias_kv=False,
                use_lucagplm1b_layer_norm=True,
                use_rotary_embeddings=True,
            )
            for _ in range(config.num_hidden_layers)
        ])
        self.use_last_layer_norm = getattr(config, 'use_last_layer_norm', True)
        if self.use_last_layer_norm:
            self.last_layer_norm = LucaGPLM1bLayerNorm(config.hidden_size)
        else:
            self.last_layer_norm = None
        self.padding_idx = config.pad_token_id
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        need_head_weights: bool = False,
        repr_layers: Optional[List[int]] = None,
        use_last_layer_norm: bool = True,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
        """Run the layer stack.

        Args:
            hidden_states: (batch, seq_len, hidden) embedding output.
            attention_mask: HF-style mask, 1 = real token, 0 = padding.
            repr_layers: layer indices to capture; negative indices allowed.
            use_last_layer_norm: allow skipping the final norm per call.
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        if repr_layers is None:
            repr_layers = [-1]
        # Map (possibly negative) indices into the original model's
        # 0..num_layers range, where 0 denotes the embedding output.
        layer_size = len(self.layers)
        repr_layers = [(i + layer_size + 1) % (layer_size + 1) for i in repr_layers]
        repr_layers = set(repr_layers)
        # NOTE(review): hidden_representations is collected below but never
        # returned from this method — presumably retained for parity with the
        # original LucaGPLM API; confirm whether a caller still needs it.
        hidden_representations = {}
        # The attention layers expect a padding mask: True at PAD positions.
        if attention_mask is None:
            padding_mask = hidden_states.new_zeros(hidden_states.shape[:2]).eq(self.padding_idx)
        else:
            # HF attention_mask uses 0 for padding, so invert it.
            padding_mask = attention_mask.eq(0)
        # Index 0: the embedding-layer output.
        if 0 in repr_layers:
            hidden_representations[0] = hidden_states
        # Switch to (seq_len, batch, hidden) — the layout the layers use.
        hidden_states = hidden_states.transpose(0, 1)
        # Drop the mask entirely when nothing is padded (cheaper attention).
        if not padding_mask.any():
            padding_mask = None
        # Collect per-layer head weights only when the caller needs them.
        if need_head_weights or output_attentions:
            attn_weights = []
        for layer_idx, layer_module in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states.transpose(0, 1),)
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    None,  # self_attn_mask
                    padding_mask,
                    need_head_weights or output_attentions,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    self_attn_mask=None,
                    self_attn_padding_mask=padding_mask,
                    need_head_weights=need_head_weights or output_attentions,
                )
            hidden_states, attn = layer_outputs
            if (layer_idx + 1) in repr_layers:
                hidden_representations[layer_idx + 1] = hidden_states.transpose(0, 1)
            if need_head_weights or output_attentions:
                # (H, B, L, L) => (B, H, L, L)
                attn_weights.append(attn.transpose(1, 0))
        # Apply the final layer norm (optional, and skippable per call).
        if self.last_layer_norm is not None and use_last_layer_norm:
            hidden_states = self.last_layer_norm(hidden_states)
        # Back to (batch, seq_len, hidden) for the HF-style return value.
        hidden_states = hidden_states.transpose(0, 1)
        # last hidden representation should have layer norm applied
        # NOTE(review): uses the loop variable after the loop — assumes the
        # stack has at least one layer.
        if (layer_idx + 1) in repr_layers:
            hidden_representations[layer_idx + 1] = hidden_states
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
        if need_head_weights or output_attentions:
            # Stack per-layer weights into (B, num_layers, H, L, L).
            if attn_weights:
                # B x Layers x H x L x L
                all_attentions = torch.stack(attn_weights, 1)
                if padding_mask is not None:
                    # Zero out attention entries touching padded positions.
                    attention_mask_expanded = 1 - padding_mask.type_as(all_attentions)
                    attention_mask_expanded = attention_mask_expanded.unsqueeze(1) * attention_mask_expanded.unsqueeze(2)
                    all_attentions = all_attentions * attention_mask_expanded[:, None, None, :, :]
            if not output_attentions:
                all_attentions = None
        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
class LucaGPLMPreTrainedModel(PreTrainedModel):
    """Base class wiring LucaGPLM modules into the transformers save/load machinery."""

    config_class = LucaGPLMConfig
    base_model_prefix = "lucaone"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LucaGPLMTransformerLayer"]

    def _init_weights(self, module):
        """Default (re-)initialization for linear, embedding and layer-norm modules."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
            return
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                # Keep the padding embedding at exactly zero.
                module.weight.data[module.padding_idx].zero_()
            return
        if isinstance(module, (LucaGPLM1LayerNorm, LucaGPLM1bLayerNorm)):
            # The layer-norm fallback may lack affine params, hence the guards.
            if hasattr(module, 'weight') and module.weight is not None:
                module.weight.data.fill_(1.0)
            if hasattr(module, 'bias') and module.bias is not None:
                module.bias.data.zero_()
class LucaGPLMModel(LucaGPLMPreTrainedModel):
    """
    The bare LucaGPLM encoder for extracting sequence representations and
    optionally predicting contacts.

    Based on the original LucaGPLM implementation but restructured to use
    modern transformers architecture.
    """
    def __init__(self, config: LucaGPLMConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = LucaGPLMEmbeddings(self.config)
        self.encoder = LucaGPLMEncoder(self.config)
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.embed_tokens

    def set_input_embeddings(self, value):
        self.embeddings.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_contacts: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        need_head_weights: Optional[bool] = None,
        repr_layers: Optional[List[int]] = None,
        use_last_layer_norm: Optional[bool] = True,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
        """
        Args:
            input_ids / inputs_embeds: exactly one of the two must be given.
            return_contacts: when True, also compute a simple symmetrized
                contact map from attention weights (forces attentions on).

        Returns:
            A ``BaseModelOutput`` (with an extra ``contacts`` attribute when
            computed), or when ``return_dict=False`` a tuple
            ``(last_hidden_state, [hidden_states], [attentions], [contacts])``.
        """
        output_attentions = output_attentions if output_attentions is not None else getattr(self.config, 'output_attentions', False)
        output_hidden_states = output_hidden_states if output_hidden_states is not None else getattr(self.config, 'output_hidden_states', False)
        return_contacts = return_contacts if return_contacts is not None else False
        return_dict = return_dict if return_dict is not None else getattr(self.config, 'use_return_dict', True)
        # Contacts are derived from attention maps, so they require head weights.
        need_head_weights = need_head_weights if need_head_weights is not None else return_contacts
        use_last_layer_norm = use_last_layer_norm if use_last_layer_norm is not None else True
        # Force attention outputs on when contacts are requested.
        if return_contacts:
            output_attentions = True
            need_head_weights = True
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        # Create attention mask if not provided
        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        # Get embeddings
        if inputs_embeds is None:
            embedding_output = self.embeddings(
                input_ids=input_ids,
                position_ids=position_ids,
                token_type_ids=token_type_ids,
            )
        else:
            embedding_output = inputs_embeds
        # Always request a dict from the encoder: the previous code read
        # ``encoder_outputs.attentions`` even when return_dict=False, which
        # raised AttributeError on the returned tuple. The dict is converted
        # back to a tuple below when the caller asked for one.
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            need_head_weights=need_head_weights,
            repr_layers=repr_layers,
            use_last_layer_norm=use_last_layer_norm,
        )
        sequence_output = encoder_outputs[0]
        # Simplified contact prediction: average the stacked attention weights
        # over layers and heads, then symmetrize. Special tokens (BOS/EOS) are
        # not trimmed from the map.
        contacts = None
        if return_contacts and encoder_outputs.attentions is not None:
            attentions = encoder_outputs.attentions  # presumably (batch, layers, heads, L, L) as stacked by the encoder
            averaged_attention = attentions.mean(dim=(1, 2))
            contacts = (averaged_attention + averaged_attention.transpose(-1, -2)) / 2
        if not return_dict:
            # Slicing a ModelOutput yields the tuple of its non-None fields,
            # matching the encoder's own tuple format.
            outputs = (sequence_output,) + encoder_outputs[1:]
            if contacts is not None:
                outputs = outputs + (contacts,)
            return outputs
        output = BaseModelOutput(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
        # Attach contacts as an extra attribute when computed.
        if contacts is not None:
            output.contacts = contacts
        return output
class LucaGPLMForMaskedLM(LucaGPLMPreTrainedModel):
    """LucaGPLM with a masked-language-modeling head on top of the encoder."""

    def __init__(self, config):
        super().__init__(config)
        # Base encoder
        self.lucaone = LucaGPLMModel(config)
        # MLM prediction head
        self.lm_head = LucaGPLMRobertaLMHead(
            embed_dim=config.hidden_size,
            output_dim=config.vocab_size
        )
        # Input embeddings and the LM-head decoder share weights.
        self._tied_weights_keys = [
            "lucaone.embeddings.embed_tokens.weight",
            "lm_head.decoder.weight"
        ]
        # Initialize weights and apply weight tying.
        self.post_init()

    def get_input_embeddings(self):
        return self.lucaone.get_input_embeddings()

    def get_output_embeddings(self):
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        self.lm_head.decoder = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,  # MLM labels; positions set to -100 are ignored
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        """
        Run the encoder, project hidden states to vocabulary logits and,
        when ``labels`` are given, compute the masked-LM cross-entropy loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # 1. Base model hidden states.
        outputs = self.lucaone(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]  # (batch_size, seq_len, hidden_size)
        # 2. Vocabulary logits from the MLM head.
        prediction_scores = self.lm_head(sequence_output)
        masked_lm_loss = None
        if labels is not None:
            # 3. MLM loss; -100 labels are ignored (CrossEntropyLoss default).
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
        if not return_dict:
            # NOTE: unlike BERT, the base model returns no pooler output, so
            # the extra encoder outputs (hidden states / attentions) start at
            # index 1 — the previous ``outputs[2:]`` silently dropped one.
            output = (prediction_scores,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class LucaGPLMForSequenceClassification(LucaGPLMPreTrainedModel):
    """
    LucaGPLM with a sequence-level classification/regression head.

    The task is selected by ``config.task_type`` (``multi_class``,
    ``binary_class``, ``regression``, ``multi_label``); the matching loss
    function is constructed once in ``__init__`` from the ``classifier_*``
    config fields.
    """
    def __init__(self, config):
        # classifier_num_labels (> 0) overrides num_labels; mutate the config
        # before super().__init__ so the base machinery sees the final value.
        if hasattr(config, "classifier_num_labels") and config.classifier_num_labels > 0:
            config.num_labels = config.classifier_num_labels
        super().__init__(config)
        self.num_labels = config.num_labels
        self.task_level = config.task_level
        self.task_type = config.task_type
        # This head only supports sequence-level tasks.
        assert self.task_level == "seq_level"
        self.classifier_pooling_type = config.classifier_pooling_type
        self.classifier_loss_type = config.classifier_loss_type
        self.classifier_loss_reduction = config.classifier_loss_reduction
        self.classifier_pos_weight = config.classifier_pos_weight
        self.classifier_weight = config.classifier_weight
        self.lucaone = LucaGPLMModel(config)  # base encoder
        # Attention-based poolers; "cls"/"mean" pooling is handled inline in forward().
        if self.classifier_pooling_type == "value_attention":
            self.pooler = LucaGPLMGlobalMaskValueAttentionPooling1D(config.hidden_size)
        elif self.classifier_pooling_type == "context_attention":
            self.pooler = LucaGPLMGlobalMaskContextAttentionPooling1D(embed_size=config.hidden_size)
        elif self.classifier_pooling_type == "weighted_attention":
            self.pooler = LucaGPLMGlobalMaskWeightedAttentionPooling1D(embed_size=config.hidden_size)
        else:
            self.pooler = None
        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        if self.task_type == "multi_class":
            # Optional per-class weights: a scalar (str/int/float) is broadcast
            # to all classes, a list is used as-is.
            weight = None
            if self.classifier_weight:
                if isinstance(self.classifier_weight, str) or isinstance(self.classifier_weight, int):
                    weight = torch.tensor([float(self.classifier_weight)] * self.num_labels, dtype=torch.float32)
                elif isinstance(self.classifier_weight, float):
                    weight = torch.tensor([self.classifier_weight] * self.num_labels, dtype=torch.float32)
                elif isinstance(self.classifier_weight, list):
                    weight = torch.tensor(self.classifier_weight, dtype=torch.float32)
            self.loss_fct = nn.CrossEntropyLoss(weight=weight, reduction="mean")
        elif self.task_type == "binary_class":
            # Optional positive-class weight for the single logit.
            pos_weight = None
            if self.classifier_pos_weight:
                if isinstance(self.classifier_pos_weight, str) or isinstance(self.classifier_pos_weight, int):
                    pos_weight = torch.tensor([float(self.classifier_pos_weight)], dtype=torch.float32)
                elif isinstance(self.classifier_pos_weight, float):
                    pos_weight = torch.tensor([self.classifier_pos_weight], dtype=torch.float32)
            self.loss_fct = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction="mean")
        elif self.task_type == "regression":
            # "mae" selects L1 loss; any other value falls back to MSE.
            if self.classifier_loss_type == "mae":
                self.loss_fct = nn.L1Loss(reduction="mean")
            else:
                self.loss_fct = nn.MSELoss(reduction="mean")
        elif self.task_type == "multi_label":
            # Scalar pos_weight is broadcast to every label.
            pos_weight = None
            if self.classifier_pos_weight:
                if isinstance(self.classifier_pos_weight, str) or isinstance(self.classifier_pos_weight, int):
                    pos_weight = torch.tensor([float(self.classifier_pos_weight)] * self.num_labels, dtype=torch.float32)
                elif isinstance(self.classifier_pos_weight, float):
                    pos_weight = torch.tensor([self.classifier_pos_weight] * self.num_labels, dtype=torch.float32)
            self.loss_fct = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction=self.classifier_loss_reduction)
        else:
            raise ValueError("Invalid task type: %s" % self.task_type)
        self.post_init()

    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
        return_dict=None
    ):
        """
        Encode the sequence, pool to one vector per example, classify and,
        when ``labels`` are given, compute the task-specific loss.
        """
        return_dict = return_dict if return_dict is not None else getattr(self.config, 'use_return_dict', True)
        outputs = self.lucaone(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            return_dict=return_dict
        )
        if self.pooler is not None:
            # NOTE(review): the attention poolers are called without the
            # attention mask — confirm they handle padding internally.
            pooled_output = self.pooler(outputs[0])
        elif self.classifier_pooling_type == "cls":
            # Take the CLS-token representation.
            pooled_output = outputs[0][:, 0, :]
        elif self.classifier_pooling_type == "mean":
            # Plain mean over all positions (padding included — TODO confirm intended).
            pooled_output = outputs[0].mean(dim=1)
        else:
            raise ValueError("Invalid classifier pooling type: %s" % self.classifier_pooling_type)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        loss = None
        if labels is not None:
            if self.task_type == "multi_class":
                loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.task_type == "binary_class":
                loss = self.loss_fct(logits.view(-1), labels.view(-1).float())
            elif self.task_type == "regression":
                loss = self.loss_fct(logits.view(-1), labels.view(-1))
            elif self.task_type == "multi_label":
                loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels).float())
            else:
                raise ValueError("Invalid task type: %s" % self.task_type)
        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output
        return SequenceClassifierOutput(loss=loss, logits=logits)
class LucaGPLMForTokenClassification(LucaGPLMPreTrainedModel):
    """
    LucaGPLM with a token-level classification/regression head.

    The task is selected by ``config.task_type`` (``multi_class``,
    ``binary_class``, ``regression``, ``multi_label``); the matching loss
    function is constructed once in ``__init__`` from the ``classifier_*``
    config fields.
    """
    def __init__(self, config):
        # classifier_num_labels (> 0) overrides num_labels; mutate the config
        # before super().__init__ so the base machinery sees the final value.
        if hasattr(config, "classifier_num_labels") and config.classifier_num_labels > 0:
            config.num_labels = config.classifier_num_labels
        super().__init__(config)
        self.num_labels = config.num_labels
        self.task_level = config.task_level
        self.task_type = config.task_type
        # This head only supports token-level tasks.
        assert self.task_level == "token_level"
        self.classifier_pooling_type = config.classifier_pooling_type
        self.classifier_loss_type = config.classifier_loss_type
        self.classifier_loss_reduction = config.classifier_loss_reduction
        self.classifier_pos_weight = config.classifier_pos_weight
        self.classifier_weight = config.classifier_weight
        self.lucaone = LucaGPLMModel(config)  # base encoder
        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        if self.task_type == "multi_class":
            weight = None
            if self.classifier_weight:
                # Scalar weight (str/int/float) is broadcast to all classes,
                # e.g. [1, 1, 1, ..., 1] of length num_labels; a list is used as-is.
                if isinstance(self.classifier_weight, str) or isinstance(self.classifier_weight, int):
                    weight = torch.tensor([float(self.classifier_weight)] * self.num_labels, dtype=torch.float32)
                elif isinstance(self.classifier_weight, float):
                    weight = torch.tensor([self.classifier_weight] * self.num_labels, dtype=torch.float32)
                elif isinstance(self.classifier_weight, list):
                    weight = torch.tensor(self.classifier_weight, dtype=torch.float32)
            self.loss_fct = nn.CrossEntropyLoss(weight=weight, reduction="mean")
        elif self.task_type == "binary_class":
            # Optional positive-class weight for the single logit.
            pos_weight = None
            if self.classifier_pos_weight:
                if isinstance(self.classifier_pos_weight, str) or isinstance(self.classifier_pos_weight, int):
                    pos_weight = torch.tensor([float(self.classifier_pos_weight)], dtype=torch.float32)
                elif isinstance(self.classifier_pos_weight, float):
                    # NOTE(review): identical result to the branch above; the
                    # float() call here is redundant but harmless.
                    pos_weight = torch.tensor([float(self.classifier_pos_weight)], dtype=torch.float32)
            self.loss_fct = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction="mean")
        elif self.task_type == "regression":
            # "mae" selects L1 loss; any other value falls back to MSE.
            if self.classifier_loss_type == "mae":
                self.loss_fct = nn.L1Loss(reduction="mean")
            else:
                self.loss_fct = nn.MSELoss(reduction="mean")
        elif self.task_type == "multi_label":
            # Scalar pos_weight is broadcast to every label.
            pos_weight = None
            if self.classifier_pos_weight:
                if isinstance(self.classifier_pos_weight, str) or isinstance(self.classifier_pos_weight, int):
                    pos_weight = torch.tensor([float(self.classifier_pos_weight)] * self.num_labels, dtype=torch.float32)
                elif isinstance(self.classifier_pos_weight, float):
                    pos_weight = torch.tensor([self.classifier_pos_weight] * self.num_labels, dtype=torch.float32)
            self.loss_fct = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction=self.classifier_loss_reduction)
        else:
            raise ValueError("Invalid task type: %s" % self.task_type)
        self.post_init()

    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
        return_dict=None
    ):
        """
        Encode the sequence, classify every token and, when ``labels`` are
        given, compute the task-specific loss. Labels must align with the
        trimmed sequence (first/last positions removed, see below).
        """
        return_dict = return_dict if return_dict is not None else getattr(self.config, 'use_return_dict', True)
        outputs = self.lucaone(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            return_dict=return_dict
        )
        # Drop the first and last positions — presumably the BOS/CLS and EOS
        # special tokens; TODO confirm against the tokenizer.
        sequence_output = outputs[0][:, 1:-1, :]  # (B, L, H)
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        loss = None
        if labels is not None:
            if self.task_type == "multi_class":
                loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.task_type == "binary_class":
                loss = self.loss_fct(logits.view(-1), labels.view(-1).float())
            elif self.task_type == "regression":
                loss = self.loss_fct(logits.view(-1), labels.view(-1))
            elif self.task_type == "multi_label":
                loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels).float())
            else:
                raise ValueError("Invalid task type: %s" % self.task_type)
        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output
        return TokenClassifierOutput(loss=loss, logits=logits)
# Public API of this module (names exported via `from ... import *`).
__all__ = [
    "LucaGPLMModel",
    "LucaGPLMPreTrainedModel",
    "LucaGPLMForMaskedLM",
    "LucaGPLMForSequenceClassification",
    "LucaGPLMForTokenClassification"
]