---
license: cc-by-sa-4.0
datasets:
- Smith42/galaxies
- Smith42/galaxies_metadata
- Smith42/galaxies_embeddings
tags:
- astronomy
- images
- huggingscience
- science
---

<center>
<img src="assets/shoggoth_telescope_sticker_2.png" alt="astroPT_shoggoth" width="300px"/>
</center>

# astroPTv2.0: a Large Observation Model for Astronomy

Here we have the model files for the astroPT project; the code to run inference
with these models can be found here:

[https://github.com/smith42/astropt](https://github.com/smith42/astropt)

You will find the fully trained models (pretrained on 8.6 million galaxies) in
folders labelled with the model parameter count in the `astropt` directory.

Unlike the older models, which were trained on the "image" column in
[smith42/galaxies](https://huggingface.co/datasets/Smith42/galaxies),
these models are trained on the cropped galaxies in the "image_crop"
column. Before upload, those galaxies were cropped and zoomed so that each one
takes up the majority of its image.
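
The exact cropping procedure is not described in this card, so the following is a rough illustration only. It assumes a simple centre crop followed by nearest-neighbour upsampling back to the original size, which makes a centred object fill more of the frame. The function `crop_and_zoom` and its `crop_fraction` parameter are hypothetical; they are not part of the astropt package or of the dataset's actual preprocessing.

```python
import numpy as np


def crop_and_zoom(image: np.ndarray, crop_fraction: float) -> np.ndarray:
    """Hypothetical illustration, not the dataset's actual pipeline.

    Keeps the central `crop_fraction` of an (H, W, C) image and upsamples it
    back to (H, W) with nearest-neighbour sampling, so a centred object fills
    more of the frame.
    """
    if not 0.0 < crop_fraction <= 1.0:
        raise ValueError("crop_fraction must be in (0, 1]")
    h, w = image.shape[:2]
    ch = max(1, round(h * crop_fraction))
    cw = max(1, round(w * crop_fraction))
    top, left = (h - ch) // 2, (w - cw) // 2
    crop = image[top:top + ch, left:left + cw]
    # Map every output pixel back to its nearest source pixel inside the crop.
    rows = np.arange(h) * ch // h
    cols = np.arange(w) * cw // w
    return crop[rows[:, None], cols[None, :]]


# Toy example: an 8x8 RGB frame with a 2x2 bright "galaxy" in the centre.
img = np.zeros((8, 8, 3), dtype=np.uint8)
img[3:5, 3:5] = 255
zoomed = crop_and_zoom(img, crop_fraction=0.5)
# The bright region now covers 4x4 pixels instead of 2x2.
```

Nearest-neighbour sampling is used only to keep the sketch dependency-free; a real pipeline would more likely use a proper interpolating resize.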

We see promising scaling behaviour on this new dataset:

<center>
<img src="assets/scaling_law.png" alt="scaling_law" width="300px"/>
</center>
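
A common way to quantify scaling like this is to fit a power law, `loss ≈ a * N**(-alpha)`, between loss and parameter count `N`. The sketch below assumes the plot shows loss against model size (the axes are not described in this card) and uses made-up synthetic numbers, not astroPT results. `fit_power_law` is a hypothetical helper that fits a straight line in log-log space.

```python
import numpy as np


def fit_power_law(n_params: np.ndarray, loss: np.ndarray) -> tuple[float, float]:
    """Fit loss ~= a * N**(-alpha) by least squares in log-log space.

    Returns (a, alpha). This two-parameter form omits the irreducible-loss
    offset that many published scaling-law fits include.
    """
    # log(loss) = log(a) - alpha * log(N) is a straight line in log-log space.
    slope, intercept = np.polyfit(np.log(n_params), np.log(loss), deg=1)
    return float(np.exp(intercept)), float(-slope)


# Synthetic, made-up numbers (NOT astroPT results): three model sizes
# loosely echoing the 015M / 095M / 850M checkpoints.
n = np.array([15e6, 95e6, 850e6])
true_a, true_alpha = 5.0, 0.1
loss = true_a * n ** (-true_alpha)

a, alpha = fit_power_law(n, loss)
print(f"a = {a:.3f}, alpha = {alpha:.3f}")
```

A larger fitted `alpha` means loss falls faster as the model grows; with real measurements the points will scatter around the line rather than sit on it exactly.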

## Usage

To use these models in anger, you can `pip install astropt` and run the following code:

```python
from astropt.model_utils import load_astropt
from astropt.local_datasets import GalaxyImageDataset

from datasets import load_dataset  # for Smith42/galaxies

import torch
import numpy as np
from functools import partial
from torch.utils.data import DataLoader
from torchvision import transforms

# boilerplate to preprocess galaxy images
def normalise(x):
    # standardise along dim 1 (zero mean, unit variance); epsilon avoids division by zero
    std, mean = torch.std_mean(x, dim=1, keepdim=True)
    return (x - mean) / (std + 1e-8)

def data_transforms():
    return transforms.Compose([transforms.Lambda(normalise)])

def _process_galaxy_wrapper(idx, func):
    """This function ensures that the image is tokenised in the same way as the pre-trained model is expecting"""
    galaxy = func(
        torch.from_numpy(np.array(idx["image_crop"]).swapaxes(0, 2)).to(float)
    ).to(torch.float)
    galaxy_positions = torch.arange(0, len(galaxy), dtype=torch.long)
    return {
        "images": galaxy,
        "images_positions": galaxy_positions,
    }

# load the 095M parameter model; 015M and 850M models are also available
model = load_astropt("Smith42/astroPT_v2.0", path="astropt/095M")

galproc = GalaxyImageDataset(
    None,
    spiral=True,
    transform={"images": data_transforms()},
    modality_registry=model.modality_registry,
)

# these models were trained on the cropped galaxies in the "image_crop" column
ds = (
    load_dataset("Smith42/galaxies", split="test", revision="v2.0", streaming=True)
    .select_columns("image_crop")
    .map(partial(_process_galaxy_wrapper, func=galproc.process_galaxy))
    .with_format("torch")
)

dl = DataLoader(ds, batch_size=128, num_workers=32)

zs = []
with torch.no_grad():  # inference only, so skip building the autograd graph
    for B in dl:
        zs.append(model.generate_embeddings(B)["images"].detach().cpu().numpy())
zs = np.concatenate(zs)

# do cool stuff with zs...
```
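
As one example of what "cool stuff" might look like, the sketch below runs a cosine-similarity nearest-neighbour search over the embeddings to find galaxies that resemble a chosen one. It assumes `zs` is a 2-D array of shape `(n_galaxies, embedding_dim)`. The helper `nearest_neighbours` is hypothetical, and the demo uses random stand-in vectors rather than real astroPT embeddings so that it runs without downloading anything.

```python
import numpy as np


def nearest_neighbours(zs: np.ndarray, query_idx: int, k: int = 5) -> np.ndarray:
    """Return indices of the k rows of `zs` most similar to row `query_idx`
    by cosine similarity, best match first, excluding the query itself."""
    # L2-normalise each embedding so a dot product equals cosine similarity.
    unit = zs / (np.linalg.norm(zs, axis=1, keepdims=True) + 1e-12)
    sims = unit @ unit[query_idx]
    sims[query_idx] = -np.inf  # never return the query as its own neighbour
    # argsort is ascending, so take the last k and reverse for best-first order.
    return np.argsort(sims)[-k:][::-1]


# Stand-in embeddings: 100 random 16-dimensional vectors.
rng = np.random.default_rng(0)
zs_demo = rng.normal(size=(100, 16))
# Make galaxy 42 point in the same direction as galaxy 7 (different length).
zs_demo[42] = zs_demo[7] * 3.0

nn = nearest_neighbours(zs_demo, query_idx=7, k=3)
print(nn)
```

Cosine similarity ignores embedding length, so row 42 is the top match for row 7 despite being three times longer. With real data, pass `zs` from the snippet above in place of `zs_demo`.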

## Updates and community

AstroPT is an open-to-all UniverseTBD project. Please join the [UniverseTBD](https://universetbd.org) Discord for updates: [https://discord.gg/MNEVegvfJq](https://discord.gg/MNEVegvfJq)