import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from transformers import PreTrainedModel, PretrainedConfig

os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


# ===================== OPTIMIZED EMG MODEL =====================
class OptimizedEMGCell(nn.Module):
    def __init__(self, input_size, hidden_size, dropout_rate=0.1, use_layer_norm=False):
        super(OptimizedEMGCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.use_layer_norm = use_layer_norm
        self.clamp_min = -1
        self.clamp_max = 1

        # Fused linear transformations for better efficiency
        self.input_transform_linear = nn.Linear(input_size, hidden_size * 2)
        self.hidden_transform_linear = nn.Linear(hidden_size, hidden_size * 2)

        # SIMPLIFIED: Use standard dropout instead of variational
        self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else None

        # Layer normalization for training stability
        if use_layer_norm:
            self.input_norm = nn.LayerNorm(hidden_size)
            self.hidden_norm = nn.LayerNorm(hidden_size)
            self.cell_norm = nn.LayerNorm(hidden_size)

        self.init_weights()

    def init_weights(self):
        for linear in [self.input_transform_linear, self.hidden_transform_linear]:
            # Use smaller initialization for RNN stability
            nn.init.uniform_(linear.weight, -0.1, 0.1)
            nn.init.zeros_(linear.bias)

    def forward(self, input, hidden):
        h_prev, c_prev = hidden

        # Project input and hidden states
        input_connections = self.input_transform_linear(input)
        hidden_connections = self.hidden_transform_linear(h_prev)

        # Split projections
        i_move, i_merge = torch.chunk(input_connections, 2, dim=-1)
        h_move, h_merge = torch.chunk(hidden_connections, 2, dim=-1)

        # EMG computation
        # merge_gate = torch.clamp(i_merge, self.clamp_min, self.clamp_max) * torch.sigmoid(torch.clamp(h_merge, self.clamp_min, self.clamp_max))
        merge_gate = torch.clamp(i_merge * torch.sigmoid(h_merge), self.clamp_min, self.clamp_max)
        move_gate = torch.clamp(torch.sigmoid(i_move) * h_move, self.clamp_min, self.clamp_max)

        if self.use_layer_norm:
            c_prev = self.cell_norm(c_prev)

        context_gate = torch.tanh(torch.clamp(c_prev + merge_gate, self.clamp_min, self.clamp_max))

        if self.use_layer_norm:
            context_gate = self.input_norm(context_gate)

        c_next = context_gate
        if self.use_layer_norm:
            c_next = self.hidden_norm(c_next)

        # Apply dropout to output instead of complex variational dropout
        m_next = (1 - move_gate) * merge_gate + move_gate * c_next
        if self.dropout is not None:
            m_next = self.dropout(m_next)

        return m_next, c_next


class OptimizedEMG(nn.Module):
    """Enhanced EMG with gradient checkpointing and other optimizations"""

    def __init__(self, input_size, hidden_size, num_layers, dropout_rate=0.1,
                 use_gradient_checkpointing=False, use_layer_norm=False):
        super(OptimizedEMG, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.use_gradient_checkpointing = use_gradient_checkpointing

        self.cells = nn.ModuleList([
            OptimizedEMGCell(
                input_size if i == 0 else hidden_size,
                hidden_size,
                dropout_rate,
                use_layer_norm
            )
            for i in range(num_layers)
        ])

    def forward(self, x, hidden=None):
        batch_size, seq_len, _ = x.size()

        if hidden is None:
            hidden = [(torch.zeros(batch_size, self.hidden_size, device=x.device),
                       torch.zeros(batch_size, self.hidden_size, device=x.device))
                      for _ in range(self.num_layers)]

        outputs = []
        for t in range(seq_len):
            layer_input = x[:, t, :]
            for layer_idx, cell in enumerate(self.cells):
                m_prev, c_prev = hidden[layer_idx]
                if self.use_gradient_checkpointing and self.training:
                    m_next, c_next = torch.utils.checkpoint.checkpoint(
                        cell, layer_input, (m_prev, c_prev), use_reentrant=False
                    )
                else:
                    m_next, c_next = cell(layer_input, (m_prev, c_prev))

                hidden[layer_idx] = (m_next, c_next)
                layer_input = m_next

            outputs.append(layer_input)

        output = torch.stack(outputs, dim=1)
        return output, hidden


# ===================== HUGGING FACE COMPATIBLE MODEL =====================
class EMGConfig(PretrainedConfig):
    """Configuration class for EMG model"""

    model_type = "emg"

    def __init__(
        self,
        vocab_size=50000,
        embedding_dim=512,
        hidden_dim=512,
        num_layers=2,
        dropout=0.1,
        use_layer_norm=True,
        use_gradient_checkpointing=False,
        tie_word_embeddings=True,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout
        self.use_layer_norm = use_layer_norm
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.tie_word_embeddings = tie_word_embeddings


class EMGLanguageModel(PreTrainedModel):
    """Hugging Face compatible EMG Language Model"""

    config_class = EMGConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
        self.emg = OptimizedEMG(
            config.embedding_dim,
            config.hidden_dim,
            config.num_layers,
            config.dropout,
            config.use_gradient_checkpointing,
            config.use_layer_norm
        )
        self.output_projection = nn.Linear(config.hidden_dim, config.vocab_size)

        # Initialize weights
        self.apply(self._init_weights)

        # Tie embedding and output weights if dimensions match
        # (tie after init so the shared weight keeps the embedding-style init)
        if config.tie_word_embeddings and config.embedding_dim == config.hidden_dim:
            self.output_projection.weight = self.embedding.weight

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input_ids, hidden=None, labels=None, **kwargs):
        embedded = self.embedding(input_ids)
        output, hidden = self.emg(embedded, hidden)
        logits = self.output_projection(output)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))

        return {'loss': loss, 'logits': logits, 'hidden_states': hidden}

    @torch.no_grad()
    def generate(self, input_ids, max_length=50, temperature=1.0, top_k=50):
        self.eval()
        generated = input_ids
        hidden = None
        # First step consumes the full prompt; subsequent steps feed only the newest
        # token and rely on the recurrent hidden state to carry the context.
        step_input = generated

        for _ in range(max_length):
            outputs = self.forward(step_input, hidden)
            hidden = outputs['hidden_states']
            logits = outputs['logits'][:, -1, :] / temperature

            # Top-k sampling
            top_k_logits, top_k_indices = torch.topk(logits, top_k)
            probs = F.softmax(top_k_logits, dim=-1)
            next_token = top_k_indices.gather(1, torch.multinomial(probs, num_samples=1))

            generated = torch.cat([generated, next_token], dim=1)
            step_input = next_token

        return generated
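

# The block below is a minimal usage sketch, not part of the original model code.
# It assumes the classes above are used as-is and picks a deliberately tiny,
# hypothetical configuration (small vocab/hidden sizes, short sequences) purely to
# smoke-test the forward pass, the shifted language-modeling loss, and top-k
# generation on random token ids.
if __name__ == "__main__":
    config = EMGConfig(
        vocab_size=1000,
        embedding_dim=64,
        hidden_dim=64,
        num_layers=2,
        dropout=0.1,
        use_layer_norm=True,
    )
    model = EMGLanguageModel(config).to(device)

    # Random token ids stand in for a tokenized batch (batch_size=2, seq_len=16)
    input_ids = torch.randint(0, config.vocab_size, (2, 16), device=device)

    # Forward pass with labels: forward() shifts internally, so labels == input_ids
    outputs = model(input_ids, labels=input_ids)
    print(f"loss: {outputs['loss'].item():.4f}, logits shape: {tuple(outputs['logits'].shape)}")

    # Sample a continuation from a short prompt with top-k sampling
    prompt = input_ids[:, :4]
    generated = model.generate(prompt, max_length=20, temperature=0.8, top_k=20)
    print(f"generated shape: {tuple(generated.shape)}")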