File size: 2,600 Bytes
f34af6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from dataset.data import PdbDataset
from dataset.data import PdbDataModule
import hydra
from omegaconf import DictConfig
from torchsummary import summary
from utils.flows import Interpolant
import numpy as np
from utils.modelUtils import get_time_embedding

# Reference copy of the model hyper-parameters this script passes to
# ProtClassifier via cfg.model. This is a bare string expression, so it has
# no runtime effect; the values actually used come from the Hydra config
# selected in @hydra.main (config_name="test_dataset").
# NOTE(review): this copy can drift from the real config — confirm it is in sync.
"""
ProteinFlow model
model:
  node_embed_size: 256
  edge_embed_size: 128
  symmetric: False
  node_features:
    c_s: ${model.node_embed_size}
    c_pos_emb: 128
    c_timestep_emb: 128
    embed_diffuse_mask: False
    max_num_res: 2000
    timestep_int: 1000
  edge_features:
    single_bias_transition_n: 2
    c_s: ${model.node_embed_size}
    c_p: ${model.edge_embed_size}
    relpos_k: 64
    use_rbf: True
    num_rbf: 32
    feat_dim: 64
    num_bins: 22
    self_condition: True
  ipa:
    c_s: ${model.node_embed_size}
    c_z: ${model.edge_embed_size}
    c_hidden: 128
    no_heads: 8
    no_qk_points: 8
    no_v_points: 12
    seq_tfmr_num_heads: 4
    seq_tfmr_num_layers: 2
    num_blocks: 6
"""

import torch
from torch import nn

from models.classifier import ProtClassifier
from utils import modelUtils as u
from models import ipa_pytorch


@hydra.main(version_base=None, config_path=".", config_name="test_dataset")
def run(cfg: DictConfig) -> None:
    """Smoke-test the data pipeline, interpolant and classifier end to end.

    Loads the training set both directly (``PdbDataset``) and through
    ``PdbDataModule``, prints a few basic facts about it, then corrupts up
    to ``num_batches`` training batches with the ``Interpolant`` and runs
    each noisy batch through a ``ProtClassifier`` forward pass.

    Args:
        cfg: Hydra config composed from ``test_dataset``; its ``data``,
            ``data.dataset``, ``interpolant`` and ``model`` sections are read.
    """
    num_batches = 3

    data = PdbDataset(dataset_cfg=cfg.data.dataset, is_training=True)
    datamodule = PdbDataModule(cfg.data)
    datamodule.setup("")
    train_loader = datamodule.train_dataloader()
    interpolant = Interpolant(cfg.interpolant)
    interpolant.set_device("cpu")

    print(len(data))
    print(data[0].keys())

    # Build the model once: re-instantiating it per batch only re-randomizes
    # the weights and repeats the construction cost.
    model = ProtClassifier(cfg.model)

    # Walk ONE iterator over the loader. The previous `next(iter(loader))`
    # restarted the loader on every pass (re-spawning any workers) and, when
    # the loader does not shuffle, returned the same first batch each time.
    # zip() also stops cleanly if the loader yields fewer than num_batches.
    for _, batch in zip(range(num_batches), train_loader):
        print(batch["trans_1"].shape)
        noisy_batch = interpolant.corrupt_batch(batch)
        model(noisy_batch)
if __name__ == "__main__":
    # Guarded so that importing this module (e.g. by a test collector) does
    # not launch Hydra and load the dataset as an import-time side effect.
    run()


# Leftover exploratory snippet, kept as an inert string literal (no runtime
# effect). It counts class-1 / class-0 rows in class_preprocessed/metadata.csv,
# before and after keeping only rows with modeled_seq_len <= 256.
# NOTE(review): it uses `pd` and `os`, which this module does not import, so
# it cannot be run as-is if turned back into code.
"""
dirname = "class_preprocessed"
df = pd.read_csv(os.path.join(dirname, "metadata.csv"))
print(df.shape[0])

pos_df = df[df["class"] == 1]
neg_df = df[df["class"] == 0]
print(pos_df.shape[0])
print(neg_df.shape[0])

filtered_df = df[df["modeled_seq_len"] <= 256]
filtered_pos_df = filtered_df[filtered_df["class"] == 1]
filtered_neg_df = filtered_df[filtered_df["class"] == 0]
print(filtered_df.shape[0])
print(filtered_pos_df.shape[0])
print(filtered_neg_df.shape[0])
"""