# GenoLite / model.py
# brscftc's picture
# Upload 3 files
# 58d4947 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
# =========================================================
# CONFIG
# =========================================================
# Number of rows in the token embedding table (see GenoLiteHybrid.embedding).
# Presumably one id per nucleotide -- confirm against the tokenizer.
VOCAB_SIZE = 4

# Nominal input length. No layer in this file reads this value.
SEQ_LEN = 64

# Width of the final classifier output (number of logits per sequence).
NUM_CLASSES = 3

# Feature width shared by the embedding and by every expert's input/output.
D_MODEL = 512

# Per-expert hyper-parameters, keyed by expert name. Each sub-dict is handed
# to the matching *Expert class constructor.
CONFIG = {
    # CNN expert: number of residual conv blocks, inner channel width,
    # and convolution kernel size.
    "cnn": {
        "blocks": 7,
        "channels": 960,
        "kernel": 3,
    },
    # GRU expert: hidden size and number of stacked GRU layers.
    "gru": {
        "hidden": 960,
        "layers": 4,
    },
    # Transformer expert: encoder depth, attention heads, feed-forward
    # width, and dropout probability.
    "transformer": {
        "layers": 6,
        "heads": 8,
        "ffn": 2048,
        "dropout": 0.1,
    },
    # Mamba-like expert: number of gated blocks and their inner width.
    "mamba": {
        "layers": 10,
        "state_dim": 1408,
    },
}
# =========================================================
# CNN EXPERT
# =========================================================
class CNNBlock(nn.Module):
    """Residual 1-D convolution block over a ``[B, S, D_MODEL]`` tensor.

    Two convolutions run along the sequence axis: the first maps
    ``D_MODEL -> channels`` and the second maps ``channels -> D_MODEL``,
    each followed by GELU. The block input is then added back and the sum
    is layer-normalised.

    Args:
        channels: Number of output channels of the first convolution.
        kernel: Kernel size used by both convolutions.
    """

    def __init__(self, channels: int, kernel: int) -> None:
        super().__init__()
        # padding="same" keeps the sequence length unchanged for any kernel
        # size (stride is left at its default of 1, which this mode requires).
        # For odd kernels it pads exactly like the former ``kernel // 2``.
        # For even kernels ``kernel // 2`` made each convolution lengthen the
        # sequence by one step, so the residual addition in forward() failed
        # on a shape mismatch.
        self.conv1 = nn.Conv1d(
            D_MODEL,
            channels,
            kernel_size=kernel,
            padding="same",
        )
        self.conv2 = nn.Conv1d(
            channels,
            D_MODEL,
            kernel_size=kernel,
            padding="same",
        )
        self.norm = nn.LayerNorm(D_MODEL)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block.

        Args:
            x: Tensor laid out as ``[B, S, D]``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        residual = x

        # Conv1d convolves over the last axis, so move the sequence there.
        x = x.transpose(1, 2)  # [B, D, S]
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.transpose(1, 2)  # [B, S, D]

        return self.norm(x + residual)
class CNNExpert(nn.Module):
    """Stack of ``CNNBlock`` modules followed by a final LayerNorm.

    ``config`` must provide the keys ``"blocks"`` (how many blocks to
    stack), ``"channels"`` and ``"kernel"`` (forwarded to every block).
    """

    def __init__(self, config):
        super().__init__()

        depth = config["blocks"]
        stack = []
        for _ in range(depth):
            stack.append(
                CNNBlock(channels=config["channels"], kernel=config["kernel"])
            )

        self.blocks = nn.ModuleList(stack)
        self.norm = nn.LayerNorm(D_MODEL)

    def forward(self, x):
        """Run ``x`` through every block in order, then normalise."""
        hidden = x
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.norm(hidden)
# =========================================================
# GRU EXPERT
# =========================================================
class GRUExpert(nn.Module):
    """Multi-layer GRU whose per-step outputs are projected to ``D_MODEL``.

    ``config`` must provide ``"hidden"`` (GRU hidden size) and ``"layers"``
    (number of stacked GRU layers). The GRU is built with
    ``batch_first=True``, so input is expected batch-major.
    """

    def __init__(self, config):
        super().__init__()

        hidden_size = config["hidden"]

        self.gru = nn.GRU(
            input_size=D_MODEL,
            hidden_size=hidden_size,
            num_layers=config["layers"],
            batch_first=True,
        )
        # Map the GRU hidden width back to the shared model width.
        self.proj = nn.Linear(hidden_size, D_MODEL)
        self.norm = nn.LayerNorm(D_MODEL)

    def forward(self, x):
        """Return the normalised projection of the GRU's per-step outputs."""
        # The second return value (final hidden state) is not used.
        step_outputs, _final_hidden = self.gru(x)
        return self.norm(self.proj(step_outputs))
# =========================================================
# TRANSFORMER EXPERT
# =========================================================
class TransformerExpert(nn.Module):
    """Transformer encoder stack followed by a final LayerNorm.

    ``config`` must provide ``"heads"``, ``"ffn"``, ``"dropout"`` and
    ``"layers"``.

    NOTE(review): this module adds no positional information to its input;
    confirm that this is intended.
    """

    def __init__(self, config):
        super().__init__()

        layer_kwargs = {
            "d_model": D_MODEL,
            "nhead": config["heads"],
            "dim_feedforward": config["ffn"],
            "dropout": config["dropout"],
            "batch_first": True,
            "activation": "gelu",
        }
        prototype_layer = nn.TransformerEncoderLayer(**layer_kwargs)

        self.encoder = nn.TransformerEncoder(
            prototype_layer,
            num_layers=config["layers"],
        )
        self.norm = nn.LayerNorm(D_MODEL)

    def forward(self, x):
        """Encode ``x`` and layer-normalise the result."""
        encoded = self.encoder(x)
        return self.norm(encoded)
