GPokeT2 / app.py
DelgadoPanadero's picture
change readme
c2bf84d
Raw
History Blame Contribute Delete
6.78 kB
import gradio as gr
import cv2
import io
import queue
import sys
import threading
import time
from huggingface_hub import snapshot_download
from transformers import (
AutoModelForCausalLM,
PreTrainedTokenizerFast,
)
class _ThreadLocalStdout:
def __init__(self, original):
self._local = threading.local()
self._original = original
def write(self, data):
buf = getattr(self._local, "buffer", None)
if buf is not None:
buf.write(data)
else:
self._original.write(data)
def flush(self):
buf = getattr(self._local, "buffer", None)
if buf is not None:
buf.flush()
else:
self._original.flush()
def __getattr__(self, name):
return getattr(self._original, name)
# Number of (model, tokenizer) pairs loaded by load_models(); also used as the
# concurrency limit of the generate event.
POOL_SIZE: int = 4
# Unbounded FIFO of ready-to-use (model, tokenizer) pairs.
model_pool: "queue.Queue[tuple]" = queue.Queue()
# Local snapshot directory; assigned by load_models().
_ckpt_path: "str | None" = None

# Replace stdout process-wide with the per-thread proxy so each generation
# worker can capture its own printed output.
sys.stdout = _ThreadLocalStdout(sys.stdout)
def _load_one(ckpt):
    """Load one tokenizer/model pair from ``ckpt`` and add it to ``model_pool``.

    Called POOL_SIZE times by load_models() at startup, and from a background
    thread by generate() to replace a model it had to give up on.

    Args:
        ckpt: Local checkpoint directory (the snapshot path from load_models()).
    """
    tokenizer = PreTrainedTokenizerFast.from_pretrained(ckpt)
    # trust_remote_code runs Python code shipped inside the checkpoint repo;
    # presumably required for the custom generate_sprite() method — confirm.
    model = AutoModelForCausalLM.from_pretrained(ckpt, trust_remote_code=True)
    model.eval()
    model_pool.put((model, tokenizer))
def load_models():
    """Download the GPokeT2 checkpoint and fill ``model_pool``.

    Records the local snapshot path in the module-level ``_ckpt_path`` (reused
    when a replacement model must be loaded later), then loads POOL_SIZE
    independent (model, tokenizer) pairs sequentially.
    """
    global _ckpt_path
    local_dir = snapshot_download("iamthinbaker/GPokeT2")
    _ckpt_path = local_dir
    for _slot in range(POOL_SIZE):
        _load_one(local_dir)
# Emoji shown on each type's selection button, in display order.
# NOTE: "ground" and "rock" currently share the same 🪨 emoji.
EMOJIS = {
    "normal": "⚪",
    "fire": "🔥",
    "water": "💧",
    "electric": "⚡",
    "grass": "🌿",
    "ice": "❄️",
    "fighting": "🥊",
    "poison": "☠️",
    "ground": "🪨",
    "flying": "🕊️",
    "psychic": "🔮",
    "bug": "🐛",
    "rock": "🪨",
    "ghost": "👻",
    "dragon": "🐉",
    "dark": "🌑",
    "steel": "⚙️",
    "fairy": "🧚",
}

# The 18 selectable types, in button order. Derived from EMOJIS (dicts keep
# insertion order) so every type is guaranteed to have an emoji.
TYPES = list(EMOJIS)

# Denominator used by generate() to turn the captured log length into a
# progress fraction.
# NOTE(review): 4096 == 64 * 64 matches the sprite size; the exact breakdown
# of this estimate is not documented here — confirm.
MAX_TOKENS = 4096 * 2 + 64 * 6
def progress_bar_html(p: float) -> str:
    """Render a horizontal progress bar as an HTML snippet.

    Args:
        p: Completion fraction. Clamped to ``[0, 1]`` so out-of-range values
            (including negatives) never produce invalid CSS widths; the shown
            percentage is truncated to a whole number.

    Returns:
        HTML containing a grey track, a green fill ``pct``% wide, and a
        centred ``pct%`` caption underneath.
    """
    fraction = max(0.0, min(p, 1.0))
    pct = int(fraction * 100)
    return (
        f'<div style="background:#e0e0e0;border-radius:8px;height:22px;width:100%;margin:8px 0;">'
        f'<div style="background:#4CAF50;width:{pct}%;height:100%;border-radius:8px;transition:width 0.3s;"></div>'
        f'</div>'
        f'<p style="text-align:center;margin:2px 0 0;">{pct}%</p>'
    )
