iiiutch committed 82bc71f (verified; parent: 16e9002)

Update README.md

Files changed (1): README.md (+168, -3)

README.md (after this commit):

---
license: mit
base_model:
- GSAI-ML/LLaDA-8B-Instruct
---

We provide the inference code below. It implements block-wise diffusion decoding with remasking: the prompt is encoded and cached once, then each block of `block_size` masked tokens is refined over `steps` denoising steps, keeping the most confident predictions at each step and re-masking the rest, until `max_length` tokens have been generated or every sequence has produced an EOS token.

```python
import torch
import transformers
from transformers.cache_utils import DynamicCache
# refer to https://github.com/iiiutch-ii/RemeDi/blob/main/RL-code
from networks.block_llada.modelling_llada_bitowel import LLaDAUPMModelLM

@torch.no_grad()
def generate_block_diffusion(
    model,
    conv,
    tokenizer,
    device,
    num_generations,
    kv_cache=None,
    steps: int = 32,
    max_length=1024,
    block_size=32,
    mask_token_id=126336,
    eos_id=126081,
):
    m = [conv]
    prompts = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
    inputs = tokenizer(prompts, return_tensors='pt', padding=True, padding_side='left')
    x_t = inputs['input_ids'].to(device)

    attention_mask = inputs['attention_mask'].to(device)
    prompt_len = attention_mask.sum(dim=1)
    attn_bias = torch.where(
        attention_mask + attention_mask.T > 0,
        0, -torch.inf
    )[None, None].repeat(x_t.shape[0], 1, 1, 1)

    x_t = x_t.repeat(num_generations, 1)
    prompt_len = prompt_len.repeat(num_generations)
    attn_bias = attn_bias.repeat(num_generations, 1, 1, 1)
    batch_size = x_t.shape[0]

    position_ids = torch.arange(x_t.shape[1], device=x_t.device, dtype=torch.long).unsqueeze(0) - (1 - attention_mask).sum(dim=-1)
    if kv_cache is None:
        kv_cache = DynamicCache()

    # cache prompt first
    with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
        model(
            x_t,
            kv_cache=kv_cache,
            update_kv_cache=True,
        )

    cur_blocks = 0
    responses = [x_t]
    is_eos_meet = torch.zeros((batch_size,), device=x_t.device, dtype=torch.bool)

    while (cur_blocks * block_size) < max_length:
        x_t = torch.full((batch_size, block_size), fill_value=mask_token_id, device=device, dtype=torch.long)

        position_ids = torch.arange(
            cur_blocks * block_size,
            (cur_blocks + 1) * block_size,
            device=x_t.device, dtype=torch.long).unsqueeze(0) + prompt_len.unsqueeze(1)

        num_transfer_tokens = torch.tensor([block_size // steps for _ in range(steps)])
        if block_size % steps != 0:
            num_transfer_tokens[-block_size % steps:] += 1
        # cumsum
        num_transfer_tokens = num_transfer_tokens.cumsum(dim=0)

        for i in range(steps):
            mask_index = (x_t == mask_token_id)

            with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
                out = model(
                    x_t,
                    position_ids=position_ids,
                    kv_cache=kv_cache,
                )
            logits = out.logits.to(torch.float32)
            x0 = torch.argmax(logits, dim=-1)  # b, l
            x0 = torch.where(mask_index, x0, x_t)

            # score of the predicted token at each position; keep the top-k
            # highest-scoring positions and re-mask the rest
            upm_prob = logits.gather(dim=-1, index=x0.unsqueeze(-1)).squeeze(-1)
            samples = torch.topk(upm_prob, k=num_transfer_tokens[i], dim=-1).indices

            bs_idx = torch.arange(batch_size, device=samples.device, dtype=samples.dtype).unsqueeze(1)
            remask_index = torch.ones_like(x_t).bool()
            remask_index[bs_idx, samples] = False

            x_t = torch.where(remask_index, mask_token_id, x0)

        responses.append(x_t.clone())
        cur_blocks += 1
        # stop once every sequence in the batch has produced an EOS token
        is_eos_meet |= (x_t == eos_id).any(dim=-1)
        if is_eos_meet.all(): break

        # update kv_cache
        with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
            model(
                x_t,
                position_ids=position_ids,
                kv_cache=kv_cache,
                update_kv_cache=True,
            )

    response_tokens = torch.cat(responses, dim=1)
    responses = []
    responses_length = []
    for i in range(batch_size):
        if eos_id in response_tokens[i]:
            eos_token_idx = (response_tokens[i] == eos_id).nonzero(as_tuple=True)[0][0].item()
            resp_token = response_tokens[i, prompt_len[i]:eos_token_idx]
        else:
            resp_token = response_tokens[i, prompt_len[i]:]
        responses.append(tokenizer.decode(resp_token, skip_special_tokens=True))
        responses_length.append(resp_token.shape[0])

    return responses


def main(
    ckpt_path='iiiutch/RemeDi-Instruct',
    seed: int = 112,
):
    torch.manual_seed(seed)
    device = 'cuda'

    tokenizer = transformers.AutoTokenizer.from_pretrained(ckpt_path)

    model = LLaDAUPMModelLM.from_pretrained(
        ckpt_path,
        torch_dtype=torch.bfloat16,
    )
    model.eval().requires_grad_(False).to(device)

    while True:
        print('=' * 20)
        prompt = input("User: ").strip()
        print('Assistant: ', end='')
        conv = [{'role': 'user', 'content': prompt}]

        responses = generate_block_diffusion(
            model,
            conv,
            tokenizer,
            device=device,
            num_generations=1,
            steps=32, max_length=1024, block_size=32,
        )

        conv.append({'role': 'assistant', 'content': responses[0]})


if __name__ == "__main__":
    main()
```
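
For non-interactive use, `generate_block_diffusion` can also be called directly with a single conversation. The snippet below is a minimal sketch that reuses the classes and default arguments from the script above and assumes `generate_block_diffusion` is in scope (e.g. both snippets live in the same file); the prompt text is only an illustration.

```python
import torch
import transformers
# refer to https://github.com/iiiutch-ii/RemeDi/blob/main/RL-code
from networks.block_llada.modelling_llada_bitowel import LLaDAUPMModelLM

ckpt_path = 'iiiutch/RemeDi-Instruct'
device = 'cuda'

tokenizer = transformers.AutoTokenizer.from_pretrained(ckpt_path)
model = LLaDAUPMModelLM.from_pretrained(ckpt_path, torch_dtype=torch.bfloat16)
model.eval().requires_grad_(False).to(device)

# Single-turn conversation; the content is an illustrative placeholder.
conv = [{'role': 'user', 'content': 'Briefly explain what a diffusion language model is.'}]

# Calls the generate_block_diffusion function defined in the script above.
responses = generate_block_diffusion(
    model, conv, tokenizer,
    device=device, num_generations=1,
    steps=32, max_length=1024, block_size=32,
)
print(responses[0])
```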