gcp v1 large
Browse files- checkpoints/best_valid.pth +3 -0
- config_gcpnet_encoder.yaml +83 -0
- config_geometric_decoder.yaml +14 -0
- config_vqvae.yaml +129 -0
checkpoints/best_valid.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7d1d43950a29834e7f702409bf957e9ffb75eb3cb3952074ba4c545bb4130eaf
|
| 3 |
+
size 2545380820
|
config_gcpnet_encoder.yaml
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
features:
|
| 2 |
+
module: models.gcpnet.features.factory.ProteinFeaturiser
|
| 3 |
+
kwargs:
|
| 4 |
+
representation: CA
|
| 5 |
+
scalar_node_features:
|
| 6 |
+
- amino_acid_one_hot
|
| 7 |
+
- sequence_positional_encoding
|
| 8 |
+
- alpha
|
| 9 |
+
- kappa
|
| 10 |
+
- dihedrals
|
| 11 |
+
vector_node_features:
|
| 12 |
+
- orientation
|
| 13 |
+
edge_types:
|
| 14 |
+
- knn_16
|
| 15 |
+
scalar_edge_features:
|
| 16 |
+
- edge_distance
|
| 17 |
+
vector_edge_features:
|
| 18 |
+
- edge_vectors
|
| 19 |
+
task:
|
| 20 |
+
transform: null
|
| 21 |
+
encoder:
|
| 22 |
+
module: models.gcpnet.models.graph_encoders.gcpnet.GCPNetModel
|
| 23 |
+
kwargs:
|
| 24 |
+
num_layers: 6
|
| 25 |
+
emb_dim: 128
|
| 26 |
+
node_s_emb_dim: 128
|
| 27 |
+
node_v_emb_dim: 16
|
| 28 |
+
edge_s_emb_dim: 32
|
| 29 |
+
edge_v_emb_dim: 4
|
| 30 |
+
r_max: 10.0
|
| 31 |
+
num_rbf: 8
|
| 32 |
+
activation: silu
|
| 33 |
+
pool: sum
|
| 34 |
+
module_cfg:
|
| 35 |
+
norm_pos_diff: true
|
| 36 |
+
scalar_gate: 0
|
| 37 |
+
vector_gate: true
|
| 38 |
+
scalar_nonlinearity: silu
|
| 39 |
+
vector_nonlinearity: silu
|
| 40 |
+
nonlinearities:
|
| 41 |
+
- silu
|
| 42 |
+
- silu
|
| 43 |
+
r_max: 10.0
|
| 44 |
+
num_rbf: 8
|
| 45 |
+
bottleneck: 4
|
| 46 |
+
vector_linear: true
|
| 47 |
+
vector_identity: true
|
| 48 |
+
default_bottleneck: 4
|
| 49 |
+
predict_node_positions: false
|
| 50 |
+
predict_node_rep: true
|
| 51 |
+
node_positions_weight: 1.0
|
| 52 |
+
update_positions_with_vector_sum: false
|
| 53 |
+
enable_e3_equivariance: false
|
| 54 |
+
pool: sum
|
| 55 |
+
model_cfg:
|
| 56 |
+
h_input_dim: 49
|
| 57 |
+
chi_input_dim: 2
|
| 58 |
+
e_input_dim: 9
|
| 59 |
+
xi_input_dim: 1
|
| 60 |
+
h_hidden_dim: 128
|
| 61 |
+
chi_hidden_dim: 16
|
| 62 |
+
e_hidden_dim: 32
|
| 63 |
+
xi_hidden_dim: 4
|
| 64 |
+
num_layers: 6
|
| 65 |
+
dropout: 0.0
|
| 66 |
+
layer_cfg:
|
| 67 |
+
pre_norm: false
|
| 68 |
+
use_gcp_norm: true
|
| 69 |
+
use_gcp_dropout: true
|
| 70 |
+
use_scalar_message_attention: true
|
| 71 |
+
num_feedforward_layers: 2
|
| 72 |
+
dropout: 0.0
|
| 73 |
+
nonlinearity_slope: 0.01
|
| 74 |
+
mp_cfg:
|
| 75 |
+
edge_encoder: false
|
| 76 |
+
edge_gate: false
|
| 77 |
+
num_message_layers: 4
|
| 78 |
+
message_residual: 0
|
| 79 |
+
message_ff_multiplier: 1
|
| 80 |
+
self_message: true
|
| 81 |
+
checkpoint_path: ./models/checkpoints/structure_denoising/gcpnet/ca_bb/last.ckpt
|
| 82 |
+
top_k: 30
|
| 83 |
+
num_positional_embeddings: 16
|
config_geometric_decoder.yaml
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dimension: 1024 # Used as dim_in and dim_out for ContinuousTransformerWrapper
|
| 2 |
+
ff_mult: 4 # Multiplier for the feedforward dimension
|
| 3 |
+
depth: 16 # Number of layers in the Encoder
|
| 4 |
+
heads: 16 # Number of attention heads in the Encoder
|
| 5 |
+
rotary_pos_emb: True
|
| 6 |
+
attn_flash: True # Use flash attention (FlashAttention-2 if installed)
|
| 7 |
+
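attn_kv_heads: 1 # Number of key/value heads for grouped-query attention (GQA); 1 = all query heads share one K/V head (multi-query attention)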
|
| 8 |
+
qk_norm: True
|
| 9 |
+
pre_norm: True
|
| 10 |
+
residual_attn: False # Set pre_norm to False if residual_attn is True
|
| 11 |
+
num_memory_tokens: 0 # Number of memory tokens, 0 means no memory tokens
|
| 12 |
+
|
| 13 |
+
direction_loss_bins: 16
|
| 14 |
+
pos_scale_factor: 1.0
|
config_vqvae.yaml
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fix_seed: 0
|
| 2 |
+
checkpoints_every: 1
|
| 3 |
+
tensorboard_log: True
|
| 4 |
+
tqdm_progress_bar: False
|
| 5 |
+
result_path: ./results/vqvae
|
| 6 |
+
find_unused_parameters: True
|
| 7 |
+
dispatch_batches: False
|
| 8 |
+
even_batches: True
|
| 9 |
+
non_blocking: False
|
| 10 |
+
split_batches: False
|
| 11 |
+
|
| 12 |
+
resume:
|
| 13 |
+
enabled: True
|
| 14 |
+
resume_path: results/vqvae/2025-07-17__16-36-40/checkpoints/epoch_1.pth
|
| 15 |
+
restart_optimizer: True
|
| 16 |
+
discard_decoder_weights: False
|
| 17 |
+
|
| 18 |
+
model:
|
| 19 |
+
compile_model: False
|
| 20 |
+
max_length: 1280
|
| 21 |
+
decoder_output_scaling_factor: 1 # Scaling factor for the decoder's backbone prediction outputs
|
| 22 |
+
use_ndlinear: False # Toggle for using NdLinear instead of Conv1d layers
|
| 23 |
+
encoder:
|
| 24 |
+
name: gcpnet # gcpnet
|
| 25 |
+
freeze_parameters: False
|
| 26 |
+
pretrained:
|
| 27 |
+
enabled: True
|
| 28 |
+
config_path: ./configs/pretrained/structure_denoising_pretrained_config.yaml
|
| 29 |
+
checkpoint_path: ./models/checkpoints/structure_denoising/gcpnet/ca_bb/last.ckpt # set the path to your pretrained checkpoint file here
|
| 30 |
+
vqvae:
|
| 31 |
+
vector_quantization:
|
| 32 |
+
enabled: True
|
| 33 |
+
freeze_parameters: False
|
| 34 |
+
dim: 256
|
| 35 |
+
decay: 0.995
|
| 36 |
+
codebook_size: 4096
|
| 37 |
+
commitment_weight: 0.05
|
| 38 |
+
orthogonal_reg_weight: 10
|
| 39 |
+
orthogonal_reg_max_codes: 512
|
| 40 |
+
orthogonal_reg_active_codes_only: True
|
| 41 |
+
rotation_trick: True
|
| 42 |
+
threshold_ema_dead_code: 2
|
| 43 |
+
kmeans_init: True
|
| 44 |
+
kmeans_iters: 10
|
| 45 |
+
alpha: 0.25
|
| 46 |
+
encoder:
|
| 47 |
+
freeze_parameters: False
|
| 48 |
+
dimension: 1024 # Used as dim_in and dim_out for ContinuousTransformerWrapper
|
| 49 |
+
ff_mult: 4 # Multiplier for the feedforward dimension
|
| 50 |
+
depth: 12 # Number of layers in the Encoder
|
| 51 |
+
heads: 12 # Number of attention heads in the Encoder
|
| 52 |
+
rotary_pos_emb: True
|
| 53 |
+
attn_flash: True # Use flash attention (FlashAttention-2 if installed)
|
| 54 |
+
attn_kv_heads: 3 # Number of key/value heads for grouped-query attention (GQA); query heads are split evenly across them
|
| 55 |
+
qk_norm: True
|
| 56 |
+
pre_norm: True
|
| 57 |
+
residual_attn: False # Set pre_norm to False if residual_attn is True
|
| 58 |
+
num_memory_tokens: 0 # Number of memory tokens, 0 means no memory tokens
|
| 59 |
+
decoder:
|
| 60 |
+
name: geometric_decoder # geometric_decoder
|
| 61 |
+
freeze_parameters: False
|
| 62 |
+
|
| 63 |
+
train_settings:
|
| 64 |
+
data_path: ../../datasets/vqvae/uniref_50/
|
| 65 |
+
num_epochs: 16
|
| 66 |
+
shuffle: True
|
| 67 |
+
mixed_precision: bf16 # no, fp16, bf16, fp8
|
| 68 |
+
save_pdb_every: 1
|
| 69 |
+
batch_size: 4
|
| 70 |
+
num_workers: 24
|
| 71 |
+
grad_accumulation: 1
|
| 72 |
+
max_task_samples: 24000000
|
| 73 |
+
profile_train_loop: False
|
| 74 |
+
cutoff_augmentation:
|
| 75 |
+
enabled: False
|
| 76 |
+
probability: 0.5
|
| 77 |
+
min_length: 25
|
| 78 |
+
nan_augmentation:
|
| 79 |
+
enabled: True
|
| 80 |
+
probability: 0.05
|
| 81 |
+
max_length: 30
|
| 82 |
+
gradient_norm_logging_freq: 50 # How often to calculate and log gradient norm (in steps)
|
| 83 |
+
losses:
|
| 84 |
+
alignment_strategy: kabsch # kabsch, quaternion, no
|
| 85 |
+
mse:
|
| 86 |
+
enabled: True
|
| 87 |
+
weight: 0.005
|
| 88 |
+
backbone_distance:
|
| 89 |
+
enabled: True
|
| 90 |
+
weight: 0.01
|
| 91 |
+
backbone_direction:
|
| 92 |
+
enabled: True
|
| 93 |
+
weight: 0.05
|
| 94 |
+
binned_distance_classification:
|
| 95 |
+
enabled: False
|
| 96 |
+
weight: 0.01
|
| 97 |
+
binned_direction_classification:
|
| 98 |
+
enabled: False
|
| 99 |
+
weight: 0.01
|
| 100 |
+
|
| 101 |
+
valid_settings:
|
| 102 |
+
data_path: ../../datasets/vqvae/whole_validation_2048_h5/
|
| 103 |
+
do_every: 1
|
| 104 |
+
save_pdb_every: 1
|
| 105 |
+
batch_size: 8
|
| 106 |
+
num_workers: 0
|
| 107 |
+
|
| 108 |
+
visualization_settings:
|
| 109 |
+
data_path: ../../datasets/vqvae/whole_validation_2048_h5/validation_set_2048_h5
|
| 110 |
+
fasta_path: visualization/Rep_subfamily_basedon_S40pdb.fa
|
| 111 |
+
do_every: 8192
|
| 112 |
+
batch_size: 1
|
| 113 |
+
num_workers: 4
|
| 114 |
+
|
| 115 |
+
optimizer:
|
| 116 |
+
name: adam
|
| 117 |
+
lr: 5e-5
|
| 118 |
+
weight_decouple: True
|
| 119 |
+
weight_decay: 1e-4
|
| 120 |
+
eps: 1e-7
|
| 121 |
+
beta_1: 0.9
|
| 122 |
+
beta_2: 0.98
|
| 123 |
+
use_8bit_adam: True
|
| 124 |
+
grad_clip_norm: 1
|
| 125 |
+
decay:
|
| 126 |
+
warmup: 16000
|
| 127 |
+
min_lr: 1e-6
|
| 128 |
+
gamma: 0.2
|
| 129 |
+
num_restarts: 1
|