Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import cv2 | |
| import io | |
| import queue | |
| import sys | |
| import threading | |
| import time | |
| from huggingface_hub import snapshot_download | |
| from transformers import ( | |
| AutoModelForCausalLM, | |
| PreTrainedTokenizerFast, | |
| ) | |
class _ThreadLocalStdout:
    """File-like stdout proxy that lets individual threads capture their output.

    A thread that sets ``self._local.buffer`` to a writable object has its
    ``write``/``flush`` calls redirected there; every other thread keeps
    writing to the wrapped original stream. Any other attribute lookup is
    delegated to the original stream.
    """

    def __init__(self, original):
        self._local = threading.local()
        self._original = original

    def _target(self):
        """Return the calling thread's capture buffer, or the original stream."""
        buf = getattr(self._local, "buffer", None)
        return self._original if buf is None else buf

    def write(self, data):
        # Propagate the return value (characters written), matching the
        # io.TextIOBase.write contract that some callers rely on.
        return self._target().write(data)

    def flush(self):
        self._target().flush()

    def __getattr__(self, name):
        # Only reached for attributes not found on the proxy itself. Read
        # _original through __dict__ so a half-initialised instance (e.g. one
        # created via __new__ during copying) raises AttributeError instead of
        # recursing into __getattr__ forever.
        try:
            original = self.__dict__["_original"]
        except KeyError:
            raise AttributeError(name) from None
        return getattr(original, name)
# Swap stdout for a proxy so a single thread can capture its own print()
# output into a private buffer without affecting other threads.
sys.stdout = _ThreadLocalStdout(sys.stdout)

# How many (model, tokenizer) copies are kept loaded at once.
POOL_SIZE = 4
# Idle (model, tokenizer) pairs; a request borrows one and puts it back.
model_pool = queue.Queue()
# Local checkpoint directory; filled in by load_models().
_ckpt_path = None


def _load_one(ckpt):
    """Load a fresh tokenizer/model pair from ``ckpt`` and add it to the pool."""
    tokenizer = PreTrainedTokenizerFast.from_pretrained(ckpt)
    # trust_remote_code executes modelling code shipped inside the checkpoint repo.
    model = AutoModelForCausalLM.from_pretrained(ckpt, trust_remote_code=True)
    model.eval()
    model_pool.put((model, tokenizer))


def load_models():
    """Fetch the GPokeT2 checkpoint snapshot and preload POOL_SIZE model copies."""
    global _ckpt_path
    local_dir = snapshot_download("iamthinbaker/GPokeT2")
    _ckpt_path = local_dir
    for _slot in range(POOL_SIZE):
        _load_one(local_dir)
# Selectable types, in the order their buttons are laid out.
TYPES = [
    "normal",
    "fire",
    "water",
    "electric",
    "grass",
    "ice",
    "fighting",
    "poison",
    "ground",
    "flying",
    "psychic",
    "bug",
    "rock",
    "ghost",
    "dragon",
    "dark",
    "steel",
    "fairy",
]

# Button icon per type. Every type gets a distinct emoji; "ground" previously
# reused rock's 🪨, making the two buttons indistinguishable by icon.
EMOJIS = {
    "normal": "⚪",
    "fire": "🔥",
    "water": "💧",
    "electric": "⚡",
    "grass": "🌿",
    "ice": "❄️",
    "fighting": "🥊",
    "poison": "☠️",
    "ground": "🏜️",
    "flying": "🕊️",
    "psychic": "🔮",
    "bug": "🐛",
    "rock": "🪨",
    "ghost": "👻",
    "dragon": "🐉",
    "dark": "🌑",
    "steel": "⚙️",
    "fairy": "🧚",
}

