---
license: other
library_name: diffuse
tags:
- text-to-image
- diffusion
- flux
- jax
- flax
---

# FLUX.1-dev Model (JAX/Flax)

![Downloads](https://img.shields.io/badge/dynamic/json?color=blue&label=downloads&query=%24.downloads&url=https%3A%2F%2Fhuggingface.co%2Fapi%2Fmodels%2Fjcopo%2Fflux_jax)

This repository contains the FLUX.1-dev text-to-image diffusion model stored in Orbax/JAX format, for use with JAX/Flax frameworks.

## Model Description

FLUX.1-dev is a text-to-image generation model built on a transformer architecture with dual text encoders (CLIP and T5) for strong prompt understanding and image generation quality.

### Components

This model includes the following components:

- **Transformer**: the main diffusion transformer
- **VAE**: variational autoencoder for encoding images to latents and decoding latents back to images
- **CLIP Text Encoder**: provides a pooled prompt embedding
- **T5 Text Encoder**: provides token-level prompt embeddings for fine-grained conditioning
- **Tokenizers**: the CLIP and T5 tokenizers

## Usage

Use this model through the **Diffuse** library, which provides a high-level interface for FLUX models.

### Tutorial

For a comprehensive tutorial on using FLUX models with Diffuse, see the **[FLUX Tutorial Documentation](https://diffuse.readthedocs.io/en/latest/flux_tutorial.html)**.

## Resources

- **Diffuse Library**: [https://github.com/jcopo/diffuse](https://github.com/jcopo/diffuse)
- **Documentation**: [https://diffuse.readthedocs.io/](https://diffuse.readthedocs.io/)
- **FLUX Tutorial**: [https://diffuse.readthedocs.io/en/latest/flux_tutorial.html](https://diffuse.readthedocs.io/en/latest/flux_tutorial.html)

## Model Format

This model is stored in **Orbax checkpoint format**, optimized for JAX/Flax frameworks. The Diffuse library handles loading and inference automatically.
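For intuition about the shapes involved, here is a sketch of the standard FLUX.1 shape arithmetic behind the `_latent_shapes` call in the code example. The function name and constants below are illustrative (standard published FLUX values), not read from this repository or from Diffuse's internals:

```python
# Standard FLUX.1 shape arithmetic (illustrative sketch, not Diffuse's actual
# implementation): the VAE downsamples each spatial side by 8x, and the
# transformer groups latents into 2x2 patches.
VAE_DOWNSAMPLE = 8
PATCH_SIZE = 2
VAE_LATENT_CHANNELS = 16

def latent_shapes(height: int, width: int):
    """Return (VAE latent HxW, transformer patch-grid HxW)."""
    latent_hw = (height // VAE_DOWNSAMPLE, width // VAE_DOWNSAMPLE)
    grid_hw = (latent_hw[0] // PATCH_SIZE, latent_hw[1] // PATCH_SIZE)
    return latent_hw, grid_hw

latent_hw, grid_hw = latent_shapes(512, 512)
image_seq_len = grid_hw[0] * grid_hw[1]
in_channels = VAE_LATENT_CHANNELS * PATCH_SIZE * PATCH_SIZE

print(latent_hw)      # (64, 64): VAE latent resolution for a 512x512 image
print(grid_hw)        # (32, 32): transformer patch grid
print(image_seq_len)  # 1024: transformer sequence length
print(in_channels)    # 64: matches the "In Channels" value in the specifications
```

This is why a 512x512 generation uses a sequence length of 1024 tokens, and why the transformer's input channel count is 64 (16 latent channels times a 2x2 patch).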
### Installation

Install the Diffuse library:

```bash
pip install git+https://github.com/jcopo/diffuse.git
```

### Code Example

```python
import jax
import numpy as np
from pathlib import Path

from huggingface_hub import snapshot_download
from PIL import Image

from diffuse import FluxModelLoader, FluxTimer, Flow, Predictor, Denoiser
from diffuse.integrators import EulerIntegrator
from diffuse.utils import _latent_shapes

# ===========================
# 1. Download Model
# ===========================
HF_REPO_ID = "jcopo/flux_jax"
checkpoint_dir = Path(snapshot_download(repo_id=HF_REPO_ID, repo_type="model"))

# ===========================
# 2. Set Generation Parameters
# ===========================
PROMPT = "A serene landscape with mountains at sunset, highly detailed, photorealistic"
HEIGHT = 512
WIDTH = 512
NUM_STEPS = 20
GUIDANCE_SCALE = 4.0
SEED = 42

# ===========================
# 3. Load Model and Prepare Network
# ===========================
loader = FluxModelLoader(checkpoint_dir=checkpoint_dir, verbose=True)
conditioned = loader.prepare_conditioned_network(
    prompt=PROMPT,
    negative_prompt=None,
    guidance_scale=GUIDANCE_SCALE,
    height=HEIGHT,
    width=WIDTH,
)

# ===========================
# 4. Setup Diffusion Components
# ===========================
_, transformer_hw = _latent_shapes(HEIGHT, WIDTH)
image_seq_len = transformer_hw[0] * transformer_hw[1]

# Initialize timer with dynamic shift
timer = FluxTimer(num_steps=NUM_STEPS, use_dynamic_shift=True)
timer.set_image_seq_len(image_seq_len)

# Create flow model and predictor
flow = Flow(tf=1.0)
predictor = Predictor(
    model=flow,
    network=conditioned.network_fn,
    prediction_type="velocity",
)

# Create integrator and denoiser
integrator = EulerIntegrator(model=flow, timer=timer)
denoiser = Denoiser(
    integrator=integrator,
    model=flow,
    predictor=predictor,
    x0_shape=(transformer_hw[0], transformer_hw[1], conditioned.in_channels),
)

# ===========================
# 5. Generate Image
# ===========================
key = jax.random.PRNGKey(SEED)
state, _ = denoiser.generate(
    rng_key=key,
    n_steps=NUM_STEPS,
    n_particles=1,
    keep_history=False,
)

# Get latent from generation
latent = state.integrator_state.position

# ===========================
# 6. Decode to Image
# ===========================
image = loader.decode_latent(latent)
print(f"Generated image shape: {image.shape}")

# Save image (the decoded array is float in [0, 1]; clipping guards against
# small numerical overshoot before the 8-bit conversion)
img = Image.fromarray((np.clip(image, 0.0, 1.0) * 255).astype(np.uint8))
img.save("output.png")
```

### Model Specifications

- **Architecture**: Transformer-based diffusion model
- **Hidden Dimension**: 3072
- **Attention Heads**: 24
- **Double Layers**: 19
- **Single Layers**: 38
- **Precision**: BFloat16
- **Joint Attention Dimension**: 4096
- **Pooled Projection Dimension**: 768
- **In Channels**: 64

## License

Please refer to the original FLUX.1-dev license terms published by Black Forest Labs for usage restrictions and guidelines.

## Citation

If you use this model in your research, please cite the original FLUX paper and the Diffuse library.

```bibtex
@software{diffuse2025,
  title  = {Diffuse: A modular diffusion model library},
  author = {Iollo, J. and Oudoumanessah, G.},
  year   = {2025},
  url    = {https://github.com/jcopo/diffuse}
}
```