File size: 5,149 Bytes
39c01d4
 
 
 
 
02ebd21
 
39c01d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7289dfb
 
 
39c01d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
---
license: apache-2.0
---
该模型是将Mamba的注意力用在stable diffusion V1.5 的U-Net网络里(即替换替换原有的自注意力层),然后进行了训练和评估,评估指标有FID和CLIP-T,GPU峰值显存占用。目前模型还未进一步改进。下面的说明是运行程序的指令。

修改后的U-Net网络图
![image/png](https://cdn-uploads.huggingface.co/production/uploads/67b858cc26e7d5f7cb139325/gs0JYLZuuNmadj4daNm00.png)

推理代码
python msd_infer.py --base_model="runwayml/stable-diffusion-v1-5" --checkpoint_dir="/root/mamba/sd-mamba-mscoco-urltext-10k-run3/checkpoint-31000" --unet_subfolder="unet_mamba" --prompt="a river" --output_path="ccat.png" --device="cuda" --seed=12345 --num_inference_steps=50 --guidance_scale=8.0 --mamba_d_state=16 --mamba_d_conv=4 --mamba_expand=2 --pipeline_dtype="float32"


训练代码(shuffling ,sample,5000/10000)
accelerate launch train_mamba_sd.py --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" --dataset_name="ChristophSchuhmann/MS_COCO_2017_URL_TEXT" --output_dir="sd-mamba-mscoco-urltext-10k-run3" --resolution=512 --max_train_samples=6000 --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing --max_train_steps=50000 --learning_rate=1e-5 --lr_scheduler="cosine" --lr_warmup_steps=100 --mamba_d_state=16 --mamba_d_conv=4 --mamba_expand=2 --dataloader_num_workers=8 --preprocessing_num_workers=16 --seed=28 --mixed_precision="fp16" --use_8bit_adam --allow_tf32 --report_to="tensorboard" --validation_prompt="A high-resolution photo of a fluffy cat sitting on a windowsill, bathed in sunlight." --validation_steps=500 --num_validation_images=1 --checkpointing_steps=500 --checkpoints_total_limit=3 --resume_from_checkpoint="/root/mamba/sd-mamba-mscoco-urltext-20k-run2/checkpoint-16500"
此时设置的seed=42,是从打乱后的数据集中选取10001个样本,并且是按照seed=42方式打乱的,所以seed改变

rm -rf /root/.cache/huggingface/datasets/ChristophSchuhmann___ms_coco_2017_url_text

评估命令
python eval.py --model_checkpoint_path /root/mamba/sd-mamba-mscoco-urltext-10k-run3/checkpoint-31000 --coco_val_images_path /root/mamba/val2014 --coco_annotations_path /root/mamba/annotations --output_dir ./eval_results --num_samples 5000 --unet_subfolder unet_mamba --skip_generation

wget annotations/val2014(复现时)

768×768分辨率(对于eval2.py直接Python eval2.py)(tuili.py是msd_infer.py的768×768版本)
python tuili.py --base_model="runwayml/stable-diffusion-v1-5" --checkpoint_dir="/root/mamba/sd-mamba-mscoco-urltext-10k-run3/checkpoint-31000" --unet_subfolder="unet_mamba" --prompt="a garden" --output_path="ccat.png" --device="cuda" --seed=12345 --num_inference_steps=50 --guidance_scale=8.0 --width 768 --height 768 --mamba_d_state=16 --mamba_d_conv=4 --mamba_expand=2 --pipeline_dtype="float32" 

下面是环境配置说明

![image/png](https://cdn-uploads.huggingface.co/production/uploads/67b858cc26e7d5f7cb139325/Y7GUws9nxydyLIiEKNYkw.png)

# Name                    Version                   Build  Channel
accelerate                1.6.0                    pypi_0    pypi
bitsandbytes              0.45.5                   pypi_0    pypi
clean-fid                 0.1.35                   pypi_0    pypi
clip                      1.0                      pypi_0    pypi
cuda-version              12.8                          3    nvidia
datasets                  3.5.0                    pypi_0    pypi
diffusers                 0.33.1                   pypi_0    pypi
einops                    0.8.1                    pypi_0    pypi

importlib-metadata        8.6.1                    pypi_0    pypi

jinja2                    3.1.6           py310h06a4308_0  

mamba-ssm                 2.2.4                    pypi_0    pypi

ninja                     1.11.1.4                 pypi_0    pypi
numpy                     2.0.1           py310h5f9d8c6_1  
numpy-base                2.0.1           py310hb5e798b_1  

open-clip-torch           2.32.0                   pypi_0    pypi

pandas                    2.2.3                    pypi_0    pypi
pillow                    11.1.0          py310hac6e08b_1  
pip                       25.0.1                   pypi_0    pypi

python                    3.10.16              he870216_1  

pytorch                   2.5.1           py3.10_cuda12.4_cudnn9.1.0_0    pytorch
pytorch-cuda              12.4                 hc786d27_7    pytorch

scipy                     1.15.2                   pypi_0    pypi

tokenizers                0.21.1                   pypi_0    pypi
torch-fidelity            0.3.0                    pypi_0    pypi
torchaudio                2.5.1               py310_cu124    pytorch
torchmetrics              1.7.1                    pypi_0    pypi
torchtriton               3.1.0                     py310    pytorch
torchvision               0.20.1              py310_cu124    pytorch
tqdm                      4.67.1                   
transformers              4.51.3                   
typing_extensions         4.12.2          

xformers                  0.0.30