import os
import re
import shutil
import sys
# Put the parent of this script's directory on the import path, so modules
# located one level above this file can be imported when it runs as a script.
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
def _split_type_from_path(root):
    """Return the split name ('tanimoto', 'random' or 'uniprot') found in ``root``.

    The names are tested in that order and the first one that occurs as a
    substring of ``root`` is returned.

    Raises:
        ValueError: If none of the known split names occurs in ``root``.
    """
    for split_type in ('tanimoto', 'random', 'uniprot'):
        if split_type in root:
            return split_type
    raise ValueError('Unknown split type')


def _target_model_name(root, file_name, trim_str):
    """Build the file name a checkpoint gets inside the models directory.

    Args:
        root: Directory in which the checkpoint was found.
        file_name: Bare file name of the checkpoint.
        trim_str: Text to cut from the name; only what follows its last
            occurrence is kept.

    Returns:
        The destination file name (no directory part).

    Raises:
        ValueError: If ``root`` names no known split type.
    """
    model_name = file_name.split(trim_str)[-1]
    # Validated for every checkpoint, even though only fold runs use the value.
    split_type = _split_type_from_path(root)
    if 'fold' in root:
        # The fold index is the single character after the last 'fold'.
        fold = root.split('fold')[-1][0]
        return model_name.replace("protac", f"cv_model_{split_type}_fold{fold}")
    # Checkpoints outside a fold directory have their "val_" tags renamed
    # to "test_".
    return model_name.replace("val_", "test_")


def _metrics_from_name(name, prefix):
    """Parse ``<prefix>_acc`` and ``<prefix>_roc_auc`` out of a model file name.

    Returns:
        A ``(acc, roc_auc)`` tuple of floats, or ``None`` when either value
        is absent from ``name``.
    """
    values = []
    for metric in ('acc', 'roc_auc'):
        match = re.search(rf'{prefix}_{metric}=(\d+\.\d+)', name)
        if match is None:
            return None
        values.append(float(match.group(1)))
    return tuple(values)


def main():
    """Collect the best checkpoints from ``../logs`` into ``../models``.

    Both directories are resolved relative to this script's location. Every
    ``*.ckpt`` file below the logs directory is renamed according to the
    split type and fold found in its directory path. It is then copied into
    the models directory when either

    * no existing model file name contains the new name's base (the part
      before the first ``-``), or
    * an existing one does and the new checkpoint has strictly higher
      accuracy *and* ROC AUC, both read from the file names; the older
      file is then removed.

    Raises:
        ValueError: If a checkpoint's directory names no known split type.
    """
    script_dir = os.path.dirname(__file__)

    models_dir = os.path.join(script_dir, '..', 'models')
    os.makedirs(models_dir, exist_ok=True)

    log_dir = os.path.join(script_dir, '..', 'logs')
    # NOTE(review): this text contains a '/', while os.walk yields bare file
    # names, so splitting on it looks like a no-op - confirm it is still needed.
    trim_str = 'logs_Active_Dmax_0.6_pDC50_6.0_test_split_0.1_random_random_cv_model_fold0/logs_Active_Dmax_0.6_pDC50_6.0_test_split_0.1_'

    for root, _dirs, files in os.walk(log_dir):
        for file in files:
            if not file.endswith('.ckpt'):
                continue

            checkpoint_file = os.path.join(root, file)
            model_name = _target_model_name(root, file, trim_str)
            destination = os.path.join(models_dir, model_name)

            # First already-copied model whose name contains the same base.
            base_model_name = model_name.split('-')[0]
            old_model_name = next(
                (m for m in os.listdir(models_dir) if base_model_name in m),
                None,
            )

            if old_model_name is None:
                print(f'Copying {model_name}')
                shutil.copy(checkpoint_file, destination)
                continue

            for prefix in ('val', 'test'):
                if f'{prefix}_acc' not in model_name:
                    continue
                new_metrics = _metrics_from_name(model_name, prefix)
                old_metrics = _metrics_from_name(old_model_name, prefix)
                if new_metrics is None or old_metrics is None:
                    # Without both metric pairs no comparison is possible;
                    # keep the existing model rather than crash.
                    print(f'Skipping {model_name}: cannot compare {prefix} '
                          f'metrics with {old_model_name}')
                    continue
                new_acc, new_roc_auc = new_metrics
                old_acc, old_roc_auc = old_metrics
                if new_acc > old_acc and new_roc_auc > old_roc_auc:
                    print(f'Replacing {old_model_name} with {model_name}')
                    # Copy before deleting so a failed copy never loses the
                    # model that is already in place.
                    shutil.copy(checkpoint_file, destination)
                    old_path = os.path.join(models_dir, old_model_name)
                    if old_path != destination:
                        os.remove(old_path)
                    break
if __name__ == '__main__':
    # Run the copy only when this file is executed as a script, not on import.
    main()