|
|
import pandas as pd
|
|
|
import numpy as np
|
|
|
import os
|
|
|
import torch
|
|
|
from sklearn.model_selection import train_test_split
|
|
|
import logging
|
|
|
|
|
|
|
|
|
def process_lsvq(train_data_name, test_data_name, metadata_path, feature_path, network_name):
    """Prepare the predefined LSVQ train/test split for downstream use.

    No random split is performed. The train and test sets are read from two
    separate metadata CSVs and two separate feature files. For each set, this
    writes a MOS table to ``<metadata_path>/mos_files/`` and a copy of the
    features to ``<feature_path>/split_train_test/``.

    Args:
        train_data_name: Training-set name (e.g. ``'lsvq_train'``). The
            upper-cased name selects ``<NAME>_metadata.csv``. The name as given
            is used in the feature and output file names.
        test_data_name: Test-set name (e.g. ``'lsvq_test'``), same conventions.
        metadata_path: Directory holding the metadata CSVs. Each CSV must have
            ``vid``, ``framerate`` and ``mos`` columns. A trailing slash is
            optional.
        feature_path: Directory holding ``<network_name>_<data>_features.pt``.
        network_name: Feature-extractor name used in the feature file names.

    Returns:
        Tuple ``(train_features, test_features, test_vids)``. The first two are
        the objects loaded with ``torch.load``. ``test_vids`` is the ``vid``
        column of the test metadata.
    """
    train_df = pd.read_csv(os.path.join(metadata_path, f'{train_data_name.upper()}_metadata.csv'))
    test_df = pd.read_csv(os.path.join(metadata_path, f'{test_data_name.upper()}_metadata.csv'))
    test_vids = test_df['vid']

    def _mos_table(df):
        # 'MOS' and 'MOS_raw' both carry the unmodified metadata score.
        return pd.DataFrame({
            'vid': df['vid'],
            'framerate': df['framerate'],
            'MOS': df['mos'],
            'MOS_raw': df['mos'],
        })

    sorted_train_df = _mos_table(train_df)
    sorted_test_df = _mos_table(test_df)

    train_features = torch.load(os.path.join(feature_path, f'{network_name}_{train_data_name}_features.pt'))
    print(f"loaded {train_data_name}: dimensions are {train_features.shape}")
    test_features = torch.load(os.path.join(feature_path, f'{network_name}_{test_data_name}_features.pt'))
    print(len(train_features))
    print(len(test_features))

    # Flag feature/metadata misalignment early instead of failing later in training.
    for name, feats, meta in ((train_data_name, train_features, train_df),
                              (test_data_name, test_features, test_df)):
        if len(feats) != len(meta):
            logging.warning("%s: %d feature rows but %d metadata rows",
                            name, len(feats), len(meta))

    # os.path.join works whether or not metadata_path ends with '/'.
    # Create the directory, since to_csv will not.
    mos_dir = os.path.join(metadata_path, 'mos_files')
    os.makedirs(mos_dir, exist_ok=True)
    # NOTE(review): both MOS files are named after train_data_name, including the
    # test file. This is kept as-is because downstream loaders may expect this
    # name; confirm before changing.
    sorted_train_df.to_csv(os.path.join(mos_dir, f'{train_data_name}_MOS_train.csv'), index=False)
    sorted_test_df.to_csv(os.path.join(mos_dir, f'{train_data_name}_MOS_test.csv'), index=False)

    split_dir = os.path.join(feature_path, 'split_train_test')
    os.makedirs(split_dir, exist_ok=True)
    torch.save(train_features, os.path.join(split_dir, f'{network_name}_{train_data_name}_train_features.pt'))
    torch.save(test_features, os.path.join(split_dir, f'{network_name}_{test_data_name}_test_features.pt'))

    return train_features, test_features, test_vids
