import torch
import torch.nn as nn
import torch.nn.functional as F

# =========================================================
# CONFIG
# =========================================================

VOCAB_SIZE = 11
SEQ_LEN = 64
NUM_CLASSES = 10

D_MODEL = 512

CONFIG = {
    # -----------------------------------------------------
    # CNN
    # -----------------------------------------------------
    "cnn": {
        "blocks": 7,
        "channels": 960,
        "kernel": 3,
    },

    # -----------------------------------------------------
    # GRU
    # -----------------------------------------------------
    "gru": {
        "hidden": 960,
        "layers": 4,
    },

    # -----------------------------------------------------
    # TRANSFORMER
    # -----------------------------------------------------
    "transformer": {
        "layers": 6,
        "heads": 8,
        "ffn": 2048,
        "dropout": 0.1,
    },

    # -----------------------------------------------------
    # MAMBA-LIKE
    # -----------------------------------------------------
    "mamba": {
        "layers": 10,
        "state_dim": 1408,
    },
}


# =========================================================
# CNN EXPERT
# =========================================================

class CNNBlock(nn.Module):
    """Residual block of two 1-D convolutions over a [B, S, D_MODEL] sequence.

    The first convolution maps D_MODEL -> ``channels`` and the second maps
    ``channels`` -> D_MODEL; the input is added back and the sum is
    layer-normalised.

    Args:
        channels: Number of hidden channels between the two convolutions.
        kernel: Convolution kernel size. Must be a positive odd integer.

    Raises:
        ValueError: If ``kernel`` is not a positive odd integer.
    """

    def __init__(self, channels, kernel):
        super().__init__()

        # With padding = kernel // 2, only an odd kernel keeps the sequence
        # length unchanged. An even kernel grows the output by one step per
        # convolution, so the residual add in forward() would fail with a
        # shape mismatch. Reject it here with a clear message instead.
        if kernel < 1 or kernel % 2 == 0:
            raise ValueError(
                "kernel must be a positive odd integer so that "
                f"padding=kernel // 2 preserves sequence length; got {kernel}"
            )

        self.conv1 = nn.Conv1d(
            D_MODEL,
            channels,
            kernel_size=kernel,
            padding=kernel // 2,
        )

        self.conv2 = nn.Conv1d(
            channels,
            D_MODEL,
            kernel_size=kernel,
            padding=kernel // 2,
        )

        self.norm = nn.LayerNorm(D_MODEL)

    def forward(self, x):
        """Apply the block to ``x`` of shape [B, S, D_MODEL]; shape is kept."""
        # x = [B, S, D]
        residual = x

        # Conv1d convolves over the last axis, so move channels to axis 1.
        x = x.transpose(1, 2)  # [B, D, S]

        x = self.conv1(x)
        x = F.gelu(x)

        x = self.conv2(x)
        x = F.gelu(x)

        x = x.transpose(1, 2)  # [B, S, D]

        x = x + residual

        return self.norm(x)


class CNNExpert(nn.Module):
    """Stack of ``CNNBlock`` modules followed by a final LayerNorm.

    Args:
        config: Mapping with the keys ``"blocks"`` (number of blocks),
            ``"channels"`` and ``"kernel"`` (passed to each ``CNNBlock``).
    """

    def __init__(self, config):
        super().__init__()

        self.blocks = nn.ModuleList([
            CNNBlock(
                channels=config["channels"],
                kernel=config["kernel"],
            )
            for _ in range(config["blocks"])
        ])

        self.norm = nn.LayerNorm(D_MODEL)

    def forward(self, x):
        """Run ``x`` through every block in order, then normalise."""
        for block in self.blocks:
            x = block(x)

        return self.norm(x)
# =========================================================
# GRU EXPERT
# =========================================================

class GRUExpert(nn.Module):
    """Multi-layer GRU whose outputs are projected back to D_MODEL.

    Args:
        config: Mapping with the keys ``"hidden"`` (GRU hidden size) and
            ``"layers"`` (number of stacked GRU layers).
    """

    def __init__(self, config):
        super().__init__()

        hidden_size = config["hidden"]

        self.gru = nn.GRU(
            input_size=D_MODEL,
            hidden_size=hidden_size,
            num_layers=config["layers"],
            batch_first=True,
        )
        self.proj = nn.Linear(hidden_size, D_MODEL)
        self.norm = nn.LayerNorm(D_MODEL)

    def forward(self, x):
        """Return the normalised, projected per-step GRU outputs."""
        # Only the per-step outputs are used; the final hidden state is dropped.
        outputs, _final_state = self.gru(x)
        return self.norm(self.proj(outputs))


# =========================================================
# TRANSFORMER EXPERT
# =========================================================

class TransformerExpert(nn.Module):
    """``nn.TransformerEncoder`` stack followed by a LayerNorm.

    Args:
        config: Mapping with the keys ``"heads"``, ``"ffn"``, ``"dropout"``
            and ``"layers"``.
    """

    def __init__(self, config):
        super().__init__()

        layer = nn.TransformerEncoderLayer(
            d_model=D_MODEL,
            nhead=config["heads"],
            dim_feedforward=config["ffn"],
            dropout=config["dropout"],
            batch_first=True,
            activation="gelu",
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=config["layers"])
        self.norm = nn.LayerNorm(D_MODEL)

    def forward(self, x):
        """Encode ``x`` (no mask is passed) and normalise the result."""
        return self.norm(self.encoder(x))


# =========================================================
# MAMBA-LIKE BLOCK
# =========================================================

