# FlowProt / model/test.py
# Author: alibtsd — commit f34af6f (verified), "Deploy FlowProt Docker Space".
# Original file size: 2.6 kB.
from dataset.data import PdbDataset
from dataset.data import PdbDataModule
import hydra
from omegaconf import DictConfig
from torchsummary import summary
from utils.flows import Interpolant
import numpy as np
from utils.modelUtils import get_time_embedding
"""
ProteinFlow model
model:
node_embed_size: 256
edge_embed_size: 128
symmetric: False
node_features:
c_s: ${model.node_embed_size}
c_pos_emb: 128
c_timestep_emb: 128
embed_diffuse_mask: False
max_num_res: 2000
timestep_int: 1000
edge_features:
single_bias_transition_n: 2
c_s: ${model.node_embed_size}
c_p: ${model.edge_embed_size}
relpos_k: 64
use_rbf: True
num_rbf: 32
feat_dim: 64
num_bins: 22
self_condition: True
ipa:
c_s: ${model.node_embed_size}
c_z: ${model.edge_embed_size}
c_hidden: 128
no_heads: 8
no_qk_points: 8
no_v_points: 12
seq_tfmr_num_heads: 4
seq_tfmr_num_layers: 2
num_blocks: 6
"""
import torch
from torch import nn
from models.classifier import ProtClassifier
from utils import modelUtils as u
from models import ipa_pytorch
@hydra.main(version_base=None, config_path=".", config_name="test_dataset")
def run(cfg: DictConfig) -> None:
    """Smoke-test the data pipeline, interpolant and classifier end to end.

    Builds the dataset and data module from ``cfg.data``, corrupts a few
    training batches with the interpolant built from ``cfg.interpolant``,
    and pushes each corrupted batch through a ``ProtClassifier`` built from
    ``cfg.model``. Sizes/keys/shapes are printed for manual inspection; the
    model output itself is not checked.

    Args:
        cfg: Hydra config composed from ``test_dataset`` in this directory;
            must provide ``data``, ``interpolant`` and ``model`` sections.
    """
    import itertools

    num_batches = 3  # number of training batches to push through the model

    data = PdbDataset(dataset_cfg=cfg.data.dataset, is_training=True)
    datamodule = PdbDataModule(cfg.data)
    datamodule.setup("")
    train_loader = datamodule.train_dataloader()

    interpolant = Interpolant(cfg.interpolant)
    interpolant.set_device("cpu")

    print(len(data))
    print(data[0].keys())

    # Build the model once; re-creating it for every batch only re-randomises
    # its weights and wastes time.
    model = ProtClassifier(cfg.model)

    # Iterate a single iterator over the loader. Calling
    # ``next(iter(train_loader))`` inside the loop restarts the loader on each
    # pass (returning the same first batch when it does not shuffle) and fails
    # with StopIteration if it has fewer batches than requested.
    for batch in itertools.islice(train_loader, num_batches):
        print(batch["trans_1"].shape)
        noisy_batch = interpolant.corrupt_batch(batch)
        model(noisy_batch)
# Guard the Hydra entry point so that importing this module (e.g. from a test
# collector or another script) does not launch a run as a side effect.
if __name__ == "__main__":
    run()
"""
dirname = "class_preprocessed"
df = pd.read_csv(os.path.join(dirname, "metadata.csv"))
print(df.shape[0])
pos_df = df[df["class"] == 1]
neg_df = df[df["class"] == 0]
print(pos_df.shape[0])
print(neg_df.shape[0])
filtered_df = df[df["modeled_seq_len"] <= 256]
filtered_pos_df = filtered_df[filtered_df["class"] == 1]
filtered_neg_df = filtered_df[filtered_df["class"] == 0]
print(filtered_df.shape[0])
print(filtered_pos_df.shape[0])
print(filtered_neg_df.shape[0])
"""