Upload 6 files
Browse files- .gitattributes +1 -0
- README.md +71 -0
- gif_for_readme.gif +3 -0
- main.py +75 -0
- model.py +113 -0
- trainer.py +82 -0
- utils.py +124 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
gif_for_readme.gif filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<div align="center">
|
| 2 |
+
<h1>Radon Averaging</h1>
|
| 3 |
+
<h3>A Practical Approach for Designing Rotation-Invariant Models</h3>
|
| 4 |
+
|
| 5 |
+
<a href="https://www.python.org/">
|
| 6 |
+
<img src="https://img.shields.io/badge/Python-3.10+-blue?logo=python&style=flat-square" alt="Python Badge"/>
|
| 7 |
+
</a>
|
| 8 |
+
|
| 9 |
+
<a href="https://pytorch.org/">
|
| 10 |
+
<img src="https://img.shields.io/badge/PyTorch-2.0+-EE4C2C?logo=pytorch&style=flat-square" alt="PyTorch Badge"/>
|
| 11 |
+
</a>
|
| 12 |
+
|
| 13 |
+
<a href="https://doi.org/10.1016/j.engappai.2025.113299">
|
| 14 |
+
<img src="https://img.shields.io/badge/EAAI%202026-Published-success?style=flat-square" alt="EAAI Badge"/>
|
| 15 |
+
</a>
|
| 16 |
+
|
| 17 |
+
<a href="https://www.elsevier.com/">
|
| 18 |
+
<img src="https://img.shields.io/badge/Elsevier-Journal-orange?style=flat-square" alt="Elsevier Badge"/>
|
| 19 |
+
</a>
|
| 20 |
+
<br/><br/>
|
| 21 |
+
|
| 22 |
+
<!-- Radon Transform Animation -->
|
| 23 |
+
<img src="./gif_for_readme.gif" width="700px"/>
|
| 24 |
+
</div>
|
| 25 |
+
|
| 26 |
+
---
|
| 27 |
+
|
| 28 |
+
## Engineering Applications of Artificial Intelligence (EAAI 2026)
|
| 29 |
+
### Pytorch Implementation
|
| 30 |
+
|
| 31 |
+
This repository contains a pytorch implementation of **Radon Averaging (RA)** from the paper:
|
| 32 |
+
|
| 33 |
+
> **Radon Averaging: A practical approach for designing rotation-invariant models**
|
| 34 |
+
> Jangwon Kim, Sanghyun Ryoo, Jiwon Kim, Junkee Hong, Soohee Han
|
| 35 |
+
> *Engineering Applications of Artificial Intelligence*, Volume 164, 2026
|
| 36 |
+
|
| 37 |
+
## 📄 Paper Link
|
| 38 |
+
> **DOI:** https://doi.org/10.1016/j.engappai.2025.113299
|
| 39 |
+
> **Journal:** Engineering Applications of Artificial Intelligence
|
| 40 |
+
|
| 41 |
+
---
|
| 42 |
+
|
| 43 |
+
## Radon Averaging
|
| 44 |
+
|
| 45 |
+
Radon Averaging achieves rotation invariance by:
|
| 46 |
+
1. **Radon Transform** (ℛ): Converts images ($I$) to sinograms, where an rotation corresponds ($$g$$) to a circular shift.
|
| 47 |
+
2. **Averaging over Discrete Rotations** ($$G$$): Eliminates boundary artifacts via group averaging
|
| 48 |
+
3. **Standard CNN Backbone** ($$Φ$$): No architectural changes required
|
| 49 |
+
```math
|
| 50 |
+
RA_G^Φ(I) = \frac{1}{|G|} \sum_{g \in G} (Φ \circ π(g) \circ ℛ)(I)
|
| 51 |
+
```
|
| 52 |
+
---
|
| 53 |
+
|
| 54 |
+
## Advantages
|
| 55 |
+
- **Plug-and-play**: works with standard (pretrained) CNN backbones (no architectural changes).
|
| 56 |
+
- **Rotation invariance in practice**: stable representations under image rotations.
|
| 57 |
+
- **Reduces boundary artifacts**: group averaging mitigates Radon transform edge effects.
|
| 58 |
+
---
|
| 59 |
+
|
| 60 |
+
## Citation Example
|
| 61 |
+
```
|
| 62 |
+
@article{kim2026radonaveraging,
|
| 63 |
+
title = {Radon Averaging: A practical approach for designing rotation-invariant models},
|
| 64 |
+
author = {Kim, Jangwon and Ryoo, Sanghyun and Kim, Jiwon and Hong, Junkee and Han, Soohee},
|
| 65 |
+
journal = {Engineering Applications of Artificial Intelligence},
|
| 66 |
+
volume = {164},
|
| 67 |
+
pages = {113299},
|
| 68 |
+
year = {2026},
|
| 69 |
+
doi = {10.1016/j.engappai.2025.113299}
|
| 70 |
+
}
|
| 71 |
+
```
|
gif_for_readme.gif
ADDED
|
Git LFS Details
|
main.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import torch
|
| 3 |
+
from model import RA
|
| 4 |
+
from trainer import train
|
| 5 |
+
import random
|
| 6 |
+
import warnings
|
| 7 |
+
warnings.filterwarnings('ignore')
|
| 8 |
+
|
| 9 |
+
from utils import *
|
| 10 |
+
|
| 11 |
+
def parse_args():
    """Collect the command-line configuration for a training run."""
    parser = argparse.ArgumentParser(description="Train & Test with configurable args")

    # Integer options: (flag, default, help text).
    int_options = [
        ("--num_train_data", 10000, "number of training samples"),
        ("--batch_size", 64, "batch size"),
        ("--epochs", 50, "number of epochs"),
        ("--n_seeds", 10, "number of seeds to run"),
    ]
    for flag, default, help_text in int_options:
        parser.add_argument(flag, type=int, default=default, help=help_text)

    parser.add_argument("--group", type=str, default="C8", help="group name for RA model (C4 or C8)")

    return parser.parse_args()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def train_and_test(model, train_loader, val_loader, test_loader, epochs, device):
    """Train `model` with fixed hyperparameters and return (accuracy, f1) on the test set."""
    # Thin wrapper around trainer.train that pins lr / weight decay for all runs.
    return train(
        model=model,
        epochs=epochs,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        lr=1e-4,
        wd=1e-4,
        device=device,
    )
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def main(args):
    """Run `args.n_seeds` independent train/test cycles and report mean/std metrics."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    base_seed = random.randint(1, 999)

    print_run_config(args, device, base_seed)

    acc_history = []
    f1_history = []

    for seed in range(base_seed, base_seed + args.n_seeds):
        # Fresh datasets and loaders for every seed so each run gets its own
        # train/validation split and random test rotations.
        loaders = {
            split: torch.utils.data.DataLoader(dataset, batch_size=args.batch_size)
            for split, dataset in (
                ("train", MnistDataset(num_train_data=args.num_train_data, mode="train", seed=seed)),
                ("validation", MnistDataset(mode="validation", seed=seed)),
                ("test", MnistDataset(mode="test", seed=seed)),
            )
        }

        model = RA(group=args.group).to(device)
        acc, f1 = train_and_test(
            model=model,
            train_loader=loaders["train"],
            val_loader=loaders["validation"],
            test_loader=loaders["test"],
            epochs=args.epochs,
            device=device,
        )
        acc_history.append(acc)
        f1_history.append(f1)

    mean_acc, std_acc = cal_mean_std(acc_history)
    print(f"[Acc] Mean: {mean_acc} | Std: {std_acc}\n")

    mean_f1, std_f1 = cal_mean_std(f1_history)
    print(f"[F-Score] Mean: {mean_f1} | Std: {std_f1}\n")
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
if __name__ == "__main__":
    main(parse_args())
|
model.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import division
|
| 2 |
+
import copy
|
| 3 |
+
import torch
|
| 4 |
+
import torchvision.transforms.functional as TF
|
| 5 |
+
from torchvision import models
|
| 6 |
+
from torchvision.transforms import Pad
|
| 7 |
+
from torchvision.transforms import Resize
|
| 8 |
+
from torchvision.transforms import ToTensor
|
| 9 |
+
from skimage.transform import radon
|
| 10 |
+
from PIL import Image
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# Side length of the square working resolution: inputs are resized to
# WIDTH x WIDTH and the sinogram is reshaped to WIDTH x WIDTH as well.
WIDTH = 29
# Projection angles (degrees) for the Radon transform: WIDTH angles evenly
# spaced over [0, 360).  NOTE(review): skimage's radon conventionally spans
# [0, 180); a full 360-degree sweep duplicates projections — presumably
# intentional so that an image rotation maps to a circular shift of the
# sinogram columns; confirm against the paper.
THETA = np.linspace(0.0, 360, WIDTH, endpoint=False)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class RA(torch.nn.Module):
    """Radon Averaging (RA) classifier.

    Each input image is converted to a Radon sinogram and the backbone's
    logits are averaged over a discrete rotation group, which makes the
    overall model approximately rotation invariant without any change to
    the CNN backbone itself.
    """

    def __init__(self, group='C4'):
        """
        Args:
            group: rotation group to average over — 'C1' (identity only),
                'C4' (90-degree steps) or 'C8' (45-degree steps).
        """
        super(RA, self).__init__()
        self.backbone = MnistCNN()
        self.group = group
        self.resize1 = Resize(87)  # to upsample
        self.resize2 = Resize(29)  # to downsample
        self.totensor = ToTensor()

    def _group_angles(self):
        """Return the rotation angles (degrees) for the configured group."""
        if self.group == 'C1':
            return [0]
        if self.group == 'C4':
            return [0, 90, 180, 270]
        if self.group == 'C8':
            return [0, 45, 90, 135, 180, 225, 270, 315]
        # Same exception type as the original implementation, now with context.
        raise NameError(f"unknown group '{self.group}' (expected C1, C4 or C8)")

    def _sinogram_batch(self, imgs, angle):
        """Rotate every image by `angle` degrees and Radon-transform it.

        Args:
            imgs: CPU float tensor of shape (N, 1, H, W).
            angle: rotation angle in degrees.

        Returns:
            CPU float tensor of sinograms with shape (N, 1, 29, 29).
        """
        out = torch.empty((imgs.shape[0], 1, 29, 29))
        for i in range(imgs.shape[0]):
            pil = Image.fromarray(imgs[i][0].numpy(), mode='F')
            # Upsample before rotating so bilinear interpolation loses less
            # detail, then downsample back to the working resolution.
            rotated = self.resize2(self.resize1(pil).rotate(angle, Image.BILINEAR))
            np_img = self.totensor(rotated).numpy()
            out[i] = torch.FloatTensor(radon(np_img[0], theta=THETA)).reshape(1, 29, 29)
        return out

    def forward(self, x):
        """Average backbone logits over every rotation in the group.

        Returns:
            (N, 10) logit tensor on the backbone's device.
        """
        # BUG FIX: the original moved inputs to a hard-coded 'cuda' device,
        # which crashed on CPU-only machines (main.py explicitly falls back
        # to "cpu"), and it wrote sinograms into the caller's tensor in place
        # when the batch was already on CPU.  We now follow the backbone's
        # device and build sinograms in a scratch buffer instead.
        device = next(self.backbone.parameters()).device
        imgs = x.detach().to('cpu')
        angles = self._group_angles()

        if self.group == 'C1':
            # Identity group: a single sinogram of the unrotated image
            # (no resize round-trip, matching the original C1 path).
            out = torch.empty((imgs.shape[0], 1, 29, 29))
            for i in range(imgs.shape[0]):
                out[i] = torch.FloatTensor(radon(imgs[i][0].numpy(), theta=THETA)).reshape(1, 29, 29)
            return self.backbone(out.to(device))

        y = 0
        for angle in angles:
            y = y + self.backbone(self._sinogram_batch(imgs, angle).to(device))
        return y / len(angles)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class MnistCNN(torch.nn.Module):
    """Plain CNN backbone: single-channel sinogram batch -> 10 class logits.

    Three conv/ReLU/max-pool stages (1 -> 32 -> 64 -> 128 channels), a
    2048 -> 128 fully connected layer with dropout, and a final 128 -> 10
    classifier head.  Weights are orthogonally initialized.
    """

    def __init__(self):
        super(MnistCNN, self).__init__()
        nn = torch.nn
        self.keep_prob = 0.5  # dropout keeps activations with this probability

        def conv_stage(c_in, c_out, **pool_kwargs):
            # conv(3x3, stride 1, pad 1) + ReLU + 2x2 max-pool.
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2, **pool_kwargs),
            )

        self.layer1 = conv_stage(1, 32)
        self.layer2 = conv_stage(32, 64)
        # Padded pool so the flattened feature size comes out to 2048.
        self.layer3 = conv_stage(64, 128, padding=1)

        self.fc1 = nn.Linear(2048, 128, bias=True)
        self.layer4 = nn.Sequential(
            self.fc1,
            nn.ReLU(),
            nn.Dropout(p=1 - self.keep_prob),
        )

        self.fc2 = nn.Linear(128, 10, bias=True)
        self._initialize_weights()

    def _initialize_weights(self):
        """Orthogonal weights and zero biases for every conv/linear module."""
        for module in self.modules():
            if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
                torch.nn.init.orthogonal_(module.weight)
                if module.bias is not None:
                    torch.nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Map a (N, 1, H, W) batch to (N, 10) logits."""
        features = self.layer3(self.layer2(self.layer1(x)))
        flat = features.view(features.size(0), -1)
        return self.fc2(self.layer4(flat))