# =========================================================
# MAMBA-LIKE BLOCK
# =========================================================
class MambaLikeBlock(nn.Module):
    """Gated linear block with a residual connection and LayerNorm.

    The input is projected to ``state_dim`` twice: once as the value path
    and once, through a sigmoid, as the gate. Their product is projected
    back to ``D_MODEL``, added to the input and normalised. Only
    ``nn.Linear`` layers are used, so each position is processed
    independently of the others.
    """

    def __init__(self, state_dim):
        super().__init__()
        self.in_proj = nn.Linear(D_MODEL, state_dim)
        self.gate = nn.Linear(D_MODEL, state_dim)
        self.out_proj = nn.Linear(state_dim, D_MODEL)
        self.norm = nn.LayerNorm(D_MODEL)

    def forward(self, x):
        """Apply the gated projection and return the normalised residual sum."""
        values = self.in_proj(x)
        gate_weights = torch.sigmoid(self.gate(x))
        projected = self.out_proj(values * gate_weights)
        return self.norm(projected + x)
class MambaExpert(nn.Module):
    """Stack of ``MambaLikeBlock`` modules followed by a final LayerNorm.

    ``config`` must provide ``"layers"`` (how many blocks to stack) and
    ``"state_dim"`` (inner width forwarded to every block).
    """

    def __init__(self, config):
        super().__init__()

        depth = config["layers"]
        stack = []
        for _ in range(depth):
            stack.append(MambaLikeBlock(state_dim=config["state_dim"]))

        self.blocks = nn.ModuleList(stack)
        self.norm = nn.LayerNorm(D_MODEL)

    def forward(self, x):
        """Run ``x`` through every block in order, then normalise."""
        hidden = x
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.norm(hidden)
# =========================================================
# HYBRID MODEL
# =========================================================
class GenoLiteHybrid(nn.Module):
    """Four-expert sequence classifier.

    Token ids are embedded once; the same embedding is fed to a CNN, a GRU,
    a transformer and a Mamba-like expert. Their outputs are concatenated on
    the feature axis, fused back to ``D_MODEL``, mean-pooled over the
    sequence axis and classified into ``NUM_CLASSES`` logits.

    NOTE(review): the embedding goes to the transformer expert without any
    positional encoding, and pooling is an unweighted mean over all
    positions; confirm both are intended.
    """

    def __init__(self):
        super().__init__()

        # Token embedding shared by all experts.
        self.embedding = nn.Embedding(VOCAB_SIZE, D_MODEL)

        # Experts, each configured from its own CONFIG section.
        self.cnn = CNNExpert(CONFIG["cnn"])
        self.gru = GRUExpert(CONFIG["gru"])
        self.transformer = TransformerExpert(CONFIG["transformer"])
        self.mamba = MambaExpert(CONFIG["mamba"])

        # Fusion: concatenated expert features (4 * D_MODEL) -> D_MODEL.
        fusion_layers = [
            nn.Linear(D_MODEL * 4, D_MODEL),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.LayerNorm(D_MODEL),
        ]
        self.fusion = nn.Sequential(*fusion_layers)

        # Classification head: D_MODEL -> 512 -> NUM_CLASSES.
        head_layers = [
            nn.Linear(D_MODEL, 512),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(512, NUM_CLASSES),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class logits for a batch of token-id sequences.

        ``x`` is looked up in the embedding table, so it must hold integer
        indices below ``VOCAB_SIZE``. Pooling averages over dim 1, i.e. the
        input is treated as batch-major ``[B, S]``.
        """
        embedded = self.embedding(x)

        # Every expert sees the same embedded input; the order here fixes
        # the layout of the concatenated feature axis.
        experts = (self.cnn, self.gru, self.transformer, self.mamba)
        expert_outputs = [expert(embedded) for expert in experts]

        fused = self.fusion(torch.cat(expert_outputs, dim=-1))

        # Global mean pooling over the sequence axis.
        sequence_summary = fused.mean(dim=1)

        return self.classifier(sequence_summary)
# =========================================================
# PARAM COUNTER
# =========================================================
def count_params(model, *, trainable_only=False):
    """Return the total number of scalar parameters in ``model``.

    Args:
        model: Object exposing ``parameters()`` that yields tensors with a
            ``numel()`` method (e.g. an ``nn.Module``).
        trainable_only: When true, count only parameters whose
            ``requires_grad`` flag is set. The default counts every
            parameter, which is the original behaviour.

    Returns:
        The summed element count as an int (0 if there are no parameters).
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if not trainable_only or p.requires_grad
    )
# =========================================================
# TEST
# =========================================================
if __name__ == "__main__":
    # Smoke test: build the model, push one random batch through it and
    # report the tensor shapes and the parameter count.
    model = GenoLiteHybrid()

    # Inference mode for the check: eval() switches dropout off, and
    # no_grad() avoids building an autograd graph that is never used.
    model.eval()

    # Draw ids from the configured vocabulary and sequence length rather
    # than repeating their literal values here.
    x = torch.randint(
        0,
        VOCAB_SIZE,
        (2, SEQ_LEN),
    )

    with torch.no_grad():
        y = model(x)

    print("\n================ TEST ================\n")
    print("Input shape :", x.shape)
    print("Output shape:", y.shape)

    total_params = count_params(model)
    print(f"\nTotal Params: {total_params / 1e6:.2f}M")
    print("\n======================================\n")