FLUX.1-dev Model (JAX/Flax)


This repository contains the FLUX.1-dev text-to-image diffusion model, converted to an Orbax checkpoint for use with JAX/Flax.

Model Description

FLUX.1-dev is a text-to-image model from Black Forest Labs built on a rectified-flow transformer. It conditions on two text encoders: CLIP, which supplies a pooled prompt embedding, and T5, which supplies per-token prompt embeddings.

Components

This model includes the following components:

  • Transformer: Main diffusion transformer, which predicts the flow velocity
  • VAE: Variational autoencoder that encodes images to latents and decodes latents back to images
  • CLIP Text Encoder: Produces the pooled prompt embedding
  • T5 Text Encoder: Produces the per-token prompt embeddings
  • Tokenizers: CLIP and T5 tokenizers

Usage

To use this model, install the Diffuse library, which provides a high-level interface for FLUX models.

Tutorial

For a comprehensive tutorial on using FLUX models with Diffuse, please refer to: FLUX Tutorial Documentation

Resources

Model Format

The weights are stored as an Orbax checkpoint (Orbax is a checkpointing library for JAX). The Diffuse library handles loading and inference automatically.

Installation

Install the Diffuse library:

pip install git+https://github.com/jcopo/diffuse.git

Code Example

import jax
from pathlib import Path
from huggingface_hub import snapshot_download
from diffuse import FluxModelLoader, FluxTimer, Flow, Predictor, Denoiser
from diffuse.integrators import EulerIntegrator
from diffuse.utils import _latent_shapes

# ===========================
# 1. Download Model
# ===========================
HF_REPO_ID = "jcopo/flux_jax"

checkpoint_dir = Path(snapshot_download(repo_id=HF_REPO_ID, repo_type="model"))

# ===========================
# 2. Set Generation Parameters
# ===========================
PROMPT = "A serene landscape with mountains at sunset, highly detailed, photorealistic"
HEIGHT = 512
WIDTH = 512
NUM_STEPS = 20
GUIDANCE_SCALE = 4.0
SEED = 42

# ===========================
# 3. Load Model and Prepare Network
# ===========================
loader = FluxModelLoader(checkpoint_dir=checkpoint_dir, verbose=True)

conditioned = loader.prepare_conditioned_network(
    prompt=PROMPT,
    negative_prompt=None,
    guidance_scale=GUIDANCE_SCALE,
    height=HEIGHT,
    width=WIDTH,
)

# ===========================
# 4. Setup Diffusion Components
# ===========================
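# Latent token grid for the requested resolution; the number of image tokens
# determines the dynamic timestep shift
_, transformer_hw = _latent_shapes(HEIGHT, WIDTH)
image_seq_len = transformer_hw[0] * transformer_hw[1]

# Initialize timer with dynamic shift (the schedule depends on image_seq_len)
timer = FluxTimer(num_steps=NUM_STEPS, use_dynamic_shift=True)
timer.set_image_seq_len(image_seq_len)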

# Create flow model and predictor
flow = Flow(tf=1.0)
predictor = Predictor(
    model=flow,
    network=conditioned.network_fn,
    prediction_type="velocity",
)

# Create integrator and denoiser
integrator = EulerIntegrator(model=flow, timer=timer)
denoiser = Denoiser(
    integrator=integrator,
    model=flow,
    predictor=predictor,
    x0_shape=(transformer_hw[0], transformer_hw[1], conditioned.in_channels),
)

# ===========================
# 5. Generate Image
# ===========================
key = jax.random.PRNGKey(SEED)
state, _ = denoiser.generate(
    rng_key=key,
    n_steps=NUM_STEPS,
    n_particles=1,
    keep_history=False,
)

# Get latent from generation
latent = state.integrator_state.position

# ===========================
# 6. Decode to Image
# ===========================
image = loader.decode_latent(latent)
print(f"Generated image shape: {image.shape}")

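# Save image (decoded values are expected in the [0, 1] range)
import numpy as np
from PIL import Image

img = np.asarray(image)
if img.ndim == 4:  # drop a leading particle/batch axis if present
    img = img[0]
# Clip before casting so out-of-range values do not wrap around in uint8
img = (np.clip(img, 0.0, 1.0) * 255).round().astype(np.uint8)
Image.fromarray(img).save("output.png")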
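Steps 4 and 5 can be illustrated with a minimal NumPy sketch of a dynamically shifted timestep schedule followed by Euler integration of a velocity field. It assumes the conventions of the reference FLUX sampling code: t = 1 is pure noise and t = 0 is data, and the shift mu is interpolated linearly from 0.5 at 256 tokens to 1.15 at 4096 tokens. It also assumes each transformer token covers a 16×16 pixel region. FluxTimer, Flow, and EulerIntegrator may parameterize time differently. The helper names (transformer_grid, dynamic_shift_mu, shifted_schedule, euler_sample) are hypothetical and not part of Diffuse.

```python
import math

import numpy as np


def transformer_grid(height: int, width: int, vae_factor: int = 8, patch: int = 2) -> tuple[int, int]:
    # Assumption: the VAE downsamples by 8 and 2x2 latent patches form one token
    step = vae_factor * patch
    return height // step, width // step


def dynamic_shift_mu(seq_len: int, base_seq_len: int = 256, max_seq_len: int = 4096,
                     base_shift: float = 0.5, max_shift: float = 1.15) -> float:
    # Linear interpolation of the shift between the base and max sequence lengths
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    return base_shift + slope * (seq_len - base_seq_len)


def shifted_schedule(num_steps: int, seq_len: int) -> np.ndarray:
    """Timesteps from t=1 (noise) to t=0 (data), warped toward t=1 by exp(mu)."""
    mu = dynamic_shift_mu(seq_len)
    t = np.linspace(1.0, 0.0, num_steps + 1)
    shifted = np.zeros_like(t)
    nonzero = t > 0  # t=0 maps to 0; skip it to avoid dividing by zero
    shifted[nonzero] = math.exp(mu) / (math.exp(mu) + (1.0 / t[nonzero] - 1.0))
    return shifted


def euler_sample(velocity_fn, noise: np.ndarray, timesteps: np.ndarray) -> np.ndarray:
    """Integrate dx/dt = v(x, t) from timesteps[0] to timesteps[-1] with explicit Euler."""
    x = noise
    for t_cur, t_next in zip(timesteps[:-1], timesteps[1:]):
        x = x + (t_next - t_cur) * velocity_fn(x, t_cur)
    return x


h, w = transformer_grid(512, 512)
seq_len = h * w
print(h, w, seq_len)                        # 32 32 1024
print(round(dynamic_shift_mu(seq_len), 4))  # 0.63
ts = shifted_schedule(num_steps=20, seq_len=seq_len)

# Toy check: on the straight path x_t = (1 - t) * x0 + t * noise the exact
# velocity is (x_t - x0) / t, so Euler steps stay on the path and end at x0.
rng = np.random.default_rng(0)
x0 = rng.standard_normal((h, w, 64))
noise = rng.standard_normal((h, w, 64))
sample = euler_sample(lambda x, t: (x - x0) / t, noise, ts)
print(np.allclose(sample, x0))              # True
```

Larger images produce more tokens and therefore a larger mu, which keeps more of the sampling steps at high noise levels.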

Model Specifications

  • Architecture: Transformer-based diffusion model
  • Hidden Dimension: 3072
  • Attention Heads: 24
  • Double Layers: 19
  • Single Layers: 38
  • Precision: BFloat16
  • Joint Attention Dimension: 4096
  • Pooled Projection Dimension: 768
  • In Channels: 64
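The "In Channels: 64" figure follows from FLUX's latent packing. The VAE produces a 16-channel latent at 1/8 of the image resolution, and the transformer flattens each 2×2 patch of that latent into one token with 16 × 2 × 2 = 64 channels. Similarly, 24 heads × 128 dimensions per head gives the 3072 hidden dimension. The NumPy sketch below shows the packing. It assumes a channels-last layout and the (channel, row-in-patch, column-in-patch) token ordering of the reference FLUX code. pack_latent and unpack_latent are hypothetical helpers, and Diffuse's decode_latent may arrange the axes differently.

```python
import numpy as np


def pack_latent(z: np.ndarray, patch: int = 2) -> np.ndarray:
    """(H, W, C) VAE latent -> (H/p, W/p, C*p*p) transformer token grid."""
    h, w, c = z.shape
    z = z.reshape(h // patch, patch, w // patch, patch, c)  # (h', ph, w', pw, c)
    z = z.transpose(0, 2, 4, 1, 3)                          # (h', w', c, ph, pw)
    return z.reshape(h // patch, w // patch, c * patch * patch)


def unpack_latent(x: np.ndarray, patch: int = 2) -> np.ndarray:
    """Inverse of pack_latent: (H/p, W/p, C*p*p) -> (H, W, C)."""
    hp, wp, d = x.shape
    c = d // (patch * patch)
    x = x.reshape(hp, wp, c, patch, patch)  # (h', w', c, ph, pw)
    x = x.transpose(0, 3, 1, 4, 2)          # (h', ph, w', pw, c)
    return x.reshape(hp * patch, wp * patch, c)


# A 512x512 image gives a 64x64x16 VAE latent (assumed downsampling factor 8)
vae_latent = np.random.default_rng(0).standard_normal((64, 64, 16))
tokens = pack_latent(vae_latent)
print(tokens.shape)  # (32, 32, 64)
```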

License

These weights are a conversion of FLUX.1-dev, so the FLUX.1 [dev] Non-Commercial License from Black Forest Labs applies. Please refer to its terms for usage restrictions and guidelines.

Citation

If you use this model in your research, please cite the original FLUX paper and the Diffuse library.

@software{diffuse2024,
  title = {Diffuse: A modular diffusion model library},
  author = {Iollo, J. and Oudoumanessah, G.},
  year = {2025},
  url = {https://github.com/jcopo/diffuse}
}