Halton-MaskGIT / README.md
llvictorll's picture
Add link to paper (#2)
2737fad verified
---
license: mit
pipeline_tag: unconditional-image-generation
library_name: pytorch
---
## 🌟 Halton Scheduler for Masked Generative Image Transformer 🌟
[![GitHub stars](https://img.shields.io/github/stars/valeoai/Halton-MaskGIT.svg?style=social)](https://github.com/valeoai/Halton-MaskGIT/stargazers)
[![Hugging Face Model](https://img.shields.io/badge/Hugging%20Face-Model%20Card-orange?logo=huggingface)](https://huggingface.co/llvictorll/Halton-MaskGIT/tree/main)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/valeoai/Halton-Maskgit/blob/main/colab_demo.ipynb)
[![License](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE.txt)
[![Paper](https://img.shields.io/badge/ICLR-2025-blue)](https://huggingface.co/papers/2503.17076) <img src="statics/its_just_a_frog_cie.png" alt="drawing" width="25"/>
Official PyTorch implementation of the paper:
**Halton Scheduler for Masked Generative Image Transformer**
*Victor Besnier, Mickael Chen, David Hurych, Eduardo Valle, Matthieu Cord*
Accepted at **ICLR 2025**.
TL;DR: We introduce a new sampling strategy using the Halton Scheduler, which spreads tokens uniformly across the image.
This approach reduces sampling errors, and improves image quality.
---
## πŸš€ Overview
Welcome to the official implementation of our ICLR 2025 paper! πŸŽ‰
This repository introduces **Halton Scheduler for Masked Generative Image Transformer (MaskGIT)** and includes:
1. **Class-to-Image Model**: Generates high-quality 384x384 images from ImageNet class labels.
<p align="center">
<img src="statics/cls2img_halton.png" width="100%" alt="Cls2Img">
</p>
2. **Text-to-Image Model**: Generates realistic images from textual descriptions (coming soon)
<p align="center">
<img src="statics/txt2img_halton.jpg" width="100%" alt="Txt2Img">
</p>
Explore, train, and extend our easy to use generative models! πŸš€
The v1.0 version, previously known as "MaskGIT-pytorch" is available [here!](https://github.com/valeoai/Halton-MaskGIT/tree/v1.0)
---
## πŸ“ Repository Structure
```plaintext
β”œ Halton-MaskGIT/
| β”œβ”€β”€ Congig/ <- Base config file for the demo
| | β”œβ”€β”€ base_cls2img.yaml
| | └── base_txt2img.yaml
| β”œβ”€β”€ Dataset/ <- Data loading utilities
| | β”œβ”€β”€ dataset.py <- PyTorch dataset class
| | └── dataloader.py <- PyTorch dataloader
| β”œβ”€β”€ launch/
| | β”œβ”€β”€ run_cls_to_img.sh <- Training script for class-to-image
| | └── run_txt_to_img.sh <- Training script for text-to-image (coming soon)
| β”œβ”€β”€ Metrics/
| | β”œβ”€β”€ extract_train_fid.py <- Precompute FID stats for ImageNet
| | β”œβ”€β”€ inception_metrics.py <- Inception score and FID evaluation
| | └── sample_and_eval.py <- Sampling and evaluation
| β”œβ”€β”€ Network/
| | β”œβ”€β”€ ema.py <- EMA model
| | β”œβ”€β”€ transformer.py <- Transformer for class-to-image
| | β”œβ”€β”€ txt_transformer.py <- Transformer for text-to-image (coming soon)
| | └── va_model.py <- VQGAN architecture
| β”œβ”€β”€ Sampler/
| | β”œβ”€β”€ confidence_sampler.py <- Confidence scheduler
| | β”œβ”€β”€ halton_sampler.py <- Halton scheduler
| β”œβ”€β”€ Trainer/ <- Training classes
| | β”œβ”€β”€ abstract_trainer.py <- Abstract trainer
| | β”œβ”€β”€ cls_trainer.py <- Class-to-image trainer
| | └── txt_trainer.py <- Text-to-image trainer (coming soon)
| β”œβ”€β”€ statics/ <- Sample images and assets
| β”œβ”€β”€ saved_networks/ <- placeholder for the downloaded models
| β”œβ”€β”€ colab_demo.ipynb <- Inference demo
| β”œβ”€β”€ app.py <- Gradio example
| β”œβ”€β”€ LICENSE.txt <- MIT license
| β”œβ”€β”€ env.yaml <- Environment setup file
| β”œβ”€β”€ README.md <- This file! πŸ“–
| └── main.py <- Main script
```
## πŸ› οΈ Usage
Get started with just a few steps:
### 1️⃣ Clone the repository
```bash
git clone https://github.com/valeoai/Halton-MaskGIT.git
cd Halton-MaskGIT
```
### 2️⃣ Install dependencies
```bash
conda env create -f env.yaml
conda activate maskgit
```
### 3️⃣ Download pretrained models
```python
from huggingface_hub import hf_hub_download
# The VQ-GAN
hf_hub_download(repo_id="FoundationVision/LlamaGen",
filename="vq_ds16_c2i.pt",
local_dir="./saved_networks/")
# (Optional) The MaskGIT
hf_hub_download(repo_id="llvictorll/Halton-Maskgit",
filename="ImageNet_384_large.pth",
local_dir="./saved_networks/")
```
### 4️⃣ Extract the code from the VQGAN
```bash
python extract_vq_features.py --data_folder="/path/to/ImageNet/" --dest_folder="/your/path/" --bsize=256 --compile
```
### 5️⃣ Train the model
To train the class-to-image model:
```bash
bash launch/run_cls_to_img.sh
```
## πŸ“Ÿ Quick Start for sampling
To quickly verify the functionality of our model, you can try this Python code:
```python
import torch
from Utils.utils import load_args_from_file
from Utils.viz import show_images_grid
from huggingface_hub import hf_hub_download
from Trainer.cls_trainer import MaskGIT
from Sampler.halton_sampler import HaltonSampler
config_path = "Config/base_cls2img.yaml" # Path to your config file
args = load_args_from_file(config_path)
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Download the VQGAN from LlamaGen
hf_hub_download(repo_id="FoundationVision/LlamaGen",
filename="vq_ds16_c2i.pt",
local_dir="./saved_networks/")
# Download the MaskGIT
hf_hub_download(repo_id="llvictorll/Halton-Maskgit",
filename="ImageNet_384_large.pth",
local_dir="./saved_networks/")
# Initialisation of the model
model = MaskGIT(args)
# select your scheduler
sampler = HaltonSampler(sm_temp_min=1, sm_temp_max=1.2, temp_pow=1, temp_warmup=0, w=2,
sched_pow=2, step=32, randomize=True, top_k=-1)
# [goldfish, chicken, tiger cat, hourglass, ship, dog, race car, airliner]
labels = [1, 7, 282, 604, 724, 179, 751, 404]
gen_images = sampler(trainer=model, nb_sample=8, labels=labels, verbose=True)[0]
show_images_grid(gen_images)
```
or run the gradio πŸ–ΌοΈ app.py --> ```python app.py ``` and connect to http://127.0.0.1:6006 on your navigator
🎨 Want to try the model, but you don't have a gpu? Check out the Colab Notebook for an easy-to-run demo!
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/valeoai/Halton-Maskgit/blob/main/colab_demo.ipynb)
## 🧠 Pretrained Models
The pretrained MaskGIT models are available on [Hugging Face](https://huggingface.co/llvictorll/Halton-MaskGIT/tree/main).
Use them to jump straight into inference or fine-tuning.
| Model | # Params | # Input | # GFLOP | VQGAN | MaskGIT |
|----------------------|----------|---------|---------|--------|-------------------------------------------------------------------|
| Halton-MaskGIT-Large | 480M | 24x24 | 83.00 | [πŸ”— Download](https://huggingface.co/FoundationVision/LlamaGen/blob/main/vq_ds16_c2i.pt) | [πŸ”— Download](https://huggingface.co/llvictorll/Halton-MaskGIT/blob/main/ImageNet_384_large.pth) |
## ❀️ Contribute
We welcome contributions and feedback! πŸ› οΈ
If you encounter any issues, have suggestions, or want to collaborate, feel free to:
- Create an issue
- Fork the repository and submit a pull request
Your input is highly valued. Let’s make this project even better together! πŸ™Œ
## πŸ“œ License
This project is licensed under the MIT License.
See the [LICENSE](LICENSE.txt) file for details.
## πŸ™ Acknowledgments
We are grateful for the support of the IT4I Karolina Cluster in the Czech Republic for powering our experiments.
The pretrained VQGAN ImageNet (f=16/8, 16384 codebook) is from the [LlamaGen official repository](https://github.com/FoundationVision/LlamaGen?tab=readme-ov-file)
## πŸ“– Citation
If you find our work useful, please cite us and add a star ⭐ to the repository :)
```
@inproceedings{besnier2025iclr,
title={Halton Scheduler for Masked Generative Image Transformer},
author={Victor Besnier, Mickael Chen, David Hurych, Eduardo Valle, Matthieu Cord},
booktitle={International Conference on Learning Representations (ICLR)},
year={2025}
}
```
## ⭐ Stars History
[![Star History Chart](https://api.star-history.com/svg?repos=valeoai/Halton-MaskGIT&type=Date)](https://star-history.com/#valeoai/Halton-MaskGIT&Date)