---
license: apache-2.0
tags:
- diffusion
- text-to-image
- safety
- dose-response
base_model: Photoroom/PRX
datasets:
- lehduong/flux_generated
- LucasFang/FLUX-Reason-6M
- brivangl/midjourney-v6-llava
pipeline_tag: text-to-image
---

# Dose-Response C4: Original proportion (~1.21% unsafe), 1M scale

This model is part of a **dose-response experiment** studying how the fraction of unsafe content in training data affects the safety of images generated by text-to-image diffusion models.

## Model Details

| | |
|---|---|
| **Architecture** | PRX-1.2B (Photoroom diffusion model) |
| **Parameters** | 1.2B (denoiser only) |
| **Resolution** | 512px |
| **Condition** | C4: Original proportion (~1.21% unsafe), 1M scale |
| **Unsafe fraction** | ~1.21% (original) |
| **Training set size** | 1M images |
| **Training steps** | 100K batches |
| **Batch size** | 1024 (global) |
| **Precision** | bf16 |
| **Hardware** | 8x H200 GPUs |

## Condition Description

Downscaled to 1M images while preserving the original ~1.21% unsafe proportion (12K unsafe, 988K safe).
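The counts in this condition follow from simple arithmetic on the target size and the unsafe fraction. The snippet below is a minimal sketch of that calculation and of a proportion-preserving subsample. It is not the pipeline used to build the dataset: the helper names `split_counts` and `stratified_subsample`, the integer image IDs, and the fixed seed are all assumptions made for illustration.

```python
import random


def split_counts(total, unsafe_fraction):
    """Return (n_unsafe, n_safe) for a subset of `total` images."""
    n_unsafe = round(total * unsafe_fraction)
    return n_unsafe, total - n_unsafe


def stratified_subsample(safe_ids, unsafe_ids, total, unsafe_fraction, seed=0):
    """Draw `total` IDs while keeping the requested unsafe fraction.

    Hypothetical helper: samples each group separately, without replacement.
    """
    n_unsafe, n_safe = split_counts(total, unsafe_fraction)
    rng = random.Random(seed)
    return rng.sample(safe_ids, n_safe), rng.sample(unsafe_ids, n_unsafe)


# C4 target: 1M images at ~1.21% unsafe.
print(split_counts(1_000_000, 0.0121))  # (12100, 987900)

# Toy pool of 10,000 IDs with the same 1.21% unsafe share, downscaled to 1,000.
safe_pool = list(range(9_879))
unsafe_pool = list(range(9_879, 10_000))
safe_sub, unsafe_sub = stratified_subsample(safe_pool, unsafe_pool, 1_000, 0.0121)
print(len(safe_sub), len(unsafe_sub))  # 988 12
```

Rounded to the nearest thousand, the first result matches the 12K unsafe and 988K safe images stated above.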

## Dose-Response Conditions Overview

This model is one of 7 conditions in the dose-response experiment:

| Condition | Unsafe Fraction | Dataset Scale | Description |
|-----------|----------------|---------------|-------------|
| **C0** | 0% | Full (~7.85M) | All unsafe removed |
| **C1** | 5% | Full (~8.24M) | Unsafe oversampled to 5% |
| **C2** | 10% | Full (~8.72M) | Unsafe oversampled to 10% |
| **C3** | ~1.21% | Full (~7.94M) | Original composition |
| **C4** | ~1.21% | 1M | Original proportion, downscaled |
| **C5** | ~9.6% | 1M | All unsafe included, downscaled |
| **C6** | ~1.21% | 100K | Original proportion, small scale |

## Training Details

- **Base architecture**: [PRX](https://github.com/Photoroom/PRX) 1.2B
- **Text encoder**: T5-Gemma-2B (frozen)
- **VAE**: Identity (no compression)
- **Optimizer**: Muon
- **Algorithms**: TREAD + REPA-v3 + LPIPS + Perceptual DINO + EMA
- **EMA smoothing**: 0.999 (updated every 10 batches)
- **Training data sources**: `lehduong/flux_generated`, `LucasFang/FLUX-Reason-6M`, `brivangl/midjourney-v6-llava`
- **Safety annotations**: Training data annotated with [LlavaGuard-7B](https://huggingface.co/AIML-TUDA/LlavaGuard-v1.2-7B-OV) to classify images as safe/unsafe

## Files

- `denoiser.pt`: Consolidated single-file checkpoint (EMA weights, ready for inference)
- `distributed/`: Original FSDP distributed checkpoint shards
- `config.yaml`: Full Hydra training configuration

## Usage

```python
import torch

# Load consolidated checkpoint
state_dict = torch.load("denoiser.pt", map_location="cpu")
# Keys are in format: denoiser.*
```

For the full generation pipeline, see the [diffusion_safety](https://github.com/felifri/diffusion_safety) repository.

## Citation

If you use these models, please cite the associated paper and the PRX architecture.

## License

Apache 2.0