seungwonwonwon committed on
Commit
eda7ec8
·
verified ·
1 Parent(s): 6b76bf2

Initial upload: weights, config, code, README, requirements

Browse files
Files changed (6) hide show
  1. README.md +57 -0
  2. config.json +9 -0
  3. model.pt +3 -0
  4. model.py +48 -0
  5. model.safetensors +3 -0
  6. requirements.txt +5 -0
README.md ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: "pytorch"
3
+ tags:
4
+ - protein
5
+ - biosequence
6
+ - cnn
7
+ - embedding
8
+ license: apache-2.0
9
+ ---
10
+
11
+ # CNNED_Protein
12
+
13
+ CNN-based embedding model for protein/bio sequences (triplet/contrastive training ready).
14
+
15
+ ## Model Summary
16
+ - **Input**: one-hot encoded sequence of shape `(B, A, L)` (batch size, alphabet size, sequence length)
17
+ - **Encoder**: 1D CNN + AvgPooling stacks
18
+ - **Output**: L2-normalized embedding `(B, D)` via projection head
19
+ - **Training**: Designed for triplet/contrastive loss (anchor, positive, negative)
20
+
21
+ ### Config
22
+ - `alphabet_size`: 27
23
+ - `target_size`: 128
24
+ - `channel`: 256
25
+ - `depth`: 3
26
+ - `kernel_size`: 7
27
+ - `l2norm`: True
28
+
29
+ ## Usage
30
+
31
+ ```python
32
import json

import torch
import torch.nn.functional as F
from safetensors.torch import load_file

from model import CNNED_Protein

# Load config. config.json also carries a "model_class" tag that is not one of
# the constructor's hyperparameters, so drop it before building the model.
with open("config.json", "r", encoding="utf-8") as f:
    cfg = json.load(f)
init_kwargs = {k: v for k, v in cfg.items() if k != "model_class"}
model = CNNED_Protein(**init_kwargs)

# Load weights: prefer the safetensors file, fall back to the PyTorch checkpoint.
try:
    sd = load_file("model.safetensors")
except Exception:
    # Best-effort fallback. weights_only=True keeps torch.load from unpickling
    # arbitrary Python objects out of the checkpoint.
    sd = torch.load("model.pt", map_location="cpu", weights_only=True)
model.load_state_dict(sd, strict=True)
model.eval()

# Dummy inference
# x: (B, A, L) one-hot tensor built from random residue indices
idx = torch.randint(cfg["alphabet_size"], (2, 512))
x = F.one_hot(idx, num_classes=cfg["alphabet_size"]).permute(0, 2, 1).float()
with torch.no_grad():
    y, z = model.encode(x)
print(y.shape)  # (2, target_size)
53
+ ```
54
+
55
+ ## Notes
56
+ - Well suited to metric learning with losses such as TripletMarginLoss or InfoNCE.
57
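+ - Robust to variation in sequence length (global average pooling); with the default `depth` of 3, inputs must be at least 4 positions long to pass through the two `AvgPool1d(2)` stages.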
config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_class": "CNNED_Protein",
3
+ "alphabet_size": 27,
4
+ "target_size": 128,
5
+ "channel": 256,
6
+ "depth": 3,
7
+ "kernel_size": 7,
8
+ "l2norm": true
9
+ }
model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b836a1eaccd791f88d122c2773a3c80ef75f7c47dbc3dcdec998600a0e6825ae
3
+ size 4277097
model.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
class CNNED_Protein(nn.Module):
    """1D-CNN encoder that maps a sequence tensor to a fixed-size embedding.

    The network is a stack of Conv1d -> BatchNorm1d -> ReLU blocks. The first
    block maps ``alphabet_size`` input channels to ``channel`` channels; each
    of the following ``depth - 1`` blocks keeps ``channel`` channels and ends
    with ``AvgPool1d(2)``. The feature map is then averaged over the length
    axis and passed through a two-layer projection head.

    Args:
        alphabet_size: Number of input channels (size of axis ``A``).
        target_size: Dimension ``D`` of the projected embedding.
        channel: Number of channels ``C`` used by every convolution.
        depth: Total number of convolution blocks. Values below 1 still
            produce the single input block, because only ``depth - 1`` extra
            blocks are appended.
        kernel_size: Convolution kernel width. Padding is ``kernel_size // 2``,
            which keeps the length unchanged for odd kernel sizes.
        l2norm: If true, ``encode`` L2-normalizes the projected embedding.
        model_class: Ignored. Accepted (keyword-only) so that a config dict
            that carries a ``"model_class"`` tag can be passed as
            ``CNNED_Protein(**cfg)`` without raising ``TypeError``.
    """

    def __init__(self, alphabet_size: int, target_size: int,
                 channel: int, depth: int, kernel_size: int,
                 l2norm: bool = True, *, model_class: str = "CNNED_Protein"):
        super().__init__()
        pad = kernel_size // 2

        def conv_block(in_channels: int) -> list:
            # Conv (no bias: BatchNorm supplies the shift) -> BN -> ReLU.
            return [
                nn.Conv1d(in_channels, channel, kernel_size,
                          stride=1, padding=pad, bias=False),
                nn.BatchNorm1d(channel),
                nn.ReLU(inplace=True),
            ]

        # Layer order (and therefore the state_dict keys conv.0, conv.1, ...)
        # must stay as is so that previously saved weights keep loading.
        layers = conv_block(alphabet_size)
        for _ in range(depth - 1):
            layers += conv_block(channel)
            layers.append(nn.AvgPool1d(2))

        self.conv = nn.Sequential(*layers)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.proj = nn.Sequential(
            nn.Linear(channel, channel),
            nn.ReLU(inplace=True),
            nn.Linear(channel, target_size),
        )
        self.l2norm = l2norm

    def encode(self, x: torch.Tensor):
        """Embed a batch of sequences.

        Args:
            x: Input tensor of shape ``(B, A, L)``.

        Returns:
            A tuple ``(y, z)`` where ``y`` is the projected embedding of shape
            ``(B, D)`` (L2-normalized along the last axis when ``l2norm`` is
            set) and ``z`` is the pooled convolutional feature of shape
            ``(B, C)``.
        """
        z = self.conv(x)                 # (B, C, L')
        z = self.pool(z).squeeze(-1)     # (B, C)
        y = self.proj(z)                 # (B, D)
        if self.l2norm:
            y = F.normalize(y, dim=-1)
        return y, z

    def forward(self, a: torch.Tensor, p: torch.Tensor, n: torch.Tensor):
        """Encode an (anchor, positive, negative) triplet.

        Each input is encoded by a separate ``encode`` call; only the
        projected embeddings are returned, in the same order as the inputs.
        """
        ay, _ = self.encode(a)
        py, _ = self.encode(p)
        ny, _ = self.encode(n)
        return ay, py, ny
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aee9a63233ca456f0f7a044ef706053698d96d4d52639e07f20be38854c61414
3
+ size 4272400
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ numpy>=1.24
3
+ safetensors>=0.4.0
4
+ tqdm>=4.65
5
+ biopython>=1.83