| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
|
|
# Model-wide sizes shared by every module in this file.
VOCAB_SIZE = 4  # number of token ids the embedding accepts (presumably A/C/G/T — confirm)
SEQ_LEN = 64  # nominal input length; the model classes themselves do not read it
NUM_CLASSES = 3  # width of the classifier's output logits
D_MODEL = 512  # feature width every expert consumes and produces

# Per-expert hyper-parameters, keyed by the expert they configure.
CONFIG = {
    # Residual conv stack: depth, hidden channel width, kernel size.
    "cnn": dict(blocks=7, channels=960, kernel=3),
    # Stacked GRU: hidden size and number of layers.
    "gru": dict(hidden=960, layers=4),
    # nn.TransformerEncoder: depth, attention heads, FFN width, dropout.
    "transformer": dict(layers=6, heads=8, ffn=2048, dropout=0.1),
    # Gated projection stack: depth and intermediate width.
    "mamba": dict(layers=10, state_dim=1408),
}
|
|
|
|
|
|
|
|
|
|
|
class CNNBlock(nn.Module):
    """Residual two-layer 1-D convolution block with a post-LayerNorm.

    Maps ``(batch, seq, d_model)`` to the same shape. The input is moved to
    channels-first layout for ``nn.Conv1d``, widened to ``channels``,
    projected back to ``d_model`` (GELU after each conv), added to the
    residual, and layer-normalised.

    Args:
        channels: Hidden channel width between the two convolutions.
        kernel: Convolution kernel size. Both convolutions use
            ``padding="same"`` so the sequence length is preserved for odd
            AND even kernels, which keeps the residual addition valid.
        d_model: Input/output feature width. ``None`` (the default) uses the
            module constant ``D_MODEL``.
    """

    def __init__(self, channels, kernel, d_model=None):
        super().__init__()

        if d_model is None:
            d_model = D_MODEL

        # padding="same" equals kernel // 2 for odd kernels. For even kernels,
        # kernel // 2 would lengthen the sequence by one per conv and make
        # ``x + residual`` fail with a shape mismatch.
        self.conv1 = nn.Conv1d(
            d_model,
            channels,
            kernel_size=kernel,
            padding="same",
        )

        self.conv2 = nn.Conv1d(
            channels,
            d_model,
            kernel_size=kernel,
            padding="same",
        )

        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        """Apply the block to ``x`` of shape ``(batch, seq, d_model)``."""
        residual = x

        # nn.Conv1d expects (batch, channels, seq).
        x = x.transpose(1, 2)

        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))

        x = x.transpose(1, 2)

        return self.norm(x + residual)
|
|
|
class CNNExpert(nn.Module):
    """Stack of ``CNNBlock`` residual blocks followed by a final LayerNorm.

    Config keys read: ``"blocks"`` (stack depth), plus ``"channels"`` and
    ``"kernel"``, which are forwarded unchanged to every ``CNNBlock``.
    """

    def __init__(self, config):
        super().__init__()

        depth = config["blocks"]
        self.blocks = nn.ModuleList(
            CNNBlock(channels=config["channels"], kernel=config["kernel"])
            for _ in range(depth)
        )

        self.norm = nn.LayerNorm(D_MODEL)

    def forward(self, x):
        """Run every block in order, then normalise the final features."""
        hidden = x
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.norm(hidden)
|
|
|
|
|
|
|
|
|
|
|
class GRUExpert(nn.Module):
    """Stacked unidirectional GRU expert, projected back to ``D_MODEL``.

    Config keys read: ``"hidden"`` (GRU hidden size, also the projection's
    input width) and ``"layers"`` (number of stacked GRU layers). The GRU is
    built with ``batch_first=True``.
    """

    def __init__(self, config):
        super().__init__()

        hidden_size = config["hidden"]

        self.gru = nn.GRU(
            input_size=D_MODEL,
            hidden_size=hidden_size,
            num_layers=config["layers"],
            batch_first=True,
        )
        self.proj = nn.Linear(hidden_size, D_MODEL)
        self.norm = nn.LayerNorm(D_MODEL)

    def forward(self, x):
        """Return normalised per-step GRU outputs projected to ``D_MODEL``."""
        # nn.GRU returns (outputs, final_hidden); only per-step outputs are used.
        step_outputs = self.gru(x)[0]
        return self.norm(self.proj(step_outputs))
