import argparse
import logging

import torch
from rich.console import Console, Group
from rich.live import Live
from rich.progress import Progress, BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
from rich.text import Text
from safetensors.torch import load_file
from transformers import AutoModelForMaskedLM

from tokenizer import get_tokenizer

# Silence verbose transformers warnings (e.g. about newly initialized weights).
logging.getLogger("transformers").setLevel(logging.ERROR)
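
# This script samples from a masked-diffusion language model: generation
# starts from a fully masked sequence that is iteratively denoised over
# `num_steps` reverse-time steps, optionally conditioned on a chat prompt.
#
# Example invocation (script and checkpoint names are illustrative):
#   python inference.py \
#       --safetensors_path checkpoints/model.safetensors \
#       --prompt "What is the capital of France?"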
|
|
|
|
|
def load_model_and_tokenizer(path_to_weights, hf_model_name, device="cuda"):
    tokenizer = get_tokenizer(hf_model_name)

    model = AutoModelForMaskedLM.from_pretrained(hf_model_name, device_map=device)
    model.resize_token_embeddings(len(tokenizer))

    # Load the fine-tuned weights; torch.load cannot read .safetensors
    # checkpoints, so use safetensors' load_file.
    state_dict = load_file(path_to_weights)
    model.load_state_dict(state_dict, strict=True)
    model.tie_weights()
    model.eval()

    return model, tokenizer
|
|
|
|
|
def prepare_unconditional_tokens_for_inference(seq_len, mask_token_id, device="cuda"):
    # Start from a fully masked sequence; `mask` marks the positions that
    # still need to be generated.
    input_tokens = torch.full((1, seq_len), mask_token_id, dtype=torch.long, device=device)
    mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
    attention_mask = torch.ones((1, seq_len), dtype=torch.long, device=device)
    return input_tokens, mask, attention_mask
|
|
|
|
|
def prepare_conditional_tokens_for_inference(seq_len, tokenizer, prompt, device="cuda"):
    chat_template = [
        {"role": "user", "content": prompt}
    ]

    tokenized = tokenizer.apply_chat_template(
        chat_template,
        tokenize=True,
        add_special_tokens=True,
        add_generation_prompt=True,
    )

    prompt_tokens = torch.tensor(tokenized).to(device)

    input_tokens, mask, attention_mask = prepare_unconditional_tokens_for_inference(
        seq_len, tokenizer.mask_token_id, device
    )

    # Clamp the prompt into the prefix and exclude it from generation.
    # Note: a prompt longer than seq_len will fail here; choose seq_len accordingly.
    input_tokens[0, :len(prompt_tokens)] = prompt_tokens
    mask[0, :len(prompt_tokens)] = False

    return input_tokens, mask, attention_mask
|
|
|
|
|
def format_display_for_qa(user_text, assistant_text):
    output = Text()
    output.append("USER: ", style="bold green")
    output.append(user_text + "\n\n")
    output.append("ASSISTANT: ", style="bold cyan")
    output.append(assistant_text, style="white")
    return output
|
|
|
|
|
def format_display_for_unconditional(gen_text):
    output = Text()
    output.append("Unconditional Generation:\n\n", style="bold green")
    output.append(gen_text, style="white")
    return output
|
|
|
|
|
def clean_text(raw_text: str) -> str:
    # Crude cleanup: strip the chat-template role markers that survive decoding.
    return (
        raw_text.replace("user", "")
        .replace("assistant", "")
        .strip()
    )
|
|
|
|
|
@torch.inference_mode()
def inference(tokenizer,
              model,
              num_steps,
              seq_len,
              strategy="predictor_corrector",
              device="cuda",
              prompt=None,
              show_mask=True):
    if prompt is None:
        input_tokens, mask, attention_mask = prepare_unconditional_tokens_for_inference(
            seq_len,
            mask_token_id=tokenizer.mask_token_id,
            device=device,
        )
    else:
        input_tokens, mask, attention_mask = prepare_conditional_tokens_for_inference(
            seq_len,
            tokenizer=tokenizer,
            prompt=prompt,
            device=device,
        )

    original_mask = mask.clone()
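
    # From here on `mask` only ever shrinks, so it stays a subset of
    # `original_mask`; the corrector below relies on this to re-noise only
    # positions inside the generation region, never the prompt prefix.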
|
|
    console = Console(highlight=False)

    # Rich permits only one live display per console, so rather than nesting
    # a Live inside a started Progress, the progress bar is rendered together
    # with the streamed text inside a single Live via a Group (see below).
    progress = Progress(
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        "[progress.percentage]{task.percentage:>3.0f}%",
        TimeElapsedColumn(),
        TimeRemainingColumn(),
        console=console,
        transient=True,
    )
    task = progress.add_task("Generating...", total=num_steps)

    times = torch.linspace(1, 0, num_steps + 1, device=device)
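
    # Reverse time runs t: 1 -> 0. Under the linear masking schedule this
    # sampler assumes, a token that is masked at time t should remain masked
    # at the next time s < t with probability s/t, which is what the
    # remasking draws below implement.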
    with Live(Group(Text(""), progress), refresh_per_second=5, console=console) as live:
        for t, s in zip(times[:-1], times[1:]):
            if strategy == "backward":
                logits = model(input_tokens, attention_mask=attention_mask).logits

                # Sample a prediction for every currently masked position.
                probs = torch.softmax(logits[mask], dim=-1)
                input_tokens[mask] = torch.multinomial(probs, num_samples=1).squeeze(-1)

                # Keep each masked position masked with probability s/t.
                remask_probs = torch.rand_like(mask, dtype=torch.float)
                remask_decision = remask_probs < s / t
                mask = mask & remask_decision
                input_tokens[mask] = tokenizer.mask_token_id
            elif strategy == "predictor_corrector":
                # Predictor: the same ancestral update as the "backward" strategy.
                logits = model(input_tokens, attention_mask=attention_mask).logits

                probs = torch.softmax(logits[mask], dim=-1)
                input_tokens[mask] = torch.multinomial(probs, num_samples=1).squeeze(-1)

                remask_probs = torch.rand_like(mask, dtype=torch.float)
                remask_decision = remask_probs < s / t
                mask = mask & remask_decision
                input_tokens[mask] = tokenizer.mask_token_id

                n_corrector_steps = 1
                corrector_step_size = 0.5 * (t - s) / (1 - s)
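
                # Corrector: re-noise a small fraction of already revealed
                # tokens and resample them, giving the model a chance to fix
                # earlier mistakes. The fraction scales with the step size
                # relative to the remaining unmasked mass, (t - s) / (1 - s);
                # the 0.5 factor and the s > 0.3 cutoff are heuristics.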
                if n_corrector_steps > 0 and s > 0.3:
                    for _ in range(n_corrector_steps):
                        # Positions that started masked but have since been revealed.
                        known_mask = original_mask & ~mask
                        noise_rng = torch.rand_like(known_mask, dtype=torch.float)
                        to_remask = known_mask & (noise_rng < corrector_step_size)

                        input_tokens[to_remask] = tokenizer.mask_token_id

                        corr_logits = model(input_tokens, attention_mask=attention_mask).logits

                        corr_probs = torch.softmax(corr_logits[to_remask], dim=-1)
                        corr_samples = torch.multinomial(corr_probs, num_samples=1).squeeze(-1)
                        input_tokens[to_remask] = corr_samples
            if show_mask:
                decoded_tokens = tokenizer.convert_ids_to_tokens(input_tokens[0])

                # Keep mask tokens visible but drop all other special tokens.
                cleaned_tokens = [
                    tok for tok in decoded_tokens
                    if tok == tokenizer.mask_token or tok not in tokenizer.all_special_tokens
                ]

                decoded_after = tokenizer.convert_tokens_to_string(cleaned_tokens)
            else:
                decoded_after = tokenizer.batch_decode(input_tokens, skip_special_tokens=True)[0]
            if prompt is None:
                format_text = format_display_for_unconditional(decoded_after)
            else:
                # Show only the generated continuation, not the echoed prompt.
                assistant_text = decoded_after.replace(prompt, "").strip()
                assistant_text = clean_text(assistant_text)
                format_text = format_display_for_qa(prompt, assistant_text)

            live.update(Group(format_text, progress))
            progress.update(task, advance=1)
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Inference for a masked-diffusion LM")
    parser.add_argument("--safetensors_path", required=True, type=str)
    parser.add_argument("--prompt", type=str, default=None)
    parser.add_argument("--seq_len", type=int, default=512)
    parser.add_argument("--num_steps", type=int, default=512)
    parser.add_argument("--strategy", type=str, default="predictor_corrector",
                        choices=["backward", "predictor_corrector"])
    parser.add_argument("--hf_model_name", type=str, default="distilbert/distilroberta-base")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--seed", type=int, default=1234)

    args = parser.parse_args()
|
|
    def seed_everything(seed: int):
        import os
        import random

        import numpy as np

        random.seed(seed)
        # Note: PYTHONHASHSEED only affects str hashing if set before the
        # interpreter starts; kept here for completeness.
        os.environ["PYTHONHASHSEED"] = str(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    seed_everything(args.seed)
|
|
    model, tokenizer = load_model_and_tokenizer(args.safetensors_path,
                                                args.hf_model_name,
                                                args.device)

    inference(tokenizer,
              model,
              args.num_steps,
              seq_len=args.seq_len,
              strategy=args.strategy,
              device=args.device,
              prompt=args.prompt)