Spaces:
Runtime error
Runtime error
| """ | |
| This script provides an example to wrap TencentPretrain for generation. | |
| Given the beginning of a text, language model generates the rest. | |
| """ | |
| import sys | |
| import os | |
| import argparse | |
| import torch | |
| import torch.nn.functional as F | |
| from torchvision import transforms | |
| from PIL import Image | |
| tencentpretrain_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) | |
| sys.path.append(tencentpretrain_dir) | |
| from tencentpretrain.embeddings import * | |
| from tencentpretrain.layers import * | |
| from tencentpretrain.encoders import * | |
| from tencentpretrain.targets import * | |
| from tencentpretrain.utils.constants import * | |
| from tencentpretrain.utils import * | |
| from tencentpretrain.utils.config import load_hyperparam | |
| from tencentpretrain.model_loader import load_model | |
| from tencentpretrain.opts import model_opts, tokenizer_opts | |
| from scripts.generate_lm import top_k_top_p_filtering | |
| from tencentpretrain.utils.image_tokenizer import * | |
class GenerateLm(torch.nn.Module):
    """Language model used for generation.

    Chains the configured embedding(s), the configured encoder and the
    output layer of an ``LmTarget`` so that a forward pass returns the
    output-layer result for every input position.
    """

    def __init__(self, args):
        """Build the sub-modules described by ``args``.

        Args:
            args: Namespace providing ``embedding`` (iterable of embedding
                names), ``encoder`` (encoder name) and ``tokenizer`` (object
                exposing a ``vocab`` whose length is used as vocabulary size).
        """
        super().__init__()
        vocab_size = len(args.tokenizer.vocab)

        # One composite embedding, populated with every requested component.
        self.embedding = Embedding(args)
        for name in args.embedding:
            component = str2embedding[name](args, vocab_size)
            self.embedding.update(component, name)

        self.encoder = str2encoder[args.encoder](args)

        # Only the LM target is registered; forward() uses its output layer.
        self.target = Target()
        self.target.update(LmTarget(args, vocab_size), "lm")

    def forward(self, src, seg):
        """Return the LM output layer's result for token ids ``src`` and segment ids ``seg``."""
        hidden = self.encoder(self.embedding(src, seg), seg)
        return self.target.lm.output_layer(hidden)
def preprocess_vqgan(x):
    """Rescale ``x`` with ``2 * x - 1`` (maps the range [0, 1] onto [-1, 1]).

    Applied right after ``transforms.ToTensor()`` in the image pipeline below.
    """
    return 2. * x - 1.


def convert_color(image, channel):
    """Convert a PIL image to RGB (``channel == 3``) or grayscale (``channel == 1``).

    For three channels, anything that is not already RGB is first brought to
    RGBA and the alpha band is then dropped. Any other ``channel`` value
    returns the image unchanged.
    """
    if channel == 3:
        if image.mode != "RGB":
            if image.mode != "RGBA":
                image = image.convert("RGBA")
            r, g, b, _ = image.split()
            image = Image.merge("RGB", (r, g, b))
    elif channel == 1:
        image = image.convert("L")
    return image


def truncate_at_separator(token_ids, sep_id):
    """Return the ids that precede the first ``sep_id``.

    Works on integers rather than on a joined string, so ids that merely
    contain the separator's digits (e.g. 1102 for separator 102) are left
    intact. Returns every id when the separator is absent and an empty list
    when the separator comes first.
    """
    token_ids = list(token_ids)
    try:
        return token_ids[:token_ids.index(sep_id)]
    except ValueError:
        return token_ids


