# FLUX.1-dev Model (JAX/Flax)
This repository contains a FLUX.1-dev text-to-image diffusion model stored in Orbax/JAX format, optimized for use with JAX/Flax frameworks.
## Model Description
FLUX.1-dev is a powerful text-to-image generation model that uses a transformer-based architecture with dual text encoders (CLIP and T5) for enhanced text understanding and image generation capabilities.
### Components
This model includes the following components:
- Transformer: Main diffusion transformer model
- VAE: Variational Autoencoder for image encoding/decoding
- CLIP Text Encoder: For text understanding
- T5 Text Encoder: For enhanced text processing
- Tokenizers: CLIP and T5 tokenizers
## Usage

To use this model, use the Diffuse library, which provides a high-level interface for loading and running FLUX models.
### Tutorial

For a comprehensive tutorial on using FLUX models with Diffuse, please refer to the [FLUX Tutorial Documentation](https://diffuse.readthedocs.io/en/latest/flux_tutorial.html).
## Resources
- Diffuse Library: https://github.com/jcopo/diffuse
- Documentation: https://diffuse.readthedocs.io/
- FLUX Tutorial: https://diffuse.readthedocs.io/en/latest/flux_tutorial.html
## Model Format
This model is stored in Orbax checkpoint format, optimized for JAX/Flax frameworks. The Diffuse library handles loading and inference automatically.
## Installation
Install the Diffuse library:
```shell
pip install git+https://github.com/jcopo/diffuse.git
```
## Code Example
```python
import jax
from pathlib import Path
from huggingface_hub import snapshot_download
from PIL import Image

from diffuse import FluxModelLoader, FluxTimer, Flow, Predictor, Denoiser
from diffuse.integrators import EulerIntegrator
from diffuse.utils import _latent_shapes

# ===========================
# 1. Download Model
# ===========================
HF_REPO_ID = "jcopo/flux_jax"
checkpoint_dir = Path(snapshot_download(repo_id=HF_REPO_ID, repo_type="model"))

# ===========================
# 2. Set Generation Parameters
# ===========================
PROMPT = "A serene landscape with mountains at sunset, highly detailed, photorealistic"
HEIGHT = 512
WIDTH = 512
NUM_STEPS = 20
GUIDANCE_SCALE = 4.0
SEED = 42

# ===========================
# 3. Load Model and Prepare Network
# ===========================
loader = FluxModelLoader(checkpoint_dir=checkpoint_dir, verbose=True)
conditioned = loader.prepare_conditioned_network(
    prompt=PROMPT,
    negative_prompt=None,
    guidance_scale=GUIDANCE_SCALE,
    height=HEIGHT,
    width=WIDTH,
)

# ===========================
# 4. Setup Diffusion Components
# ===========================
_, transformer_hw = _latent_shapes(HEIGHT, WIDTH)
image_seq_len = transformer_hw[0] * transformer_hw[1]

# Initialize timer with dynamic shift
timer = FluxTimer(num_steps=NUM_STEPS, use_dynamic_shift=True)
timer.set_image_seq_len(image_seq_len)

# Create flow model and predictor
flow = Flow(tf=1.0)
predictor = Predictor(
    model=flow,
    network=conditioned.network_fn,
    prediction_type="velocity",
)

# Create integrator and denoiser
integrator = EulerIntegrator(model=flow, timer=timer)
denoiser = Denoiser(
    integrator=integrator,
    model=flow,
    predictor=predictor,
    x0_shape=(transformer_hw[0], transformer_hw[1], conditioned.in_channels),
)

# ===========================
# 5. Generate Image
# ===========================
key = jax.random.PRNGKey(SEED)
state, _ = denoiser.generate(
    rng_key=key,
    n_steps=NUM_STEPS,
    n_particles=1,
    keep_history=False,
)

# Get latent from generation
latent = state.integrator_state.position

# ===========================
# 6. Decode to Image
# ===========================
image = loader.decode_latent(latent)
print(f"Generated image shape: {image.shape}")

# Save image (image is a NumPy array in the [0, 1] range)
img = Image.fromarray((image * 255).astype("uint8"))
img.save("output.png")
```
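The `_latent_shapes` helper and the dynamic-shift timer used in step 4 are internal to Diffuse, but the arithmetic they typically encode is standard in FLUX-style pipelines. Here is a rough sketch, under two assumptions not confirmed by this card: the VAE downsamples by a factor of 8 followed by 2×2 patchification (which also explains the 64 = 16 × 2 × 2 input channels), and the shift constants match those used in common FLUX implementations such as Hugging Face diffusers:

```python
import math

def latent_shapes(height, width, vae_factor=8, patch=2):
    """VAE downsamples by 8; 2x2 patchification halves each side again."""
    lat_h, lat_w = height // vae_factor, width // vae_factor      # VAE latent grid
    return (lat_h, lat_w), (lat_h // patch, lat_w // patch)       # transformer grid

def dynamic_shift_mu(seq_len, base_len=256, max_len=4096,
                     base_shift=0.5, max_shift=1.15):
    """Interpolate the time-shift parameter linearly in sequence length."""
    m = (max_shift - base_shift) / (max_len - base_len)
    return seq_len * m + (base_shift - base_len * m)

def shifted_sigma(t, mu):
    """Shifted noise level: sigma(t) = e^mu / (e^mu + (1/t - 1))."""
    return math.exp(mu) / (math.exp(mu) + (1.0 / t - 1.0))

(lat_h, lat_w), (tf_h, tf_w) = latent_shapes(512, 512)
seq_len = tf_h * tf_w
print(lat_h, lat_w, tf_h, tf_w, seq_len)  # 64 64 32 32 1024

mu = dynamic_shift_mu(seq_len)
print(shifted_sigma(1.0, mu))  # 1.0 -- the shift leaves the endpoint fixed
print(shifted_sigma(0.5, mu) > 0.5)  # True -- midpoints are pushed toward more noise
```

For 512×512 this gives a 32×32 transformer grid, i.e. the 1,024-token `image_seq_len` fed to the timer above.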
## Model Specifications
- Architecture: Transformer-based diffusion model
- Hidden Dimension: 3072
- Attention Heads: 24
- Double Layers: 19
- Single Layers: 38
- Precision: BFloat16
- Joint Attention Dimension: 4096
- Pooled Projection Dimension: 768
- In Channels: 64
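As a quick consistency check, these numbers fit the conventional transformer split (head_dim = hidden_dim / num_heads) and a 2×2-patchified 16-channel VAE latent. The head size and latent channel count below are inferred from standard FLUX configurations, not stated by this checkpoint itself:

```python
# Attention geometry implied by the specifications above.
hidden_dim, num_heads = 3072, 24
head_dim = hidden_dim // num_heads      # per-head dimension
assert head_dim * num_heads == hidden_dim

# Input channels as 2x2 patches of an assumed 16-channel VAE latent.
vae_channels, patch = 16, 2
in_channels = vae_channels * patch * patch

print(head_dim, in_channels)  # 128 64
```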
## License
Please refer to the original FLUX.1-dev license terms for usage restrictions and guidelines.
## Citation
If you use this model in your research, please cite the original FLUX paper and the Diffuse library.
```bibtex
@software{diffuse2024,
  title  = {Diffuse: A modular diffusion model library},
  author = {Iollo, J. and Oudoumanessah, G.},
  year   = {2025},
  url    = {https://github.com/jcopo/diffuse}
}
```