Mahdip72 committed
Commit 64bb4ff · verified · 1 Parent(s): 08f1ade

gcp v1 large

checkpoints/best_valid.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7d1d43950a29834e7f702409bf957e9ffb75eb3cb3952074ba4c545bb4130eaf
+ size 2545380820
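
The file above is only a Git LFS pointer; the ~2.5 GB checkpoint itself is fetched with git lfs pull. Below is a minimal loading sketch, assuming the artifact is a plain torch.save payload; the model_state_dict key is an assumption for illustration, not something confirmed by this commit.

import torch

# Assumes the real weights behind the LFS pointer have been pulled locally.
state = torch.load("checkpoints/best_valid.pth", map_location="cpu")

# Hypothetical layout: many training loops nest the weights under a key such as
# "model_state_dict" next to optimizer state; fall back to the raw object otherwise.
weights = state.get("model_state_dict", state) if isinstance(state, dict) else state
print(type(state))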
config_gcpnet_encoder.yaml ADDED
@@ -0,0 +1,83 @@
+ features:
+ module: models.gcpnet.features.factory.ProteinFeaturiser
+ kwargs:
+ representation: CA
+ scalar_node_features:
+ - amino_acid_one_hot
+ - sequence_positional_encoding
+ - alpha
+ - kappa
+ - dihedrals
+ vector_node_features:
+ - orientation
+ edge_types:
+ - knn_16
+ scalar_edge_features:
+ - edge_distance
+ vector_edge_features:
+ - edge_vectors
+ task:
+ transform: null
+ encoder:
+ module: models.gcpnet.models.graph_encoders.gcpnet.GCPNetModel
+ kwargs:
+ num_layers: 6
+ emb_dim: 128
+ node_s_emb_dim: 128
+ node_v_emb_dim: 16
+ edge_s_emb_dim: 32
+ edge_v_emb_dim: 4
+ r_max: 10.0
+ num_rbf: 8
+ activation: silu
+ pool: sum
+ module_cfg:
+ norm_pos_diff: true
+ scalar_gate: 0
+ vector_gate: true
+ scalar_nonlinearity: silu
+ vector_nonlinearity: silu
+ nonlinearities:
+ - silu
+ - silu
+ r_max: 10.0
+ num_rbf: 8
+ bottleneck: 4
+ vector_linear: true
+ vector_identity: true
+ default_bottleneck: 4
+ predict_node_positions: false
+ predict_node_rep: true
+ node_positions_weight: 1.0
+ update_positions_with_vector_sum: false
+ enable_e3_equivariance: false
+ pool: sum
+ model_cfg:
+ h_input_dim: 49
+ chi_input_dim: 2
+ e_input_dim: 9
+ xi_input_dim: 1
+ h_hidden_dim: 128
+ chi_hidden_dim: 16
+ e_hidden_dim: 32
+ xi_hidden_dim: 4
+ num_layers: 6
+ dropout: 0.0
+ layer_cfg:
+ pre_norm: false
+ use_gcp_norm: true
+ use_gcp_dropout: true
+ use_scalar_message_attention: true
+ num_feedforward_layers: 2
+ dropout: 0.0
+ nonlinearity_slope: 0.01
+ mp_cfg:
+ edge_encoder: false
+ edge_gate: false
+ num_message_layers: 4
+ message_residual: 0
+ message_ff_multiplier: 1
+ self_message: true
+ checkpoint_path: ./models/checkpoints/structure_denoising/gcpnet/ca_bb/last.ckpt
+ top_k: 30
+ num_positional_embeddings: 16
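
The config above follows a module/kwargs pattern, so a minimal sketch of resolving it into objects follows; it assumes module and kwargs sit directly under the features and encoder keys in the committed YAML, and the build helper is illustrative, not code from this repo.

import importlib
import yaml

with open("config_gcpnet_encoder.yaml") as f:
    cfg = yaml.safe_load(f)

def build(entry):
    # Import the dotted "module" path and instantiate it with "kwargs".
    module_path, cls_name = entry["module"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_path), cls_name)
    return cls(**entry.get("kwargs", {}))

featuriser = build(cfg["features"])  # e.g. ProteinFeaturiser(representation="CA", ...)
encoder = build(cfg["encoder"])      # e.g. GCPNetModel(num_layers=6, emb_dim=128, ...)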
config_geometric_decoder.yaml ADDED
@@ -0,0 +1,14 @@
+ dimension: 1024 # Used as dim_in and dim_out for ContinuousTransformerWrapper
+ ff_mult: 4 # Multiplier for the feedforward dimension
+ depth: 16 # Number of layers in the Encoder
+ heads: 16 # Number of attention heads in the Encoder
+ rotary_pos_emb: True
+ attn_flash: True # FA-2 if installed
+ attn_kv_heads: 1 # GQA
+ qk_norm: True
+ pre_norm: True
+ residual_attn: False # Set pre_norm to False if residual_attn is True
+ num_memory_tokens: 0 # Number of memory tokens, 0 means no memory tokens
+
+ direction_loss_bins: 16
+ pos_scale_factor: 1.0
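
The comments in this file point at ContinuousTransformerWrapper, so here is a hedged sketch of wiring these keys into x-transformers. The max_seq_len value and the assumption that the remaining keys (attn_kv_heads, qk_norm, pre_norm, residual_attn, num_memory_tokens) map onto similarly named x-transformers options are not confirmed by the commit.

import yaml
from x_transformers import ContinuousTransformerWrapper, Encoder

with open("config_geometric_decoder.yaml") as f:
    cfg = yaml.safe_load(f)

decoder = ContinuousTransformerWrapper(
    dim_in=cfg["dimension"],   # "dimension" feeds both dim_in and dim_out per the comment
    dim_out=cfg["dimension"],
    max_seq_len=1280,          # assumed to match model.max_length in config_vqvae.yaml
    attn_layers=Encoder(
        dim=cfg["dimension"],
        depth=cfg["depth"],
        heads=cfg["heads"],
        ff_mult=cfg["ff_mult"],
        rotary_pos_emb=cfg["rotary_pos_emb"],
        attn_flash=cfg["attn_flash"],
    ),
)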
config_vqvae.yaml ADDED
@@ -0,0 +1,129 @@
+ fix_seed: 0
+ checkpoints_every: 1
+ tensorboard_log: True
+ tqdm_progress_bar: False
+ result_path: ./results/vqvae
+ find_unused_parameters: True
+ dispatch_batches: False
+ even_batches: True
+ non_blocking: False
+ split_batches: False
+
+ resume:
+ enabled: True
+ resume_path: results/vqvae/2025-07-17__16-36-40/checkpoints/epoch_1.pth
+ restart_optimizer: True
+ discard_decoder_weights: False
+
+ model:
+ compile_model: False
+ max_length: 1280
+ decoder_output_scaling_factor: 1 # Added scaling factor for backbone prediction outputs
+ use_ndlinear: False # Toggle for using NdLinear instead of Conv1d layers
+ encoder:
+ name: gcpnet # gcpnet
+ freeze_parameters: False
+ pretrained:
+ enabled: True
+ config_path: ./configs/pretrained/structure_denoising_pretrained_config.yaml
+ checkpoint_path: ./models/checkpoints/structure_denoising/gcpnet/ca_bb/last.ckpt # define your checkpoint directory here
+ vqvae:
+ vector_quantization:
+ enabled: True
+ freeze_parameters: False
+ dim: 256
+ decay: 0.995
+ codebook_size: 4096
+ commitment_weight: 0.05
+ orthogonal_reg_weight: 10
+ orthogonal_reg_max_codes: 512
+ orthogonal_reg_active_codes_only: True
+ rotation_trick: True
+ threshold_ema_dead_code: 2
+ kmeans_init: True
+ kmeans_iters: 10
+ alpha: 0.25
+ encoder:
+ freeze_parameters: False
+ dimension: 1024 # Used as dim_in and dim_out for ContinuousTransformerWrapper
+ ff_mult: 4 # Multiplier for the feedforward dimension
+ depth: 12 # Number of layers in the Encoder
+ heads: 12 # Number of attention heads in the Encoder
+ rotary_pos_emb: True
+ attn_flash: True # FA-2 if installed
+ attn_kv_heads: 3 # GQA
+ qk_norm: True
+ pre_norm: True
+ residual_attn: False # Set pre_norm to False if residual_attn is True
+ num_memory_tokens: 0 # Number of memory tokens, 0 means no memory tokens
+ decoder:
+ name: geometric_decoder # geometric_decoder
+ freeze_parameters: False
+
+ train_settings:
+ data_path: ../../datasets/vqvae/uniref_50/
+ num_epochs: 16
+ shuffle: True
+ mixed_precision: bf16 # no, fp16, bf16, fp8
+ save_pdb_every: 1
+ batch_size: 4
+ num_workers: 24
+ grad_accumulation: 1
+ max_task_samples: 24000000
+ profile_train_loop: False
+ cutoff_augmentation:
+ enabled: False
+ probability: 0.5
+ min_length: 25
+ nan_augmentation:
+ enabled: True
+ probability: 0.05
+ max_length: 30
+ gradient_norm_logging_freq: 50 # How often to calculate and log gradient norm (in steps)
+ losses:
+ alignment_strategy: kabsch # kabsch, quaternion, no
+ mse:
+ enabled: True
+ weight: 0.005
+ backbone_distance:
+ enabled: True
+ weight: 0.01
+ backbone_direction:
+ enabled: True
+ weight: 0.05
+ binned_distance_classification:
+ enabled: False
+ weight: 0.01
+ binned_direction_classification:
+ enabled: False
+ weight: 0.01
+
+ valid_settings:
+ data_path: ../../datasets/vqvae/whole_validation_2048_h5/
+ do_every: 1
+ save_pdb_every: 1
+ batch_size: 8
+ num_workers: 0
+
+ visualization_settings:
+ data_path: ../../datasets/vqvae/whole_validation_2048_h5/validation_set_2048_h5
+ fasta_path: visualization/Rep_subfamily_basedon_S40pdb.fa
+ do_every: 8192
+ batch_size: 1
+ num_workers: 4
+
+ optimizer:
+ name: adam
+ lr: 5e-5
+ weight_decouple: True
+ weight_decay: 1e-4
+ eps: 1e-7
+ beta_1: 0.9
+ beta_2: 0.98
+ use_8bit_adam: True
+ grad_clip_norm: 1
+ decay:
+ warmup: 16000
+ min_lr: 1e-6
+ gamma: 0.2
+ num_restarts: 1
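
For the vector_quantization block, a hedged sketch of handing these keys to vector-quantize-pytorch's VectorQuantize follows. It assumes the nesting model -> vqvae -> vector_quantization suggested by the diff; rotation_trick and alpha are left out because their mapping onto that library is less certain.

import yaml
from vector_quantize_pytorch import VectorQuantize

with open("config_vqvae.yaml") as f:
    cfg = yaml.safe_load(f)
vq_cfg = cfg["model"]["vqvae"]["vector_quantization"]  # assumed nesting

vq = VectorQuantize(
    dim=vq_cfg["dim"],
    codebook_size=vq_cfg["codebook_size"],
    decay=vq_cfg["decay"],
    commitment_weight=vq_cfg["commitment_weight"],
    orthogonal_reg_weight=vq_cfg["orthogonal_reg_weight"],
    orthogonal_reg_max_codes=vq_cfg["orthogonal_reg_max_codes"],
    orthogonal_reg_active_codes_only=vq_cfg["orthogonal_reg_active_codes_only"],
    threshold_ema_dead_code=vq_cfg["threshold_ema_dead_code"],
    kmeans_init=vq_cfg["kmeans_init"],
    kmeans_iters=vq_cfg["kmeans_iters"],
)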