---
license: apache-2.0
tags: [diffusion, text-to-image, safety, dose-response]
base_model: Photoroom/PRX
datasets:
- lehduong/flux_generated
- LucasFang/FLUX-Reason-6M
- brivangl/midjourney-v6-llava
pipeline_tag: text-to-image
---
# Dose-Response C2 (8M-5%): Unsafe oversampled to 5%
Part of a **dose-response experiment** studying how the fraction of unsafe training data affects the safety of a text-to-image model's outputs.
## Condition
| | |
|---|---|
| Label | C2 (8M-5%) |
| Description | Unsafe samples oversampled to p = 5% of the training set. |
| Training set size N | 8.24M |
| Unsafe fraction p | 5% |
| Unsafe count U | ~412K |
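The unsafe count in the table follows from U = p · N. The short check below reproduces it from the table values, along with the implied number of safe samples. It is plain arithmetic on the numbers above.

```python
# Condition C2 values from the table above.
N = 8_240_000   # training set size (8.24M)
p = 0.05        # target unsafe fraction

unsafe_count = round(p * N)      # U = p * N
safe_count = N - unsafe_count    # remaining (safe) samples

print(f"U = {unsafe_count:,}")      # U = 412,000
print(f"safe = {safe_count:,}")     # safe = 7,828,000
```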
## Architecture
| | |
|---|---|
| Class | PRX (rectified-flow DiT) |
| Hidden size | 1792 |
| Depth | 16 |
| Heads | 28 |
| MLP ratio | 3.5 |
| Patch size | 32 px |
| Bottleneck | 256 |
| Resolution | 512×512 |
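Several sizes follow directly from the architecture table: per-head width, MLP inner width, and the number of image tokens per sample. The sketch below computes them. It assumes the 32 px patch size refers to pixel-space patches of a 3-channel RGB image. If, as we also assume, the 256 bottleneck is the width of a low-rank patch embedding, each 3,072-value patch would be projected to 256 dims before reaching the 1,792-dim hidden size. Note that the token count (256) coincides numerically with the bottleneck width but is a different quantity.

```python
# Architecture hyperparameters from the table above.
hidden_size = 1792
num_heads = 28
mlp_ratio = 3.5
patch_size = 32      # pixels per patch side
resolution = 512     # square training resolution
channels = 3         # assumption: RGB pixel-space input

head_dim = hidden_size // num_heads            # 1792 / 28
mlp_hidden = int(hidden_size * mlp_ratio)      # width of the MLP inner layer
patches_per_side = resolution // patch_size    # 512 / 32
num_tokens = patches_per_side ** 2             # image tokens per sample
patch_values = patch_size * patch_size * channels  # raw values per patch

print(head_dim, mlp_hidden, num_tokens, patch_values)  # 64 6272 256 3072
```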
## Text encoder
| | |
|---|---|
| Model | `google/t5gemma-2b-2b-ul2` |
| Max prompt tokens | 256 |
| Dtype | bfloat16 |
## Diffusion scheduler
| | |
|---|---|
| Type | x-prediction flow matching |
| Train timesteps | 1000 |
| Timestep shift | 3.0 |
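To make the scheduler rows concrete, here is a minimal sketch of timestep shifting and x-prediction under rectified flow. The conventions are assumptions, not taken from PRX: the interpolation z_t = (1 − t)·x + t·ε with t = 1 as pure noise, and the SD3-style shift t' = s·t / (1 + (s − 1)·t) with s = 3.0. With x-prediction the network outputs a clean-image estimate x̂, from which a velocity v̂ = (z_t − x̂)/t can be recovered for t > 0. The helper names are hypothetical.

```python
def shift_timestep(t: float, shift: float = 3.0) -> float:
    """Assumed SD3-style shift: maps [0, 1] to [0, 1], pushing t toward noise when shift > 1."""
    return shift * t / (1.0 + (shift - 1.0) * t)


def noisy_sample(x: float, eps: float, t: float) -> float:
    """Assumed rectified-flow interpolation z_t = (1 - t) * x + t * eps (t = 1 is pure noise)."""
    return (1.0 - t) * x + t * eps


def velocity_from_x_pred(z_t: float, x_hat: float, t: float) -> float:
    """Convert an x-prediction into a velocity estimate of v = eps - x (requires t > 0)."""
    return (z_t - x_hat) / t


# Take the midpoint of a 1000-step timestep grid and apply the shift.
t = shift_timestep(500 / 1000)
print(t)  # 0.75

# With a perfect x-prediction, the recovered velocity equals eps - x.
x, eps = 1.0, 3.0
z = noisy_sample(x, eps, 0.25)
print(velocity_from_x_pred(z, x_hat=x, t=0.25))  # 2.0
```

The functions use only elementwise arithmetic, so they apply unchanged to NumPy or PyTorch tensors. A shift of 3.0 moves the midpoint from t = 0.5 to t = 0.75, spending more training on noisier timesteps, which is the usual motivation for shifting at higher token counts.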
## Training
| | |
|---|---|
| Iterations | 100,000 |
| Samples seen | 25.6M (100,000 iterations × 256) |
| Global batch size | 256 |
| Microbatch (per GPU) | 32 |
| Hardware | 8× NVIDIA H200 |
| Precision | bfloat16 (amp_bf16) |
| Optimizer (transformer blocks) | Muon (lr=1e-4, momentum=0.95, nesterov, ns_steps=5, weight_decay=0) |
| Optimizer (other params) | AdamW (lr=1e-4, β=(0.9, 0.95), eps=1e-8, weight_decay=0) |
| LR schedule | 1,000-step linear warmup, constant after |
| EMA | decay 0.999, started at step 0 |
| Random seed | 42 |
| Trainer | Composer + FSDP |
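The training table implies a few derived quantities: total samples seen, whether gradient accumulation is needed, the number of passes over the C2 data, and how many unsafe samples the model sees. The sketch below computes them and spells out the warmup schedule and EMA update in their textbook forms. The warmup is assumed to ramp linearly from 0, and the unsafe-sample estimate assumes uniform sampling from the mixture. Neither is confirmed by the config, and the helper names are hypothetical.

```python
# Values from the training table above.
iterations = 100_000
global_batch = 256
micro_batch = 32
num_gpus = 8
base_lr = 1e-4
warmup_steps = 1_000
ema_decay = 0.999
N = 8_240_000     # training set size for C2
p = 0.05          # unsafe fraction for C2

samples_seen = iterations * global_batch                # 25,600,000
grad_accum = global_batch // (micro_batch * num_gpus)  # 1: 32 x 8 already fills the batch
epochs = samples_seen / N                               # about 3.1 passes over the data
unsafe_seen = p * samples_seen                          # expected count under uniform sampling


def lr_at(step: int) -> float:
    """Assumed schedule: linear warmup from 0 to base_lr, then constant."""
    return base_lr * min(1.0, step / warmup_steps)


def ema_update(ema: float, param: float, decay: float = ema_decay) -> float:
    """One EMA step: ema <- decay * ema + (1 - decay) * param."""
    return decay * ema + (1.0 - decay) * param


print(f"{samples_seen:,} samples, accum={grad_accum}, epochs={epochs:.2f}")
print(f"~{unsafe_seen:,.0f} unsafe samples seen")
print(lr_at(0), lr_at(500), lr_at(1_000), lr_at(50_000))
```

With decay 0.999, the EMA weights average over roughly the last 1 / (1 − 0.999) = 1,000 optimizer steps, so the released EMA denoiser reflects approximately the final 1% of training.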
## Training data sources
The training set combines three image datasets, with per-condition filtering/oversampling:
- [`lehduong/flux_generated`](https://huggingface.co/datasets/lehduong/flux_generated) (~1.7M)
- [`LucasFang/FLUX-Reason-6M`](https://huggingface.co/datasets/LucasFang/FLUX-Reason-6M) (~6M)
- [`brivangl/midjourney-v6-llava`](https://huggingface.co/datasets/brivangl/midjourney-v6-llava) (~1M)
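The card does not describe how oversampling is implemented. One simple way to reach a target unsafe fraction p is to keep every safe sample and repeat the available unsafe samples until they make up p of the final index list. Solving U / (n_safe + U) = p gives U = p / (1 − p) · n_safe. The sketch below is a hypothetical helper illustrating that idea. The real pipeline may instead fix N and subsample safe data. The seed 42 mirrors the table's random seed, but whether that seed drives data mixing is not stated.

```python
import random


def build_mixture(safe_ids, unsafe_ids, p, seed=42):
    """Hypothetical helper: keep all safe samples and repeat unsafe samples
    cyclically until they form a fraction p of the returned index list."""
    n_safe = len(safe_ids)
    n_unsafe = round(p / (1.0 - p) * n_safe)  # solve U / (n_safe + U) = p for U
    if n_unsafe > 0 and not unsafe_ids:
        raise ValueError("need at least one unsafe sample to oversample")
    # Cyclic repetition spreads duplicates evenly over the unique unsafe samples.
    repeated = [unsafe_ids[i % len(unsafe_ids)] for i in range(n_unsafe)]
    mixture = list(safe_ids) + repeated
    random.Random(seed).shuffle(mixture)  # seeded shuffle for a reproducible order
    return mixture


# Toy example: 95 safe images and only 3 unique unsafe images.
safe = list(range(95))
unsafe = [100, 101, 102]
mix = build_mixture(safe, unsafe, p=0.05)
print(len(mix), sum(1 for i in mix if i >= 100))  # 100 5
```

Cyclic repetition, rather than sampling unsafe items with replacement, guarantees that no unique unsafe image is repeated more than one time beyond any other. This keeps the dose spread evenly across the available unsafe pool.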
## Files
- `denoiser.pt`: consolidated (single-file, non-sharded) checkpoint of the EMA denoiser weights
- `config.yaml`: full training configuration
## Framework
Trained with the [PRX](https://github.com/Photoroom/PRX) framework (Composer + FSDP). The full `config.yaml` is included for reproducibility.