# Expected length (in characters) of the verbose generation log; the progress
# fraction is len(log_text) / MAX_TOKENS. NOTE(review): presumably 64x64 sprite
# cells at ~2 characters each plus per-row overhead — confirm against the
# model's verbose output format.
MAX_TOKENS = 4096*2 + 64*6
def progress_bar_html(p: float) -> str:
    """Render a horizontal progress bar as an HTML snippet.

    Args:
        p: Completion fraction. It is clamped to [0, 1] (NaN counts as 0), so
            callers may pass raw estimates that overshoot or undershoot.

    Returns:
        HTML for a grey track with a green fill of the matching width,
        followed by a centred percentage label.
    """
    # max(0.0, ...) also maps NaN to 0.0, since every comparison with NaN is False.
    pct = int(max(0.0, min(p, 1.0)) * 100)
    return (
        f'<div style="background:#e0e0e0;border-radius:8px;height:22px;width:100%;margin:8px 0;">'
        f'<div style="background:#4CAF50;width:{pct}%;height:100%;border-radius:8px;transition:width 0.3s;"></div>'
        f'</div>'
        f'<p style="text-align:center;margin:2px 0 0;">{pct}%</p>'
    )
def generate(selected):
    """Generate a sprite for two selected types, streaming logs and progress.

    A (model, tokenizer) pair is borrowed from ``model_pool`` and
    ``generate_sprite`` runs on a worker thread whose stdout is captured into
    a per-request buffer. While the worker runs, this generator yields
    ``(None, logs, progress_html)`` whenever the captured text changes. When
    the worker finishes it yields ``(image, logs, progress_html(1.0))``, where
    the image is converted BGR->RGB and resized to 256x256.

    The model goes back to the pool whenever the worker is no longer running.
    If the request is abandoned (e.g. cancelled) while the worker still uses
    the model, a replacement model is loaded in the background instead.

    Args:
        selected: List of selected type names; exactly two are required.

    Raises:
        gr.Error: If ``selected`` does not contain exactly two types, or if
            sprite generation failed inside the worker.
    """
    import traceback

    if len(selected) != 2:
        raise gr.Error("Select exactly 2 types.")
    type1, type2 = selected

    log_buffer = io.StringIO()
    result = {"image": None, "error": None}

    def worker(model, tokenizer):
        # Route this thread's print() output into log_buffer via the
        # thread-local stdout proxy installed at module level.
        sys.stdout._local.buffer = log_buffer
        try:
            image = model.generate_sprite(
                tokenizer,
                temperature=1.2,
                type1=type1,
                type2=type2,
                verbose=True,
            )
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            result["image"] = cv2.resize(
                image,
                (256, 256),
                interpolation=cv2.INTER_NEAREST,
            )
        except Exception as err:
            # Keep the traceback in the server log and hand the failure to the
            # request thread instead of silently returning no image.
            traceback.print_exc()
            result["error"] = err

    m, tok = model_pool.get()
    thread = threading.Thread(target=worker, args=(m, tok))
    try:
        thread.start()
        previous_text = ""
        while thread.is_alive():
            current_text = log_buffer.getvalue()
            if current_text != previous_text:
                previous_text = current_text
                yield None, current_text, progress_bar_html(len(current_text) / MAX_TOKENS)
            # Wait up to 0.1 s, but wake immediately once the worker finishes.
            thread.join(timeout=0.1)
        final_logs = log_buffer.getvalue()
        if result["error"] is not None:
            yield None, final_logs, progress_bar_html(len(final_logs) / MAX_TOKENS)
            raise gr.Error(f"Generation failed: {result['error']}") from result["error"]
        yield result["image"], final_logs, progress_bar_html(1.0)
    finally:
        if thread.is_alive():
            # Abandoned mid-generation: the worker still owns this model, so
            # let it finish on its own and load a replacement in the
            # background to keep the pool at POOL_SIZE entries.
            threading.Thread(
                target=_load_one,
                args=(_ckpt_path,),
                daemon=True,
            ).start()
        else:
            # Worker finished (or never started): the model is idle again.
            model_pool.put((m, tok))
