brscftc commited on
Commit
e49356a
·
verified ·
1 Parent(s): 4718e19

Upload 4 files

Browse files
Files changed (4) hide show
  1. chat.py +422 -0
  2. config.json +113 -0
  3. model.pt +3 -0
  4. model.py +401 -0
chat.py ADDED
@@ -0,0 +1,422 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from model import GenoLiteHybrid
5
+
6
+ # =========================================================
7
+ # CONFIG
8
+ # =========================================================
9
+
10
+ DEVICE = torch.device(
11
+ "cuda" if torch.cuda.is_available() else "cpu"
12
+ )
13
+
14
+ CHUNK_SIZE = 64
15
+ TOKEN_MAP = {
16
+ "U": 0,
17
+ "D": 1,
18
+ "-": 2,
19
+ "+": 3,
20
+ "J": 4,
21
+ "R": 5,
22
+ "L": 6,
23
+ "T": 7,
24
+ "C": 8,
25
+ "H": 9,
26
+ "F": 10
27
+ }
28
+
29
+ ID2LABEL = {
30
+ 0: "0",
31
+ 1: "1",
32
+ 2: "2",
33
+ 3: "3",
34
+ 4: "4",
35
+ 5: "5",
36
+ 6: "6",
37
+ 7: "7",
38
+ 8: "8",
39
+ 9: "9"
40
+ }
41
+
42
+ # =========================================================
43
+ # LOAD MODEL
44
+ # =========================================================
45
+
46
+ model = GenoLiteHybrid().to(DEVICE)
47
+
48
+ checkpoint = torch.load(
49
+ "model.pt",
50
+ map_location=DEVICE
51
+ )
52
+
53
+ # ---------------------------------------------------------
54
+ # RAW OR FULL CHECKPOINT
55
+ # ---------------------------------------------------------
56
+
57
+ if isinstance(checkpoint, dict) and \
58
+ "model_state_dict" in checkpoint:
59
+
60
+ model.load_state_dict(
61
+ checkpoint["model_state_dict"]
62
+ )
63
+
64
+ print("\nLoaded full checkpoint.")
65
+
66
+ else:
67
+
68
+ model.load_state_dict(checkpoint)
69
+
70
+ print("\nLoaded raw state_dict.")
71
+
72
+ model.eval()
73
+
74
+ print("\n===================================")
75
+ print(" MODEL LOADED")
76
+ print("===================================\n")
77
+
78
+ # =========================================================
79
+ # ENCODE
80
+ # =========================================================
81
+
82
+ def encode(seq):
83
+
84
+ return torch.tensor(
85
+
86
+ [TOKEN_MAP[c] for c in seq],
87
+
88
+ dtype=torch.long
89
+ )
90
+
91
+ # =========================================================
92
+ # CHUNKING
93
+ # =========================================================
94
+
95
+ def split_chunks(sequence):
96
+
97
+ chunks = []
98
+
99
+ for i in range(
100
+ 0,
101
+ len(sequence),
102
+ CHUNK_SIZE
103
+ ):
104
+
105
+ chunk = sequence[
106
+ i:i + CHUNK_SIZE
107
+ ]
108
+
109
+ chunks.append(chunk)
110
+
111
+ return chunks
112
+
113
+ # =========================================================
114
+ # SINGLE CHUNK INFERENCE
115
+ # =========================================================
116
+
117
+ def analyze_chunk(sequence):
118
+
119
+ x = encode(sequence)
120
+
121
+ x = x.unsqueeze(0).to(DEVICE)
122
+
123
+ with torch.no_grad():
124
+
125
+ # ---------------------------------------------
126
+ # EMBEDDING
127
+ # ---------------------------------------------
128
+
129
+ emb = model.embedding(x)
130
+
131
+ # ---------------------------------------------
132
+ # EXPERTS
133
+ # ---------------------------------------------
134
+
135
+ cnn_out = model.cnn(emb)
136
+
137
+ gru_out = model.gru(emb)
138
+
139
+ tf_out = model.transformer(emb)
140
+
141
+ mamba_out = model.mamba(emb)
142
+
143
+ # ---------------------------------------------
144
+ # EXPERT ACTIVITY
145
+ # ---------------------------------------------
146
+
147
+ cnn_score = cnn_out.abs().mean().item()
148
+
149
+ gru_score = gru_out.abs().mean().item()
150
+
151
+ tf_score = tf_out.abs().mean().item()
152
+
153
+ mamba_score = mamba_out.abs().mean().item()
154
+
155
+ total = (
156
+ cnn_score +
157
+ gru_score +
158
+ tf_score +
159
+ mamba_score
160
+ )
161
+
162
+ cnn_w = cnn_score / total
163
+ gru_w = gru_score / total
164
+ tf_w = tf_score / total
165
+ mamba_w = mamba_score / total
166
+
167
+ # ---------------------------------------------
168
+ # FINAL PRED
169
+ # ---------------------------------------------
170
+
171
+ fused = torch.cat(
172
+ [
173
+ cnn_out,
174
+ gru_out,
175
+ tf_out,
176
+ mamba_out
177
+ ],
178
+ dim=-1
179
+ )
180
+
181
+ fused = model.fusion(fused)
182
+
183
+ pooled = fused.mean(dim=1)
184
+
185
+ logits = model.classifier(pooled)
186
+
187
+ probs = F.softmax(
188
+ logits,
189
+ dim=-1
190
+ )
191
+
192
+ pred = probs.argmax(dim=-1).item()
193
+
194
+ return {
195
+
196
+ "prediction": ID2LABEL[pred],
197
+
198
+ "probs": probs[0].cpu(),
199
+
200
+ "cnn": cnn_w,
201
+ "gru": gru_w,
202
+ "tf": tf_w,
203
+ "mamba": mamba_w
204
+ }
205
+
206
+ # =========================================================
207
+ # FULL ANALYSIS
208
+ # =========================================================
209
+
210
+ def analyze_sequence(sequence):
211
+
212
+ sequence = sequence.strip().upper()
213
+
214
+ # -----------------------------------------------------
215
+ # VALIDATION
216
+ # -----------------------------------------------------
217
+
218
+ valid = all(
219
+ c in TOKEN_MAP
220
+ for c in sequence
221
+ )
222
+
223
+ if not valid:
224
+
225
+ print("\nOnly A/T/G/C allowed.\n")
226
+ return
227
+
228
+ # -----------------------------------------------------
229
+ # LENGTH CHECK
230
+ # -----------------------------------------------------
231
+
232
+ length = len(sequence)
233
+
234
+ if length < CHUNK_SIZE:
235
+
236
+ missing = CHUNK_SIZE - length
237
+
238
+ print("\n===================================")
239
+ print(" LENGTH ERROR")
240
+ print("===================================\n")
241
+
242
+ print("Input too short.\n")
243
+
244
+ print(
245
+ f"Current Length : {length}"
246
+ )
247
+
248
+ print(
249
+ f"Missing Chars : {missing}"
250
+ )
251
+
252
+ print(
253
+ f"Required Length: {CHUNK_SIZE}"
254
+ )
255
+
256
+ print("\n===================================\n")
257
+
258
+ return
259
+
260
+ # -----------------------------------------------------
261
+ # MULTIPLE CHECK
262
+ # -----------------------------------------------------
263
+
264
+ if length % CHUNK_SIZE != 0:
265
+
266
+ next_valid = (
267
+ (
268
+ length // CHUNK_SIZE
269
+ ) + 1
270
+ ) * CHUNK_SIZE
271
+
272
+ missing = next_valid - length
273
+
274
+ print("\n===================================")
275
+ print(" LENGTH ERROR")
276
+ print("===================================\n")
277
+
278
+ print(
279
+ f"Sequence length must be "
280
+ f"a multiple of {CHUNK_SIZE}.\n"
281
+ )
282
+
283
+ print(
284
+ f"Current Length : {length}"
285
+ )
286
+
287
+ print(
288
+ f"Next Valid Size: {next_valid}"
289
+ )
290
+
291
+ print(
292
+ f"Missing Chars : {missing}"
293
+ )
294
+
295
+ print("\n===================================\n")
296
+
297
+ return
298
+
299
+ # -----------------------------------------------------
300
+ # CHUNKING
301
+ # -----------------------------------------------------
302
+
303
+ chunks = split_chunks(sequence)
304
+
305
+ print("\n===================================")
306
+ print(" ANALYZING INPUT")
307
+ print("===================================\n")
308
+
309
+ print(f"Total Length : {len(sequence)}")
310
+
311
+ print(f"Chunks : {len(chunks)}")
312
+
313
+ # -----------------------------------------------------
314
+ # AGGREGATION
315
+ # -----------------------------------------------------
316
+
317
+ total_probs = torch.zeros(10)
318
+
319
+ total_cnn = 0
320
+ total_gru = 0
321
+ total_tf = 0
322
+ total_mamba = 0
323
+
324
+ # -----------------------------------------------------
325
+ # PROCESS CHUNKS
326
+ # -----------------------------------------------------
327
+
328
+ for idx, chunk in enumerate(chunks):
329
+
330
+ result = analyze_chunk(chunk)
331
+
332
+ total_probs += result["probs"]
333
+
334
+ total_cnn += result["cnn"]
335
+ total_gru += result["gru"]
336
+ total_tf += result["tf"]
337
+ total_mamba += result["mamba"]
338
+
339
+ print("\n-----------------------------------")
340
+ print(f"Chunk {idx+1}")
341
+ print("-----------------------------------\n")
342
+
343
+ print(chunk)
344
+
345
+ print("\nPrediction:")
346
+ print(result["prediction"])
347
+
348
+ print("\nProbabilities:\n")
349
+
350
+ for i in range(3):
351
+
352
+ print(
353
+ f"{ID2LABEL[i]}: "
354
+ f"{result['probs'][i].item():.4f}"
355
+ )
356
+
357
+ # -----------------------------------------------------
358
+ # AVERAGES
359
+ # -----------------------------------------------------
360
+
361
+ total_probs /= len(chunks)
362
+
363
+ total_cnn /= len(chunks)
364
+ total_gru /= len(chunks)
365
+ total_tf /= len(chunks)
366
+ total_mamba /= len(chunks)
367
+
368
+ # -----------------------------------------------------
369
+ # FINAL DECISION
370
+ # -----------------------------------------------------
371
+
372
+ final_pred = total_probs.argmax().item()
373
+
374
+ print("\n===================================")
375
+ print(" FINAL RESULT")
376
+ print("===================================\n")
377
+
378
+ print(
379
+ f"FINAL DECISION: "
380
+ f"{ID2LABEL[final_pred]}"
381
+ )
382
+
383
+ print("\n-----------------------------------")
384
+ print("Average Probabilities")
385
+ print("-----------------------------------\n")
386
+
387
+ for i in range(3):
388
+
389
+ print(
390
+ f"{ID2LABEL[i]}: "
391
+ f"{total_probs[i].item():.4f}"
392
+ )
393
+
394
+ print("\n-----------------------------------")
395
+ print("Average Expert Activity")
396
+ print("-----------------------------------\n")
397
+
398
+ print(f"CNN : {total_cnn:.4f}")
399
+ print(f"GRU : {total_gru:.4f}")
400
+ print(f"Transformer : {total_tf:.4f}")
401
+ print(f"Mamba : {total_mamba:.4f}")
402
+
403
+ print("\n===================================\n")
404
+
405
+ # =========================================================
406
+ # CHAT LOOP
407
+ # =========================================================
408
+
409
+ print("Type DNA sequence.")
410
+ print("Length must be 64 or multiples of 64.")
411
+ print("Type EXIT to quit.\n")
412
+
413
+ while True:
414
+
415
+ seq = input("logs > ")
416
+
417
+ if seq.strip().upper() == "EXIT":
418
+
419
+ print("\nBye.\n")
420
+ break
421
+
422
+ analyze_sequence(seq)
config.json ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "GenoLiteHybrid",
3
+
4
+ "vocab_size": 4,
5
+ "sequence_length": 64,
6
+ "num_classes": 3,
7
+
8
+ "d_model": 512,
9
+
10
+ "cnn": {
11
+ "enabled": true,
12
+ "blocks": 7,
13
+ "channels": 960,
14
+ "kernel_size": 3,
15
+ "residual": true,
16
+ "layernorm": true,
17
+ "activation": "gelu"
18
+ },
19
+
20
+ "gru": {
21
+ "enabled": true,
22
+ "hidden_size": 960,
23
+ "layers": 4,
24
+ "bidirectional": false,
25
+ "batch_first": true,
26
+ "projection_to_d_model": true,
27
+ "layernorm": true
28
+ },
29
+
30
+ "transformer": {
31
+ "enabled": true,
32
+ "layers": 6,
33
+ "heads": 8,
34
+ "ffn_dim": 2048,
35
+ "dropout": 0.1,
36
+ "activation": "gelu",
37
+ "batch_first": true,
38
+ "layernorm": true
39
+ },
40
+
41
+ "mamba": {
42
+ "enabled": true,
43
+ "layers": 10,
44
+ "state_dim": 1408,
45
+ "gated": true,
46
+ "residual": true,
47
+ "layernorm": true
48
+ },
49
+
50
+ "fusion": {
51
+ "input_dim": 2048,
52
+ "output_dim": 512,
53
+ "activation": "gelu",
54
+ "dropout": 0.1,
55
+ "layernorm": true
56
+ },
57
+
58
+ "classifier": {
59
+ "hidden_dim": 512,
60
+ "dropout": 0.1,
61
+ "activation": "gelu",
62
+ "num_classes": 3
63
+ },
64
+
65
+ "pooling": {
66
+ "type": "mean"
67
+ },
68
+
69
+ "training": {
70
+ "epochs": 3,
71
+ "batch_size": 3,
72
+ "learning_rate": 0.0001,
73
+ "optimizer": "AdamW",
74
+ "weight_decay": 0.01,
75
+ "gradient_clipping": 1.0,
76
+ "shuffle": true
77
+ },
78
+
79
+ "dataset": {
80
+ "type": "synthetic",
81
+ "samples_total": 9000,
82
+ "samples_per_class": 3000,
83
+
84
+ "classes": [
85
+ "OK",
86
+ "MHAP",
87
+ "PROBLEM"
88
+ ],
89
+
90
+ "difficulty_levels": [
91
+ "easy",
92
+ "medium",
93
+ "hard"
94
+ ],
95
+
96
+ "features": [
97
+ "controlled_entropy",
98
+ "motif_repetition",
99
+ "hidden_illegal_pairs",
100
+ "partial_shuffle",
101
+ "duplicate_prevention",
102
+ "class_overlap"
103
+ ]
104
+ },
105
+
106
+ "hardware": {
107
+ "device": "cpu",
108
+ "ram_gb": 8,
109
+ "cpu": "Intel i7-4700MQ"
110
+ },
111
+
112
+ "estimated_parameters": "88M+"
113
+ }
model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a66d5ab0a6fbdbf546b84737bf95488cf6a5a08e73e4864c9fbab9cdd1fc00e4
3
+ size 1007279607
model.py ADDED
@@ -0,0 +1,401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ # =========================================================
6
+ # CONFIG
7
+ # =========================================================
8
+
9
+ VOCAB_SIZE = 11
10
+ SEQ_LEN = 64
11
+ NUM_CLASSES = 10
12
+
13
+ D_MODEL = 512
14
+
15
+ CONFIG = {
16
+
17
+ # -----------------------------------------------------
18
+ # CNN
19
+ # -----------------------------------------------------
20
+
21
+ "cnn": {
22
+ "blocks": 7,
23
+ "channels": 960,
24
+ "kernel": 3
25
+ },
26
+
27
+ # -----------------------------------------------------
28
+ # GRU
29
+ # -----------------------------------------------------
30
+
31
+ "gru": {
32
+ "hidden": 960,
33
+ "layers": 4
34
+ },
35
+
36
+ # -----------------------------------------------------
37
+ # TRANSFORMER
38
+ # -----------------------------------------------------
39
+
40
+ "transformer": {
41
+ "layers": 6,
42
+ "heads": 8,
43
+ "ffn": 2048,
44
+ "dropout": 0.1
45
+ },
46
+
47
+ # -----------------------------------------------------
48
+ # MAMBA-LIKE
49
+ # -----------------------------------------------------
50
+
51
+ "mamba": {
52
+ "layers": 10,
53
+ "state_dim": 1408
54
+ }
55
+ }
56
+
57
+ # =========================================================
58
+ # CNN EXPERT
59
+ # =========================================================
60
+
61
+ class CNNBlock(nn.Module):
62
+ def __init__(self, channels, kernel):
63
+ super().__init__()
64
+
65
+ self.conv1 = nn.Conv1d(
66
+ D_MODEL,
67
+ channels,
68
+ kernel_size=kernel,
69
+ padding=kernel // 2
70
+ )
71
+
72
+ self.conv2 = nn.Conv1d(
73
+ channels,
74
+ D_MODEL,
75
+ kernel_size=kernel,
76
+ padding=kernel // 2
77
+ )
78
+
79
+ self.norm = nn.LayerNorm(D_MODEL)
80
+
81
+ def forward(self, x):
82
+
83
+ # x = [B, S, D]
84
+
85
+ residual = x
86
+
87
+ x = x.transpose(1, 2) # [B, D, S]
88
+
89
+ x = self.conv1(x)
90
+ x = F.gelu(x)
91
+
92
+ x = self.conv2(x)
93
+ x = F.gelu(x)
94
+
95
+ x = x.transpose(1, 2) # [B, S, D]
96
+
97
+ x = x + residual
98
+
99
+ return self.norm(x)
100
+
101
+ class CNNExpert(nn.Module):
102
+ def __init__(self, config):
103
+ super().__init__()
104
+
105
+ self.blocks = nn.ModuleList([
106
+ CNNBlock(
107
+ channels=config["channels"],
108
+ kernel=config["kernel"]
109
+ )
110
+ for _ in range(config["blocks"])
111
+ ])
112
+
113
+ self.norm = nn.LayerNorm(D_MODEL)
114
+
115
+ def forward(self, x):
116
+
117
+ for block in self.blocks:
118
+ x = block(x)
119
+
120
+ return self.norm(x)
121
+
122
+ # =========================================================
123
+ # GRU EXPERT
124
+ # =========================================================
125
+
126
+ class GRUExpert(nn.Module):
127
+ def __init__(self, config):
128
+ super().__init__()
129
+
130
+ self.gru = nn.GRU(
131
+ input_size=D_MODEL,
132
+ hidden_size=config["hidden"],
133
+ num_layers=config["layers"],
134
+ batch_first=True
135
+ )
136
+
137
+ self.proj = nn.Linear(
138
+ config["hidden"],
139
+ D_MODEL
140
+ )
141
+
142
+ self.norm = nn.LayerNorm(D_MODEL)
143
+
144
+ def forward(self, x):
145
+
146
+ x, _ = self.gru(x)
147
+
148
+ x = self.proj(x)
149
+
150
+ return self.norm(x)
151
+
152
+ # =========================================================
153
+ # TRANSFORMER EXPERT
154
+ # =========================================================
155
+
156
+ class TransformerExpert(nn.Module):
157
+ def __init__(self, config):
158
+ super().__init__()
159
+
160
+ encoder_layer = nn.TransformerEncoderLayer(
161
+ d_model=D_MODEL,
162
+ nhead=config["heads"],
163
+ dim_feedforward=config["ffn"],
164
+ dropout=config["dropout"],
165
+ batch_first=True,
166
+ activation="gelu"
167
+ )
168
+
169
+ self.encoder = nn.TransformerEncoder(
170
+ encoder_layer,
171
+ num_layers=config["layers"]
172
+ )
173
+
174
+ self.norm = nn.LayerNorm(D_MODEL)
175
+
176
+ def forward(self, x):
177
+
178
+ x = self.encoder(x)
179
+
180
+ return self.norm(x)
181
+
182
+ # =========================================================
183
+ # MAMBA-LIKE BLOCK
184
+ # =========================================================
185
+
186
+ class MambaLikeBlock(nn.Module):
187
+ def __init__(self, state_dim):
188
+ super().__init__()
189
+
190
+ self.in_proj = nn.Linear(
191
+ D_MODEL,
192
+ state_dim
193
+ )
194
+
195
+ self.gate = nn.Linear(
196
+ D_MODEL,
197
+ state_dim
198
+ )
199
+
200
+ self.out_proj = nn.Linear(
201
+ state_dim,
202
+ D_MODEL
203
+ )
204
+
205
+ self.norm = nn.LayerNorm(D_MODEL)
206
+
207
+ def forward(self, x):
208
+
209
+ residual = x
210
+
211
+ h = self.in_proj(x)
212
+
213
+ g = torch.sigmoid(
214
+ self.gate(x)
215
+ )
216
+
217
+ x = h * g
218
+
219
+ x = self.out_proj(x)
220
+
221
+ x = x + residual
222
+
223
+ return self.norm(x)
224
+
225
+ class MambaExpert(nn.Module):
226
+ def __init__(self, config):
227
+ super().__init__()
228
+
229
+ self.blocks = nn.ModuleList([
230
+ MambaLikeBlock(
231
+ state_dim=config["state_dim"]
232
+ )
233
+ for _ in range(config["layers"])
234
+ ])
235
+
236
+ self.norm = nn.LayerNorm(D_MODEL)
237
+
238
+ def forward(self, x):
239
+
240
+ for block in self.blocks:
241
+ x = block(x)
242
+
243
+ return self.norm(x)
244
+
245
+ # =========================================================
246
+ # HYBRID MODEL
247
+ # =========================================================
248
+
249
+ class GenoLiteHybrid(nn.Module):
250
+ def __init__(self):
251
+ super().__init__()
252
+
253
+ # -------------------------------------------------
254
+ # EMBEDDING
255
+ # -------------------------------------------------
256
+
257
+ self.embedding = nn.Embedding(
258
+ VOCAB_SIZE,
259
+ D_MODEL
260
+ )
261
+
262
+ # -------------------------------------------------
263
+ # EXPERTS
264
+ # -------------------------------------------------
265
+
266
+ self.cnn = CNNExpert(CONFIG["cnn"])
267
+
268
+ self.gru = GRUExpert(CONFIG["gru"])
269
+
270
+ self.transformer = TransformerExpert(
271
+ CONFIG["transformer"]
272
+ )
273
+
274
+ self.mamba = MambaExpert(CONFIG["mamba"])
275
+
276
+ # -------------------------------------------------
277
+ # FUSION
278
+ # -------------------------------------------------
279
+
280
+ self.fusion = nn.Sequential(
281
+
282
+ nn.Linear(
283
+ D_MODEL * 4,
284
+ D_MODEL
285
+ ),
286
+
287
+ nn.GELU(),
288
+
289
+ nn.Dropout(0.1),
290
+
291
+ nn.LayerNorm(D_MODEL)
292
+ )
293
+
294
+ # -------------------------------------------------
295
+ # CLASSIFIER
296
+ # -------------------------------------------------
297
+
298
+ self.classifier = nn.Sequential(
299
+
300
+ nn.Linear(
301
+ D_MODEL,
302
+ 512
303
+ ),
304
+
305
+ nn.GELU(),
306
+
307
+ nn.Dropout(0.1),
308
+
309
+ nn.Linear(
310
+ 512,
311
+ NUM_CLASSES
312
+ )
313
+ )
314
+
315
+ def forward(self, x):
316
+
317
+ # -------------------------------------------------
318
+ # EMBEDDING
319
+ # -------------------------------------------------
320
+
321
+ x = self.embedding(x)
322
+
323
+ # -------------------------------------------------
324
+ # EXPERTS
325
+ # -------------------------------------------------
326
+
327
+ cnn_out = self.cnn(x)
328
+
329
+ gru_out = self.gru(x)
330
+
331
+ tf_out = self.transformer(x)
332
+
333
+ mamba_out = self.mamba(x)
334
+
335
+ # -------------------------------------------------
336
+ # FUSION
337
+ # -------------------------------------------------
338
+
339
+ fused = torch.cat(
340
+ [
341
+ cnn_out,
342
+ gru_out,
343
+ tf_out,
344
+ mamba_out
345
+ ],
346
+ dim=-1
347
+ )
348
+
349
+ fused = self.fusion(fused)
350
+
351
+ # -------------------------------------------------
352
+ # GLOBAL POOLING
353
+ # -------------------------------------------------
354
+
355
+ pooled = fused.mean(dim=1)
356
+
357
+ # -------------------------------------------------
358
+ # CLASSIFIER
359
+ # -------------------------------------------------
360
+
361
+ logits = self.classifier(pooled)
362
+
363
+ return logits
364
+
365
+ # =========================================================
366
+ # PARAM COUNTER
367
+ # =========================================================
368
+
369
+ def count_params(model):
370
+ return sum(
371
+ p.numel()
372
+ for p in model.parameters()
373
+ )
374
+
375
+ # =========================================================
376
+ # TEST
377
+ # =========================================================
378
+
379
+ if __name__ == "__main__":
380
+
381
+ model = GenoLiteHybrid()
382
+
383
+ x = torch.randint(
384
+ 0,
385
+ 11,
386
+ (2, 64)
387
+ )
388
+
389
+ y = model(x)
390
+
391
+ print("\n================ TEST ================\n")
392
+
393
+ print("Input shape :", x.shape)
394
+
395
+ print("Output shape:", y.shape)
396
+
397
+ total_params = count_params(model)
398
+
399
+ print(f"\nTotal Params: {total_params / 1e6:.2f}M")
400
+
401
+ print("\n======================================\n")