---
license: apache-2.0
tags:
- diffusion
- text-to-image
- safety
- dose-response
base_model: Photoroom/PRX
datasets:
- lehduong/flux_generated
- LucasFang/FLUX-Reason-6M
- brivangl/midjourney-v6-llava
pipeline_tag: text-to-image
---
Dose-Response C4: Original proportion (~1.21% unsafe), 1M scale
This model is part of a dose-response experiment that studies how the fraction of unsafe content in the training data affects the safety of images generated by text-to-image diffusion models.
Model Details
| Property | Value |
|---|---|
| Architecture | PRX-1.2B (Photoroom diffusion model) |
| Parameters | 1.2B (denoiser only) |
| Resolution | 512px |
| Condition | C4: Original proportion (~1.21% unsafe), 1M scale |
| Unsafe fraction | ~1.21% (original) |
| Training set size | 1M images |
| Training steps | 100K batches |
| Batch size | 1024 (global) |
| Precision | bf16 |
| Hardware | 8x H200 GPUs |
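As a rough check on training length, the figures above imply about 102 passes over the 1M-image training set. This sketch assumes each of the 100K steps consumes exactly one global batch of 1024 images and that the batch is split evenly across the 8 GPUs with no gradient accumulation; the card does not state either detail.

```python
# Back-of-the-envelope training budget from the Model Details table.
# Assumption: one optimizer step == one global batch of 1024 images,
# sharded evenly across GPUs with no gradient accumulation.
TRAINING_STEPS = 100_000
GLOBAL_BATCH = 1024
DATASET_SIZE = 1_000_000
NUM_GPUS = 8

samples_seen = TRAINING_STEPS * GLOBAL_BATCH   # total images processed
epochs = samples_seen / DATASET_SIZE           # passes over the 1M-image set
per_gpu_batch = GLOBAL_BATCH // NUM_GPUS       # micro-batch per H200 under the assumption above

print(f"samples seen:  {samples_seen:,}")
print(f"epochs:        {epochs:.1f}")
print(f"per-GPU batch: {per_gpu_batch}")
```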
Condition Description
Downscaled to 1M images while preserving the original ~1.21% unsafe proportion (12K unsafe, 988K safe).
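A proportion-preserving downscale can be done by sampling each class separately (stratified sampling). The sketch below is a hypothetical illustration, not the experiment's actual data pipeline: `downscale_preserving_ratio`, the integer sample IDs, and the fixed seed are assumptions, and the two ID lists stand in for the LlavaGuard safe/unsafe labels.

```python
import random

def downscale_preserving_ratio(safe_ids, unsafe_ids, target_size, unsafe_fraction, seed=0):
    """Hypothetical stratified subsample: draw each class separately so the
    unsafe share of the result matches `unsafe_fraction` (rounded to whole images)."""
    rng = random.Random(seed)
    n_unsafe = round(target_size * unsafe_fraction)
    n_safe = target_size - n_unsafe
    if n_unsafe > len(unsafe_ids) or n_safe > len(safe_ids):
        raise ValueError("not enough samples in one class for the requested split")
    # Sample without replacement within each class, then mix the two classes.
    subset = rng.sample(unsafe_ids, n_unsafe) + rng.sample(safe_ids, n_safe)
    rng.shuffle(subset)
    return subset, n_unsafe, n_safe

# Toy pools; the real full-scale pool is ~7.94M images (condition C3).
safe_pool = list(range(100_000))
unsafe_pool = list(range(100_000, 102_000))
subset, n_unsafe, n_safe = downscale_preserving_ratio(safe_pool, unsafe_pool, 10_000, 0.0121)
print(n_unsafe, n_safe, len(subset))
```

With `target_size=1_000_000` and `unsafe_fraction=0.0121`, the same arithmetic gives 12,100 unsafe and 987,900 safe images, which rounds to the 12K/988K split quoted above.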
Dose-Response Conditions Overview
This model is one of 7 conditions in the dose-response experiment:
| Condition | Unsafe Fraction | Dataset Scale | Description |
|---|---|---|---|
| C0 | 0% | Full (~7.85M) | All unsafe removed |
| C1 | 5% | Full (~8.24M) | Unsafe oversampled to 5% |
| C2 | 10% | Full (~8.72M) | Unsafe oversampled to 10% |
| C3 | ~1.21% | Full (~7.94M) | Original composition |
| C4 | ~1.21% | 1M | Original proportion, downscaled |
| C5 | ~9.6% | 1M | All unsafe included, downscaled |
| C6 | ~1.21% | 100K | Original proportion, small scale |
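The sizes of the oversampled conditions can be approximated with simple arithmetic. Assuming the safe pool is held fixed at S images and unsafe images are repeated until they make up a fraction f of the total, the dataset size is N = S / (1 - f). The sketch takes S ≈ 7.85M from C0 and about 96K unique unsafe images as implied by C5 (~9.6% of 1M); both are approximations, and the exact oversampling procedure is an assumption here.

```python
# Hypothetical reconstruction of condition sizes, assuming the safe pool is
# fixed and unsafe images are repeated (oversampled) to hit a target fraction.
SAFE_POOL = 7_850_000     # approx. size of C0 (all unsafe removed)
UNSAFE_UNIQUE = 96_000    # approx., implied by C5 (~9.6% of 1M)

def oversampled_size(safe, unsafe_fraction):
    """Return (total N, unsafe count) such that unsafe samples are `unsafe_fraction` of N."""
    total = safe / (1.0 - unsafe_fraction)
    return total, total - safe

for name, frac in [("C1", 0.05), ("C2", 0.10)]:
    total, unsafe = oversampled_size(SAFE_POOL, frac)
    repeats = unsafe / UNSAFE_UNIQUE  # average times each unique unsafe image appears
    print(f"{name}: total ~{total / 1e6:.2f}M, unsafe ~{unsafe / 1e3:.0f}K, ~{repeats:.1f}x repetition")
```

C2 comes out at the listed ~8.72M; C1 lands slightly above the listed ~8.24M, so the real pipeline presumably differs in detail (for example, in the exact pool sizes).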
Training Details
- Base architecture: PRX 1.2B
- Text encoder: T5-Gemma-2B (frozen)
- VAE: Identity (no latent compression; the denoiser operates directly in pixel space)
- Optimizer: Muon
- Algorithms: TREAD + REPA-v3 + LPIPS + Perceptual DINO + EMA
- EMA smoothing: 0.999 (updated every 10 batches)
- Training data sources: lehduong/flux_generated, LucasFang/FLUX-Reason-6M, brivangl/midjourney-v6-llava
- Safety annotations: Training data annotated with LlavaGuard-7B to classify images as safe/unsafe
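The EMA copy of the weights is what the consolidated checkpoint ships. The minimal sketch below shows the update rule with decay 0.999 applied every 10 batches, using plain floats instead of tensors; the function name and the convention that steps are counted from 1 are assumptions, not taken from the training code. With these settings the average has a half-life of about 693 updates, or roughly 6,900 batches. In PyTorch the per-tensor update is commonly written as `ema.lerp_(param, 1 - decay)`.

```python
import math

DECAY = 0.999
UPDATE_EVERY = 10  # batches between EMA updates

def maybe_update_ema(ema, params, step, decay=DECAY, every=UPDATE_EVERY):
    """Blend live params into the EMA copy on every `every`-th step:
    ema <- decay * ema + (1 - decay) * param. Returns True if an update happened."""
    if step % every != 0:
        return False
    for name, value in params.items():
        ema[name] = decay * ema[name] + (1.0 - decay) * value
    return True

# Half-life of the average, measured in EMA updates and in training batches.
half_life_updates = math.log(0.5) / math.log(DECAY)
half_life_batches = half_life_updates * UPDATE_EVERY
print(f"half-life: ~{half_life_updates:.0f} updates = ~{half_life_batches:,.0f} batches")

# Toy run: a single scalar "parameter" jumps from 0 to 1 and stays there.
ema, params = {"w": 0.0}, {"w": 1.0}
updates = sum(maybe_update_ema(ema, params, step) for step in range(1, 10_001))
print(updates, round(ema["w"], 3))
```

Updating only every 10 batches keeps the EMA bookkeeping cheap; the trade-off is that the effective averaging window, measured in batches, is 10 times longer than the decay alone suggests.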
Files
- `denoiser.pt`: Consolidated single-file checkpoint (EMA weights, ready for inference)
- `distributed/`: Original FSDP distributed checkpoint shards
- `config.yaml`: Full Hydra training configuration
Usage
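```python
import torch

# Load the consolidated checkpoint (EMA weights) onto the CPU
state_dict = torch.load("denoiser.pt", map_location="cpu")

# Keys are in the format "denoiser.*"; strip the prefix when loading into a bare denoiser module
denoiser_state = {k.removeprefix("denoiser."): v for k, v in state_dict.items()}
```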
For the full generation pipeline, see the diffusion_safety repository.
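To sanity-check a downloaded checkpoint against the stated 1.2B parameters, the element counts in the state dict can be summed and grouped by submodule. This is a hypothetical helper, not part of the repository; it assumes the flat `denoiser.*` key layout described above and counts any stored buffers along with the parameters.

```python
from collections import defaultdict

def count_parameters(state_dict, prefix="denoiser.", size_of=lambda t: t.numel()):
    """Sum element counts over a {name: tensor} mapping, grouped by the first key
    component after `prefix`. Buffers (if present) are counted too. `size_of`
    can be overridden for non-tensor values."""
    total = 0
    per_module = defaultdict(int)
    for name, value in state_dict.items():
        n = size_of(value)
        total += n
        top = name.removeprefix(prefix).split(".", 1)[0]
        per_module[top] += n
    return total, dict(per_module)
```

For example, `total, per_module = count_parameters(torch.load("denoiser.pt", map_location="cpu"))`; if the file holds only the denoiser, as the key prefix suggests, `total / 1e9` should land near 1.2.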
Citation
If you use these models, please cite the associated paper and the PRX architecture.
License
Apache 2.0