# File size: 2,600 Bytes (revision f34af6f)
from dataset.data import PdbDataset
from dataset.data import PdbDataModule
import hydra
from omegaconf import DictConfig
from torchsummary import summary
from utils.flows import Interpolant
import numpy as np
from utils.modelUtils import get_time_embedding
"""
ProteinFlow model

model:
  node_embed_size: 256
  edge_embed_size: 128
  symmetric: False
  node_features:
    c_s: ${model.node_embed_size}
    c_pos_emb: 128
    c_timestep_emb: 128
    embed_diffuse_mask: False
    max_num_res: 2000
    timestep_int: 1000
  edge_features:
    single_bias_transition_n: 2
    c_s: ${model.node_embed_size}
    c_p: ${model.edge_embed_size}
    relpos_k: 64
    use_rbf: True
    num_rbf: 32
    feat_dim: 64
    num_bins: 22
    self_condition: True
  ipa:
    c_s: ${model.node_embed_size}
    c_z: ${model.edge_embed_size}
    c_hidden: 128
    no_heads: 8
    no_qk_points: 8
    no_v_points: 12
    seq_tfmr_num_heads: 4
    seq_tfmr_num_layers: 2
    num_blocks: 6
"""
import torch
from torch import nn
from models.classifier import ProtClassifier
from utils import modelUtils as u
from models import ipa_pytorch
@hydra.main(version_base=None, config_path=".", config_name="test_dataset")
def run(cfg: DictConfig) -> None:
    """Smoke-test the dataset, interpolant and classifier on a few batches.

    Builds the PDB dataset and data module from the config, prints the
    dataset length and the keys of its first example, then corrupts up to
    three training batches with the interpolant and runs each through a
    ``ProtClassifier`` forward pass. The forward-pass outputs are discarded.

    Args:
        cfg: Hydra config. This function reads ``cfg.data``,
            ``cfg.data.dataset``, ``cfg.interpolant`` and ``cfg.model``.
    """
    data = PdbDataset(dataset_cfg=cfg.data.dataset, is_training=True)
    datamodule = PdbDataModule(cfg.data)
    datamodule.setup("")
    train_loader = datamodule.train_dataloader()

    interpolant = Interpolant(cfg.interpolant)
    interpolant.set_device("cpu")

    print(len(data))
    # print(data[0])
    print(data[0].keys())

    # Build the model once: its construction does not depend on the batch,
    # so re-instantiating it on every pass was wasted work.
    model = ProtClassifier(cfg.model)

    num_batches = 3
    # Walk a single pass over the loader instead of calling
    # next(iter(train_loader)) per step, which restarted the iterator each
    # time. zip() advances range() first, so no extra batch is fetched once
    # the limit is reached; a loader with fewer batches simply ends early.
    for _, batch in zip(range(num_batches), train_loader):
        print(batch["trans_1"].shape)
        noisy_batch = interpolant.corrupt_batch(batch)
        model(noisy_batch)
        # Optional shape checks, kept for debugging:
        # print(noisy_batch["class"].shape)
        # print(noisy_batch["t"].shape)
        # print(noisy_batch["trans_t"].shape)
        # print(noisy_batch["rotmats_t"].shape)
        # time_emb = get_time_embedding(noisy_batch["t"][:, 0], 128)
        # print(time_emb.shape)
    # print(summary(model, data[0].shape))
if __name__ == "__main__":
    # Guarded so that importing this module does not launch the Hydra app;
    # running the file as a script behaves as before.
    run()
"""
dirname = "class_preprocessed"
df = pd.read_csv(os.path.join(dirname, "metadata.csv"))
print(df.shape[0])
pos_df = df[df["class"] == 1]
neg_df = df[df["class"] == 0]
print(pos_df.shape[0])
print(neg_df.shape[0])
filtered_df = df[df["modeled_seq_len"] <= 256]
filtered_pos_df = filtered_df[filtered_df["class"] == 1]
filtered_neg_df = filtered_df[filtered_df["class"] == 0]
print(filtered_df.shape[0])
print(filtered_pos_df.shape[0])
print(filtered_neg_df.shape[0])
"""