---
license: apache-2.0
base_model:
- lambda/miniSD-diffusers
- stabilityai/stable-diffusion-3-medium-diffusers
library_name: diffusers
pipeline_tag: text-to-image
tags:
- diffusion
- stable-diffusion
- stable-diffusion-3
- controlnet
- causal-inference
- counterfactual-generation
- causal-adapter
---
# Causal-Adapter Pretrained Weights
## Model Overview
This repository provides pretrained Causal-Adapter weights for four benchmark settings. The released checkpoints cover both SD1.5-style and SD3-style diffusion backbones.
Causal-Adapter is designed to inject structured causal semantics into pretrained text-to-image diffusion models for controllable and causally consistent counterfactual image generation.
Detailed usage examples are available in our notebook benchmarks:
[Notebook Benchmarks](https://github.com/LeiTong02/Causal-Adapter/tree/main/notebook_benchmarks)
An example configuration can be found in:
```text
notebook_benchmarks/counterfactuals_celeba.ipynb
```
## Base Models
The released checkpoints are based on the following pretrained diffusion backbones:
- **SD1.5-style structure:** `lambda/miniSD-diffusers`
- **SD3-style structure:** `stabilityai/stable-diffusion-3-medium-diffusers`
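As a minimal sketch, assuming the `diffusers` library is installed, the two backbones listed above could be loaded with the standard diffusers pipeline classes. The `BACKBONES` table and the `load_backbone` helper are hypothetical illustrations, not part of the Causal-Adapter code. Attaching the Causal-Adapter weights is handled in the notebooks and is not shown here.

```python
# Hypothetical helper (not from the Causal-Adapter code): map each backbone
# structure named above to its Hub model ID and a diffusers pipeline class.
BACKBONES = {
    "sd15": ("lambda/miniSD-diffusers", "StableDiffusionPipeline"),
    "sd3": ("stabilityai/stable-diffusion-3-medium-diffusers", "StableDiffusion3Pipeline"),
}


def load_backbone(structure: str, **kwargs):
    """Load the frozen base pipeline for a structure ("sd15" or "sd3").

    Extra keyword arguments (e.g. torch_dtype) are forwarded to
    from_pretrained. diffusers is imported lazily, so the structure name is
    validated before any heavy import or download happens.
    """
    if structure not in BACKBONES:
        raise ValueError(
            f"unknown structure {structure!r}; expected one of {sorted(BACKBONES)}"
        )
    model_id, class_name = BACKBONES[structure]
    import diffusers  # deferred import: requires diffusers to be installed

    pipeline_cls = getattr(diffusers, class_name)
    return pipeline_cls.from_pretrained(model_id, **kwargs)
```

For example, `load_backbone("sd15")` would return the frozen miniSD pipeline. The SD3 medium repository is gated on the Hub, so loading `"sd3"` may require accepting its license and authenticating first.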
## Benchmark Resources
The released weights are evaluated on benchmark settings built from the following resources:
- **Pendulum dataset generation:**
[CausalVAE Pendulum](https://github.com/huawei-noah/trustworthyAI/blob/master/research/CausalVAE/causal_data/pendulum.py)
- **CelebA and ADNI benchmark configuration:**
[counterfactual-benchmark](https://github.com/gulnazaki/counterfactual-benchmark)
- **CelebA-HQ dataset:**
[CelebAMask-HQ](https://github.com/switchablenorms/CelebAMask-HQ)
## Example Configuration
The following example shows the main paths required for running the CelebA counterfactual generation notebook.
```python
import os
# Shared roots
# 1) Frozen SD1.5 backbone.
# For example: "lambda/miniSD-diffusers"
BASE_MODEL_PATH = ""
# 2) Causal-Adapter ControlNet checkpoint and the matching MCPL learned pseudo-tokens.
# Example ControlNet checkpoint:
# https://huggingface.co/LeiTong/Causal-Adapter/tree/main/celeba/controlnet/controlnet-steps-200000.safetensors
CONTROLNET_PATH = ""
# Example learned text embeddings:
# https://huggingface.co/LeiTong/Causal-Adapter/tree/main/celeba/controlnet/learned_embeds-steps-200000.safetensors
TEXT_EMBEDDING_PATH = ""
# 3) Optional pretrained SCM head from SCM_modeling/.
# Example SCM checkpoint:
# https://huggingface.co/LeiTong/Causal-Adapter/tree/main/celeba/scm/best_model.pt
SCM_PATH = ""
# 4) CelebA root expected by torchvision.datasets.CelebA(root=...).
DATA_ROOT = os.environ.get("DATA_ROOT", "")
DATASET = "celeA_complex"
SIZE = 256
```
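The example checkpoint locations in the configuration are given as Hub web URLs. A minimal sketch, assuming `huggingface_hub` is installed and that the notebook accepts local file paths, parses such a URL into the `repo_id`, `revision`, and `filename` arguments of `huggingface_hub.hf_hub_download` and reports which settings are still empty. The helper names (`parse_hub_url`, `download_from_hub_url`, `missing_required`) are hypothetical, and treating `SCM_PATH` as the only optional entry follows the comments in the configuration.

```python
from urllib.parse import urlparse


def parse_hub_url(url: str) -> tuple[str, str, str]:
    """Split a Hub file URL into (repo_id, revision, filename).

    Accepts ".../tree/<rev>/<path>", ".../blob/<rev>/<path>" and
    ".../resolve/<rev>/<path>" forms. Assumes the revision is a single
    path segment such as "main".
    """
    parsed = urlparse(url)
    if parsed.netloc != "huggingface.co":
        raise ValueError(f"not a huggingface.co URL: {url}")
    parts = parsed.path.strip("/").split("/")
    # Expected layout: <owner>/<repo>/<tree|blob|resolve>/<revision>/<path...>
    if len(parts) < 5 or parts[2] not in {"tree", "blob", "resolve"}:
        raise ValueError(f"unexpected Hub file URL layout: {url}")
    repo_id = "/".join(parts[:2])
    return repo_id, parts[3], "/".join(parts[4:])


def download_from_hub_url(url: str) -> str:
    """Download the file behind a Hub URL and return its local cache path."""
    from huggingface_hub import hf_hub_download  # deferred import

    repo_id, revision, filename = parse_hub_url(url)
    return hf_hub_download(repo_id=repo_id, filename=filename, revision=revision)


def missing_required(config: dict) -> list[str]:
    """Return the names of required settings that are still empty.

    SCM_PATH is treated as optional, so it is never reported.
    """
    required = ["BASE_MODEL_PATH", "CONTROLNET_PATH", "TEXT_EMBEDDING_PATH", "DATA_ROOT"]
    return [name for name in required if not config.get(name)]
```

For example, `CONTROLNET_PATH = download_from_hub_url(url)`, with `url` set to the ControlNet URL from the comment, would point the variable at a file in the local Hub cache.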
## Repository Structure
The checkpoint files are organized by benchmark and model component. A typical setting may include:
- Causal-Adapter / ControlNet weights
- Learned pseudo-token embeddings
- Optional pretrained SCM head
- Example notebooks for counterfactual image generation
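To see which components are available for each benchmark, a minimal sketch can list the repository files with `huggingface_hub.list_repo_files` and group them. It assumes the `<benchmark>/<component>/<file>` layout suggested by the CelebA example paths (`celeba/controlnet/...`, `celeba/scm/...`). The folder names for the other benchmarks are not documented here, and both helper names are hypothetical.

```python
from collections import defaultdict


def group_checkpoints(paths: list[str]) -> dict[str, dict[str, list[str]]]:
    """Group repo file paths as {benchmark: {component: [file names]}}.

    Assumes a "<benchmark>/<component>/<file>" layout; paths with fewer than
    three segments (e.g. README.md, .gitattributes) are skipped.
    """
    grouped = defaultdict(lambda: defaultdict(list))
    for path in paths:
        parts = path.split("/")
        if len(parts) < 3:
            continue
        benchmark, component = parts[0], parts[1]
        grouped[benchmark][component].append("/".join(parts[2:]))
    # Convert the nested defaultdicts to plain dicts with sorted file lists.
    return {
        bench: {comp: sorted(files) for comp, files in comps.items()}
        for bench, comps in grouped.items()
    }


def list_causal_adapter_files(repo_id: str = "LeiTong/Causal-Adapter") -> list[str]:
    """List all files in the model repository (needs huggingface_hub and network)."""
    from huggingface_hub import list_repo_files  # deferred import

    return list_repo_files(repo_id)
```

For example, `group_checkpoints(list_causal_adapter_files())["celeba"]` would map `controlnet` and `scm` to their file names, provided the layout assumption holds.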
## Usage
Please refer to the notebook examples for loading the pretrained weights and running counterfactual generation:
[Notebook Benchmarks](https://github.com/LeiTong02/Causal-Adapter/tree/main/notebook_benchmarks)
## License
This repository is released under the Apache-2.0 license.