if __name__ == "__main__":
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Model options.
    model_opts(parser)

    # Inference options.
    parser.add_argument("--load_model_path", default=None, type=str,
                        help="Path of the input model.")
    parser.add_argument("--config_path", type=str, required=True,
                        help="Path of the config file.")
    parser.add_argument("--seq_length", type=int, default=48,
                        help="Sequence length.")
    parser.add_argument("--samples_num", type=int, default=10,
                        help="Number of iterations for sampling.")
    parser.add_argument("--prompt", choices=["to_attributes", "to_caption", "to_image"], default="to_attributes",
                        help="Prompt that indicates the output format.")
    parser.add_argument("--text_prefix_path", type=str, default=None,
                        help="Text prefix for to_attributes and to_image.")
    parser.add_argument("--image_prefix_path", type=str, default=None,
                        help="Input image path.")
    parser.add_argument("--top_k", type=int, default=70)
    parser.add_argument("--top_p", type=float, default=0)
    parser.add_argument("--temperature", type=float, default=1.0)

    tokenizer_opts(parser)

    args = parser.parse_args()

    # Without an image the text prefix is tokenized unconditionally, so reject
    # the "no prefix at all" case here instead of failing after model loading.
    if args.image_prefix_path is None and args.text_prefix_path is None:
        parser.error("At least one of --text_prefix_path or --image_prefix_path is required.")

    args.batch_size = 1

    args = load_hyperparam(args)

    args.tokenizer = str2tokenizer[args.tokenizer](args)

    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Only the first line of the prefix file is used.
    args.text_prefix = None
    if args.text_prefix_path is not None:
        # Explicit encoding: the platform default may not be UTF-8.
        with open(args.text_prefix_path, "r", encoding="utf-8") as f:
            args.text_prefix = f.readline()

    # Image pipeline: force RGB, resize to the configured size, to tensor, rescale.
    transform = transforms.Compose([
        transforms.Lambda(lambda img: convert_color(img, 3)),
        transforms.Resize((args.image_height, args.image_width)),
        transforms.ToTensor(),
        transforms.Lambda(preprocess_vqgan),
    ])

    model = GenerateLm(args)
    model = load_model(model, args.load_model_path)
    model = model.to(args.device)
    model.eval()

    vqgan = build_vqgan_model(args)
    vqgan = vqgan.to(args.device)

    # "to_attributes" -> "to attributes", etc.
    prompt = args.prompt.replace("_", " ")
    PAD_ID, CLS_ID, SEP_ID, MASK_ID = 0, 101, 102, 103

    if args.image_prefix_path is None:
        # Text-only prefix: [CLS] prompt [SEP] text [SEP], capped at 64 ids.
        src = args.tokenizer.convert_tokens_to_ids(
            [CLS_TOKEN] + args.tokenizer.tokenize(prompt) + [SEP_TOKEN] +
            args.tokenizer.tokenize(args.text_prefix) + [SEP_TOKEN]
        )
        if len(src) > 64:
            src = src[:64]
        seg = [1] * len(src)
    else:
        # Image prefix: [CLS] prompt [SEP] image-tokens [SEP]; image token ids
        # are shifted by the tokenizer's vocab_bias.
        image = Image.open(args.image_prefix_path)
        image = transform(image).to(args.device)
        image_token = image_tokenize(vqgan, image)
        p_src = args.tokenizer.convert_tokens_to_ids(
            [CLS_TOKEN] + args.tokenizer.tokenize(prompt) + [SEP_TOKEN]
        )
        src = p_src + [i + args.tokenizer.vocab_bias for i in image_token] + [SEP_ID]
        seg = [1] * len(src)
        # NOTE(review): SOURCE's indentation was lost; the optional text prefix
        # is appended only in this branch because the text-only branch above
        # already embeds it. It is tagged with segment id 2.
        if args.text_prefix is not None:
            attr_prompt_src = args.tokenizer.convert_tokens_to_ids(args.tokenizer.tokenize(args.text_prefix))
            src = src + attr_prompt_src
            seg = seg + [2] * len(attr_prompt_src)

    beginning_length = len(src)

    for r in range(args.samples_num):
        # Move the prefix to the device once per sample; every tensor
        # concatenated below is created on that device already.
        src_tensor = torch.LongTensor([src]).to(args.device)
        seg_tensor = torch.LongTensor([seg]).to(args.device)

        with torch.no_grad():
            for _ in range(args.seq_length - beginning_length):
                output = model(src_tensor, seg_tensor)
                # Sample the next token from the filtered last-position logits.
                next_token_logits = output[0][-1] / args.temperature
                filtered_logits = top_k_top_p_filtering(next_token_logits, args.top_k, args.top_p)
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)

                src_tensor = torch.cat([src_tensor, next_token.view(1, 1)], dim=1)
                seg_tensor = torch.cat([seg_tensor, torch.tensor([[2]], device=args.device)], dim=1)

        generated_ids = src_tensor[0][beginning_length:].tolist()

        if args.image_prefix_path is not None:
            # Image in -> text out: keep the ids generated before the first SEP.
            text_ids = truncate_at_separator(generated_ids, SEP_ID)
            generated_sentence = " ".join(args.tokenizer.convert_ids_to_tokens(text_ids))
            print("output " + str(r) + ":" + "\n")
            print(generated_sentence + "\n")
        else:
            # Text in -> image out: undo the vocab_bias shift on the first
            # img_length generated ids and hand them to the VQGAN detokenizer.
            img_length = (args.image_height // args.image_tokenizer["frame_size"]) ** 2
            img_seg = [i - args.tokenizer.vocab_bias for i in generated_ids[:img_length]]
            image_detokenize(vqgan, img_seg, args.image_tokenizer["image_vocab_size"], False, "output-" + str(r) + ".jpg")