import torch
import torch.nn as nn
from transformers import Qwen2PreTrainedModel, Qwen2Config
from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm, Qwen2DecoderLayer


class Qwen2Connector(Qwen2PreTrainedModel):
    """A stack of Qwen2 decoder layers with the causal mask disabled, applied
    directly to input embeddings and followed by a final RMSNorm."""

    def __init__(self, config: Qwen2Config):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        # Disable the causal mask so every position can attend to every other position.
        for layer in self.layers:
            layer.self_attn.is_causal = False

        # The connector is written for the flash_attention_2 backend, which respects
        # is_causal=False directly when no attention mask is passed.
        self._attn_implementation = config._attn_implementation
        assert self._attn_implementation == 'flash_attention_2'

        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def forward(self, inputs_embeds):
        # Positions 0..seq_len-1, broadcast over the batch dimension for the rotary embeddings.
        position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
        position_ids = position_ids.expand(inputs_embeds.shape[0], -1)

        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    None,  # attention_mask
                    position_ids,
                    use_reentrant=False,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask=None,
                    position_ids=position_ids,
                )
            hidden_states = layer_outputs[0]

        hidden_states = self.norm(hidden_states)
        return hidden_states
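

# Usage sketch (not from the original file): builds a small connector and runs a
# forward pass. The config sizes below are illustrative, and flash_attention_2
# requires a CUDA device, half precision, and the flash-attn package installed.
if __name__ == "__main__":
    config = Qwen2Config(
        hidden_size=896,
        intermediate_size=4864,
        num_hidden_layers=2,
        num_attention_heads=14,
        num_key_value_heads=2,
    )
    # Mirror the attribute the connector asserts on.
    config._attn_implementation = "flash_attention_2"

    connector = Qwen2Connector(config).to("cuda", dtype=torch.bfloat16)

    # The connector maps (batch, seq_len, hidden_size) -> (batch, seq_len, hidden_size).
    embeds = torch.randn(2, 16, config.hidden_size, device="cuda", dtype=torch.bfloat16)
    out = connector(embeds)
    print(out.shape)  # torch.Size([2, 16, 896])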