mmtf committed on
Commit
91fa862
·
verified ·
1 Parent(s): e38b42f

Upload STAR-GO checkpoint + config

Browse files
Files changed (3) hide show
  1. README.md +46 -0
  2. config.toml +43 -0
  3. model.ckpt +3 -0
README.md ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: "stargo-cc"
3
+ tags:
4
+ - star-go
5
+ - protein
6
+ - gene-ontology
7
+ - bioinformatics
8
+ - pytorch
9
+ - lightning
10
+ ---
11
+
12
+ # stargo-cc
13
+
14
+ STAR-GO checkpoint published for easier discoverability. This repository stores the original Lightning `.ckpt` and the original TOML config so you can reconstruct the model as trained.
15
+
16
+ ## Files
17
+ - `model.ckpt`: PyTorch Lightning checkpoint for `TrainingModel`
18
+ - `config.toml`: training/model config (same schema as this repo's `configs/*.toml`)
19
+
20
+ ## Provenance
21
+ - W&B artifact: `contempro-cc-2020-ordered-encdec-medium:best`
22
+
23
+ ## Usage
24
+ This repository contains a Lightning checkpoint and the original TOML config. Load it like this:
25
+
26
+ ```python
27
+ import torch
28
+ from huggingface_hub import hf_hub_download
29
+
30
+ from config import from_toml
31
+ from model import TrainingModel, get_model_cls
32
+
33
+ repo_id = "mmtf/stargo-cc"
34
+ ckpt_path = hf_hub_download(repo_id, "model.ckpt")
35
+ cfg_path = hf_hub_download(repo_id, "config.toml")
36
+
37
+ cfg = from_toml(cfg_path)
38
+
39
+ module = TrainingModel.load_from_checkpoint(
40
+ ckpt_path,
41
+ model=get_model_cls(cfg.model.name)(cfg.model),
42
+ training_config=cfg.train,
43
+ )
44
+ module = module.to("cuda" if torch.cuda.is_available() else "cpu")
45
+ module.eval()
46
+ ```
config.toml ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [train]
2
+ # Data paths and configuration
3
+ data_dir = "datasets/pfresgo"
4
+ go_embed_file = "ontology.embeddings.npy"
5
+ protein_embed_file = "per_residue_embeddings.h5"
6
+ subontology = "cellular_component" # overridden in train.py CLI calls
7
+ go_release = "2020" # overridden in train.py CLI calls
8
+ order_go_terms = true
9
+
10
+ # Compute settings
11
+ use_tpu = false
12
+ prepare_data = false
13
+ dm_num_workers = 0
14
+ bf16_precision = true
15
+
16
+ # Training hyperparameters
17
+ batch_size = 8
18
+ learning_rate = 6e-5
19
+ weight_decay = 0.01
20
+ max_epochs = 100
21
+ gradient_accumulation = 4
22
+
23
+ [model]
24
+ # Model type
25
+ name = "bert"
26
+ decoder = true
27
+
28
+ # Architecture configuration
29
+ hidden_dim = 256
30
+ intermediate_size = 1024
31
+ num_encoder_layers = 6
32
+ num_decoder_layers = 6
33
+ num_attention_heads = 8
34
+
35
+ # Input dimensions
36
+ go_input_dim = 200
37
+ seq_input_dim = 1024
38
+
39
+ # Regularization and activation
40
+ hidden_dropout_prob = 0.1
41
+ attention_probs_dropout_prob = 0.1
42
+ hidden_act = "gelu"
43
+ layer_norm_eps = 1e-12
model.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:80c19794d909f9c745f92f439b275c8d45b52ee1f5f9f768230733a633beb129
3
+ size 137597176