---
library_name: "pytorch"
tags:
- protein
- biosequence
- cnn
- embedding
license: apache-2.0
---

# CNNED_Protein

CNN-based embedding model for protein and other biological sequences, ready for triplet or contrastive training.

## Model Summary

- **Input**: one-hot encoded sequence of shape `(B, A, L)` (batch, alphabet size, sequence length)
- **Encoder**: stacked 1D convolution and average-pooling blocks
- **Output**: L2-normalized embedding of shape `(B, D)`, produced by a projection head (`D` is `target_size`)
- **Training**: designed for triplet/contrastive loss over (anchor, positive, negative) samples
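
The `(B, A, L)` input layout described above can be sketched as follows. This is a minimal illustration, not the encoding used to train the model: the 27-symbol alphabet (`ALPHABET`), the padding symbol at index 0, and the helper `one_hot_encode` are all assumptions made for the example, and the real token order may differ.

```python
import torch
import torch.nn.functional as F

# Hypothetical 27-symbol alphabet: index 0 is padding, 1-26 are the letters A-Z.
# The token order actually used by CNNED_Protein may differ.
ALPHABET = "-ABCDEFGHIJKLMNOPQRSTUVWXYZ"
CHAR_TO_INDEX = {ch: i for i, ch in enumerate(ALPHABET)}


def one_hot_encode(sequences, max_len):
    """Encode strings as a float tensor of shape (B, A, L)."""
    idx = torch.zeros(len(sequences), max_len, dtype=torch.long)  # 0 = padding
    for row, seq in enumerate(sequences):
        tokens = [CHAR_TO_INDEX[ch] for ch in seq.upper()[:max_len]]
        idx[row, : len(tokens)] = torch.tensor(tokens, dtype=torch.long)
    # one_hot yields (B, L, A); move the alphabet axis to position 1.
    return F.one_hot(idx, num_classes=len(ALPHABET)).permute(0, 2, 1).float()


x = one_hot_encode(["MKTAYIAKQR", "GSHMLE"], max_len=16)
print(x.shape)  # torch.Size([2, 27, 16])
```

Each position along `L` holds exactly one `1` across the alphabet axis. In this sketch, padded positions are encoded as the padding symbol rather than as all zeros.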

### Config

- `alphabet_size`: 27
- `target_size`: 128
- `channel`: 256
- `depth`: 3
- `kernel_size`: 7
- `l2norm`: True
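
For reference, the values above might be serialized as follows. This sketch assumes `config.json` is a flat JSON object whose keys match the constructor's keyword arguments; that layout is inferred from the config list and is not confirmed by the repository.

```python
import json

# Values copied from the Config list; the flat key layout is an assumption.
cfg = {
    "alphabet_size": 27,
    "target_size": 128,
    "channel": 256,
    "depth": 3,
    "kernel_size": 7,
    "l2norm": True,
}

# Serialize the way a config.json file would store it, then round-trip it.
text = json.dumps(cfg, indent=2)
print(text)
restored = json.loads(text)
```

Note that Python's `True` becomes `true` in JSON and is restored as `True` on load.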

## Usage

```python
import json

import torch
import torch.nn.functional as F
from safetensors.torch import load_file

from model import CNNED_Protein

# Load config
with open("config.json", "r") as f:
    cfg = json.load(f)
model = CNNED_Protein(**cfg)

# Load weights: prefer safetensors, fall back to a PyTorch checkpoint
try:
    sd = load_file("model.safetensors")
except Exception:
    sd = torch.load("model.pt", map_location="cpu")
model.load_state_dict(sd, strict=True)
model.eval()

# Dummy inference
# x: (B, A, L) one-hot tensor built from random token indices
idx = torch.randint(0, cfg["alphabet_size"], (2, 512))
x = F.one_hot(idx, num_classes=cfg["alphabet_size"]).permute(0, 2, 1).float()
with torch.no_grad():
    y, z = model.encode(x)
print(y.shape)  # (2, target_size)
```
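
Since the model is designed for triplet training, a single optimization step might look like the sketch below. Everything here is illustrative: `TinyEncoder` is a small stand-in so the example runs without the repository's `model.py`, and it is not the CNNED_Protein architecture. The margin, learning rate, and random inputs are likewise placeholder assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyEncoder(nn.Module):
    """Stand-in encoder for illustration only; not the CNNED_Protein architecture."""

    def __init__(self, alphabet_size=27, channel=32, target_size=16):
        super().__init__()
        self.conv = nn.Conv1d(alphabet_size, channel, kernel_size=7, padding=3)
        self.proj = nn.Linear(channel, target_size)

    def forward(self, x):
        h = F.relu(self.conv(x)).mean(dim=2)  # global average pool over length
        return F.normalize(self.proj(h), dim=1)  # L2-normalize each embedding


def random_one_hot(batch, alphabet_size, length):
    """Random (B, A, L) one-hot batch used as placeholder data."""
    idx = torch.randint(0, alphabet_size, (batch, length))
    return F.one_hot(idx, num_classes=alphabet_size).permute(0, 2, 1).float()


torch.manual_seed(0)
encoder = TinyEncoder()
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)  # assumed learning rate
loss_fn = nn.TripletMarginLoss(margin=0.2)  # assumed margin

# Random placeholders; real triplets would pair an anchor with a similar
# (positive) and a dissimilar (negative) sequence.
anchor = random_one_hot(4, 27, 64)
positive = random_one_hot(4, 27, 64)
negative = random_one_hot(4, 27, 64)

optimizer.zero_grad()
loss = loss_fn(encoder(anchor), encoder(positive), encoder(negative))
loss.backward()
optimizer.step()
print(loss.item())
```

With the real model, the stand-in would presumably be replaced by `CNNED_Protein` and the embeddings taken from its own forward or `encode` call, whichever output is the projected embedding.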