class MambaLikeBlock(nn.Module):
    """Gated feed-forward block with a residual connection and LayerNorm.

    The input is projected to ``state_dim`` twice: once as the value path and
    once through a sigmoid as the gate. Their element-wise product is
    projected back to D_MODEL, added to the input and normalised.

    Args:
        state_dim: Width of the gated hidden projection.
    """

    def __init__(self, state_dim):
        super().__init__()

        self.in_proj = nn.Linear(D_MODEL, state_dim)
        self.gate = nn.Linear(D_MODEL, state_dim)
        self.out_proj = nn.Linear(state_dim, D_MODEL)
        self.norm = nn.LayerNorm(D_MODEL)

    def forward(self, x):
        """Apply the gated projection to ``x`` (last dimension D_MODEL)."""
        values = self.in_proj(x)
        gate = torch.sigmoid(self.gate(x))
        update = self.out_proj(values * gate)
        return self.norm(update + x)


class MambaExpert(nn.Module):
    """Stack of ``MambaLikeBlock`` modules followed by a final LayerNorm.

    Args:
        config: Mapping with the keys ``"layers"`` (number of blocks) and
            ``"state_dim"`` (passed to each block).
    """

    def __init__(self, config):
        super().__init__()

        depth = config["layers"]
        self.blocks = nn.ModuleList([
            MambaLikeBlock(state_dim=config["state_dim"])
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(D_MODEL)

    def forward(self, x):
        """Run ``x`` through every block in order, then normalise."""
        hidden = x
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.norm(hidden)


# =========================================================
# HYBRID MODEL
# =========================================================
class GenoLiteHybrid(nn.Module):
    """Sequence classifier that fuses four parallel experts.

    Token ids are embedded, passed independently through a CNN, a GRU, a
    Transformer encoder and a gated feed-forward ("Mamba-like") stack, and the
    four outputs are concatenated on the feature axis. The concatenation is
    projected back to D_MODEL, mean-pooled over the sequence axis and
    classified into NUM_CLASSES logits.

    Args:
        config: Optional mapping with the keys ``"cnn"``, ``"gru"``,
            ``"transformer"`` and ``"mamba"``, each holding that expert's
            settings. Defaults to the module-level ``CONFIG``.
    """

    def __init__(self, config=None):
        super().__init__()

        # Fall back to the module-level settings so GenoLiteHybrid() behaves
        # exactly as before; passing a config allows other model sizes.
        if config is None:
            config = CONFIG

        # -------------------------------------------------
        # EMBEDDING
        # -------------------------------------------------

        self.embedding = nn.Embedding(
            VOCAB_SIZE,
            D_MODEL,
        )

        # -------------------------------------------------
        # EXPERTS
        # -------------------------------------------------

        self.cnn = CNNExpert(config["cnn"])

        self.gru = GRUExpert(config["gru"])

        self.transformer = TransformerExpert(
            config["transformer"]
        )

        self.mamba = MambaExpert(config["mamba"])

        # -------------------------------------------------
        # FUSION
        # -------------------------------------------------

        self.fusion = nn.Sequential(
            # Four expert outputs are concatenated on the feature axis.
            nn.Linear(
                D_MODEL * 4,
                D_MODEL,
            ),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.LayerNorm(D_MODEL),
        )

        # -------------------------------------------------
        # CLASSIFIER
        # -------------------------------------------------

        self.classifier = nn.Sequential(
            nn.Linear(
                D_MODEL,
                512,
            ),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(
                512,
                NUM_CLASSES,
            ),
        )

    def forward(self, x):
        """Return class logits for a batch of token-id sequences.

        Args:
            x: Integer token ids of shape [B, S]; values are looked up in an
                embedding table with VOCAB_SIZE rows.

        Returns:
            Tensor of shape [B, NUM_CLASSES] holding unnormalised logits.
        """

        # -------------------------------------------------
        # EMBEDDING
        # -------------------------------------------------

        x = self.embedding(x)

        # -------------------------------------------------
        # EXPERTS
        # -------------------------------------------------

        # Every expert sees the same embedded input; they do not feed each other.
        cnn_out = self.cnn(x)

        gru_out = self.gru(x)

        tf_out = self.transformer(x)

        mamba_out = self.mamba(x)

        # -------------------------------------------------
        # FUSION
        # -------------------------------------------------

        fused = torch.cat(
            [
                cnn_out,
                gru_out,
                tf_out,
                mamba_out,
            ],
            dim=-1,
        )

        fused = self.fusion(fused)

        # -------------------------------------------------
        # GLOBAL POOLING
        # -------------------------------------------------

        # Average over the sequence axis; every position is weighted equally.
        pooled = fused.mean(dim=1)

        # -------------------------------------------------
        # CLASSIFIER
        # -------------------------------------------------

        logits = self.classifier(pooled)

        return logits
# =========================================================
# PARAM COUNTER
# =========================================================

def count_params(model):
    """Return the total number of elements over ``model.parameters()``.

    Args:
        model: Object exposing a ``parameters()`` iterable of tensors
            (e.g. an ``nn.Module``).

    Returns:
        Sum of ``numel()`` over all parameters; 0 when there are none.
    """
    return sum(p.numel() for p in model.parameters())


# =========================================================
# TEST
# =========================================================

if __name__ == "__main__":

    model = GenoLiteHybrid()

    # Use the configured vocabulary size and sequence length rather than
    # repeating the literals, so this smoke test tracks the constants.
    x = torch.randint(
        0,
        VOCAB_SIZE,
        (2, SEQ_LEN),
    )

    # Only the output shape is inspected, so skip building the autograd graph.
    with torch.no_grad():
        y = model(x)

    print("\n================ TEST ================\n")

    print("Input shape :", x.shape)
    print("Output shape:", y.shape)

    total_params = count_params(model)

    print(f"\nTotal Params: {total_params / 1e6:.2f}M")

    print("\n======================================\n")