|
trainer.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
from tqdm.auto import tqdm
|
| 6 |
+
from PIL import Image
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
from torch.utils.data import Dataset
|
| 9 |
+
from torchvision.transforms import Pad
|
| 10 |
+
from torchvision.transforms import Resize
|
| 11 |
+
from torchvision.transforms import ToTensor
|
| 12 |
+
from skimage.transform import radon, rescale
|
| 13 |
+
from scipy.ndimage import rotate
|
| 14 |
+
import random
|
| 15 |
+
from sklearn.metrics import f1_score
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def test(model, test_loader, device, final_test=False):
    """Evaluate `model` over the full (rotated) evaluation set.

    Args:
        model: classifier returning (N, 10) logits.
        test_loader: iterable of (inputs, targets) batches.
        device: device the batches are moved to before the forward pass.
        final_test: when True, also print the accuracy and F1 score.

    Returns:
        (accuracy in percent, weighted F1 score).
    """
    all_predictions = []
    all_targets = []

    model.eval()
    with torch.no_grad():
        for x, t in test_loader:
            logits = model(x.to(device))
            logits = logits.view(-1, 10)
            _, prediction = torch.max(logits.data, 1)
            # Collect per-sample results for accuracy and F1 computation.
            all_predictions.extend(prediction.cpu().numpy())
            all_targets.extend(t.to(device).cpu().numpy())

    total = len(all_targets)
    correct = sum(int(p == t) for p, t in zip(all_predictions, all_targets))
    f1 = f1_score(all_targets, all_predictions, average='weighted')

    if final_test:
        print(f"[Final Test] Acc: {correct/total*100.} | F1-Score: {f1}\n")

    return correct/total*100.0, f1
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def train(model, epochs, train_loader, val_loader, test_loader, lr=1e-4, wd=1e-4, device='cuda'):
    """Train `model.backbone`, track the best validation snapshot, and return
    that snapshot's (test accuracy in percent, weighted F1 score).

    Args:
        model: RA-style wrapper — only `model.backbone` has learnable params.
        epochs: number of training epochs; validation runs after each one.
        train_loader / val_loader / test_loader: (inputs, targets) batches.
        lr, wd: Adam learning rate and weight decay.
        device: compute device for training and evaluation.
    """
    best_acc = 0.0
    best_model = None
    loss_function = torch.nn.CrossEntropyLoss()
    # Optimize only the backbone: the RA wrapper itself is parameter-free.
    optimizer = torch.optim.Adam(model.backbone.parameters(), lr=lr, weight_decay=wd)

    for epoch in range(epochs):
        model.train()
        for x, t in train_loader:
            optimizer.zero_grad()
            x = x.to(device)
            t = t.to(device)
            y = model(x).view(-1, 10)
            loss = loss_function(y, t)
            loss.backward()
            optimizer.step()
            # Drop batch references early to keep peak memory down.
            del x, y, t, loss

        # Validate every epoch and snapshot the best model so far.
        # (The original gated this with `(epoch + 1) % 1 == 0`, which is
        # always true — simplified away.)
        accuracy, _ = test(model, val_loader, device=device)
        print(f"epoch {epoch + 1} | validation accuracy: {accuracy}")
        if accuracy > best_acc:
            best_acc = accuracy
            # Snapshot on CPU so the copy does not hold device memory.
            best_model = copy.deepcopy(model.to('cpu'))
            # BUG FIX: was hard-coded `model.to('cuda')`, which crashed when
            # training on CPU despite the `device` parameter.
            model = model.to(device)

    if best_model is None:
        # Robustness: no epoch improved on 0.0 (or epochs == 0) — fall back
        # to the final model instead of crashing below.
        best_model = copy.deepcopy(model.to('cpu'))

    print(f"Max validation accuracy: {best_acc}\n")
    # BUG FIX: was hard-coded `best_model.to('cuda')`.
    best_model = best_model.to(device)
    score, f1 = test(best_model, test_loader, device=device, final_test=True)
    del best_model
    del model
    return score, f1
