# emg-10m-conv_test / model_eMG_simplified.py
# EMG language model with MorPiece tokenizer (NeTS-lab)
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from transformers import PreTrainedModel, PretrainedConfig
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# ===================== OPTIMIZED EMG MODEL =====================
class OptimizedEMGCell(nn.Module):
def __init__(self, input_size, hidden_size, dropout_rate=0.1, use_layer_norm=False):
        super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.use_layer_norm = use_layer_norm
self.clamp_min = -1
self.clamp_max = 1
# Fused linear transformations for better efficiency
self.input_transform_linear = nn.Linear(input_size, hidden_size * 2)
self.hidden_transform_linear = nn.Linear(hidden_size, hidden_size * 2)
# SIMPLIFIED: Use standard dropout instead of variational
self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else None
# Layer normalization for training stability
if use_layer_norm:
self.input_norm = nn.LayerNorm(hidden_size)
self.hidden_norm = nn.LayerNorm(hidden_size)
self.cell_norm = nn.LayerNorm(hidden_size)
self.init_weights()
def init_weights(self):
for linear in [self.input_transform_linear, self.hidden_transform_linear]:
# Use smaller initialization for RNN stability
nn.init.uniform_(linear.weight, -0.1, 0.1)
nn.init.zeros_(linear.bias)
def forward(self, input, hidden):
h_prev, c_prev = hidden
# Project input and hidden states
input_connections = self.input_transform_linear(input)
hidden_connections = self.hidden_transform_linear(h_prev)
# Split projections
i_move, i_merge = torch.chunk(input_connections, 2, dim=-1)
h_move, h_merge = torch.chunk(hidden_connections, 2, dim=-1)
        # EMG computation: merge fuses the projected input with a sigmoid-gated
        # hidden state; move decides how much of the cell state flows forward
        # Alternative (clamp before gating), kept for reference:
        # merge_gate = torch.clamp(i_merge, self.clamp_min, self.clamp_max) * torch.sigmoid(torch.clamp(h_merge, self.clamp_min, self.clamp_max))
        merge_gate = torch.clamp(i_merge * torch.sigmoid(h_merge), self.clamp_min, self.clamp_max)
        move_gate = torch.clamp(torch.sigmoid(i_move) * h_move, self.clamp_min, self.clamp_max)
if self.use_layer_norm:
c_prev = self.cell_norm(c_prev)
context_gate = torch.tanh(torch.clamp(c_prev + merge_gate, self.clamp_min, self.clamp_max))
if self.use_layer_norm:
context_gate = self.input_norm(context_gate)
c_next = context_gate
if self.use_layer_norm:
c_next = self.hidden_norm(c_next)
# Apply dropout to output instead of complex variational dropout
m_next = (1 - move_gate) * merge_gate + move_gate * c_next
if self.dropout is not None:
m_next = self.dropout(m_next)
return m_next, c_next
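# A minimal shape sanity check for a single cell step. This is an illustrative
# sketch, not part of the released model: the sizes below are arbitrary
# assumptions chosen only to exercise the forward contract.
def _demo_cell_step():
    cell = OptimizedEMGCell(input_size=8, hidden_size=16)
    x = torch.randn(4, 8)                       # (batch, input_size)
    h = (torch.zeros(4, 16), torch.zeros(4, 16))  # (m_prev, c_prev)
    m_next, c_next = cell(x, h)
    assert m_next.shape == (4, 16) and c_next.shape == (4, 16)
    return m_next, c_next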
class OptimizedEMG(nn.Module):
"""Enhanced EMG with gradient checkpointing and other optimizations"""
    def __init__(self, input_size, hidden_size, num_layers, dropout_rate=0.1,
                 use_gradient_checkpointing=False, use_layer_norm=False):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.use_gradient_checkpointing = use_gradient_checkpointing
        # Pass use_layer_norm through to the cells so the config option
        # actually takes effect (it was previously dropped here)
        self.cells = nn.ModuleList([
            OptimizedEMGCell(
                input_size if i == 0 else hidden_size,
                hidden_size,
                dropout_rate,
                use_layer_norm
            ) for i in range(num_layers)
        ])
def forward(self, x, hidden=None):
batch_size, seq_len, _ = x.size()
if hidden is None:
hidden = [(torch.zeros(batch_size, self.hidden_size, device=x.device),
torch.zeros(batch_size, self.hidden_size, device=x.device))
for _ in range(self.num_layers)]
outputs = []
for t in range(seq_len):
layer_input = x[:, t, :]
for layer_idx, cell in enumerate(self.cells):
m_prev, c_prev = hidden[layer_idx]
if self.use_gradient_checkpointing and self.training:
m_next, c_next = torch.utils.checkpoint.checkpoint(
cell, layer_input, (m_prev, c_prev), use_reentrant=False
)
else:
m_next, c_next = cell(layer_input, (m_prev, c_prev))
hidden[layer_idx] = (m_next, c_next)
layer_input = m_next
outputs.append(layer_input)
output = torch.stack(outputs, dim=1)
return output, hidden
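# Illustrative sketch (assumed sizes, not from the checkpoint): run the stacked
# recurrence over a short random sequence and check the output contract.
def _demo_emg_shapes():
    emg = OptimizedEMG(input_size=8, hidden_size=16, num_layers=2)
    x = torch.randn(4, 10, 8)           # (batch, seq_len, input_size)
    output, hidden = emg(x)
    assert output.shape == (4, 10, 16)  # one hidden vector per time step
    assert len(hidden) == 2             # one (m, c) pair per layer
    return output, hidden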
# ===================== HUGGING FACE COMPATIBLE MODEL =====================
class EMGConfig(PretrainedConfig):
"""Configuration class for EMG model"""
model_type = "emg"
def __init__(
self,
vocab_size=50000,
embedding_dim=512,
hidden_dim=512,
num_layers=2,
dropout=0.1,
use_layer_norm=True,
use_gradient_checkpointing=False,
tie_word_embeddings=True,
**kwargs
):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.embedding_dim = embedding_dim
self.hidden_dim = hidden_dim
self.num_layers = num_layers
self.dropout = dropout
self.use_layer_norm = use_layer_norm
self.use_gradient_checkpointing = use_gradient_checkpointing
self.tie_word_embeddings = tie_word_embeddings
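# Hedged example: a deliberately tiny configuration for quick smoke tests.
# These hyperparameters are illustrative assumptions and do not correspond to
# any released checkpoint.
def _demo_config():
    return EMGConfig(vocab_size=1000, embedding_dim=64, hidden_dim=64,
                     num_layers=2, dropout=0.1)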
class EMGLanguageModel(PreTrainedModel):
"""Hugging Face compatible EMG Language Model"""
config_class = EMGConfig
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
        self.emg = OptimizedEMG(
            config.embedding_dim,
            config.hidden_dim,
            config.num_layers,
            config.dropout,
            config.use_gradient_checkpointing,
            config.use_layer_norm
        )
        self.output_projection = nn.Linear(config.hidden_dim, config.vocab_size)
        # Initialize weights first, then tie, so the shared tensor keeps the
        # embedding initialization instead of being overwritten afterwards
        self.apply(self._init_weights)
        # Tie embedding and output weights if dimensions match
        if config.tie_word_embeddings and config.embedding_dim == config.hidden_dim:
            self.output_projection.weight = self.embedding.weight
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(self, input_ids, hidden=None, labels=None, **kwargs):
embedded = self.embedding(input_ids)
output, hidden = self.emg(embedded, hidden)
logits = self.output_projection(output)
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1))
return {'loss': loss, 'logits': logits, 'hidden_states': hidden}
    @torch.no_grad()
    def generate(self, input_ids, max_length=50, temperature=1.0, top_k=50):
        self.eval()
        generated = input_ids
        hidden = None
        # Prime the recurrent state on the full prompt, then feed one token at a time
        step_input = generated
        for _ in range(max_length):
            outputs = self.forward(step_input, hidden)
            # Carry the recurrent state forward; without this every step would
            # restart from a zero hidden state and ignore the prompt
            hidden = outputs['hidden_states']
            logits = outputs['logits'][:, -1, :] / temperature
            # Top-k sampling
            top_k_logits, top_k_indices = torch.topk(logits, top_k)
            probs = F.softmax(top_k_logits, dim=-1)
            next_token = top_k_indices.gather(1, torch.multinomial(probs, num_samples=1))
            generated = torch.cat([generated, next_token], dim=1)
            step_input = next_token
        return generated
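# Usage sketch, guarded so it only runs when the file is executed directly.
# The token ids below are random placeholders; a real run would encode text
# with the MorPiece tokenizer that ships with this repository.
if __name__ == "__main__":
    config = _demo_config()
    model = EMGLanguageModel(config).to(device)
    dummy_ids = torch.randint(0, config.vocab_size, (2, 12), device=device)
    out = model(dummy_ids, labels=dummy_ids)
    print(f"loss: {out['loss'].item():.4f}, logits: {tuple(out['logits'].shape)}")
    sample = model.generate(dummy_ids[:, :4], max_length=8, temperature=0.9, top_k=20)
    print(f"generated shape: {tuple(sample.shape)}")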