---
license: other
library_name: diffuse
tags:
- text-to-image
- diffusion
- flux
- jax
- flax
---
# FLUX.1-dev Model (JAX/Flax)
![Downloads](https://img.shields.io/badge/dynamic/json?color=blue&label=downloads&query=%24.downloads&url=https%3A%2F%2Fhuggingface.co%2Fapi%2Fmodels%2Fjcopo%2Fflux_jax)
This repository contains a FLUX.1-dev text-to-image diffusion model stored in Orbax/JAX format, optimized for use with JAX/Flax frameworks.
## Model Description
FLUX.1-dev is a powerful text-to-image generation model that uses a transformer-based architecture with dual text encoders (CLIP and T5) for enhanced text understanding and image generation capabilities.
### Components
This model includes the following components:
- **Transformer**: Main diffusion transformer model
- **VAE**: Variational Autoencoder for image encoding/decoding
- **CLIP Text Encoder**: For text understanding
- **T5 Text Encoder**: For enhanced text processing
- **Tokenizers**: CLIP and T5 tokenizers
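How these components fit together can be sketched with plain shape bookkeeping. The dimensions below come from the Model Specifications section of this card; the 8x VAE downsampling and 2x2 latent patching are standard for FLUX-style models but are assumptions here, not part of this repository's API:

```python
# Illustrative shape bookkeeping for a 512x512 generation (batch size 1).
# Text side: the prompt is tokenized and encoded twice.
clip_pooled_dim = 768   # pooled CLIP projection fed to the transformer
t5_token_dim = 4096     # per-token T5 embeddings (joint attention dimension)

# Image side: the VAE maps 512x512x3 pixels to an 8x-downsampled latent,
# which is then patched 2x2 into 64-channel transformer tokens.
height = width = 512
vae_latent = (height // 8, width // 8, 16)                # (64, 64, 16)
transformer_tokens = (height // 16 * (width // 16), 64)   # (1024, 64)

# 2x2 patching quarters the token count and quadruples the channels.
assert vae_latent[0] * vae_latent[1] // 4 == transformer_tokens[0]
assert vae_latent[2] * 4 == transformer_tokens[1]
print(transformer_tokens)  # (1024, 64)
```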
## Usage
This model is designed to be loaded through the **Diffuse** library, which provides an easy-to-use interface for FLUX models.
### Tutorial
For a comprehensive tutorial on using FLUX models with Diffuse, please refer to:
**[FLUX Tutorial Documentation](https://diffuse.readthedocs.io/en/latest/flux_tutorial.html)**
## Resources
- **Diffuse Library**: [https://github.com/jcopo/diffuse](https://github.com/jcopo/diffuse)
- **Documentation**: [https://diffuse.readthedocs.io/](https://diffuse.readthedocs.io/)
- **FLUX Tutorial**: [https://diffuse.readthedocs.io/en/latest/flux_tutorial.html](https://diffuse.readthedocs.io/en/latest/flux_tutorial.html)
## Model Format
This model is stored in **Orbax checkpoint format**, optimized for JAX/Flax frameworks. The Diffuse library handles loading and inference automatically.
### Installation
Install the Diffuse library:
```bash
pip install git+https://github.com/jcopo/diffuse.git
```
### Code Example
```python
import jax
from pathlib import Path
from huggingface_hub import snapshot_download
from diffuse import FluxModelLoader, FluxTimer, Flow, Predictor, Denoiser
from diffuse.integrators import EulerIntegrator
from diffuse.utils import _latent_shapes
# ===========================
# 1. Download Model
# ===========================
HF_REPO_ID = "jcopo/flux_jax"
checkpoint_dir = Path(snapshot_download(repo_id=HF_REPO_ID, repo_type="model"))
# ===========================
# 2. Set Generation Parameters
# ===========================
PROMPT = "A serene landscape with mountains at sunset, highly detailed, photorealistic"
HEIGHT = 512
WIDTH = 512
NUM_STEPS = 20
GUIDANCE_SCALE = 4.0
SEED = 42
# ===========================
# 3. Load Model and Prepare Network
# ===========================
loader = FluxModelLoader(checkpoint_dir=checkpoint_dir, verbose=True)
conditioned = loader.prepare_conditioned_network(
prompt=PROMPT,
negative_prompt=None,
guidance_scale=GUIDANCE_SCALE,
height=HEIGHT,
width=WIDTH,
)
# ===========================
# 4. Setup Diffusion Components
# ===========================
_, transformer_hw = _latent_shapes(HEIGHT, WIDTH)
image_seq_len = transformer_hw[0] * transformer_hw[1]
# Initialize timer with dynamic shift
timer = FluxTimer(num_steps=NUM_STEPS, use_dynamic_shift=True)
timer.set_image_seq_len(image_seq_len)
# Create flow model and predictor
flow = Flow(tf=1.0)
predictor = Predictor(
model=flow,
network=conditioned.network_fn,
prediction_type="velocity",
)
# Create integrator and denoiser
integrator = EulerIntegrator(model=flow, timer=timer)
denoiser = Denoiser(
integrator=integrator,
model=flow,
predictor=predictor,
x0_shape=(transformer_hw[0], transformer_hw[1], conditioned.in_channels),
)
# ===========================
# 5. Generate Image
# ===========================
key = jax.random.PRNGKey(SEED)
state, _ = denoiser.generate(
rng_key=key,
n_steps=NUM_STEPS,
n_particles=1,
keep_history=False,
)
# Get latent from generation
latent = state.integrator_state.position
# ===========================
# 6. Decode to Image
# ===========================
image = loader.decode_latent(latent)
print(f"Generated image shape: {image.shape}")
# Save image (image is a numpy array in [0, 1] range)
from PIL import Image
img = Image.fromarray((image * 255).astype('uint8'))
img.save("output.png")
```
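The `use_dynamic_shift=True` flag above enables FLUX's resolution-dependent timestep shift: higher resolutions (longer token sequences) push the schedule toward higher noise levels. A minimal sketch of how this shift is commonly computed for FLUX follows; the constants match the reference schedule popularized by the official FLUX code and diffusers, but whether `FluxTimer` uses exactly these values internally is an assumption:

```python
import math

# Reference constants: interpolate the log-shift mu linearly in token count,
# from 256 tokens -> 0.5 up to 4096 tokens -> 1.15.
BASE_SEQ_LEN, MAX_SEQ_LEN = 256, 4096
BASE_SHIFT, MAX_SHIFT = 0.5, 1.15

def calculate_mu(image_seq_len: int) -> float:
    """Linearly interpolate the log-shift mu from the transformer token count."""
    m = (MAX_SHIFT - BASE_SHIFT) / (MAX_SEQ_LEN - BASE_SEQ_LEN)
    b = BASE_SHIFT - m * BASE_SEQ_LEN
    return m * image_seq_len + b

def shift_sigma(sigma: float, mu: float) -> float:
    """Time-shift a flow-matching sigma; larger mu spends more steps at high noise."""
    return math.exp(mu) / (math.exp(mu) + (1.0 / sigma - 1.0))

# 512x512 -> (512 // 16) ** 2 = 1024 transformer tokens
mu = calculate_mu(1024)
print(mu)                   # 0.63
print(shift_sigma(0.5, mu)) # > 0.5: the midpoint is shifted toward higher noise
```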
### Model Specifications
- **Architecture**: Transformer-based diffusion model
- **Hidden Dimension**: 3072
- **Attention Heads**: 24
- **Double Layers**: 19
- **Single Layers**: 38
- **Precision**: BFloat16
- **Joint Attention Dimension**: 4096
- **Pooled Projection Dimension**: 768
- **In Channels**: 64
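As a sanity check, these specifications are roughly consistent with the transformer's known ~12B parameter count. The back-of-envelope estimate below assumes the standard FLUX block layout (qkv and output projections, 4x MLPs, adaLN modulation); the per-block breakdown is an approximation that ignores norms, embedders, and the final layer, not an exact accounting:

```python
d = 3072            # hidden dimension
double_layers = 19  # dual-stream (image + text) blocks
single_layers = 38  # fused single-stream blocks

# Dual-stream block, per stream: qkv (3d^2) + attention out (d^2)
# + MLP (2 * 4d^2) + modulation (6d^2) = 18d^2; two streams per block.
double_block = 2 * 18 * d * d
# Single-stream block: fused qkv+MLP input (7d^2) + fused output (5d^2)
# + modulation (3d^2) = 15d^2.
single_block = 15 * d * d

total = double_layers * double_block + single_layers * single_block
print(f"~{total / 1e9:.1f}B transformer parameters")  # lands near the known ~12B
```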
## License
This model is subject to the original FLUX.1-dev license terms from Black Forest Labs; please refer to them for usage restrictions and guidelines.
## Citation
If you use this model in your research, please cite the original FLUX paper and the Diffuse library.
```bibtex
@software{diffuse2025,
  title  = {Diffuse: A modular diffusion model library},
  author = {Iollo, J. and Oudoumanessah, G.},
  year   = {2025},
  url    = {https://github.com/jcopo/diffuse}
}
```