---
library_name: diffusers
pipeline_tag: text-to-image
---

# TDM-R1: Reinforcing Few-Step Diffusion Models with Non-Differentiable Reward

<div align="center">
  <a href="https://luo-yihong.github.io/TDM-R1-Page/"><img src="https://img.shields.io/static/v1?label=Project%20Page&message=Github&color=blue&logo=github-pages"></a> &ensp;
  <a href="https://arxiv.org/abs/2603.07700"><img src="https://img.shields.io/static/v1?label=Paper&message=Arxiv:TDM-R1&color=red&logo=arxiv"></a> &ensp;
  <a href="https://github.com/Luo-Yihong/TDM-R1"><img src="https://img.shields.io/static/v1?label=Code&message=Github&color=green&logo=github"></a>
</div>

This is the official repository of "[TDM-R1: Reinforcing Few-Step Diffusion Models with Non-Differentiable Reward](https://arxiv.org/abs/2603.07700)", by *Yihong Luo, Tianyang Hu, Weijian Luo, Jing Tang*.

<div align="center">
  <img src="teaser_git.png" width="100%">
</div>

<p align="center">
  Samples generated by <b>TDM-R1</b> using only <b>4 NFEs</b>, obtained by reinforcing the recent and powerful Z-Image model.
</p>

## Description
TDM-R1 is a reinforcement learning (RL) paradigm for few-step generative models. It decouples the learning process into surrogate reward learning and generator learning, allowing for the use of non-differentiable rewards (e.g., human preference, object counts). This repository contains the reinforced version of the [Z-Image-Turbo](https://huggingface.co/Tongyi-MAI/Z-Image-Turbo) model.
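
To make the idea of decoupling concrete, here is a deliberately tiny toy. It is **not** the TDM-R1 algorithm, only a sketch of the general pattern under simple assumptions: a 1-D "generator" draws samples `x = theta + sigma * noise`, a black-box reward `count_reward` (a hypothetical stand-in for an object-count check) returns 1 only when the rounded sample equals a target, and so has zero gradient almost everywhere. Each round fits a differentiable quadratic surrogate to the observed rewards, then updates the generator by gradient ascent on that surrogate. All names and the quadratic surrogate are illustrative choices.

```python
import random


def count_reward(x: float, target: int = 3) -> float:
    # Hypothetical stand-in for a non-differentiable reward such as
    # "did the detector count exactly `target` objects?". It is piecewise
    # constant, so its gradient is zero almost everywhere.
    return 1.0 if round(x) == target else 0.0


def train_toy(steps: int = 50, n: int = 200, sigma: float = 2.0,
              lr: float = 0.1, seed: int = 0) -> float:
    rng = random.Random(seed)
    theta = 0.0  # toy "generator" parameter: samples are theta + sigma * noise
    for _ in range(steps):
        xs = [theta + sigma * rng.gauss(0.0, 1.0) for _ in range(n)]
        rs = [count_reward(x) for x in xs]

        # Surrogate reward learning: fit s_c(x) = -(x - c)^2 by maximizing
        # sum_i r_i * s_c(x_i); the closed-form optimum is the reward-weighted mean.
        total = sum(rs)
        if total == 0:
            continue  # no reward signal this round, so skip the update
        c = sum(r * x for r, x in zip(rs, xs)) / total

        # Generator learning: gradient ascent on E[s_c(theta + sigma * eps)],
        # whose gradient with respect to theta is -2 * (theta - c).
        theta += lr * (-2.0 * (theta - c))
    return theta


print(f"learned theta: {train_toy():.2f}")
```

The generator never differentiates through `count_reward`; it only follows the surrogate, which is refit from fresh reward evaluations every round.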

## Pre-trained Model

- [TDM-R1-ZImage](https://huggingface.co/Luo-Yihong/TDM-R1)

## Usage

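You can use this model with `diffusers` and `peft`. The checkpoint holds EMA weights for a LoRA adapter on the Z-Image-Turbo transformer. The example below loads the base pipeline, attaches a LoRA adapter with `peft`, copies the EMA weights into it, and samples an image with 4 NFEs.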

```python
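import os
# Optional: route Hugging Face Hub downloads through a mirror. This must be set
# before importing diffusers/huggingface_hub; delete it to use huggingface.co directly.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"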
import torch
from diffusers import ZImagePipeline
from peft import LoraConfig, get_peft_model

def load_ema(pipeline, lora_path, adapter_name='default'):
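    """Copy the EMA weights saved in `lora_path` into the adapter's trainable parameters.

    The checkpoint's "ema_parameters" list is matched to the adapter parameters
    by position, so the LoRA config must match the one used in training.
    """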
    pipeline.transformer.set_adapter(adapter_name)
    trainable_params = [
        p for n, p in pipeline.transformer.named_parameters()
        if adapter_name in n and p.requires_grad
    ]
    state_dict = torch.load(lora_path, map_location=pipeline.transformer.device)
    ema_params = state_dict["ema_parameters"]
    assert len(trainable_params) == len(ema_params), \
        f"Parameter count mismatch: {len(trainable_params)} vs {len(ema_params)}"
    for param, ema_param in zip(trainable_params, ema_params):
        param.data.copy_(ema_param.to(param.device))
    print(f"Loaded EMA weights for adapter '{adapter_name}' from {lora_path}")

pipeline = ZImagePipeline.from_pretrained(
    "Tongyi-MAI/Z-Image-Turbo",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=False,
)
transformer_lora_config = LoraConfig(
    r=32,
    lora_alpha=64,
    init_lora_weights="gaussian",
    target_modules=["to_q", "to_k", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
)
pipeline.transformer = get_peft_model(
    pipeline.transformer,
    transformer_lora_config,
    adapter_name="tdmr1",
)
# Download tdmr1_zimage_ema.ckpt from https://huggingface.co/Luo-Yihong/TDM-R1 to this path first
load_ema(
    pipeline,
    lora_path="./tdmr1_zimage_ema.ckpt",
    adapter_name="tdmr1",
)
pipeline = pipeline.to("cuda")
image = pipeline(
    prompt="A high quality photo of a cat",
    height=1024,
    width=1024,
    num_inference_steps=5,  # 5 scheduler steps = 4 DiT forward passes (4 NFEs)
    guidance_scale=0.0,  # no classifier-free guidance
    generator=torch.Generator("cuda").manual_seed(42),
).images[0]
image.save("tdmr1_cat.png")  # in a notebook, you can also display `image` directly
```
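
The `LoraConfig` above sets `r=32` and `lora_alpha=64`. In `peft`'s LoRA layers the adapter output is scaled by `lora_alpha / r` (with the default `use_rslora=False`), so each targeted projection behaves like `W + 2 * B @ A`. The NumPy sketch below illustrates only that arithmetic on random matrices; `lora_linear` and `merged_weight` are hypothetical helper names, not `peft` APIs, and the shapes are arbitrary.

```python
import numpy as np


def lora_linear(x, W, A, B, r=32, lora_alpha=64):
    # Base projection plus the low-rank update, scaled by lora_alpha / r
    # (64 / 32 = 2.0 for the config used above).
    scaling = lora_alpha / r
    return x @ W.T + scaling * (x @ A.T) @ B.T


def merged_weight(W, A, B, r=32, lora_alpha=64):
    # The equivalent single weight matrix once the adapter is merged into W.
    return W + (lora_alpha / r) * (B @ A)


rng = np.random.default_rng(0)
d_in, d_out, r = 64, 48, 32
W = rng.standard_normal((d_out, d_in))  # frozen base weight
A = rng.standard_normal((r, d_in))      # LoRA "down" projection
B = rng.standard_normal((d_out, r))     # LoRA "up" projection
x = rng.standard_normal((4, d_in))      # a batch of 4 input vectors

y = lora_linear(x, W, A, B, r=r)
print(y.shape)
```

With `init_lora_weights="gaussian"`, `peft` draws A from a Gaussian and initializes B to zero, so an untrained adapter leaves the base output unchanged; the trained values come entirely from the EMA checkpoint copied in by `load_ema`.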

## Contact

Please contact Yihong Luo (yluocg@connect.ust.hk) if you have any questions about this work.

## BibTeX
```bibtex
@misc{luo2025tdmr1,
  title={TDM-R1: Reinforcing Few-Step Diffusion Models with Non-Differentiable Reward},
  author={Yihong Luo and Tianyang Hu and Weijian Luo and Jing Tang},
  year={2026},
  eprint={2603.07700},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```