---
license: apache-2.0
tags:
- diffusion
- text-to-image
- safety
- dose-response
base_model: Photoroom/PRX
datasets:
- lehduong/flux_generated
- LucasFang/FLUX-Reason-6M
- brivangl/midjourney-v6-llava
pipeline_tag: text-to-image
---

# Dose-Response C0: 0% unsafe, full scale

This model is part of a **dose-response experiment** that studies how the fraction of unsafe content in the training data affects the safety of images generated by text-to-image diffusion models.

## Model Details

| Property | Value |
|---|---|
| **Architecture** | PRX-1.2B (Photoroom diffusion model) |
| **Parameters** | 1.2B (denoiser only) |
| **Resolution** | 512px |
| **Condition** | C0: 0% unsafe, full scale |
| **Unsafe fraction** | 0% |
| **Training set size** | ~7.85M images |
| **Training steps** | 100K batches |
| **Batch size** | 1024 (global) |
| **Precision** | bf16 |
| **Hardware** | 8× H200 GPUs |

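The figures in the table imply how often the model saw each training image. The back-of-the-envelope check below uses only the numbers listed above. The per-GPU figure additionally assumes that the global batch is split evenly across the 8 GPUs with no gradient accumulation, which this card does not state.

```python
# Figures copied from the Model Details table
GLOBAL_BATCH = 1024
TRAIN_STEPS = 100_000        # "100K batches"
DATASET_SIZE = 7_850_000     # "~7.85M images" (approximate)
NUM_GPUS = 8

samples_seen = GLOBAL_BATCH * TRAIN_STEPS   # total images drawn during training
epochs = samples_seen / DATASET_SIZE        # approximate passes over the safe pool
per_gpu_batch = GLOBAL_BATCH // NUM_GPUS    # assumption: even split, no grad accumulation

print(f"samples seen:  {samples_seen:,}")
print(f"approx epochs: {epochs:.1f}")
print(f"per-GPU batch: {per_gpu_batch}")
```

This works out to about 102.4M samples, or roughly 13 epochs over the ~7.85M-image safe pool.
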
## Condition Description

All images that the LlavaGuard annotations flag as unsafe are removed, so training uses only the safe pool (~7.85M images).

## Dose-Response Conditions Overview

This model is one of seven conditions in the dose-response experiment:

| Condition | Unsafe Fraction | Dataset Scale | Description |
|-----------|----------------|---------------|-------------|
| **C0** | 0% | Full (~7.85M) | All unsafe images removed |
| **C1** | 5% | Full (~8.24M) | Unsafe oversampled to 5% |
| **C2** | 10% | Full (~8.72M) | Unsafe oversampled to 10% |
| **C3** | ~1.21% | Full (~7.94M) | Original composition |
| **C4** | ~1.21% | 1M | Original proportion, downscaled |
| **C5** | ~9.6% | 1M | All unsafe included, downscaled |
| **C6** | ~1.21% | 100K | Original proportion, small scale |

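The full-scale conditions (C0 to C3) appear to keep the entire safe pool and vary only the number of unsafe samples, so each dataset size should follow from total = safe / (1 − f). The sketch below reconstructs those numbers from the table and is not the experiment's actual data-mixing code. It assumes a fixed safe pool of 7.85M images and that C3 contains every unique unsafe image, which puts that set at roughly 96K. The function name `full_scale_mix` is hypothetical. The computed sizes match the table to within about 0.3%; for example, it gives ~8.26M for C1 against the listed ~8.24M.

```python
from typing import Tuple

SAFE_POOL = 7_850_000             # assumption: C0 size = full safe pool (~7.85M)
ORIGINAL_UNSAFE_FRACTION = 0.0121 # C3: original composition (~1.21%)


def full_scale_mix(unsafe_fraction: float, safe: int = SAFE_POOL) -> Tuple[int, int]:
    """Keep every safe image and add enough (possibly repeated) unsafe samples
    to reach the target fraction. Returns (total_size, unsafe_sample_count)."""
    if not 0.0 <= unsafe_fraction < 1.0:
        raise ValueError("unsafe_fraction must be in [0, 1)")
    total = round(safe / (1.0 - unsafe_fraction))
    return total, total - safe


# Assumption: C3 uses every unique unsafe image exactly once (~96K images)
unique_unsafe = full_scale_mix(ORIGINAL_UNSAFE_FRACTION)[1]

for name, frac in [("C0", 0.0), ("C3", ORIGINAL_UNSAFE_FRACTION), ("C1", 0.05), ("C2", 0.10)]:
    total, unsafe = full_scale_mix(frac)
    repeats = unsafe / unique_unsafe  # average times each unsafe image appears
    print(f"{name}: total={total / 1e6:.2f}M unsafe={unsafe:,} repeats~{repeats:.1f}x")
```

Under these assumptions, C1 repeats each unsafe image about 4.3 times on average and C2 about 9.1 times. The same ~96K count is consistent with C5: including all unsafe images in a 1M-image set gives about 9.6%.
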
## Training Details

- **Base architecture**: [PRX](https://github.com/Photoroom/PRX) 1.2B
- **Text encoder**: T5-Gemma-2B (frozen)
- **VAE**: Identity (no compression; the denoiser operates directly on pixels)
- **Optimizer**: Muon
- **Algorithms**: TREAD + REPA-v3 + LPIPS + Perceptual DINO + EMA
- **EMA smoothing**: 0.999 (updated every 10 batches)
- **Training data sources**: `lehduong/flux_generated`, `LucasFang/FLUX-Reason-6M`, `brivangl/midjourney-v6-llava`
- **Safety annotations**: Training data annotated with [LlavaGuard-7B](https://huggingface.co/AIML-TUDA/LlavaGuard-v1.2-7B-OV) to classify images as safe or unsafe

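Because the released weights come from an exponential moving average (EMA) of the training weights, here is a plain-Python illustration of the update rule. It assumes the standard form ema ← decay · ema + (1 − decay) · θ, applied once every 10 batches with decay 0.999, and an EMA shadow initialized as a copy of the weights. The actual training code may differ in details such as initialization or warm-up. `step_fn` is a hypothetical stand-in for one optimizer step, and the weights are represented as plain lists of floats rather than tensors.

```python
from typing import Callable, Dict, List

Params = Dict[str, List[float]]

DECAY = 0.999       # "EMA smoothing" from Training Details
UPDATE_EVERY = 10   # EMA refreshed once every 10 batches


def init_ema(params: Params) -> Params:
    """Start the EMA shadow as a copy of the current weights (assumed)."""
    return {name: list(values) for name, values in params.items()}


def ema_update(ema: Params, params: Params, decay: float = DECAY) -> None:
    """In place: ema <- decay * ema + (1 - decay) * params."""
    for name, values in params.items():
        shadow = ema[name]
        for i, value in enumerate(values):
            shadow[i] = decay * shadow[i] + (1.0 - decay) * value


def train(num_batches: int, params: Params, ema: Params,
          step_fn: Callable[[Params], None]) -> None:
    """Run optimizer steps, refreshing the EMA every UPDATE_EVERY batches."""
    for batch_idx in range(1, num_batches + 1):
        step_fn(params)  # hypothetical: one optimizer step that mutates params
        if batch_idx % UPDATE_EVERY == 0:
            ema_update(ema, params)


# Rough averaging window: about 1 / (1 - decay) EMA updates
window_updates = 1.0 / (1.0 - DECAY)            # ~1,000 updates
window_batches = window_updates * UPDATE_EVERY  # ~10,000 training batches
print(f"EMA window ~ {window_updates:.0f} updates ~ {window_batches:.0f} batches")
```

Over 100K training batches, this gives about 10K EMA updates. The released weights are therefore weighted most heavily toward roughly the last 10K batches of training.
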
## Files

- `denoiser.pt`: consolidated single-file checkpoint (EMA weights, ready for inference)
- `distributed/`: original FSDP distributed checkpoint shards
- `config.yaml`: full Hydra training configuration

## Usage

```python
import torch

# Load the consolidated checkpoint (EMA weights) onto the CPU
state_dict = torch.load("denoiser.pt", map_location="cpu")

# Every key carries the "denoiser." prefix; inspect a few to confirm
print(list(state_dict)[:5])
```

For the full generation pipeline, see the [diffusion_safety](https://github.com/felifri/diffusion_safety) repository.

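Because every key in the consolidated checkpoint starts with `denoiser.`, loading the weights into a bare denoiser module, rather than through the diffusion_safety pipeline, will likely require stripping that prefix first. The helpers below are a hypothetical sketch: `strip_prefix`, `count_elements`, and `load_denoiser_state` are illustrative names, not part of any released API. The element count offers a quick sanity check against the ~1.2B denoiser parameters listed in Model Details. Note that it also counts any non-parameter buffers stored in the file.

```python
from typing import Any, Dict, Mapping

PREFIX = "denoiser."  # key prefix described in the Usage section


def strip_prefix(state_dict: Mapping[str, Any], prefix: str = PREFIX) -> Dict[str, Any]:
    """Return a copy with `prefix` removed from every key that starts with it."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def count_elements(state_dict: Mapping[str, Any]) -> int:
    """Sum of .numel() over all tensor-like values (parameters plus any buffers)."""
    return sum(value.numel() for value in state_dict.values() if hasattr(value, "numel"))


def load_denoiser_state(path: str = "denoiser.pt") -> Dict[str, Any]:
    """Load the consolidated checkpoint on CPU and strip the key prefix."""
    import torch  # imported lazily so the helpers above work without torch

    return strip_prefix(torch.load(path, map_location="cpu"))
```

For example, `sd = load_denoiser_state()` followed by `print(count_elements(sd))` should print a number near 1.2 billion if the file holds only the denoiser weights.
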
## Citation

If you use these models, please cite the associated paper and the PRX architecture.

## License

Apache 2.0