import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from transformers import PreTrainedModel, PretrainedConfig

os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

class OptimizedEMGCell(nn.Module):
    """Single EMG recurrent cell with optional layer normalization and dropout."""

    def __init__(self, input_size, hidden_size, dropout_rate=0.1, use_layer_norm=False):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.use_layer_norm = use_layer_norm
        self.clamp_min = -1
        self.clamp_max = 1

        # Fused projections: each produces the "move" and "merge" pre-activations at once.
        self.input_transform_linear = nn.Linear(input_size, hidden_size * 2)
        self.hidden_transform_linear = nn.Linear(hidden_size, hidden_size * 2)

        self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else None

        if use_layer_norm:
            self.input_norm = nn.LayerNorm(hidden_size)
            self.hidden_norm = nn.LayerNorm(hidden_size)
            self.cell_norm = nn.LayerNorm(hidden_size)

        self.init_weights()

    def init_weights(self):
        for linear in [self.input_transform_linear, self.hidden_transform_linear]:
            nn.init.uniform_(linear.weight, -0.1, 0.1)
            nn.init.zeros_(linear.bias)

    def forward(self, input, hidden):
        h_prev, c_prev = hidden

        # Project the input and the previous hidden state in a single pass each.
        input_connections = self.input_transform_linear(input)
        hidden_connections = self.hidden_transform_linear(h_prev)

        # Split each projection into its "move" and "merge" halves.
        i_move, i_merge = torch.chunk(input_connections, 2, dim=-1)
        h_move, h_merge = torch.chunk(hidden_connections, 2, dim=-1)

        # Gates are clamped to keep activations bounded.
        merge_gate = torch.clamp(i_merge * torch.sigmoid(h_merge), self.clamp_min, self.clamp_max)
        move_gate = torch.clamp(torch.sigmoid(i_move) * h_move, self.clamp_min, self.clamp_max)

        if self.use_layer_norm:
            c_prev = self.cell_norm(c_prev)

        # Combine the previous cell state with the merge gate to form the new context.
        context_gate = torch.tanh(torch.clamp(c_prev + merge_gate, self.clamp_min, self.clamp_max))

        if self.use_layer_norm:
            context_gate = self.input_norm(context_gate)

        c_next = context_gate

        if self.use_layer_norm:
            c_next = self.hidden_norm(c_next)

        # Interpolate between the merge gate and the new cell state to form the output.
        m_next = (1 - move_gate) * merge_gate + move_gate * c_next

        if self.dropout is not None:
            m_next = self.dropout(m_next)

        return m_next, c_next

class OptimizedEMG(nn.Module):
    """Enhanced EMG with gradient checkpointing and other optimizations"""

    def __init__(self, input_size, hidden_size, num_layers, dropout_rate=0.1,
                 use_gradient_checkpointing=False, use_layer_norm=False):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.use_gradient_checkpointing = use_gradient_checkpointing

        self.cells = nn.ModuleList([
            OptimizedEMGCell(
                input_size if i == 0 else hidden_size,
                hidden_size,
                dropout_rate,
                use_layer_norm,
            ) for i in range(num_layers)
        ])

    def forward(self, x, hidden=None):
        batch_size, seq_len, _ = x.size()

        if hidden is None:
            hidden = [(torch.zeros(batch_size, self.hidden_size, device=x.device),
                       torch.zeros(batch_size, self.hidden_size, device=x.device))
                      for _ in range(self.num_layers)]

        outputs = []

        for t in range(seq_len):
            layer_input = x[:, t, :]

            for layer_idx, cell in enumerate(self.cells):
                m_prev, c_prev = hidden[layer_idx]

                if self.use_gradient_checkpointing and self.training:
                    m_next, c_next = torch.utils.checkpoint.checkpoint(
                        cell, layer_input, (m_prev, c_prev), use_reentrant=False
                    )
                else:
                    m_next, c_next = cell(layer_input, (m_prev, c_prev))

                hidden[layer_idx] = (m_next, c_next)
                layer_input = m_next

            outputs.append(layer_input)

        output = torch.stack(outputs, dim=1)
        return output, hidden

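# Quick shape check for the recurrent stack (illustrative only; the sizes below
# are arbitrary assumptions, not values used elsewhere in this script).
_demo_emg = OptimizedEMG(input_size=8, hidden_size=8, num_layers=2)
_demo_out, _demo_hidden = _demo_emg(torch.randn(4, 5, 8))
print(f"EMG demo output shape: {tuple(_demo_out.shape)}")  # expected (4, 5, 8)
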
class EMGConfig(PretrainedConfig):
    """Configuration class for EMG model"""

    model_type = "emg"

    def __init__(
        self,
        vocab_size=50000,
        embedding_dim=512,
        hidden_dim=512,
        num_layers=2,
        dropout=0.1,
        use_layer_norm=True,
        use_gradient_checkpointing=False,
        tie_word_embeddings=True,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout
        self.use_layer_norm = use_layer_norm
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.tie_word_embeddings = tie_word_embeddings

class EMGLanguageModel(PreTrainedModel):
    """Hugging Face compatible EMG Language Model"""

    config_class = EMGConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
        self.emg = OptimizedEMG(
            config.embedding_dim,
            config.hidden_dim,
            config.num_layers,
            config.dropout,
            use_gradient_checkpointing=config.use_gradient_checkpointing,
            use_layer_norm=config.use_layer_norm,
        )
        self.output_projection = nn.Linear(config.hidden_dim, config.vocab_size)

        # Tie input and output embeddings when the dimensions allow it.
        if config.tie_word_embeddings and config.embedding_dim == config.hidden_dim:
            self.output_projection.weight = self.embedding.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input_ids, hidden=None, labels=None, **kwargs):
        embedded = self.embedding(input_ids)
        output, hidden = self.emg(embedded, hidden)
        logits = self.output_projection(output)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))

        return {'loss': loss, 'logits': logits, 'hidden_states': hidden}

    @torch.no_grad()
    def generate(self, input_ids, max_length=50, temperature=1.0, top_k=50):
        self.eval()
        generated = input_ids
        hidden = None

        for _ in range(max_length):
            # Feed the full prompt on the first step, then only the newest token,
            # carrying the recurrent state forward between steps.
            step_input = generated if hidden is None else generated[:, -1:]
            outputs = self.forward(step_input, hidden)
            hidden = outputs['hidden_states']
            logits = outputs['logits'][:, -1, :] / temperature

            # Top-k sampling over the temperature-scaled logits.
            top_k_logits, top_k_indices = torch.topk(logits, top_k)
            probs = F.softmax(top_k_logits, dim=-1)
            next_token = top_k_indices.gather(1, torch.multinomial(probs, num_samples=1))

            generated = torch.cat([generated, next_token], dim=1)

        return generated
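

# Minimal usage sketch: builds a small model and runs one forward pass and a short
# sampled generation. The vocabulary size, sequence length, and sampling settings
# below are illustrative assumptions, not values required by the classes above.
if __name__ == "__main__":
    config = EMGConfig(vocab_size=1000, embedding_dim=128, hidden_dim=128, num_layers=2)
    model = EMGLanguageModel(config).to(device)

    # Forward pass with labels to get the language-modeling loss.
    demo_ids = torch.randint(0, config.vocab_size, (2, 16), device=device)
    out = model(demo_ids, labels=demo_ids)
    print(f"Demo loss: {out['loss'].item():.4f}, logits shape: {tuple(out['logits'].shape)}")

    # Short sampled continuation from a random prompt.
    prompt = torch.randint(0, config.vocab_size, (1, 4), device=device)
    continuation = model.generate(prompt, max_length=8, temperature=1.0, top_k=20)
    print(f"Generated token ids: {continuation.tolist()}")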