import argparse
import logging

import torch
from rich.console import Console
from rich.live import Live
from rich.progress import Progress, BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
from rich.text import Text
from safetensors.torch import load_file
from transformers import AutoModelForMaskedLM

from tokenizer import get_tokenizer

logging.getLogger("transformers").setLevel(logging.ERROR)

def load_model_and_tokenizer(path_to_weights, hf_model_name, device="cuda"):
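    """Load the HF MaskedLM backbone, resize its embeddings to match the (possibly extended) tokenizer, then restore the trained weights from the safetensors checkpoint."""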
    ### Load Tokenizer ###
    tokenizer = get_tokenizer(hf_model_name)
    
    ### Load Model and Update Embedding Size ###
    model = AutoModelForMaskedLM.from_pretrained(hf_model_name, device_map=device)
    model.resize_token_embeddings(len(tokenizer))

    # Load the trained checkpoint (weights are stored in safetensors format)
    state_dict = load_file(path_to_weights)
    model.load_state_dict(state_dict, strict=True)
    model.tie_weights()
    model.eval()

    return model, tokenizer

def prepare_unconditional_tokens_for_inference(seq_len, mask_token_id, device="cuda"):
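    """Start generation from pure noise: a sequence of seq_len [MASK] tokens, with every position marked as generable and attended to."""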
    input_tokens = torch.full((1, seq_len), mask_token_id, dtype=torch.long, device=device)
    mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
    attention_mask = torch.ones((1, seq_len), dtype=torch.long, device=device) 
    return input_tokens, mask, attention_mask

def prepare_conditional_tokens_for_inference(seq_len, tokenizer, prompt, device="cuda"):
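    """Build a fully masked sequence, write the chat-templated prompt into the prefix, and pin those prompt positions (mask=False) so they are never resampled."""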

    chat_template = [
        {"role": "user", "content": prompt}
    ]

    tokenized = tokenizer.apply_chat_template(
        chat_template,
        tokenize=True,
        add_special_tokens=True,
        add_generation_prompt=True
    )

    prompt_tokens = torch.tensor(tokenized).to(device)

    input_tokens, mask, attention_mask = prepare_unconditional_tokens_for_inference(
        seq_len, tokenizer.mask_token_id, device
    )

    input_tokens[0, :len(prompt_tokens)] = prompt_tokens

    mask[0, :len(prompt_tokens)] = False

    return input_tokens, mask, attention_mask

def format_display_for_qa(user_text, assistant_text):
    output = Text()
    output.append("USER: ", style="bold green")
    output.append(user_text + "\n\n")
    output.append("ASSISTANT: ", style="bold cyan")
    output.append(assistant_text, style="white")
    return output

def format_display_for_unconditional(gen_text):
    output = Text()
    output.append("Unconditional Generation: \n\n", style="bold green")
    output.append(gen_text, style="white")
    return output

def clean_text(raw_text: str) -> str:
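    """Strip the bare "user"/"assistant" role markers that the chat template leaves in the decoded text."""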
    return (
        raw_text.replace("user", "")
        .replace("assistant", "")
        .strip()
    )

@torch.inference_mode()
def inference(tokenizer,
              model,
              num_steps,
              seq_len=512,
              strategy="backward",
              device="cuda",
              prompt=None,
              show_mask=True):
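    """Reverse the masking process over num_steps, walking the noise level t from 1 down to 0.

    "backward" is plain ancestral sampling: fill every masked position, then remask with prob s/t.
    "predictor_corrector" adds a corrector pass that remasks and resamples already-decoded tokens.
    The evolving sequence is rendered live in the terminal with rich.
    """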

    if prompt is None:
        input_tokens, mask, attention_mask = prepare_unconditional_tokens_for_inference(seq_len,
                                                                                        mask_token_id=tokenizer.mask_token_id,
                                                                                        device=device)
    else:
        input_tokens, mask, attention_mask = prepare_conditional_tokens_for_inference(seq_len,
                                                                                      tokenizer=tokenizer,
                                                                                      prompt=prompt,
                                                                                      device=device)
    original_mask = mask.clone()

    ### Nice Printing Stuff ###
    console = Console(highlight=False)

    with Progress(
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        "[progress.percentage]{task.percentage:>3.0f}%",
        TimeElapsedColumn(),
        TimeRemainingColumn(),
        console=console,
        transient=True,
    ) as progress:
        
        ### What Controls our Progress Bar ###
        task = progress.add_task("Generating...", total=num_steps)

        ### Get Timesteps for Inference ###
        times = torch.linspace(1, 0, num_steps + 1, device=device)
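        ### Each consecutive pair (t, s) with t > s is one reverse step from noise level t down to s ###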

        with Live("", refresh_per_second=5, console=console) as live:
            for t, s in zip(times[:-1], times[1:]):

                if strategy == "backward":
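                    ### Ancestral step: sample a candidate for every masked position, then remask each ###
                    ### with prob s/t, which shrinks the expected masked fraction from t down to s ###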
                    logits = model(input_tokens, attention_mask=attention_mask).logits

                    probs = torch.softmax(logits[mask], dim=-1)
                    input_tokens[mask] = torch.multinomial(probs, num_samples=1).squeeze(-1)

                    remask_noise = torch.rand_like(mask, dtype=torch.float, device=device)
                    remask_decision = (remask_noise < s / t)
                    mask = mask & remask_decision
                    input_tokens[mask] = tokenizer.mask_token_id

                if strategy == "predictor_corrector":
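                    ### Predictor: the same ancestral step as the "backward" strategy ###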
                    logits = model(input_tokens, attention_mask=attention_mask).logits
                    
                    probs = torch.softmax(logits[mask], dim=-1)
                    input_tokens[mask] = torch.multinomial(probs, num_samples=1).squeeze(-1)

                    remask_noise = torch.rand_like(mask, dtype=torch.float, device=device)
                    remask_decision = (remask_noise < s / t)
                    
                    mask = mask & remask_decision 
                    input_tokens[mask] = tokenizer.mask_token_id

                    n_corrector_steps = 1
                    corrector_step_size = 0.5 * (t-s)/(1-s)
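                    ### Corrector: while the noise level is still high (s > 0.3), remask a small fraction ###
                    ### of already-decoded tokens and resample them so earlier mistakes can be revised ###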

                    if n_corrector_steps > 0 and s > 0.3:
                        for _ in range(n_corrector_steps):
                            known_mask = ~mask & original_mask  # decoded so far, excluding the fixed prompt
                            noise_rng = torch.rand_like(known_mask, dtype=torch.float, device=device)
                            
                            to_remask = known_mask & (noise_rng < corrector_step_size)
                            
                            input_tokens[to_remask] = tokenizer.mask_token_id
                            
                            corr_logits = model(input_tokens, attention_mask=attention_mask).logits
                            
                            corr_probs = torch.softmax(corr_logits[to_remask], dim=-1)
                            corr_samples = torch.multinomial(corr_probs, num_samples=1).squeeze(-1)
                            
                            input_tokens[to_remask] = corr_samples
                
                if show_mask:
                    ### Get all of the Tokens ###
                    decoded_tokens = tokenizer.convert_ids_to_tokens(input_tokens[0])

                    ### Keep [MASK] tokens, drop all other special tokens ###
                    cleaned_tokens = []
                    for tok in decoded_tokens:
                        if tok == tokenizer.mask_token:  # keep mask tokens
                            cleaned_tokens.append(tok)
                        elif tok in tokenizer.all_special_tokens:  # drop all other specials
                            continue
                        else:
                            cleaned_tokens.append(tok)

                    ### Put all the tokens back together into a string ###
                    decoded_after = tokenizer.convert_tokens_to_string(cleaned_tokens)
                
                else:
                    decoded_after = tokenizer.batch_decode(input_tokens, skip_special_tokens=True)[0]

                if prompt is None:
                    format_text = format_display_for_unconditional(decoded_after)
                else:
                    ### Remove Prompt Text from Assistant Text ###
                    assistant_text = decoded_after.replace(prompt, "").strip()
                    ### Remove Keywords user and assistant ###
                    assistant_text = clean_text(assistant_text)
                    format_text = format_display_for_qa(prompt, assistant_text)
                live.update(format_text)
                progress.update(task, advance=1)

if __name__ == "__main__":
    parser = argparse.ArgumentParser("Inference LDM")
    parser.add_argument("--safetensors_path", required=True, type=str)
    parser.add_argument("--prompt", type=str, default=None)
    parser.add_argument("--seq_len", type=int, default=512)
    parser.add_argument("--num_steps", type=int, default=512)
    parser.add_argument("--strategy", type=str, default="predictor_corrector", choices=["backward", "predictor_corrector"])
    parser.add_argument("--hf_model_name", type=str, default="distilbert/distilroberta-base")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--seed", type=int, default=1234)
    
    args = parser.parse_args()

    # class args:
    #     safetensors_path = "/kaggle/working/runs/Philosopher_v0/final_model/model.safetensors"
    #     prompt = "Generate a quote on life"
    #     seq_len = 64
    #     num_steps = 128
    #     strategy = "predictor_corrector"
    #     # strategy = "backward"
    #     hf_model_name = "answerdotai/ModernBERT-base"
    #     device = "cpu"
    #     seed = 1234
    
    # args = args()
    
    def seed_everything(seed: int):
        import random, os
        import numpy as np
        import torch
        
        random.seed(seed)
        os.environ['PYTHONHASHSEED'] = str(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        
    seed_everything(args.seed)

    ### Load Model ###
    model, tokenizer = load_model_and_tokenizer(args.safetensors_path, 
                                                args.hf_model_name, 
                                                args.device)

    inference(tokenizer,
              model,
              args.num_steps,
              seq_len=args.seq_len,
              strategy=args.strategy,
              device=args.device,
              prompt=args.prompt)
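
    ### Example invocation (hypothetical checkpoint path, adjust to your run): ###
    # python inference.py --safetensors_path path/to/model.safetensors \
    #     --prompt "Generate a quote on life" \
    #     --seq_len 64 --num_steps 128 --strategy predictor_corrector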