gabrielbianchin committed on
Commit
117e99b
·
1 Parent(s): 75a0fcb

update files

Browse files
README.md CHANGED
@@ -70,7 +70,7 @@ with torch.no_grad():
70
  mol_rep = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-8)
71
 
72
  # seqscreen
73
- seqscreen = AutoModel.from_pretrained('SaeedLab/SeqScreen-Frozen', trust_remote_code=True).eval()
74
 
75
  with torch.no_grad():
76
  outputs = seqscreen(prot=prot_rep, mol=mol_rep)
 
70
  mol_rep = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-8)
71
 
72
  # seqscreen
73
+ seqscreen = AutoModel.from_pretrained('SaeedLab/SeqScreen-Finetuning', trust_remote_code=True).eval()
74
 
75
  with torch.no_grad():
76
  outputs = seqscreen(prot=prot_rep, mol=mol_rep)
config.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "SeqScreenModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_seqscreen.SeqScreenConfig",
7
+ "AutoModel": "modeling_seqscreen.SeqScreenModel"
8
+ },
9
+ "dropout": 0.1,
10
+ "dtype": "float32",
11
+ "esm2_model_name": "facebook/esm2_t36_3B_UR50D",
12
+ "lora_adapter_repo": "SaeedLab/SeqScreen-lora",
13
+ "model_type": "seqscreen",
14
+ "mol_dim": 768,
15
+ "proj_dim": 512,
16
+ "prot_dim": 2560,
17
+ "transformers_version": "4.57.3"
18
+ }
configuration_seqscreen.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
class SeqScreenConfig(PretrainedConfig):
    """Configuration object for the SeqScreen model.

    Every constructor argument is stored unchanged as an attribute of the
    same name; remaining keyword arguments are forwarded to
    ``PretrainedConfig``.

    Args:
        prot_dim: Size of the protein feature vectors fed to the model.
        mol_dim: Size of the molecule feature vectors fed to the model.
        proj_dim: Size of the shared space both inputs are projected into.
        dropout: Dropout probability used inside the projection heads.
        esm2_model_name: Identifier of the ESM-2 checkpoint associated with
            this model (presumably the protein encoder -- confirm with the
            model card).
        lora_adapter_repo: Identifier of a LoRA adapter repository, or
            ``None`` (the default) when no repository is given.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "seqscreen"

    def __init__(
        self,
        prot_dim: int = 2560,
        mol_dim: int = 768,
        proj_dim: int = 512,
        dropout: float = 0.1,
        esm2_model_name: str = "facebook/esm2_t36_3B_UR50D",
        # Quoted so the union syntax is never evaluated on older interpreters.
        lora_adapter_repo: "str | None" = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.prot_dim = prot_dim
        self.mol_dim = mol_dim
        self.proj_dim = proj_dim
        self.dropout = dropout
        self.esm2_model_name = esm2_model_name
        self.lora_adapter_repo = lora_adapter_repo
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7c4a55beefcc43242daf84849856898dcf03039d21a5849c1878e7a2edc05042
3
+ size 8930448
modeling_seqscreen.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from dataclasses import dataclass
5
+ import torch
6
+ from transformers.utils import ModelOutput
7
+ from transformers import PreTrainedModel
8
+
9
+ from .configuration_seqscreen import SeqScreenConfig
10
+
11
@dataclass
class SeqScreenModelOutput(ModelOutput):
    """Container for the values returned by ``SeqScreenModel.forward``."""

    # Protein input after the protein projection head (L2-normalised).
    prot_rep: torch.FloatTensor = None
    # Molecule input after the molecule projection head (L2-normalised).
    mol_rep: torch.FloatTensor = None
    # Matrix product of ``prot_rep`` with the transpose of ``mol_rep``.
    similarity: torch.FloatTensor = None
16
+
17
class ProjectionLayer(nn.Module):
    """Two-layer MLP head that maps features to unit-length embeddings.

    The stack is Linear -> LayerNorm -> GELU -> Dropout -> Linear, and the
    result is L2-normalised along the last dimension.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Size of the last dimension of the output.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(self, in_dim: int, out_dim: int, dropout: float):
        super().__init__()
        # The attribute name and layer order determine the state-dict keys
        # ("projection.0.weight", ...), so they must stay as they are for
        # saved weights to keep loading.
        self.projection = nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.LayerNorm(out_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(out_dim, out_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` and scale each vector to unit L2 norm."""
        x = self.projection(x)
        return F.normalize(x, dim=-1)
31
+
32
+
33
class SeqScreenModel(PreTrainedModel):
    """Projects protein and molecule features into one space and scores them.

    The model owns two independent ``ProjectionLayer`` heads, one per
    modality, both producing ``config.proj_dim``-sized unit vectors.
    """

    config_class = SeqScreenConfig
    base_model_prefix = "seqscreen"

    def __init__(self, config: SeqScreenConfig):
        super().__init__(config)

        self.proj_prot = ProjectionLayer(
            config.prot_dim, config.proj_dim, dropout=config.dropout
        )
        self.proj_mol = ProjectionLayer(
            config.mol_dim, config.proj_dim, dropout=config.dropout
        )

        self.post_init()

    def forward(self, prot: torch.Tensor, mol: torch.Tensor) -> SeqScreenModelOutput:
        """Embed both inputs and compute their pairwise similarity.

        Args:
            prot: Protein features whose last dimension is
                ``config.prot_dim``.
            mol: Molecule features whose last dimension is
                ``config.mol_dim``.

        Returns:
            A ``SeqScreenModelOutput`` holding the two projected
            representations and ``prot_rep @ mol_rep.T``.
        """
        prot_rep = self.proj_prot(prot)
        mol_rep = self.proj_mol(mol)
        # NOTE(review): `.T` is a plain matrix transpose only for 2-D tensors,
        # so this assumes (batch, dim) inputs -- confirm against callers.
        # Both operands are unit-normalised, so entries are cosine similarities.
        similarity = prot_rep @ mol_rep.T

        return SeqScreenModelOutput(
            prot_rep=prot_rep,
            mol_rep=mol_rep,
            similarity=similarity,
        )