# TDM-R1: Reinforcing Few-Step Diffusion Models with Non-Differentiable Reward

<div align="center">
<a href="https://luo-yihong.github.io/TDM-R1-Page/"><img src="https://img.shields.io/static/v1?label=Project%20Page&message=Github&color=blue&logo=github-pages"></a> &ensp;
<a href="https://arxiv.org/abs/xxx"><img src="https://img.shields.io/static/v1?label=Paper&message=Arxiv:TDM-R1&color=red&logo=arxiv"></a> &ensp;
</div>

This is the official repository of "[TDM-R1: Reinforcing Few-Step Diffusion Models with Non-Differentiable Reward](https://arxiv.org/abs/xxx)", by *Yihong Luo, Tianyang Hu, Weijian Luo, Jing Tang*.

<div align="center">
<img src="teaser_git.png" width="100%">
</div>

<p align="center">
Samples generated by <b>TDM-R1</b> using only <b>4 NFEs</b>, obtained by reinforcing the recent, powerful Z-Image model.
</p>

## Pre-trained Model

- [TDM-R1-ZImage](https://huggingface.co/Luo-Yihong/TDM-R1)
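The EMA checkpoint can also be fetched programmatically from the Hub. A minimal sketch, assuming the checkpoint file in the repo is named `tdmr1_zimage_ema.ckpt` (the filename the usage snippet loads from disk); verify the exact filename against the repo's file listing:

```python
from huggingface_hub import hf_hub_download

REPO_ID = "Luo-Yihong/TDM-R1"
CKPT_NAME = "tdmr1_zimage_ema.ckpt"  # assumed filename; check the repo's file list

def download_checkpoint(local_dir: str = ".") -> str:
    """Download the TDM-R1 EMA checkpoint and return its local path."""
    return hf_hub_download(repo_id=REPO_ID, filename=CKPT_NAME, local_dir=local_dir)
```

The returned path can then be passed as `lora_path` when loading the EMA weights.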
## Usage

```python
import os

# Optional: route Hugging Face downloads through a mirror if the Hub is unreachable.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

import torch
from diffusers import ZImagePipeline
from peft import LoraConfig, get_peft_model


def load_ema(pipeline, lora_path, adapter_name="default"):
    """Load EMA weights into the pipeline's transformer adapter."""
    pipeline.transformer.set_adapter(adapter_name)
    trainable_params = [
        p for n, p in pipeline.transformer.named_parameters()
        if adapter_name in n and p.requires_grad
    ]
    state_dict = torch.load(lora_path, map_location=pipeline.transformer.device)
    ema_params = state_dict["ema_parameters"]
    assert len(trainable_params) == len(ema_params), \
        f"Parameter count mismatch: {len(trainable_params)} vs {len(ema_params)}"
    for param, ema_param in zip(trainable_params, ema_params):
        param.data.copy_(ema_param.to(param.device))
    print(f"Loaded EMA weights for adapter '{adapter_name}' from {lora_path}")


# Build the base pipeline and attach a LoRA adapter matching the TDM-R1 training config.
pipeline = ZImagePipeline.from_pretrained(
    "Tongyi-MAI/Z-Image-Turbo",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=False,
)
transformer_lora_config = LoraConfig(
    r=32,
    lora_alpha=64,
    init_lora_weights="gaussian",
    target_modules=["to_q", "to_k", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
)
pipeline.transformer = get_peft_model(
    pipeline.transformer,
    transformer_lora_config,
    adapter_name="tdmr1",
)
load_ema(
    pipeline,
    lora_path="./tdmr1_zimage_ema.ckpt",
    adapter_name="tdmr1",
)
pipeline = pipeline.to("cuda")

prompt = "A corgi wearing sunglasses on a beach"  # replace with your own prompt
seed = 0  # any integer seed
image = pipeline(
    prompt=prompt,
    height=1024,
    width=1024,
    num_inference_steps=5,  # this actually results in 4 DiT forwards
    guidance_scale=0.0,
    generator=torch.Generator("cuda").manual_seed(seed),
).images[0]
image.save("tdmr1_sample.png")
```
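To sample several seeds for the same prompt, the single-image call above can be wrapped in a small helper. This is an illustrative sketch, not part of the official repo; it assumes `pipeline` is the prepared ZImagePipeline from the snippet above:

```python
import torch

def generate(pipeline, prompt: str, seeds, steps: int = 5):
    """Sample one image per seed with the reinforced 4-NFE sampler."""
    images = []
    for seed in seeds:
        # A fresh seeded generator per sample keeps results reproducible.
        generator = torch.Generator("cuda").manual_seed(seed)
        images.append(
            pipeline(
                prompt=prompt,
                height=1024,
                width=1024,
                num_inference_steps=steps,  # 5 steps -> 4 DiT forwards
                guidance_scale=0.0,
                generator=generator,
            ).images[0]
        )
    return images
```

For example, `generate(pipeline, prompt, seeds=range(4))` returns four images for a quick qualitative comparison.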
## Contact

Please contact Yihong Luo (yluocg@connect.ust.hk) if you have any questions about this work.

## BibTeX

```
@misc{luo2025tdmr1,
  title={TDM-R1: Reinforcing Few-Step Diffusion Models with Non-Differentiable Reward},
  author={Yihong Luo and Tianyang Hu and Weijian Luo and Jing Tang},
  year={2025},
  eprint={TODO},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```