|
|
|
|
|
|
def process_other(data_name, test_size, random_state, metadata_path, feature_path, network_name):
    """Randomly split a single dataset into train/test sets by video id.

    The split is done on unique ``vid`` values, so all metadata rows of one
    video land in the same set. The matching feature rows are selected, MOS
    tables are written to ``<metadata_path>/mos_files/``, and the split
    features are saved to ``<feature_path>/split_train_test/``.

    Args:
        data_name: Dataset name. The upper-cased name selects
            ``<NAME>_metadata.csv``. The name as given is used in the feature
            and output file names.
        test_size: Forwarded to ``train_test_split`` (fraction or count of vids).
        random_state: Forwarded to ``train_test_split``. ``None`` gives a
            different split on every call.
        metadata_path: Directory holding the metadata CSV, which must have
            ``vid``, ``framerate`` and ``mos`` columns. A trailing slash is
            optional.
        feature_path: Directory holding ``<network_name>_<data_name>_features.pt``.
        network_name: Feature-extractor name used in the feature file names.

    Returns:
        Tuple ``(train_features, test_features, test_vids)``. ``test_vids``
        holds the vids assigned to the test set, as returned by
        ``train_test_split``.
    """
    df = pd.read_csv(os.path.join(metadata_path, f'{data_name.upper()}_metadata.csv'))

    # Split on unique ids so no video contributes rows to both sets.
    unique_vids = df['vid'].unique()
    train_vids, test_vids = train_test_split(unique_vids, test_size=test_size, random_state=random_state)

    train_mask = df['vid'].isin(train_vids).to_numpy()
    test_mask = df['vid'].isin(test_vids).to_numpy()
    train_df = df[train_mask]
    test_df = df[test_mask]

    def _mos_table(part):
        # 'MOS' and 'MOS_raw' both carry the unmodified metadata score.
        return pd.DataFrame({
            'vid': part['vid'],
            'framerate': part['framerate'],
            'MOS': part['mos'],
            'MOS_raw': part['mos'],
        })

    sorted_train_df = _mos_table(train_df)
    sorted_test_df = _mos_table(test_df)

    features = torch.load(os.path.join(feature_path, f'{network_name}_{data_name}_features.pt'))
    if len(features) != len(df):
        logging.warning("%s: %d feature rows but %d metadata rows",
                        data_name, len(features), len(df))

    # Assumes feature row i corresponds to metadata row i, as the original
    # label-based indexing did. Select by explicit row position so this does
    # not depend on df carrying a default RangeIndex.
    train_features = features[np.flatnonzero(train_mask)]
    test_features = features[np.flatnonzero(test_mask)]

    # os.path.join works whether or not metadata_path ends with '/'.
    # Create the directory, since to_csv will not.
    mos_dir = os.path.join(metadata_path, 'mos_files')
    os.makedirs(mos_dir, exist_ok=True)
    sorted_train_df.to_csv(os.path.join(mos_dir, f'{data_name}_MOS_train.csv'), index=False)
    sorted_test_df.to_csv(os.path.join(mos_dir, f'{data_name}_MOS_test.csv'), index=False)

    split_dir = os.path.join(feature_path, 'split_train_test')
    os.makedirs(split_dir, exist_ok=True)
    torch.save(train_features, os.path.join(split_dir, f'{network_name}_{data_name}_train_features.pt'))
    torch.save(test_features, os.path.join(split_dir, f'{network_name}_{data_name}_test_features.pt'))

    return train_features, test_features, test_vids
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Run configuration.
    network_name = "slowfast"  # feature extractor; appears in the feature file names
    data_name = "test"  # dataset to process
    metadata_path = '../../metadata/'  # directory with <DATA_NAME>_metadata.csv
    feature_path = '../../features/konvid_1k_test/slowfast/'

    # Random-split settings. They are only used on the non-LSVQ path.
    test_size = 0.2
    random_state = None  # forwarded to train_test_split; None -> split is not reproducible

    if data_name != 'lsvq_train':
        process_other(
            data_name,
            test_size,
            random_state,
            metadata_path=metadata_path,
            feature_path=feature_path,
            network_name=network_name,
        )
    else:
        # LSVQ's train and test sets are read from separate files,
        # so no random split is done.
        test_data_name = 'lsvq_test'
        process_lsvq(
            data_name,
            test_data_name,
            metadata_path=metadata_path,
            feature_path=feature_path,
            network_name=network_name,
        )
|
|
|
|