"""Gradio demo for GPokeT2.

A GPT-2 checkpoint generates a 64x64 Pokemon sprite as ASCII art, which the
model decodes into an image. A small pool of model copies serves concurrent
requests, and each request's verbose stdout is streamed to the UI as logs.
"""

import io
import queue
import sys
import threading
import time
import traceback

import cv2
import gradio as gr
from huggingface_hub import snapshot_download
from transformers import (
    AutoModelForCausalLM,
    PreTrainedTokenizerFast,
)


class _ThreadLocalStdout:
    """Proxy for ``sys.stdout`` that can redirect writes on a per-thread basis.

    A thread that called ``set_buffer(buf)`` with a non-None ``buf`` has its
    ``write``/``flush`` calls sent to ``buf``; every other thread (and any
    thread whose buffer is ``None``) writes to the wrapped original stream.
    All other attributes are delegated to the original stream.
    """

    def __init__(self, original):
        self._local = threading.local()
        self._original = original

    def set_buffer(self, buf):
        """Redirect the calling thread's output to ``buf`` (``None`` restores it)."""
        self._local.buffer = buf

    def _target(self):
        # Per-thread destination: the thread's buffer if set, else the real stream.
        buf = getattr(self._local, "buffer", None)
        return self._original if buf is None else buf

    def write(self, data):
        # Return the count like a real text stream's write() does.
        return self._target().write(data)

    def flush(self):
        return self._target().flush()

    def __getattr__(self, name):
        return getattr(self._original, name)


# Keep a direct reference so workers don't depend on sys.stdout still being
# this proxy when they run.
_stdout_proxy = _ThreadLocalStdout(sys.stdout)
sys.stdout = _stdout_proxy

POOL_SIZE = 4
# Holds (model, tokenizer) pairs; each request takes one exclusively.
model_pool = queue.Queue()
_ckpt_path = None


def _load_one(ckpt):
    """Load one (model, tokenizer) pair from ``ckpt`` and add it to the pool."""
    tok = PreTrainedTokenizerFast.from_pretrained(ckpt)
    m = AutoModelForCausalLM.from_pretrained(ckpt, trust_remote_code=True)
    m.eval()
    model_pool.put((m, tok))


def load_models():
    """Download the checkpoint and fill the pool with ``POOL_SIZE`` copies."""
    global _ckpt_path
    _ckpt_path = snapshot_download("iamthinbaker/GPokeT2")
    for _ in range(POOL_SIZE):
        _load_one(_ckpt_path)


TYPES = [
    "normal", "fire", "water", "electric", "grass", "ice",
    "fighting", "poison", "ground", "flying", "psychic", "bug",
    "rock", "ghost", "dragon", "dark", "steel", "fairy",
]

EMOJIS = {
    "normal": "⚪",
    "fire": "🔥",
    "water": "💧",
    "electric": "⚡",
    "grass": "🌿",
    "ice": "❄️",
    "fighting": "🥊",
    "poison": "☠️",
    "ground": "🪨",
    "flying": "🕊️",
    "psychic": "🔮",
    "bug": "🐛",
    "rock": "🪨",
    "ghost": "👻",
    "dragon": "🐉",
    "dark": "🌑",
    "steel": "⚙️",
    "fairy": "🧚",
}

# Rough upper bound on the generation log length; progress is estimated as
# len(log characters) / MAX_TOKENS. The breakdown (presumably 64x64 pixels x 2
# plus per-row overhead) is an assumption — TODO confirm against the model.
MAX_TOKENS = 4096 * 2 + 64 * 6


def progress_bar_html(p: float) -> str:
    """Render a progress bar filled to fraction ``p`` (clamped to [0, 1])."""
    pct = int(max(0.0, min(p, 1.0)) * 100)
    return (
        '<div style="position: relative; width: 100%; height: 20px; '
        'background: #eeeeee; border-radius: 10px; overflow: hidden;">'
        f'<div style="width: {pct}%; height: 100%; background: #e53935;"></div>'
        '<div style="position: absolute; top: 0; left: 0; right: 0; '
        'line-height: 20px; text-align: center; font-size: 12px;">'
        f'{pct}%</div>'
        '</div>'
    )