def generate(selected):
    """Stream one sprite generation for the two selected types.

    Borrows a ``(model, tokenizer)`` pair from ``model_pool`` and runs
    ``generate_sprite`` in a worker thread whose printed output is captured
    into a private buffer via the thread-local ``sys.stdout`` proxy. While the
    worker runs, the buffer is polled every 0.1 s and any new text is yielded
    with a progress bar estimated from the log length relative to MAX_TOKENS.

    Args:
        selected: Sequence whose first two items are the chosen type names.

    Yields:
        ``(image, logs, progress_html)`` tuples. ``image`` is ``None`` until
        the final yield, where it is the RGB sprite resized to 256x256 with
        nearest-neighbour interpolation.

    Raises:
        gr.Error: if the worker raised while generating or post-processing the
            sprite (previously this was swallowed and a ``None`` image shown).
    """
    type1, type2 = selected[0], selected[1]
    m, tok = model_pool.get()
    log_buffer = io.StringIO()
    result = {"image": None, "error": None}

    def worker():
        try:
            # Route only this thread's print() output into log_buffer.
            sys.stdout._local.buffer = log_buffer
            image = m.generate_sprite(
                tok,
                temperature=1.2,
                type1=type1,
                type2=type2,
                verbose=True,
            )
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            result["image"] = cv2.resize(
                image,
                (256, 256),
                interpolation=cv2.INTER_NEAREST,
            )
        except Exception as exc:
            # Exceptions in a thread are otherwise only printed to stderr;
            # keep it so the generator can report it to the user.
            result["error"] = exc

    thread = threading.Thread(target=worker)
    thread.start()
    try:
        previous_text = ""
        while thread.is_alive():
            current_text = log_buffer.getvalue()
            if current_text != previous_text:
                previous_text = current_text
                yield None, current_text, progress_bar_html(len(current_text) / MAX_TOKENS)
            time.sleep(0.1)
        thread.join()
        error = result["error"]
        if error is not None:
            raise gr.Error(f"Sprite generation failed: {error}") from error
        yield result["image"], log_buffer.getvalue(), progress_bar_html(1.0)
    finally:
        if thread.is_alive():
            # The generator is exiting (e.g. closed after the event was
            # cancelled) while the worker is still using `m`, so the model
            # cannot go back to the pool. Load a replacement in the
            # background to keep POOL_SIZE models available.
            threading.Thread(
                target=_load_one,
                args=(_ckpt_path,),
                daemon=True,
            ).start()
        else:
            # Worker finished (successfully or not): the model is free again.
            model_pool.put((m, tok))
def toggle(t, selected):
    """Return a copy of ``selected`` with type ``t`` toggled.

    An already-selected type is removed; a new type is appended only while
    fewer than two types are selected. The input list is never mutated.
    """
    updated = selected.copy()
    if t in updated:
        updated.remove(t)
        return updated
    if len(updated) < 2:
        updated.append(t)
    return updated
def update_ui(selected):
    """Build Gradio updates reflecting the current type selection.

    Returns one update per entry of TYPES (``"primary"`` variant when the type
    is selected, ``"secondary"`` otherwise), followed by one update for the
    Generate button that makes it interactive only when exactly two types are
    selected.
    """
    type_button_updates = [
        gr.update(variant="primary" if t in selected else "secondary")
        for t in TYPES
    ]
    generate_button_update = gr.update(interactive=(len(selected) == 2))
    return type_button_updates + [generate_button_update]
css = """
#gen-btn { background: #e53935 !important; border-color: #e53935 !important; color: white !important; }
#gen-btn:hover { background: #b71c1c !important; border-color: #b71c1c !important; }
#gen-btn:disabled { background: #ef9a9a !important; border-color: #ef9a9a !important; }
"""
with gr.Blocks(css=css) as demo:
    # Title: Markdown ATX headings need a space after "#" to render.
    gr.Markdown("# GPokeT2 🔥🌱💧⚡")
    gr.Markdown("A GPT-2 trained to generate 64×64 Pokemon sprites as ASCII art, which is then decoded into an image. Select 2 types and see what it creates!")
    # Per-session list of selected type names (toggle() caps it at 2).
    selected = gr.State([])
    # Placeholder text shown in the log box before the first generation.
    with open("default_text.txt", "r", encoding="utf-8") as f:
        default_text = f.read()
    with gr.Row():
        img = gr.Image(
            label="Pokemon",
            visible=True,
            value="default_image.webp",
            height=256,
        )
    progress_bar = gr.HTML(progress_bar_html(0))
    with gr.Row():
        # LEFT
        with gr.Column(scale=1):
            gr.Markdown("# Select 2 types")
            buttons = []
            with gr.Row():
                for t in TYPES:
                    buttons.append(
                        gr.Button(
                            f"{EMOJIS[t]} {t.title()}",
                            variant="secondary",
                        )
                    )
            gen = gr.Button(
                "Generate Pokemon!!",
                interactive=False,
                elem_id="gen-btn",
            )
        # RIGHT
        with gr.Column(scale=2):
            gr.Markdown("# Generation Logs")
            logs = gr.Textbox(
                lines=21,
                max_lines=21,
                autoscroll=True,
                elem_id="logs-box",
                value=default_text,
            )
    gen_event = gen.click(
        fn=generate,
        inputs=selected,
        outputs=[img, logs, progress_bar],
        concurrency_limit=POOL_SIZE,
        concurrency_id="inference",
    )
    # NOTE(review): this second listener on the same button fires on every
    # Generate click together with gen_event and cancels gen_event. Presumably
    # intended so a new click replaces a running generation — confirm it can
    # never cancel the job started by that same click.
    gen.click(fn=None, cancels=[gen_event])
    # Wire up button clicks after gen is defined
    for t, button in zip(TYPES, buttons):
        button.click(
            fn=toggle,
            inputs=[gr.State(t), selected],
            outputs=selected,
        ).then(
            fn=update_ui,
            inputs=selected,
            outputs=buttons + [gen],
        )
# Fill the model pool before serving, then start the app with a request
# queue that holds at most 10 waiting events.
load_models()
demo.queue(max_size=10).launch()