|
|
|
|
|
|
|
|
|
|
|
class TransformerExpert(nn.Module):
    """Transformer-encoder expert with fixed sinusoidal position information.

    ``nn.TransformerEncoder`` has no notion of token order on its own. Without
    positional information its output is permutation-equivariant, so after the
    hybrid model's mean pooling this branch would ignore sequence order
    entirely. A parameter-free sinusoidal table is therefore added to the input
    before encoding. No parameters or state_dict keys are introduced.

    Config keys read: ``"layers"``, ``"heads"``, ``"ffn"``, ``"dropout"``.

    Args:
        config: Hyper-parameter mapping described above.
        d_model: Feature width. ``None`` (the default) uses ``D_MODEL``.
    """

    def __init__(self, config, d_model=None):
        super().__init__()

        if d_model is None:
            d_model = D_MODEL

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=config["heads"],
            dim_feedforward=config["ffn"],
            dropout=config["dropout"],
            batch_first=True,
            activation="gelu",
        )

        self.encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=config["layers"],
        )

        self.norm = nn.LayerNorm(d_model)

    @staticmethod
    def _sinusoidal_positions(length, dim, device=None, dtype=None):
        """Return a ``(length, dim)`` fixed sinusoidal position table.

        Even columns hold ``sin(pos / 10000**(2i/dim))`` and odd columns the
        matching cosine. For odd ``dim`` the surplus cosine column is dropped.
        """
        positions = torch.arange(length, device=device, dtype=torch.float32).unsqueeze(1)
        exponents = torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim
        inv_freq = torch.pow(10000.0, -exponents)
        angles = positions * inv_freq  # (length, ceil(dim / 2))

        table = torch.zeros(length, dim, device=device, dtype=torch.float32)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)[:, : dim // 2]
        return table if dtype is None else table.to(dtype)

    def forward(self, x):
        """Encode ``x`` of shape ``(batch, seq, d_model)`` and layer-normalise it."""
        positions = self._sinusoidal_positions(
            x.size(1), x.size(-1), device=x.device, dtype=x.dtype
        )
        # The (seq, d_model) table broadcasts across the batch dimension.
        x = self.encoder(x + positions)
        return self.norm(x)
|
|
|
|
|
|
|
|
|
|
|
class MambaLikeBlock(nn.Module):
    """Sigmoid-gated residual projection block.

    Despite the name, this block performs no state-space recurrence and no
    mixing across sequence positions. Every operation is either an
    ``nn.Linear`` over the last dimension or elementwise, so each position is
    transformed independently.

    Args:
        state_dim: Width of the gated intermediate projection.
    """

    def __init__(self, state_dim):
        super().__init__()

        self.in_proj = nn.Linear(D_MODEL, state_dim)
        self.gate = nn.Linear(D_MODEL, state_dim)
        self.out_proj = nn.Linear(state_dim, D_MODEL)
        self.norm = nn.LayerNorm(D_MODEL)

    def forward(self, x):
        """Return ``LayerNorm(out_proj(in_proj(x) * sigmoid(gate(x))) + x)``."""
        gated = self.in_proj(x) * torch.sigmoid(self.gate(x))
        return self.norm(self.out_proj(gated) + x)
|
|
|
class MambaExpert(nn.Module):
    """Stack of ``MambaLikeBlock`` layers followed by a final LayerNorm.

    Config keys read: ``"layers"`` (stack depth) and ``"state_dim"``, which is
    forwarded to every block.
    """

    def __init__(self, config):
        super().__init__()

        depth = config["layers"]
        self.blocks = nn.ModuleList(
            MambaLikeBlock(state_dim=config["state_dim"])
            for _ in range(depth)
        )

        self.norm = nn.LayerNorm(D_MODEL)

    def forward(self, x):
        """Run every block in order, then normalise the final features."""
        hidden = x
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.norm(hidden)
|
|
|
|
|
|
|
|
|
|
|
class GenoLiteHybrid(nn.Module):
    """Four-expert hybrid sequence classifier.

    Token ids are embedded once and the shared embedding is fed to four
    experts (CNN, GRU, Transformer, gated projection stack). Their
    ``D_MODEL``-wide outputs are concatenated along the feature axis, fused
    back to ``D_MODEL``, mean-pooled over the sequence axis, and classified.

    Input: integer token ids of shape ``(batch, seq)``, each in
    ``[0, VOCAB_SIZE)`` as required by ``nn.Embedding``.
    Output: logits of shape ``(batch, NUM_CLASSES)``.
    """

    def __init__(self):
        super().__init__()

        # Submodules are created in a fixed order so that parameter
        # initialisation consumes the global RNG in the same sequence.
        self.embedding = nn.Embedding(VOCAB_SIZE, D_MODEL)

        self.cnn = CNNExpert(CONFIG["cnn"])
        self.gru = GRUExpert(CONFIG["gru"])
        self.transformer = TransformerExpert(CONFIG["transformer"])
        self.mamba = MambaExpert(CONFIG["mamba"])

        # Concatenated expert features (4 * D_MODEL) -> D_MODEL.
        expert_count = 4
        self.fusion = nn.Sequential(
            nn.Linear(expert_count * D_MODEL, D_MODEL),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.LayerNorm(D_MODEL),
        )

        # Two-layer classification head on the pooled representation.
        head_width = 512
        self.classifier = nn.Sequential(
            nn.Linear(D_MODEL, head_width),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(head_width, NUM_CLASSES),
        )

    def forward(self, x):
        """Return class logits for a batch of token-id sequences ``x``."""
        embedded = self.embedding(x)

        # Experts run in this exact order: cnn, gru, transformer, mamba.
        expert_outputs = [
            expert(embedded)
            for expert in (self.cnn, self.gru, self.transformer, self.mamba)
        ]

        fused = self.fusion(torch.cat(expert_outputs, dim=-1))

        # Average over the sequence axis to get one vector per example.
        pooled = fused.mean(dim=1)

        return self.classifier(pooled)
|
|
|
|
|
|
|
|
|
|
|
def count_params(model):
    """Return the total element count over ``model.parameters()``.

    Parameters are counted regardless of ``requires_grad``, so frozen and
    trainable weights are both included.
    """
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the full model, run one random batch through it, and
    # report shapes plus the parameter count.
    model = GenoLiteHybrid()

    # eval() disables dropout; no_grad() avoids building an autograd graph
    # that a forward-only shape check never uses.
    model.eval()

    batch_size = 2
    # Use the configured vocabulary and sequence length rather than
    # duplicating them as literals that can drift out of sync.
    x = torch.randint(
        0,
        VOCAB_SIZE,
        (batch_size, SEQ_LEN),
    )

    with torch.no_grad():
        y = model(x)

    if y.shape != (batch_size, NUM_CLASSES):
        raise RuntimeError(f"unexpected output shape {tuple(y.shape)}")

    print("\n================ TEST ================\n")

    print("Input shape :", x.shape)

    print("Output shape:", y.shape)

    total_params = count_params(model)

    print(f"\nTotal Params: {total_params / 1e6:.2f}M")

    print("\n======================================\n")