Initial upload: weights, config, code, README, requirements
Browse files- README.md +57 -0
- config.json +9 -0
- model.pt +3 -0
- model.py +48 -0
- model.safetensors +3 -0
- requirements.txt +5 -0
README.md
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
library_name: "pytorch"
|
| 3 |
+
tags:
|
| 4 |
+
- protein
|
| 5 |
+
- biosequence
|
| 6 |
+
- cnn
|
| 7 |
+
- embedding
|
| 8 |
+
license: apache-2.0
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
# CNNED_Protein
|
| 12 |
+
|
| 13 |
+
CNN-based embedding model for protein/bio sequences (triplet/contrastive training ready).
|
| 14 |
+
|
| 15 |
+
## Model Summary
|
| 16 |
+
- **Input**: one-hot encoded sequence of shape `(B, A, L)`
|
| 17 |
+
- **Encoder**: 1D CNN + AvgPooling stacks
|
| 18 |
+
- **Output**: L2-normalized embedding `(B, D)` via projection head
|
| 19 |
+
- **Training**: Designed for triplet/contrastive loss (anchor, positive, negative)
|
| 20 |
+
|
| 21 |
+
### Config
|
| 22 |
+
- `alphabet_size`: 27
|
| 23 |
+
- `target_size`: 128
|
| 24 |
+
- `channel`: 256
|
| 25 |
+
- `depth`: 3
|
| 26 |
+
- `kernel_size`: 7
|
| 27 |
+
- `l2norm`: True
|
| 28 |
+
|
| 29 |
+
## Usage
|
| 30 |
+
|
| 31 |
+
```python
|
| 32 |
+
import json, torch
|
| 33 |
+
from safetensors.torch import load_file
|
| 34 |
+
|
| 35 |
+
# Load config
|
| 36 |
+
cfg = json.load(open("config.json","r"))
|
| 37 |
+
from model import CNNED_Protein
|
| 38 |
+
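# Drop the "model_class" metadata entry before building the model
model_kwargs = {k: v for k, v in cfg.items() if k != "model_class"}
model = CNNED_Protein(**model_kwargs).eval()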
|
| 39 |
+
|
| 40 |
+
# Load weights
|
| 41 |
+
try:
|
| 42 |
+
sd = load_file("model.safetensors")
|
| 43 |
+
except Exception:
|
| 44 |
+
sd = torch.load("model.pt", map_location="cpu")
|
| 45 |
+
model.load_state_dict(sd, strict=True)
|
| 46 |
+
model.eval()
|
| 47 |
+
|
| 48 |
+
# Dummy inference
|
| 49 |
+
# x: (B, A, L) tensor; real inputs are one-hot encoded, random values are used here only to check shapes
|
| 50 |
+
x = torch.randn(2, cfg['alphabet_size'], 512)
|
| 51 |
+
y, z = model.encode(x)
|
| 52 |
+
print(y.shape) # (2, target_size)
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
## Notes
|
| 56 |
+
- Well suited to metric learning with TripletMarginLoss, InfoNCE, and similar losses.
|
| 57 |
+
- Robust to varying sequence lengths (global pooling).
|
config.json
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_class": "CNNED_Protein",
|
| 3 |
+
"alphabet_size": 27,
|
| 4 |
+
"target_size": 128,
|
| 5 |
+
"channel": 256,
|
| 6 |
+
"depth": 3,
|
| 7 |
+
"kernel_size": 7,
|
| 8 |
+
"l2norm": true
|
| 9 |
+
}
|
model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b836a1eaccd791f88d122c2773a3c80ef75f7c47dbc3dcdec998600a0e6825ae
|
| 3 |
+
size 4277097
|
model.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
class CNNED_Protein(nn.Module):
    """1D-CNN encoder that maps sequence tensors to fixed-size embeddings.

    The convolution stack consists of ``depth`` Conv1d -> BatchNorm1d -> ReLU
    blocks; every block after the first additionally ends with
    ``AvgPool1d(2)``. The result is averaged over the length axis and passed
    through a two-layer projection head.

    The submodule names and ordering (``conv``, ``pool``, ``proj``) define
    the ``state_dict`` keys, so they must stay as they are for existing
    checkpoints to keep loading with ``strict=True``.
    """

    def __init__(self, alphabet_size: int, target_size: int,
                 channel: int, depth: int, kernel_size: int, l2norm: bool = True,
                 *, model_class: "str | None" = None):
        """Build the encoder.

        Args:
            alphabet_size: Number of input channels of the first convolution.
            target_size: Dimension of the projected embedding.
            channel: Number of channels used by every convolution and by the
                hidden layer of the projection head.
            depth: Total number of convolution blocks; ``depth - 1`` of them
                are followed by ``AvgPool1d(2)``.
            kernel_size: Kernel size of every convolution; padding is
                ``kernel_size // 2``.
            l2norm: If true, ``encode`` L2-normalizes the embedding.
            model_class: Optional class name as stored in ``config.json``.
                Accepted so that the whole config can be passed as
                ``CNNED_Protein(**cfg)``; it does not affect the architecture.

        Raises:
            ValueError: If ``model_class`` is given and does not match the
                name of this class.
        """
        super().__init__()

        # config.json carries a "model_class" entry next to the
        # hyper-parameters; tolerate it here, but refuse a config that was
        # written for a different model.
        if model_class is not None and model_class != type(self).__name__:
            raise ValueError(
                f"config is for model_class={model_class!r}, "
                f"not {type(self).__name__!r}"
            )

        padding = kernel_size // 2

        # First block: alphabet channels -> hidden channels, no pooling.
        layers = [
            nn.Conv1d(alphabet_size, channel, kernel_size,
                      stride=1, padding=padding, bias=False),
            nn.BatchNorm1d(channel),
            nn.ReLU(inplace=True),
        ]
        # Remaining blocks keep the channel count and pool the length by 2.
        for _ in range(depth - 1):
            layers += [
                nn.Conv1d(channel, channel, kernel_size,
                          stride=1, padding=padding, bias=False),
                nn.BatchNorm1d(channel),
                nn.ReLU(inplace=True),
                nn.AvgPool1d(2),
            ]

        self.conv = nn.Sequential(*layers)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.proj = nn.Sequential(
            nn.Linear(channel, channel),
            nn.ReLU(inplace=True),
            nn.Linear(channel, target_size),
        )
        self.l2norm = l2norm

    def encode(self, x: torch.Tensor):
        """Embed a batch of sequences.

        Args:
            x: Input of shape ``(B, A, L)`` where ``A`` is the alphabet size.

        Returns:
            A tuple ``(y, z)``: ``y`` is the projected embedding of shape
            ``(B, D)`` (L2-normalized along the last axis when ``l2norm`` is
            set) and ``z`` is the pooled convolutional feature of shape
            ``(B, C)`` that was fed to the projection head.
        """
        z = self.conv(x)              # (B, C, L')
        z = self.pool(z).squeeze(-1)  # (B, C)
        y = self.proj(z)              # (B, D)
        if self.l2norm:
            y = F.normalize(y, dim=-1)
        return y, z

    def forward(self, a: torch.Tensor, p: torch.Tensor, n: torch.Tensor):
        """Embed an (anchor, positive, negative) triplet.

        Each input goes through ``encode`` separately, so the three tensors
        are not required to share a sequence length.

        Returns:
            The three projected embeddings ``(ay, py, ny)``.
        """
        ay, _ = self.encode(a)
        py, _ = self.encode(p)
        ny, _ = self.encode(n)
        return ay, py, ny
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:aee9a63233ca456f0f7a044ef706053698d96d4d52639e07f20be38854c61414
|
| 3 |
+
size 4272400
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.0.0
|
| 2 |
+
numpy>=1.24
|
| 3 |
+
safetensors>=0.4.0
|
| 4 |
+
tqdm>=4.65
|
| 5 |
+
biopython>=1.83
|