mwalmsley committed
Commit c927f0b · verified · 1 Parent(s): c67df2d

Update README.md

Files changed (1)
  1. README.md +4 -181
README.md CHANGED
@@ -10,190 +10,13 @@ library_name: timm
 
 # Masked Autoencoder for Euclid Images
 
- ## Overview
-
 This masked autoencoder (MAE) is trained to reconstruct Euclid galaxy images where 90% of the image is masked.
 The trained model shows superhuman performance at reconstruction.
 
- **There is an [interactive demo here](https://huggingface.co/spaces/mwalmsley/euclid_masked_autoencoder). Try choosing your own images!**
-
- This version is trained on RR2 (3M images). A DR1 version (13.6M images) will follow.
-
- The base model is a custom `timm` vision transformer; see config.yaml for the exact hyperparameters. These don't matter much; the main things to note are that we need fine-grained patches (8x8 pixels here), but the model otherwise need not be large to work well. This version is under 100M parameters and comfortably runs predictions on CPU.
-
- This version is as presented in `Galaxy Morphology and Interpretability through Sparsity, Wu & Walmsley, NeurIPS ML4Science workshop 2025` (on arXiv very shortly, bibtex to follow), except that I have removed the angular-scale-dependent positional encoding for simplicity (it turns out not to be necessary for good performance). Please cite the workshop paper if you find our work helpful.
-
- ## Instructions
-
- ### Quickstart - Download
-
- ```python
- import mae_timm_simplified  # download this script from the "files and versions" tab
-
- import omegaconf
- from huggingface_hub import hf_hub_download
-
- cfg_path = hf_hub_download(repo_id="mwalmsley/euclid-rr2-mae", filename="config.yaml")
- cfg = omegaconf.OmegaConf.load(cfg_path)
- mae = mae_timm_simplified.MAE.from_pretrained("mwalmsley/euclid-rr2-mae", cfg=cfg)
- ```
-
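As an optional sanity check (a sketch, not part of the original quickstart; it assumes the download cell above has run), you can confirm the model really is small enough for CPU inference and inspect the sequence length used for masking:

```python
# assumes `mae` from the quickstart above
n_params = sum(p.numel() for p in mae.parameters())
print(f'parameters: {n_params / 1e6:.1f}M')  # expected to be under 100M

# sequence length = number of patch tokens (plus class token), used below for masking
print(f'sequence length: {mae.sequence_length}')

mae = mae.to('cpu').eval()  # CPU inference is comfortable at this model size
```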
- ### Quickstart - Make Prediction
-
- ```python
- from PIL import Image
- import torch
- from lightly.models.utils import random_token_mask
-
- image = Image.open('foo.jpg')
-
- image = preprocess_image(image)  # defined below
-
- batch = {
-     'image': image.unsqueeze(0),  # (1, 3, H, W)
-     'id_str': ['dummy'],  # required by my convention, ignore
- }
-
- # utility function for generating random patch indices to mask
- _, idx_mask = random_token_mask(
-     size=(1, mae.sequence_length),  # (batch_size, seq_len)
-     mask_ratio=0.9,  # your choice
-     device=batch['image'].device,
- )
-
- with torch.no_grad():
-     result = mae.predict(batch, idx_mask=idx_mask)
- ```
-
- `result` is a dict with keys for the original image, masked image, reconstructed image, reconstruction loss, and embedding.
-
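If you want to see exactly what comes back, a small sketch (nothing model-specific; the key names are whatever your copy of `mae_timm_simplified` returns) is to print the dict:

```python
# inspect the prediction output; exact keys/types depend on the mae_timm_simplified version
for key, value in result.items():
    print(key, type(value))
```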
- `preprocess_image` is a torchvision transform that returns 3x224x224 float tensors normalised from 0 to 1:
-
- ```python
- import torch
- from torchvision.transforms import v2
-
- def preprocess_image(image):
-     preprocess = v2.Compose([
-         v2.ToImage(),
-         v2.Resize((224, 224)),
-         v2.ToDtype(torch.float32, scale=True)
-     ])
-     return preprocess(image)
- ```
-
- `idx_mask` is a tensor of patch indices to mask (e.g. [1, 5, 45, ...]). `random_token_mask` is a small utility from `lightly` that generates random patch indices - but it's basically equivalent to `np.random.choice`. For custom masks, read on.
-
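If you'd rather not add `lightly` as a dependency, an equivalent random mask is a few lines of plain `torch` (a sketch; it assumes the 1-indexed, row-major patch convention described in the next section):

```python
# hand-rolled equivalent of random_token_mask (sketch only)
batch_size = 1          # match the batch above
num_patches = 28 * 28   # 784 patches for a 224x224 image with 8x8 patches
num_masked = int(0.9 * num_patches)

# draw num_masked distinct patch indices in [1, 784] per image; index 0 is the class token
perm = torch.stack([torch.randperm(num_patches) + 1 for _ in range(batch_size)])
idx_mask = perm[:, :num_masked]  # (batch_size, num_masked)
```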
- #### Generate a custom mask
-
- We divide the image into 784 patches, in a grid of 28 by 28 patches.
- Each patch is 8x8 pixels (covering our 224x224 image).
 
- The patch in the top-left corner is index 1 (not 0!), and higher indices go left-to-right and then down a row (like reading a page).
 
- For example, to mask only the first 28 patches (the top row of the image):
-
- ```python
- batch_size = batch['image'].shape[0]  # e.g. 1 for the quickstart batch above
- row_mask = torch.arange(28) + 1  # patch indices 1 to 28
- # copy for all images in the batch
- idx_mask = row_mask.unsqueeze(0).repeat(batch_size, 1)  # (batch_size, num_masked)
- idx_mask = idx_mask.to('cuda')
- ```
-
- To mask the middle strip:
-
- ```python
- row_mask = torch.arange(28) + 1 + 13*28  # row 14 of 28
- ```
-
- And so on, however you like. Just remember to add 1 for the class token!
-
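For more irregular regions, it can be easier to start from a 28x28 boolean mask and convert it to indices. A minimal sketch (the central square is just an arbitrary example region):

```python
# build idx_mask from any 28x28 boolean mask (True = mask this patch)
mask_2d = torch.zeros(28, 28, dtype=torch.bool)
mask_2d[10:18, 10:18] = True  # arbitrary example: an 8x8-patch square near the centre

# flatten row-major ("like reading a page"), then +1 because index 0 is the class token
idx_mask = torch.nonzero(mask_2d.flatten()).squeeze(1) + 1  # (num_masked,)
idx_mask = idx_mask.unsqueeze(0).repeat(batch_size, 1).to('cuda')  # (batch_size, num_masked)
```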
- #### Make predictions for the masked patches
-
- ```python
- import matplotlib.pyplot as plt
-
- mae = mae.to('cuda')
- batch['image'] = batch['image'].to('cuda')  # batch and model must share a device
- with torch.no_grad():
-     result = mae.predict(batch, idx_mask=idx_mask)
-
- # result has keys including images, masked, reconstructed
- # each key is a list of standard PIL images
- images = result['images']
- masked = result['masked']
- reconstructed = result['reconstructed']
-
- # Visualize the results (assumes a batch of at least 8 images)
- fig, axes = plt.subplots(nrows=3, ncols=8, figsize=(24, 9))
- for i in range(8):
-     axes[0, i].imshow(images[i])
-     axes[0, i].set_title("Original")
-     axes[1, i].imshow(masked[i])
-     axes[1, i].set_title("Masked")
-     axes[2, i].imshow(reconstructed[i])
-     axes[2, i].set_title("Reconstructed")
- plt.tight_layout()
- plt.show()
- ```
-
- That's everything you need to know for doing inference (reconstructions or embeddings). To reproduce my training, keep reading.
-
- ### Download Data
-
- Get a dataset of Euclid images, prepared as Galaxy-Zoo-style jpgs:
-
- ```python
- from datasets import load_dataset
-
- dataset_dict = load_dataset(
-     'mwalmsley/euclid_q1',  # _rr2, _dr1 versions are available to EC members
-     name='tiny-v1-gz_arcsinh_vis_y'  # tiny subset for testing
- )
- ```
-
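A quick way to check what you downloaded (a sketch; it assumes the usual HuggingFace `datasets` layout with an `image` column holding PIL images):

```python
print(dataset_dict)  # available splits and columns

# peek at one example from whichever split is present
split = list(dataset_dict.keys())[0]
example = dataset_dict[split][0]
print(example['image'].size)  # galaxy cutout as a PIL image
```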
- Use my utility package `galaxy-datasets` to load this as a Lightning DataModule, including an appropriate torchvision transform...
-
- ```python
- from galaxy_datasets.pytorch.galaxy_datamodule import HuggingFaceDataModule
- from galaxy_datasets.transforms import default_view_config, get_galaxy_transform
-
- # define augmentations to use
- view_config = default_view_config()
- view_config.output_size = 224
- view_config.erase_iterations = 0  # for simplicity
- ssl_image_transform = get_galaxy_transform(cfg=view_config)
- # this is just a torchvision Compose transform
- # returns 3x224x224 float tensor normalised 0-1.
-
- datamodule = HuggingFaceDataModule(
-     dataset_dict=dataset_dict,
-     train_transform=ssl_image_transform,
-     test_transform=ssl_image_transform,
-     batch_size=8,  # example loader settings - adjust to your hardware
-     num_workers=4,
-     prefetch_factor=2
- )
- datamodule.setup()
- # this is just a lightning datamodule
- # should yield batches with an 'image' key, see below
-
- # get a batch
- test_loader = datamodule.test_dataloader()
- for batch in test_loader:
-     batch['image'] = batch['image'].to('cuda')
-     break
- ```
-
- ...or you can do this yourself (a minimal sketch follows after the list below). You should make batches that include an 'image' key which contains:
-
- - BxCx224x224 float tensors normalised from 0 to 1
- - where those tensors are created by transforming (e.g. with torchvision) a GZ-style jpg (download from HuggingFace above)
-
- It might work for other human-friendly jpgs, but that's outside of the training distribution, so no promises.
-
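If you do roll your own, here is a minimal sketch of a compatible loader (the file paths and the `JpgDataset` name are placeholders; it reuses `preprocess_image` from the quickstart and yields the same batch format, an `image` tensor plus an `id_str` list):

```python
from PIL import Image
from torch.utils.data import Dataset, DataLoader

class JpgDataset(Dataset):
    """Wrap a list of GZ-style jpg paths into model-ready batches."""
    def __init__(self, paths):
        self.paths = paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        return {
            'image': preprocess_image(Image.open(self.paths[idx])),  # 3x224x224 float, 0-1
            'id_str': self.paths[idx],
        }

# default collation stacks 'image' into (B, 3, 224, 224) and gathers 'id_str' into a list
loader = DataLoader(JpgDataset(['foo.jpg', 'bar.jpg']), batch_size=2)
```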
- ---
-
- Walmsley trained the model and Wu ran the sparsity analysis. Additional thanks to Inigo Val Slijepcevic, Micah Bowles, Devina Mohan, Anna Scaife, and Joshua Speagle for their help and advice. We are grateful to the Euclid Consortium and the European Space Agency for making the data available.
 
+ ## UPDATED VERSION AVAILABLE
+
+ This version is trained on Euclid RR2 (3M images). We later trained an otherwise-identical model on Euclid DR1 (13.6M images). You should use that instead - it's strictly better!
+
+ Go to [mwalmsley/euclid-dr1-mae](https://huggingface.co/mwalmsley/euclid-dr1-mae) for full details and example code.
+
+ We'll leave this model up for reproducibility, but seriously, please use the DR1 version!