---
license: other
library_name: diffuse
tags:
- text-to-image
- diffusion
- flux
- jax
- flax
---

# FLUX.1-dev Model (JAX/Flax)

This repository contains the FLUX.1-dev text-to-image diffusion model stored in Orbax/JAX checkpoint format, for use with JAX/Flax frameworks.

## Model Description

FLUX.1-dev is a text-to-image generation model built on a transformer-based diffusion architecture with dual text encoders (CLIP and T5) for enhanced text understanding and image generation.

### Components

This model includes the following components:

- **Transformer**: main diffusion transformer model
- **VAE**: variational autoencoder for image encoding/decoding
- **CLIP Text Encoder**: pooled text embedding for global conditioning
- **T5 Text Encoder**: sequence-level text embeddings for fine-grained conditioning
- **Tokenizers**: CLIP and T5 tokenizers

## Usage

Use this model through the **Diffuse** library, which provides a high-level interface for FLUX models.

### Tutorial

For a comprehensive tutorial on using FLUX models with Diffuse, see:
**[FLUX Tutorial Documentation](https://diffuse.readthedocs.io/en/latest/flux_tutorial.html)**

## Resources

- **Diffuse Library**: [https://github.com/jcopo/diffuse](https://github.com/jcopo/diffuse)
- **Documentation**: [https://diffuse.readthedocs.io/](https://diffuse.readthedocs.io/)
- **FLUX Tutorial**: [https://diffuse.readthedocs.io/en/latest/flux_tutorial.html](https://diffuse.readthedocs.io/en/latest/flux_tutorial.html)

## Model Format

This model is stored in the **Orbax checkpoint format** used by JAX/Flax frameworks. The Diffuse library handles loading and inference automatically.

### Installation

Install the Diffuse library:

```bash
pip install git+https://github.com/jcopo/diffuse.git
```

### Code Example

```python
import jax
from pathlib import Path
from huggingface_hub import snapshot_download
from diffuse import FluxModelLoader, FluxTimer, Flow, Predictor, Denoiser
from diffuse.integrators import EulerIntegrator
from diffuse.utils import _latent_shapes

# ===========================
# 1. Download Model
# ===========================
HF_REPO_ID = "jcopo/flux_jax"

checkpoint_dir = Path(snapshot_download(repo_id=HF_REPO_ID, repo_type="model"))

# ===========================
# 2. Set Generation Parameters
# ===========================
PROMPT = "A serene landscape with mountains at sunset, highly detailed, photorealistic"
HEIGHT = 512
WIDTH = 512
NUM_STEPS = 20
GUIDANCE_SCALE = 4.0
SEED = 42

# ===========================
# 3. Load Model and Prepare Network
# ===========================
loader = FluxModelLoader(checkpoint_dir=checkpoint_dir, verbose=True)

conditioned = loader.prepare_conditioned_network(
    prompt=PROMPT,
    negative_prompt=None,
    guidance_scale=GUIDANCE_SCALE,
    height=HEIGHT,
    width=WIDTH,
)

# ===========================
# 4. Setup Diffusion Components
# ===========================
_, transformer_hw = _latent_shapes(HEIGHT, WIDTH)
image_seq_len = transformer_hw[0] * transformer_hw[1]

# Initialize timer with dynamic shift
timer = FluxTimer(num_steps=NUM_STEPS, use_dynamic_shift=True)
timer.set_image_seq_len(image_seq_len)

# Create flow model and predictor
flow = Flow(tf=1.0)
predictor = Predictor(
    model=flow,
    network=conditioned.network_fn,
    prediction_type="velocity",
)

# Create integrator and denoiser
integrator = EulerIntegrator(model=flow, timer=timer)
denoiser = Denoiser(
    integrator=integrator,
    model=flow,
    predictor=predictor,
    x0_shape=(transformer_hw[0], transformer_hw[1], conditioned.in_channels),
)

# ===========================
# 5. Generate Image
# ===========================
key = jax.random.PRNGKey(SEED)
state, _ = denoiser.generate(
    rng_key=key,
    n_steps=NUM_STEPS,
    n_particles=1,
    keep_history=False,
)

# Get latent from generation
latent = state.integrator_state.position

# ===========================
# 6. Decode to Image
# ===========================
image = loader.decode_latent(latent)
print(f"Generated image shape: {image.shape}")

# Save image (image is a numpy array in [0, 1] range)
from PIL import Image

img = Image.fromarray((image * 255).astype("uint8"))
img.save("output.png")
```
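
The `_latent_shapes` helper and the dynamic-shift timer above are internal to Diffuse. As a rough sketch of the arithmetic behind them (an assumption based on the standard FLUX latent layout and the FLUX reference sampler, not Diffuse's actual implementation):

```python
import math

# Hedged sketch, NOT Diffuse's implementation: FLUX's VAE downsamples
# images by 8x and the transformer packs latents into 2x2 patches, so the
# token grid is (H // 16, W // 16) with 16 * 4 = 64 channels per token.
def latent_shapes(height: int, width: int):
    vae_hw = (height // 8, width // 8)            # VAE latent grid
    transformer_hw = (height // 16, width // 16)  # after 2x2 patchify
    return vae_hw, transformer_hw

# Dynamic shift as in the FLUX reference sampler (assumed to be what
# use_dynamic_shift=True enables): mu interpolates linearly between
# (256 tokens, 0.5) and (4096 tokens, 1.15), then warps each timestep.
def dynamic_shift_mu(seq_len: int, base: float = 0.5, max_shift: float = 1.15) -> float:
    m = (max_shift - base) / (4096 - 256)
    return m * seq_len + (base - m * 256)

def time_shift(mu: float, t: float, sigma: float = 1.0) -> float:
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

vae_hw, transformer_hw = latent_shapes(512, 512)
image_seq_len = transformer_hw[0] * transformer_hw[1]
print(vae_hw, transformer_hw, image_seq_len)  # (64, 64) (32, 32) 1024
```

At 512x512 this yields 1024 image tokens; larger images produce a larger `mu`, which warps the timestep schedule toward the high-noise region.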

### Model Specifications

- **Architecture**: transformer-based diffusion model
- **Hidden Dimension**: 3072
- **Attention Heads**: 24
- **Double-Stream Layers**: 19
- **Single-Stream Layers**: 38
- **Precision**: bfloat16
- **Joint Attention Dimension**: 4096
- **Pooled Projection Dimension**: 768
- **In Channels**: 64
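
As a quick sanity check, the dimensions above fit together as sketched below. These are illustrative tensor shapes only, not Diffuse API calls; the batch size of 1, T5 sequence length of 512, and the 512x512 image size are assumptions for the example.

```python
import numpy as np

# Illustrative tensors tying the spec numbers together (assumptions:
# batch=1, T5 max length 512, a 512x512 image -> 32x32 latent grid).
pooled_text = np.zeros((1, 768))       # CLIP pooled projection dimension
t5_tokens = np.zeros((1, 512, 4096))   # T5 sequence, joint attention dimension
latents = np.zeros((1, 32 * 32, 64))   # packed image tokens, 64 in-channels

hidden_dim, num_heads = 3072, 24
head_dim = hidden_dim // num_heads     # per-head attention dimension
print(head_dim)  # 128
```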

## License

This model is subject to the original FLUX.1-dev license terms (the FLUX.1 [dev] Non-Commercial License); refer to them for usage restrictions and guidelines.

## Citation

If you use this model in your research, please cite the original FLUX release and the Diffuse library.

```bibtex
@software{diffuse2025,
  title  = {Diffuse: A modular diffusion model library},
  author = {Iollo, J. and Oudoumanessah, G.},
  year   = {2025},
  url    = {https://github.com/jcopo/diffuse}
}
```