"""Configuration for masked-autoencoder (MAE) training on CheXpert-v1.0-small.

Everything is exposed through the module-level ``mae_config`` dict.
"""
import os

import torch

# Base directory; looks like a Google Colab Drive mount -- adjust when running elsewhere.
root = "/content/drive/MyDrive"

# Shared path prefixes. Every path entry below is derived from these, so they
# cannot drift apart. Previously "dirsToMake" and "resume" repeated the
# checkpoint/log paths literally.
_chexpert_dir = os.path.join(root, "CheXpert-v1.0-small")
_checkpoint_dir = os.path.join(_chexpert_dir, "maecheckpoints")
_log_dir = os.path.join(_chexpert_dir, "maelogs")

# NOTE(review): key semantics are inferred from names and consumed elsewhere;
# confirm against the training code. Insertion order matches the original.
mae_config = {
    # --- optimisation / schedule ---
    "lr": 1e-4,
    "warmup": 5,
    "weight_decay": 5e-4,
    "num_epochs": 200,
    "num_classes": 14,
    # --- filesystem locations ---
    "zip_path": os.path.join(_chexpert_dir, "chexpert.zip"),
    "resume": os.path.join(_checkpoint_dir, "best_mae.pth"),
    "logdir": _log_dir,
    "checkpoints": _checkpoint_dir,
    "datadir": root,
    "lmdb": os.path.join(_chexpert_dir, "lmdb"),
    "csv": os.path.join(_chexpert_dir, "train.csv"),
    # --- data loading / hardware ---
    "batch_size": 96,
    # Selected once at import time: CUDA when available, else CPU.
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    "accumulation": 1,
    # Same directories as "checkpoints" and "logdir" (fresh list per import).
    "dirsToMake": [_checkpoint_dir, _log_dir],
    "train_csv": os.path.join(_chexpert_dir, "train_ready.csv"),
    "val_csv": os.path.join(_chexpert_dir, "val_ready.csv"),
    "test_csv": os.path.join(_chexpert_dir, "test_ready.csv"),
    # --- model architecture ---
    # NOTE(review): presumably img_size must be divisible by patch_size and
    # encoder_dim/decoder_dim by their head counts -- confirm with model code.
    "channels": 1,
    "mask_ratio": 0.75,
    "dropout": 0.25,
    "img_size": 384,
    "encoder_dim": 768,
    "mlp_dim": 3072,
    "decoder_dim": 512,
    "encoder_depth": 12,
    "encoder_head": 8,
    "decoder_depth": 8,
    "decoder_head": 8,
    "patch_size": 16,
}