Update README.md
Browse files
README.md
CHANGED
|
@@ -10,190 +10,13 @@ library_name: timm
|
|
| 10 |
|
| 11 |
# Masked Autoencoder for Euclid Images
|
| 12 |
|
| 13 |
-
## Overview
|
| 14 |
-
|
| 15 |
This masked autoencoder (MAE) is trained to reconstruct Euclid galaxy images where 90% of the image is masked.
|
| 16 |
The trained model shows superhuman performance at reconstruction.
|
| 17 |
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
This version is trained on RR2 (3M images) A DR1 version (13.6M images) will follow.
|
| 21 |
-
|
| 22 |
-
The base model is a custom `timm` vision transformer; see config.yaml for the exact hyperparameters. These don't matter much; the main things to note are that we need finegrained patches (8x8 pixels here) but the model otherwise need not be large to work well. This version is under 100M parameters and comfortably runs predictions on CPU.
|
| 23 |
-
|
| 24 |
-
This version is as presented in `Galaxy Morphology and Interpretability through Sparsity, Wu & Walmsley, NeurIPS ML4Science workshop 2025` (on arxiv very shortly, bibtex to follow), except that I have removed the angular-scale-dependent positional encoding for simplicity (it turns out to not be necessary for good performance). Please cite the workshop paper if you find our work helpful.
|
| 25 |
-
|
| 26 |
-
## Instructions
|
| 27 |
-
|
| 28 |
-
### Quickstart - Download
|
| 29 |
-
|
| 30 |
-
```python
|
| 31 |
-
import mae_timm_simplified # download this script from the "files and versions" tab
|
| 32 |
-
|
| 33 |
-
import omegaconf
|
| 34 |
-
from huggingface_hub import hf_hub_download
|
| 35 |
-
|
| 36 |
-
cfg_path = hf_hub_download(repo_id="mwalmsley/euclid-rr2-mae", filename="config.yaml")
|
| 37 |
-
cfg = omegaconf.OmegaConf.load(cfg_path)
|
| 38 |
-
mae = mae_timm_simplified.MAE.from_pretrained("mwalmsley/euclid-rr2-mae", cfg=cfg)
|
| 39 |
-
|
| 40 |
-
```
|
| 41 |
-
|
| 42 |
-
### Quickstart - Make Prediction
|
| 43 |
-
|
| 44 |
-
```python
|
| 45 |
-
from PIL import Image
|
| 46 |
-
import torch
|
| 47 |
-
from lightly.models.utils import random_token_mask
|
| 48 |
-
|
| 49 |
-
image = Image.open('foo.jpg')
|
| 50 |
-
|
| 51 |
-
image = preprocess_image(image)
|
| 52 |
-
|
| 53 |
-
batch = {
|
| 54 |
-
'image': image.unsqueeze(0), # (1, 3, H, W)
|
| 55 |
-
'id_str': ['dummy' ], # required by my convention, ignore
|
| 56 |
-
}
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
# utility function for generating random patch indices to mask
|
| 60 |
-
_, idx_mask = random_token_mask(
|
| 61 |
-
size=(1, mae.sequence_length), # (batch_size, seq_len)
|
| 62 |
-
mask_ratio=0.9, # your choice
|
| 63 |
-
device=batch['image'].device,
|
| 64 |
-
)
|
| 65 |
-
|
| 66 |
-
with torch.no_grad():
|
| 67 |
-
result = mae.predict(batch, idx_mask=idx_mask)
|
| 68 |
-
|
| 69 |
-
```
|
| 70 |
-
|
| 71 |
-
`result` is a dict with keys for the original image, masked image, reconstructed image, reconstruction loss, and embedding.
|
| 72 |
-
|
| 73 |
-
`preprocess_image` is a torchvision transform that returns 3x224x224 float tensors normalised from 0 to 1:
|
| 74 |
-
|
| 75 |
-
```python
|
| 76 |
-
from torchvision.transforms import v2
|
| 77 |
-
|
| 78 |
-
def preprocess_image(image):
|
| 79 |
-
preprocess = transforms.Compose([
|
| 80 |
-
v2.ToImage(),
|
| 81 |
-
transforms.Resize((224, 224)),
|
| 82 |
-
v2.ToDtype(torch.float32, scale=True)
|
| 83 |
-
])
|
| 84 |
-
return preprocess(image)
|
| 85 |
-
```
|
| 86 |
-
|
| 87 |
-
`idx_mask` is a list of patch indices to mask (e.g. [1, 5, 45, ...]). `random_token_mask` is a small utility from `lightly` that generates random patch indices - but it's basically equivalent to `np.random.choice`. For custom masks, read on.
|
| 88 |
-
|
| 89 |
-
#### Generate a custom mask
|
| 90 |
-
|
| 91 |
-
We divide the image into 784 patches, in a grid of 28 by 28 patches.
|
| 92 |
-
Each patch is 8x8 pixels (covering our 224x224 image).
|
| 93 |
|
| 94 |
-
|
| 95 |
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
```python
|
| 99 |
-
row_mask = torch.tensor(range(28)) + 1 # 1 to 29
|
| 100 |
-
# copy for all images in the batch
|
| 101 |
-
idx_mask = row_mask.unsqueeze(0).repeat(batch_size, 1) # (batch_size, num_masked)
|
| 102 |
-
idx_mask = idx_mask.to('cuda')
|
| 103 |
-
```
|
| 104 |
-
|
| 105 |
-
To mask the middle strip:
|
| 106 |
-
|
| 107 |
-
```python
|
| 108 |
-
row_mask = torch.tensor(range(28)) + 1 + 13*28
|
| 109 |
-
```
|
| 110 |
-
|
| 111 |
-
And so on, however you like. Just remember to add 1 for the class token!
|
| 112 |
-
|
| 113 |
-
#### Make predictions for the masked patches
|
| 114 |
-
|
| 115 |
-
```python
|
| 116 |
-
mae = mae.to('cuda')
|
| 117 |
-
with torch.no_grad():
|
| 118 |
-
result = mae.predict(batch, idx_mask=idx_mask)
|
| 119 |
-
|
| 120 |
-
# result has keys including images, masked, reconstructed
|
| 121 |
-
# each key is a list of standard PIL images
|
| 122 |
-
images = result['images']
|
| 123 |
-
masked = result['masked']
|
| 124 |
-
reconstructed = result['reconstructed']
|
| 125 |
-
|
| 126 |
-
# Visualize the results
|
| 127 |
-
fig, axes = plt.subplots(nrows=3, ncols=8, figsize=(24, 9))
|
| 128 |
-
for i in range(8):
|
| 129 |
-
axes[0, i].imshow(images[i])
|
| 130 |
-
axes[0, i].set_title("Original")
|
| 131 |
-
axes[1, i].imshow(masked[i])
|
| 132 |
-
axes[1, i].set_title("Masked")
|
| 133 |
-
axes[2, i].imshow(reconstructed[i])
|
| 134 |
-
axes[2, i].set_title("Reconstructed")
|
| 135 |
-
plt.tight_layout()
|
| 136 |
-
plt.show()
|
| 137 |
-
```
|
| 138 |
-
|
| 139 |
-
That's everything you need to know for doing inference (reconstructions or embeddings). To reproduce my training, keep reading.
|
| 140 |
-
|
| 141 |
-
### Download Data
|
| 142 |
-
|
| 143 |
-
Get a dataset of Euclid images, prepared as Galaxy-Zoo-style jpgs:
|
| 144 |
-
|
| 145 |
-
```python
|
| 146 |
-
|
| 147 |
-
from datasets import load_dataset
|
| 148 |
-
|
| 149 |
-
dataset_dict = load_dataset(
|
| 150 |
-
'mwalmsley/euclid_q1', # _rr2, _dr1 versions are available to EC members
|
| 151 |
-
name='tiny-v1-gz_arcsinh_vis_y' # tiny subset for testing
|
| 152 |
-
)
|
| 153 |
-
```
|
| 154 |
-
|
| 155 |
-
Use my utility package `galaxy-datasets` to load this as a Lightning DataModule, including an appropriate torchvision transform...
|
| 156 |
-
|
| 157 |
-
```python
|
| 158 |
-
|
| 159 |
-
from galaxy_datasets.pytorch.galaxy_datamodule import HuggingFaceDataModule
|
| 160 |
-
from galaxy_datasets.transforms import default_view_config, get_galaxy_transform
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
# define augmentations to use
|
| 164 |
-
view_config = default_view_config()
|
| 165 |
-
view_config.output_size = 224
|
| 166 |
-
view_config.erase_iterations = 0 # for simplicity
|
| 167 |
-
ssl_image_transform = get_galaxy_transform(cfg=view_config)
|
| 168 |
-
# this is just a torchvision Compose transform
|
| 169 |
-
# returns 3x224x224 float tensor normalised 0-1.
|
| 170 |
-
|
| 171 |
-
datamodule = HuggingFaceDataModule(
|
| 172 |
-
dataset_dict=dataset_dict,
|
| 173 |
-
train_transform=ssl_image_transform,
|
| 174 |
-
test_transform=ssl_image_transform,
|
| 175 |
-
batch_size=batch_size,
|
| 176 |
-
num_workers=num_workers,
|
| 177 |
-
prefetch_factor=prefetch_factor
|
| 178 |
-
)
|
| 179 |
-
datamodule.setup()
|
| 180 |
-
# this is just a lightning datamodule
|
| 181 |
-
# should yield batches with an 'image' key, see below
|
| 182 |
-
|
| 183 |
-
# get a batch
|
| 184 |
-
test_loader = datamodule.test_dataloader()
|
| 185 |
-
for batch in test_loader:
|
| 186 |
-
batch['image'] = batch['image'].to('cuda')
|
| 187 |
-
break
|
| 188 |
-
```
|
| 189 |
-
|
| 190 |
-
...or you can do this yourself. You should make batches that include an 'image' key which contains
|
| 191 |
-
|
| 192 |
-
- BxCx224x224 float tensors normalised from 0 to 1
|
| 193 |
-
- where those tensors are created by transforming (e.g. with torchvision) a GZ-style jpg (download from HuggingFace above)
|
| 194 |
-
|
| 195 |
-
It might work for other human-friendly jpgs, but that's outside of the training distribution, so no promises.
|
| 196 |
-
|
| 197 |
-
---
|
| 198 |
|
| 199 |
-
|
|
|
|
| 10 |
|
| 11 |
# Masked Autoencoder for Euclid Images
|
| 12 |
|
|
|
|
|
|
|
| 13 |
This masked autoencoder (MAE) is trained to reconstruct Euclid galaxy images where 90% of the image is masked.
|
| 14 |
The trained model shows superhuman performance at reconstruction.
|
| 15 |
|
| 16 |
+
## UPDATED VERSION AVAILABLE
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
+
This version is trained on Euclid RR2 (3M images). We later trained an otherwise-identical model on Euclid DR1 (13.6M images). You should use that instead - it's strictly better!
|
| 19 |
|
| 20 |
+
Go to [mwalmsley/euclid-dr1-mae](https://huggingface.co/mwalmsley/euclid-dr1-mae) for full details and example code.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
+
We'll leave this model up for reproducibility, but seriously, please use the DR1 version!
|