|
utils.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
import shutil
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
import tempfile
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from torch.utils.data import Dataset
|
| 8 |
+
from torchvision.transforms import Pad, Resize, ToTensor
|
| 9 |
+
from skimage.transform import radon, rescale
|
| 10 |
+
from scipy.ndimage import rotate
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# Google Drive share links for the MNIST `.amat` files that MnistDataset
# loads; fetched on demand by _ensure_mnist_amat_files / _gdrive_download.
MNIST_TRAIN_GDRIVE = "https://drive.google.com/file/d/15G2FsYGRSpEkr5MTVofSFhKaMMIeiFhk/view?usp=drive_link"
MNIST_TEST_GDRIVE = "https://drive.google.com/file/d/1PK1DeFpw2OomuHDoA8ZtTPWLkPHOT6u6/view?usp=drive_link"
|
| 15 |
+
|
| 16 |
+
def _gdrive_download(url_or_id: str, out_path: Path) -> None:
    """Download a Google Drive file to `out_path`.

    The payload is downloaded into a temporary sibling file first and then
    moved into place, so a partial download never appears at `out_path`.

    Raises:
        RuntimeError: if the download fails or produces an empty file.
    """
    import gdown  # local import: only needed when a download actually happens

    out_path.parent.mkdir(parents=True, exist_ok=True)

    # Reserve a temp file next to the destination (same filesystem, so the
    # final move is cheap and atomic on most platforms).
    handle = tempfile.NamedTemporaryFile(delete=False, dir=str(out_path.parent), suffix=".tmp")
    handle.close()
    tmp_path = Path(handle.name)

    try:
        result = gdown.download(url_or_id, str(tmp_path), quiet=False, fuzzy=True)
        download_ok = bool(result) and tmp_path.exists() and tmp_path.stat().st_size > 0
        if not download_ok:
            raise RuntimeError(f"[Google Drive] Failed to download: {url_or_id}")
        shutil.move(str(tmp_path), str(out_path))
    finally:
        # Best-effort cleanup of the temp file on any failure path.
        if tmp_path.exists():
            try:
                tmp_path.unlink()
            except Exception:
                pass
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _ensure_mnist_amat_files(mnist_dir: str = "mnist") -> None:
|
| 39 |
+
mnist_dir = Path(mnist_dir)
|
| 40 |
+
train_path = mnist_dir / "mnist_train.amat"
|
| 41 |
+
test_path = mnist_dir / "mnist_test.amat"
|
| 42 |
+
|
| 43 |
+
if not train_path.is_file():
|
| 44 |
+
print(f"[MNIST] '{train_path}' not found. Downloading MNIST train file from Google Drive...")
|
| 45 |
+
_gdrive_download(MNIST_TRAIN_GDRIVE, train_path)
|
| 46 |
+
print(f"[MNIST] Download complete: {train_path}")
|
| 47 |
+
|
| 48 |
+
if not test_path.is_file():
|
| 49 |
+
print(f"[MNIST] '{test_path}' not found. Downloading MNIST test file from Google Drive...")
|
| 50 |
+
_gdrive_download(MNIST_TEST_GDRIVE, test_path)
|
| 51 |
+
print(f"[MNIST] Download complete: {test_path}")
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class MnistDataset(Dataset):
    """MNIST dataset backed by the `.amat` files under ./mnist.

    Modes:
        'train'      — `num_train_data` upright 28x28 digits, zero-padded
                       to 29x29 (no rotation applied at train time).
        'validation' — 12000 held-out samples from the train file, each
                       given a random rotation in [0, 360).
        'test'       — the full test file, each sample randomly rotated.

    Passing the same `seed` reproduces both the train/validation split and
    the random rotation angles.
    """

    def __init__(self, mode, num_train_data=10000, seed=1):
        # Train split + fixed 12000-sample validation split must fit in the
        # file; 45000 is the assumed row count of mnist_train.amat — TODO
        # confirm.  NOTE(review): `assert` is stripped under `python -O`.
        assert num_train_data + 12000 <= 45000
        assert mode in ['train', 'validation', 'test']
        if seed is not None:
            # Seed numpy (split shuffle + rotation angles) and torch for
            # per-seed reproducibility.
            np.random.seed(seed)
            torch.manual_seed(seed)

        # Download the .amat files from Google Drive if they are missing.
        _ensure_mnist_amat_files("mnist")

        if mode == "test":
            file = "mnist/mnist_test.amat"
        else:
            file = "mnist/mnist_train.amat"

        # Each row: 784 pixel values followed by the class label.
        data = np.loadtxt(file)
        images = data[:, :-1].reshape(-1, 28, 28).astype(np.float32)

        # Shuffle the images
        # The same seeded permutation defines both the train and validation
        # split; 'test' mode keeps the file order (indices unused there).
        indices = np.arange(num_train_data + 12000)
        np.random.shuffle(indices)
        if mode == 'train':
            images = images[indices[:num_train_data]]
            data = data[indices[:num_train_data]]
        elif mode == 'validation':
            images = images[indices[num_train_data:]]
            data = data[indices[num_train_data:]]

        if mode == 'test' or mode == 'validation':
            # Evaluation samples are randomly rotated: pad 28 -> 29, upsample
            # to 87 so the bilinear rotation loses less detail, rotate, then
            # downsample back to 29.
            pad = Pad((0, 0, 1, 1), fill=0)
            resize1 = Resize(87)  # to upsample
            resize2 = Resize(29)  # to downsample
            totensor = ToTensor()

            self.images = torch.empty((images.shape[0], 1, 29, 29))
            for i in range(images.shape[0]):
                img = images[i]
                img = Image.fromarray(img, mode='F')
                r = (np.random.rand() * 360.)
                self.images[i] = totensor(resize2(resize1(pad(img)).rotate(r, Image.BILINEAR))).reshape(1, 29, 29)
        else:
            # Training samples stay upright; just zero-pad 28x28 -> 29x29
            # in the bottom/right, matching Pad((0, 0, 1, 1)) above.
            self.images = torch.zeros((images.shape[0], 1, 29, 29))
            self.images[:, :, :28, :28] = torch.tensor(images).reshape(-1, 1, 28, 28)

        self.labels = data[:, -1].astype(np.int64)

    def __getitem__(self, index):
        """Return the (image, label) pair at `index`."""
        image, label = self.images[index], self.labels[index]

        return image, label

    def __len__(self):
        """Number of samples in this split."""
        return len(self.labels)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def cal_mean_std(ary):
    """Return the mean and (population) standard deviation of `ary`."""
    values = np.asarray(ary)
    return values.mean(), values.std()
|
| 112 |
+
|
| 113 |
+
def print_run_config(args, device, base_seed):
    """Pretty-print the run-configuration banner to stdout."""
    rows = [
        ("Device", device),
        ("Group", args.group),
        ("Epochs", args.epochs),
        ("Batch size", args.batch_size),
        ("Train samples", args.num_train_data),
        ("Num seeds", args.n_seeds),
        ("Seed range", f"{base_seed} ~ {base_seed + args.n_seeds - 1}"),
    ]

    print("\n" + "=" * 34)
    print("RA Training Configuration")
    print("-" * 34)
    for label, value in rows:
        print(f"{label} : {value}")
    print("=" * 34 + "\n")
|