def toggle(t, selected):
    """Return a new selection with type ``t`` toggled, capped at two entries.

    A selected ``t`` is removed; an unselected ``t`` is appended only while
    fewer than two types are selected. The input list is never mutated.
    """
    if t in selected:
        updated = selected.copy()
        updated.remove(t)
        return updated
    if len(selected) >= 2:
        # Already full: ignore the click but still hand back a fresh list.
        return selected.copy()
    return selected + [t]
def update_ui(selected):
    """Build Gradio updates reflecting the current type selection.

    Returns one update per entry of TYPES (``"primary"`` variant when that type
    is selected, ``"secondary"`` otherwise), followed by one update that makes
    the Generate button interactive only when exactly two types are selected.
    """
    updates = [
        gr.update(variant="primary" if name in selected else "secondary")
        for name in TYPES
    ]
    updates.append(gr.update(interactive=(len(selected) == 2)))
    return updates
css = """
#gen-btn { background: #e53935 !important; border-color: #e53935 !important; color: white !important; }
#gen-btn:hover { background: #b71c1c !important; border-color: #b71c1c !important; }
#gen-btn:disabled { background: #ef9a9a !important; border-color: #ef9a9a !important; }
"""

# Layout: result image and progress bar on top; type picker (left) and the
# streaming generation log (right) below.
# NOTE(review): the original indentation was lost; the progress bar is placed
# below the image row and the Generate button below the type-button row —
# confirm against the deployed layout.
with gr.Blocks(css=css) as demo:
    # CommonMark requires a space after "#" for a heading; "#GPokeT2"
    # rendered as literal text.
    gr.Markdown("# GPokeT2 🔥🌱💧⚡")
    gr.Markdown(
        "A GPT-2 trained to generate 64×64 Pokemon sprites as ASCII art, which is "
        "then decoded into an image. Select 2 types and see what it creates!"
    )
    # Per-session list of the currently selected type names.
    selected = gr.State([])

    with open("default_text.txt", "r", encoding="utf-8") as f:
        default_text = f.read()

    with gr.Row():
        img = gr.Image(
            label="Pokemon",
            visible=True,
            value="default_image.webp",
            height=256,
        )
    progress_bar = gr.HTML(progress_bar_html(0))

    with gr.Row():
        # LEFT: type picker and Generate button.
        with gr.Column(scale=1):
            gr.Markdown("# Select 2 types")
            buttons = []
            with gr.Row():
                for t in TYPES:
                    buttons.append(
                        gr.Button(
                            f"{EMOJIS[t]} {t.title()}",
                            variant="secondary",
                        )
                    )
            # Starts disabled; update_ui enables it once two types are picked.
            gen = gr.Button(
                "Generate Pokemon!!",
                interactive=False,
                elem_id="gen-btn",
            )
        # RIGHT: streaming generation log.
        with gr.Column(scale=2):
            gr.Markdown("# Generation Logs")
            logs = gr.Textbox(
                lines=21,
                max_lines=21,
                autoscroll=True,
                elem_id="logs-box",
                value=default_text,
            )

    gen_event = gen.click(
        fn=generate,
        inputs=selected,
        outputs=[img, logs, progress_bar],
        # At most one concurrent run per pooled model.
        concurrency_limit=POOL_SIZE,
        concurrency_id="inference",
    )
    # NOTE(review): this cancel listener is on the same Generate button, so
    # every click also requests cancellation of gen_event — presumably to
    # abort a still-running previous generation. Confirm it cannot cancel the
    # run started by that same click.
    gen.click(fn=None, cancels=[gen_event])

    # Type buttons are wired after `gen` exists because update_ui also
    # updates the Generate button.
    for button, t in zip(buttons, TYPES):
        button.click(
            fn=toggle,
            inputs=[gr.State(t), selected],
            outputs=selected,
        ).then(
            fn=update_ui,
            inputs=selected,
            outputs=buttons + [gen],
        )

# Allow at most 10 requests to wait in Gradio's queue.
demo.queue(max_size=10)
# Fill the model pool before the server starts handling requests.
load_models()
demo.launch()