Spaces:
Runtime error
Runtime error
| import base64 | |
| import pathlib | |
| import re | |
| import time | |
| from io import BytesIO | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from fastai.callback.core import Callback | |
| from fastai.learner import * | |
| from fastai.torch_core import TitledStr | |
| from html2image import Html2Image | |
| from min_dalle import MinDalle | |
| from torch import tensor, Tensor, float16, float32 | |
| from torch.distributions import Transform | |
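# The helpers below (tokenize, TransformersTokenizer, DropOutput) must be defined in __main__ (or in
# whichever module defined them when the learner was exported), because fastai's load_learner looks
# them up in that module when unpickling the exported learner, see:
# https://docs.fast.ai/learner.html#load_learner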
| from transformers import GPT2TokenizerFast | |
| # update requirements.txt with: | |
| # C:\Users\Grant\PycharmProjects\test_space\venv\Scripts\pip3.exe freeze > requirements.txt | |
| # Huggingface Spaces have 16GB RAM and 8 CPU cores | |
| # See https://huggingface.co/docs/hub/spaces-overview#hardware-resources | |
# Hugging Face hub id of the GPT-2 tokenizer; presumably the same one the exported learner was
# fine-tuned with (TODO confirm against the training notebook).
pretrained_weights = "gpt2"
tokenizer = GPT2TokenizerFast.from_pretrained(pretrained_weights)
def tokenize(text):
    """Turn raw *text* into a tensor of GPT-2 token ids using the module-level ``tokenizer``."""
    token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
    return tensor(token_ids)
# fastai's Transform (fastcore's class, re-exported by fastai.text.all), which is what the fastai
# Transformers tutorial subclasses. The module-level name `Transform` is bound by
# `from torch.distributions import Transform`, an unrelated class for probability-distribution
# transforms that does not dispatch `encodes` / `decodes`.
from fastai.text.all import Transform as _FastaiTransform


class TransformersTokenizer(_FastaiTransform):
    """fastai Transform wrapping a Hugging Face tokenizer.

    Keep this class defined in this module under this name: load_learner presumably needs it to
    unpickle the exported learner (export.pkl).
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def encodes(self, x):
        # Tensors are assumed to be already tokenized and pass through unchanged; anything else is
        # tokenized with the module-level `tokenize` helper.
        return x if isinstance(x, Tensor) else tokenize(x)

    def decodes(self, x):
        # Token ids -> text, wrapped in TitledStr for fastai's display helpers.
        return TitledStr(self.tokenizer.decode(x.cpu().numpy()))
class DropOutput(Callback):
    """fastai callback that replaces the learner's prediction with its first element.

    Presumably the wrapped Hugging Face model returns a tuple whose first item is the logits,
    which is all fastai's loss needs (as in the fastai Transformers tutorial).
    """

    def after_pred(self):
        first_output = self.pred[0]
        self.learn.pred = first_output
# Build the MinDalle model a single time at import; gen_image() reuses it for every request.
# Author's measurement: about 2 minutes (126 seconds) per image on Hugging Face Spaces CPU hardware.
model = MinDalle(
    is_mega=True,
    is_reusable=True,
    models_root='./pretrained',
    device='cpu',
    dtype=float32,
)
def gen_image(prompt, seed=-1):
    """Generate one image for *prompt* with the module-level MinDalle ``model``.

    Args:
        prompt: Text prompt forwarded to ``model.generate_images``.
        seed: Forwarded to MinDalle unchanged. The default ``-1`` is the value this function
            always used before (presumably "do not reseed", i.e. a different image each call).

    Returns:
        A PIL image built from the first generated image.
    """
    # See https://huggingface.co/spaces/pootow/min-dalle/blob/main/app.py
    # Hugging Face Spaces seems to run out of memory if grads are not disabled. Disable them only
    # for this generation instead of flipping torch's process-wide autograd switch.
    print(f'RUNNING gen_image with prompt: {prompt}')
    with torch.no_grad():
        images = model.generate_images(
            text=prompt,
            seed=seed,
            grid_size=1,  # grid size above 2 causes out of memory on 12 GB 3080Ti; grid size 2 gives 4 images
            is_seamless=False,
            temperature=1,
            top_k=256,
            supercondition_factor=16,
            is_verbose=True,
        )
    print('COMPLETED GENERATION')
    # Copy to host memory and cast to uint8 so PIL can build an image from the first entry.
    pixels = images.to('cpu').numpy().astype(np.uint8)
    return Image.fromarray(pixels[0])
# Set to True to run the GPT-2 learner on CUDA; the installed torch must then be a CUDA build, e.g.
# `pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116`
gpu = False

# Load the exported fastai learner a single time at import (cpu=False uses the GPU).
# load_learner unpickles export.pkl, so only ever point it at a trusted file.
learner = load_learner('export.pkl', cpu=(not gpu))
def parse_monster_description(text, default=''):
    """Return the first sentence of the ``Description:`` line of generated monster text.

    Args:
        text: Generated stat-block text, e.g. the output of gen_monster_text.
        default: Returned when *text* contains no ``Description: ...`` line. The text is
            sampled from a language model, so the header is not guaranteed; previously a
            missing header raised AttributeError and aborted the whole run.

    Returns:
        The description up to (not including) its first '.', with surrounding whitespace
        (such as a trailing carriage return) stripped, or *default*.
    """
    match = re.search(r"Description: (.*)", text)
    if match is None:
        print('NO DESCRIPTION FOUND IN MONSTER TEXT')
        return default
    first_sentence = match.group(1).split('.')[0].strip()
    print(first_sentence)
    return first_sentence
def gen_monster_text(name):
    """Generate a monster stat block for *name* with the fine-tuned GPT-2 learner.

    The prompt is ``"Name: <name>"`` followed by CRLF. Generation is sampled, so every call
    yields different text.

    Args:
        name: Monster name typed by the user.

    Returns:
        The decoded text up to the first ``'###'`` marker, with line endings normalized to
        ``'\\n'``.
    """
    prompt = f"Name: {name}\r\n"
    print(f'GENERATING MONSTER TEXT with prompt: {prompt}')
    inp = tensor(tokenizer.encode(prompt))[None]  # [None] adds a leading batch dimension
    if gpu:
        inp = inp.cuda()  # requires a CUDA build of torch
    preds = learner.model.generate(inp, max_length=1024, num_beams=5, temperature=1.5, do_sample=True)
    result = tokenizer.decode(preds[0].cpu().numpy())
    result = result.split('###')[0]
    # The prompt uses real CR/LF, so the model can emit real CRLF line endings. The previous
    # raw-string replace only matched the literal 4-character text "\r\n", which left real CRLFs in
    # place and broke the "Passives:\n" / "Actions:\n" regexes in format_monster_card.
    # Normalize both forms.
    result = result.replace('\r\n', '\n').replace(r'\r\n', '\n')
    print('GENERATING MONSTER COMPLETE')
    return result
def extract_text_for_header(text, header):
    """Return the value after ``"<header>: "`` in *text*, or ``''`` if there is none.

    Only the first occurrence is used, and the value runs to the end of that line.

    Args:
        text: Generated stat-block text.
        header: Literal header label such as ``'Armor Class'``. It is regex-escaped, so
            characters like ``(`` or ``.`` are matched literally.

    Returns:
        The value with surrounding whitespace stripped; this drops a trailing ``'\\r'`` left
        by CRLF line endings, since ``.`` in the pattern matches ``'\\r'``.
    """
    match = re.search(re.escape(header) + r": (.*)", text)
    if match is None:
        return ''
    return match.group(1).strip()
def remove_section(html, html_class):
    """Remove the first ``<li class="<html_class>" ...>...</li>`` element found in *html*.

    Args:
        html: Card HTML.
        html_class: Exact class attribute value of the ``<li>`` to drop; it is regex-escaped.

    Returns:
        *html* with every occurrence of that element's exact text removed, or *html*
        unchanged if no such element exists.
    """
    # Raw string: the old non-raw f-string contained "\w\W", an invalid escape sequence
    # (SyntaxWarning on Python 3.12+). [\s\S] also spans newlines inside the element.
    pattern = r'<li class="' + re.escape(html_class) + r'"[\s\S]*?</li>'
    match = re.search(pattern, html)
    if match is None:
        return html
    return html.replace(match.group(0), '')
# (placeholder in monsterMakerTemplate.html, stat-block header, <li> class removed when empty).
# Order matters: placeholders are substituted in exactly this sequence.
_CARD_FIELDS = (
    ('{name}', 'Name', None),
    ('{monster_type}', 'Type', None),
    ('{armor_class}', 'Armor Class', None),
    ('{hit_points}', 'Hit Points', None),
    ('{speed}', 'Speed', None),
    ('{str_stat}', 'STR', None),
    ('{dex_stat}', 'DEX', None),
    ('{con_stat}', 'CON', None),
    ('{int_stat}', 'INT', None),
    ('{wis_stat}', 'WIS', None),
    ('{cha_stat}', 'CHA', None),
    ('{saving_throws}', 'Saving Throws', 'monster-saves'),
    ('{skills}', 'Skills', 'monster-skills'),
    ('{damage_vulnerabilities}', 'Damage Vulnerabilities', 'monster-vulnerabilities'),
    ('{damage_resistances}', 'Damage Resistances', 'monster-resistances'),
    ('{damage_immunities}', 'Damage Immunities', 'monster-immunities'),
    ('{condition_immunities}', 'Condition Immunities', 'monster-conditions'),
    ('{senses}', 'Senses', 'monster-senses'),
    ('{languages}', 'Languages', 'monster-languages'),
    ('{challenge}', 'Challenge', 'monster-challenge'),
    ('{description}', 'Description', None),
)


def _format_entries(section_pattern, monster_text, css_class, stop_word):
    """Render the ``Label: detail`` lines of one stat-block section as card ``<div>`` elements.

    The section is whatever *section_pattern*'s first group captures (the heading through the
    end of the text). Lines without a colon are skipped. Rendering stops at the first labelled
    line whose label contains *stop_word*, i.e. the next section's heading. If the section has
    no colon at all, it is emitted as a single paragraph. All model text is HTML-escaped.
    """
    from html import escape

    match = re.search(section_pattern, monster_text)
    section = '' if match is None else match.group(1)
    if ':' not in section:
        return f'<div class="{css_class}"><p>{escape(section)}</p></div>'
    rendered = []
    for line in section.split('\n'):
        label, colon, detail = line.partition(':')
        if not colon:
            continue
        if stop_word in label:
            break
        rendered.append(
            f'<div class="{css_class}"><p><span class="name">{escape(label)}</span> '
            f'<span class="detail">{escape(detail)}</span></p></div>'
        )
    return ''.join(rendered)


def format_monster_card(monster_text, image_data):
    """Fill monsterMakerTemplate.html (read from the working directory) with a generated monster.

    See giffyglyph's monster maker https://giffyglyph.com/monstermaker/app/ for the card layout
    and other formatting styles.

    Args:
        monster_text: Generated stat-block text (``Header: value`` lines plus ``Passives:`` /
            ``Actions:`` sections).
        image_data: ``bytes``/``bytearray`` are treated as a base64-encoded PNG and embedded
            as a ``data:`` URI. Any other value is inserted verbatim as the image source.

    Returns:
        The card HTML string. Optional ``<li>`` sections whose header is missing are removed.
        Text taken from *monster_text* is HTML-escaped, because it derives from user input and
        model output and the HTML is rendered by a headless browser and by the Gradio UI.
    """
    from html import escape

    print('FORMATTING MONSTER TEXT')
    # Assumes the template is UTF-8; being explicit avoids locale-dependent decoding on Windows.
    card = pathlib.Path('monsterMakerTemplate.html').read_text(encoding='utf-8')
    if isinstance(image_data, (bytes, bytearray)):
        image_src = f'data:image/png;base64,{image_data.decode("utf-8")}'
    else:
        image_src = f'{image_data}'
    card = card.replace('{image_data}', image_src)

    for placeholder, header, section_class in _CARD_FIELDS:
        value = extract_text_for_header(monster_text, header)
        card = card.replace(placeholder, escape(value))
        if section_class is not None and not value:
            card = remove_section(card, section_class)

    # Each section regex captures through the end of the text; the stop word ends rendering
    # at the other section's heading.
    card = card.replace(
        '{passives}', _format_entries(r"Passives:\n([\w\W]*)", monster_text, 'monster-trait', 'Action'))
    card = card.replace(
        '{actions}', _format_entries(r"Actions:\n([\w\W]*)", monster_text, 'monster-action', 'Passive'))

    for phrase in ('Melee Weapon Attack:', 'Ranged Weapon Attack:', 'Hit:'):
        card = card.replace(phrase, f'<i>{phrase}</i>')
    print('FORMATTING MONSTER TEXT COMPLETE')
    return card
def pil_to_base64(image):
    """Encode *image* as PNG and return the base64 encoding as ``bytes`` (not ``str``)."""
    print('CONVERTING PIL IMAGE TO BASE64 STRING')
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    print('CONVERTING PIL IMAGE TO BASE64 STRING COMPLETE')
    return encoded
# Single Html2Image instance; screenshots are written under ./rendered_cards.
hti = Html2Image(output_path='rendered_cards')


def html_to_png(html):
    """Screenshot the card *html* (styled with monstermaker.css) and return it as an RGB PIL image.

    NOTE(review): every call saves to the same file name ("test.png"), so concurrent requests
    could overwrite each other's screenshot; consider a per-request name if that matters.
    """
    print('CONVERTING HTML CARD TO PNG IMAGE')
    saved_paths = hti.screenshot(html_str=html, css_file="monstermaker.css", save_as="test.png")
    screenshot_path = saved_paths[0]
    with Image.open(screenshot_path) as screenshot:
        rgb_card = screenshot.convert("RGB")
    print('CONVERTING HTML CARD TO PNG IMAGE COMPLETE')
    return rgb_card
def run(name):
    """Gradio entry point: monster text -> portrait -> card HTML -> card screenshot.

    Returns:
        ``(monster_text, portrait, card_image, card_html)``, matching the four Interface outputs.
    """
    started_at = time.time()
    print(f'BEGINNING RUN FOR {name}')
    monster_text = gen_monster_text(name)
    portrait = gen_image(parse_monster_description(monster_text))
    card_html = format_monster_card(monster_text, pil_to_base64(portrait))
    card_image = html_to_png(card_html)
    print(f'RUN COMPLETED IN {time.time() - started_at} seconds')
    return monster_text, portrait, card_image, card_html
# One text input (the monster name); outputs are the generated text, the portrait, the rendered
# card image and the card HTML, in the order run() returns them.
# NOTE(review): confirm the installed gradio version accepts the "pil" output shortcut.
iface = gr.Interface(
    fn=run,
    inputs="text",
    outputs=["text", "pil", "pil", 'html'],
)
iface.launch()