def generate(selected):
    """Generate a sprite for the two selected types, streaming progress.

    Takes a model from the pool and runs generation in a worker thread. The
    worker's stdout is captured so the logs can be streamed to the UI.

    Args:
        selected: list of exactly two type names from ``TYPES``.

    Yields:
        ``(image, logs, progress_html)``. ``image`` is ``None`` until the final
        yield, when it is the 256x256 RGB sprite.

    Raises:
        gr.Error: if the selection is not two types, or generation fails.
    """
    if len(selected) != 2:
        raise gr.Error("Select exactly 2 types before generating.")
    type1, type2 = selected

    m, tok = model_pool.get()  # blocks until a pooled model is free
    log_buffer = io.StringIO()
    result = {"image": None, "error": None}

    def worker():
        _stdout_proxy.set_buffer(log_buffer)
        try:
            image = m.generate_sprite(
                tok,
                temperature=1.2,
                type1=type1,
                type2=type2,
                verbose=True,
            )
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            result["image"] = cv2.resize(
                image,
                (256, 256),
                interpolation=cv2.INTER_NEAREST,  # keep pixel-art edges crisp
            )
        except Exception as exc:
            traceback.print_exc()  # stderr is not redirected
            result["error"] = exc
        finally:
            _stdout_proxy.set_buffer(None)

    # Daemon so an orphaned (cancelled) worker never blocks interpreter exit.
    thread = threading.Thread(target=worker, daemon=True)
    try:
        thread.start()
        previous_text = ""
        while thread.is_alive():
            current_text = log_buffer.getvalue()
            if current_text != previous_text:
                previous_text = current_text
                yield None, current_text, progress_bar_html(len(current_text) / MAX_TOKENS)
            time.sleep(0.1)
        thread.join()
        error = result["error"]
        if error is not None:
            raise gr.Error(f"Generation failed: {error}") from error
        yield result["image"], log_buffer.getvalue(), progress_bar_html(1.0)
    finally:
        if thread.is_alive():
            # Cancelled while the worker still uses the model: abandon that copy
            # and load a fresh one in the background to keep the pool full.
            threading.Thread(
                target=_load_one,
                args=(_ckpt_path,),
                daemon=True,
            ).start()
        else:
            # The worker finished (or never started), so the model is free.
            model_pool.put((m, tok))


def toggle(t, selected):
    """Return a copy of ``selected`` with type ``t`` toggled (at most 2 kept)."""
    selected = selected.copy()  # don't mutate the gr.State value in place
    if t in selected:
        selected.remove(t)
    elif len(selected) < 2:
        selected.append(t)
    return selected


def update_ui(selected):
    """Highlight selected type buttons; enable Generate only with 2 selected."""
    btn_updates = [
        gr.update(variant="primary" if t in selected else "secondary")
        for t in TYPES
    ]
    return btn_updates + [gr.update(interactive=(len(selected) == 2))]


css = """
#gen-btn {
    background: #e53935 !important;
    border-color: #e53935 !important;
    color: white !important;
}
#gen-btn:hover {
    background: #b71c1c !important;
    border-color: #b71c1c !important;
}
#gen-btn:disabled {
    background: #ef9a9a !important;
    border-color: #ef9a9a !important;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("# GPokeT2 🔥🌱💧⚡")
    gr.Markdown(
        "A GPT-2 trained to generate 64×64 Pokemon sprites as ASCII art, "
        "which is then decoded into an image.<br>"
        "Select 2 types and see what it creates!"
    )

    selected = gr.State([])

    with open("default_text.txt", "r", encoding="utf-8") as f:
        default_text = f.read()

    with gr.Row():
        img = gr.Image(
            label="Pokemon",
            visible=True,
            value="default_image.webp",
            height=256,
        )
    progress_bar = gr.HTML(progress_bar_html(0))

    with gr.Row():
        # LEFT
        with gr.Column(scale=1):
            gr.Markdown("# Select 2 types")
            with gr.Row():
                buttons = [
                    gr.Button(
                        f"{EMOJIS[t]} {t.title()}",
                        variant="secondary",
                    )
                    for t in TYPES
                ]
            gen = gr.Button(
                "Generate Pokemon!!",
                interactive=False,
                elem_id="gen-btn",
            )

        # RIGHT
        with gr.Column(scale=2):
            gr.Markdown("# Generation Logs")
            logs = gr.Textbox(
                lines=21,
                max_lines=21,
                autoscroll=True,
                elem_id="logs-box",
                value=default_text,
            )

    gen_event = gen.click(
        fn=generate,
        inputs=selected,
        outputs=[img, logs, progress_bar],
        concurrency_limit=POOL_SIZE,
        concurrency_id="inference",
    )
    # NOTE(review): this listener fires on the same click as gen_event;
    # presumably the intent is that clicking Generate again cancels the run in
    # progress — confirm it never cancels the run this same click just started.
    gen.click(fn=None, cancels=[gen_event])

    # Wire up button clicks after gen is defined
    for t, btn in zip(TYPES, buttons):
        btn.click(
            fn=toggle,
            inputs=[gr.State(t), selected],
            outputs=selected,
        ).then(
            fn=update_ui,
            inputs=selected,
            outputs=buttons + [gen],
        )

demo.queue(max_size=10)
load_models()
demo.launch()