import os
import numpy as np
import torch
import argparse
import shutil
from torch.utils.data import DataLoader
from tqdm import tqdm
from dataset import Autoencoder_dataset
from model import Autoencoder
if __name__ == '__main__':
    # Encode high-dimensional language features to 3-D with a trained
    # autoencoder, writing the compressed .npy files (and copies of the
    # segmentation maps) to <dataset_path>/language_features_dim3.
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_path', type=str, required=True)
    parser.add_argument('--dataset_name', type=str, required=True)
    parser.add_argument('--encoder_dims',
                        nargs='+',
                        type=int,
                        default=[256, 128, 64, 32, 3],
                        )
    parser.add_argument('--decoder_dims',
                        nargs='+',
                        type=int,
                        default=[16, 32, 64, 128, 256, 256, 512],
                        )
    args = parser.parse_args()

    dataset_name = args.dataset_name
    encoder_hidden_dims = args.encoder_dims
    decoder_hidden_dims = args.decoder_dims
    dataset_path = args.dataset_path

    ckpt_path = f"ckpt/{dataset_name}/best_ckpt.pth"
    data_dir = f"{dataset_path}/language_features"
    output_dir = f"{dataset_path}/language_features_dim3"
    os.makedirs(output_dir, exist_ok=True)

    # Copy the segmentation maps (*_s.npy) through unchanged; only the
    # feature files are re-encoded below.
    for filename in os.listdir(data_dir):
        if filename.endswith("_s.npy"):
            source_path = os.path.join(data_dir, filename)
            target_path = os.path.join(output_dir, filename)
            shutil.copy(source_path, target_path)

    checkpoint = torch.load(ckpt_path)

    train_dataset = Autoencoder_dataset(data_dir)
    test_loader = DataLoader(
        dataset=train_dataset,
        batch_size=256,
        shuffle=False,  # order must match train_dataset.data_dic offsets
        num_workers=16,
        drop_last=False,
    )

    model = Autoencoder(encoder_hidden_dims, decoder_hidden_dims).to("cuda:0")
    model.load_state_dict(checkpoint)
    model.eval()

    # Encode every batch, collecting per-batch outputs and concatenating
    # once at the end (O(n) instead of the O(n^2) growth of repeated
    # np.concatenate inside the loop). no_grad() is hoisted around the
    # whole loop rather than re-entered per batch.
    batch_outputs = []
    with torch.no_grad():
        for feature in tqdm(test_loader):
            data = feature.to("cuda:0")
            batch_outputs.append(model.encode(data).cpu().numpy())
    features = np.concatenate(batch_outputs, axis=0)

    # Split the flat feature array back into one output file per source
    # .npy, using the per-file row counts recorded by the dataset.
    # NOTE(review): assumes data_dic maps filename -> row count in
    # loader order — confirm against Autoencoder_dataset.
    start = 0
    for k, v in train_dataset.data_dic.items():
        path = os.path.join(output_dir, k)
        np.save(path, features[start:start + v])
        start += v