Spaces: Runtime error
| import gradio as gr | |
| from datasets import load_dataset, Image | |
| import torch | |
| import nltk | |
| import io | |
| import base64 | |
| import shutil | |
| from torchvision import transforms | |
| from pytorch_pretrained_biggan import BigGAN, one_hot_from_names, truncated_noise_sample | |
class PreTrainedPipeline:
    """Text-to-image pipeline that renders an ImageNet class name with BigGAN."""

    def __init__(self, path=""):
        """
        Initialize model.

        Args:
            path (:obj:`str`): name or location passed to
                ``BigGAN.from_pretrained`` to load the generator weights.
        """
        # `one_hot_from_names` resolves class names through WordNet, so the
        # corpus must be available before the first call.
        nltk.download('wordnet')
        self.model = BigGAN.from_pretrained(path)
        # Truncation value used both for sampling the noise vector and as the
        # generator's truncation argument in `__call__`.
        self.truncation = 0.1

    def __call__(self, inputs: str):
        """
        Generate one image for the given class name.

        Args:
            inputs (:obj:`str`):
                a string naming an ImageNet class (e.g. "dog").
        Return:
            A :obj:`PIL.Image` with the raw image representation as PIL.
        Raises:
            ValueError: if ``inputs`` cannot be mapped to an ImageNet class.
        """
        class_vector = one_hot_from_names([inputs], batch_size=1)
        if class_vector is None:
            raise ValueError("Input is not in ImageNet")

        noise_vector = truncated_noise_sample(truncation=self.truncation, batch_size=1)
        noise_vector = torch.from_numpy(noise_vector)
        class_vector = torch.from_numpy(class_vector)

        with torch.no_grad():
            output = self.model(noise_vector, class_vector, self.truncation)

        # Scale image: the rescale below assumes generator outputs in [-1, 1].
        # Clamp afterwards so ToPILImage's float -> uint8 conversion cannot
        # overflow if any value lands slightly outside [0, 1].
        img = output[0]
        img = ((img + 1) / 2.0).clamp_(0.0, 1.0)
        img = transforms.ToPILImage()(img)
        # Previously the image was built but never returned (callers got None).
        return img
# NOTE(review): `dataset` is not referenced by any code visible in this file,
# and loading it may fetch the dataset at startup — confirm it is still needed.
dataset = load_dataset("botmaster/mother-2-battle-sprites", split="train")

# Build the demo from a hosted model, then start the Gradio server.
# NOTE(review): the PreTrainedPipeline class is not used here. Newer Gradio
# releases also expose remote loading as `gr.load(...)` — confirm the call
# against the installed Gradio version.
demo = gr.Interface.load("models/templates/text-to-image")
demo.launch()