# KaosGen / app.py
# Author: JohanBeytell (Hugging Face Space, commit d3049bb, verified)
# === ADDITIONAL UI FEEDBACK + SEED + TIMING ===
import gradio as gr
import torch
import torch.nn as nn
import re, unicodedata, random
from pathlib import Path
import pandas as pd
import tempfile
import time
import os
from valx import detect_profanity, detect_hate_speech
# Run on GPU when available; the model and all input tensors share this device.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
SEED = 1337  # default seed; the UI can override or randomize it per request

# === Model Loading Diagnostics ===
torch.manual_seed(SEED)
random.seed(SEED)
print(f"📦 Model loading on: {DEVICE}")

# The checkpoint bundles the trained weights plus the character-level
# vocabulary tables used for encoding/decoding.
ckpt = torch.load("kaos.pt", map_location=DEVICE)
stoi, itos = ckpt['stoi'], ckpt['itos']  # char -> id and id -> char

# Special tokens; their ids come from the checkpoint's own vocab, so they
# stay consistent with whatever the training run used.
SPECIAL = ['<pad>', '<bos>', '<eos>', '<sep>']
PAD, BOS, EOS, SEP = [stoi[s] for s in SPECIAL]
VOCAB_SIZE = len(itos)
MAX_LEN = 128  # maximum context window (in tokens) fed to the model
class GPTSmall(nn.Module):
    """Small GPT-style character-level transformer (~890k parameters).

    Built from nn.TransformerEncoderLayer blocks driven with a causal
    (upper-triangular) attention mask, so it behaves as a decoder despite
    the "encoder" class name.
    """

    def __init__(self, vocab_size, d_model=256, n_head=8, n_layer=4, dropout=0.2, max_len=MAX_LEN):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        # Learned absolute positional embeddings, one row per position.
        self.pos_emb = nn.Parameter(torch.zeros(1, max_len, d_model))
        nn.init.trunc_normal_(self.pos_emb, std=0.02)
        # BUG FIX: the original did ModuleList([block for _ in range(n_layer)])
        # with a single pre-built layer instance, so every "layer" shared one
        # set of weights. Build an independent layer per depth instead.
        # Checkpoint compatibility is preserved: state_dict keys
        # (blocks.0 .. blocks.n-1) exist either way, and loading the saved
        # (shared) weights into independent layers yields identical inference.
        self.blocks = nn.ModuleList(
            nn.TransformerEncoderLayer(d_model, n_head, d_model * 4, dropout=dropout, batch_first=True)
            for _ in range(n_layer)
        )
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, x):
        """Map (B, T) int64 token ids to (B, T, vocab_size) logits."""
        B, T = x.shape
        tok = self.tok_emb(x) + self.pos_emb[:, :T]
        # Causal mask: position t may only attend to positions <= t.
        mask = torch.triu(torch.ones(T, T, device=x.device, dtype=torch.bool), 1)
        for blk in self.blocks:
            # PAD positions are excluded from attention via the padding mask.
            tok = blk(tok, src_key_padding_mask=(x == PAD), src_mask=mask)
        return self.head(self.norm(tok))
# Instantiate the network and restore the trained weights for inference.
model = GPTSmall(VOCAB_SIZE).to(DEVICE)
model.load_state_dict(ckpt['model'])
model.eval()  # disable dropout so sampling is driven solely by the RNG seed
def proper_case(text):
    """Title-case *text*, keeping minor words (of/the/and/in/on/a) lowercase.

    BUG FIX: the original pattern had no re.IGNORECASE, but it ran on
    str.title() output where every word starts with a capital — so it could
    never match and the function degraded to a bare .title(). The flag is
    added here, and a lookbehind keeps the first word capitalized, matching
    conventional English title case.
    """
    titled = text.title()
    return re.sub(
        r"(?<!^)\b(of|the|and|in|on|a)\b",  # (?<!^): never lowercase the first word
        lambda m: m.group(0).lower(),
        titled,
        flags=re.IGNORECASE,
    )
def clean_name(text, title_case=True, max_repeats=2):
    """Normalize a raw generated name into presentable form.

    Pipeline: Unicode-normalize, collapse character runs, normalize
    possessives, strip disallowed characters, squeeze whitespace,
    optionally title-case, deduplicate doubled articles, and repair
    the "'S" artifact left by str.title().
    """
    # Unicode-normalize so composed/decomposed forms compare equally.
    cleaned = unicodedata.normalize("NFC", text)
    # Any character repeated 3+ times is collapsed down to `max_repeats`.
    cleaned = re.sub(r'(.)\1{2,}', lambda m: m.group(1) * max_repeats, cleaned)
    # Normalize possessive suffixes written with a curly or straight quote.
    cleaned = re.sub(r"’S|\'S", "'s", cleaned)
    # Keep only letters (incl. Latin-1 accents), digits, apostrophes,
    # hyphens, and whitespace.
    cleaned = re.sub(r"[^0-9A-Za-zÀ-ÖØ-öø-ÿ'’\-\s]", "", cleaned)
    # Squeeze runs of whitespace and trim the ends.
    cleaned = re.sub(r"\s+", " ", cleaned).strip()
    if title_case:
        cleaned = proper_case(cleaned)
    # Collapse doubled articles ("The The" -> "The"), case-insensitively.
    cleaned = re.sub(r'\b(The|Of|In|On|A)\s+\1\b', r'\1', cleaned, flags=re.IGNORECASE)
    # str.title() uppercases the letter after an apostrophe; restore "'s".
    return re.sub(r"([a-zA-Z])'S\b", lambda m: m.group(1) + "'s", cleaned)
def sample_once(prompt, temperature=1.0, top_k=40, max_new=40):
    """Autoregressively sample one name continuation for `prompt`.

    Returns (cleaned_name, elapsed_seconds). One character id is drawn per
    step from temperature-scaled logits, optionally restricted to the
    `top_k` most likely tokens; sampling stops at EOS or after `max_new`.
    """
    start_time = time.time()
    # Characters missing from the vocab fall back to PAD (filtered later).
    prompt_ids = [stoi.get(c, PAD) for c in prompt]
    seq = [BOS] + prompt_ids + [SEP]
    for _ in range(max_new):
        # Feed at most the last MAX_LEN tokens (the model's context window).
        x = torch.tensor(seq[-MAX_LEN:], dtype=torch.long, device=DEVICE)[None]
        with torch.no_grad():
            logits = model(x)[:, -1, :] / temperature
        if top_k:
            v, i = torch.topk(logits, top_k)
            idx = i[0, torch.softmax(v, -1).multinomial(1)].item()
        else:
            idx = torch.softmax(logits, -1).multinomial(1).item()
        if idx == EOS:
            break
        seq.append(idx)
    # BUG FIX: the original decoded the whole sequence and then did
    # str.replace(prompt, ""), which deleted *every* occurrence of the
    # prompt inside the generated name, and silently failed to strip the
    # prefix at all when any prompt character was out-of-vocabulary
    # (mapped to PAD and filtered, so the decoded prefix != prompt).
    # Slice the BOS + prompt + SEP prefix off by position instead.
    new_ids = seq[len(prompt_ids) + 2:]
    chars = [itos[i] for i in new_ids if i not in {BOS, SEP, EOS, PAD}]
    name = ''.join(chars).strip()
    return clean_name(name), time.time() - start_time
def generate_names(prompt, temperature, top_k, count, retries, seed, randomize_seed):
    """Gradio callback: generate up to `count` names from `prompt`.

    Returns (file_path, file_path, dataframe, markdown_report), matching the
    click handler's outputs=[download, download, output, retry_report].
    Raises gr.Error for an empty or over-long prompt.
    """
    # Re-seed per request so a given seed reproduces the same names.
    if randomize_seed:
        seed = random.randint(1, 999999)
    torch.manual_seed(seed)
    random.seed(seed)

    prompt = prompt.strip()
    promptx = prompt.lower()
    # Screen the prompt itself; fall back to a safe default instead of erroring.
    if detect_profanity([promptx], language='All'):
        gr.Warning("Profanity detected in the prompt, using the default prompt.")
        prompt = 'a kind king'
    elif (hate_speech_result := detect_hate_speech(promptx)) and hate_speech_result[0] in ['Hate Speech', 'Offensive Speech']:
        gr.Warning('Harmful speech detected in the prompt, using default prompt.')
        prompt = 'a kind king'
    if not prompt:
        raise gr.Error("Prompt cannot be empty.")
    if len(prompt) > 64:
        raise gr.Error("Prompt is too long. Please keep it under 64 characters.")

    results = []
    rejected = []
    retry_count = 0  # total sampling attempts across all names
    timings = []
    for _ in range(count):
        for _attempt in range(retries):
            name, t = sample_once(prompt, temperature=temperature, top_k=top_k)
            namex = name.strip().lower()
            retry_count += 1
            # Flagged names are rejected outright and the attempt is retried.
            # NOTE(review): the flat original appears to let flagged names
            # still reach `results` if long enough — confirm intent.
            if detect_profanity([namex], language='All'):
                gr.Warning("Profanity detected in the generated name, flagging...")
                rejected.append(name + " (Profanity Detected)")
                continue
            if (hate_speech_result := detect_hate_speech(namex)) and hate_speech_result[0] in ['Hate Speech', 'Offensive Speech']:
                gr.Warning('Harmful speech detected in the generated name, flagging...')
                rejected.append(name + " (Harmful Speech Detected)")
                continue
            if len(name) >= 3:
                results.append({"Generated Name": name, "Time (s)": f"{t:.2f}"})
                timings.append(t)
                break
            rejected.append(name)  # too short — retry

    # Explicit columns keep the frame well-formed even when `results` is
    # empty (BUG FIX: df[["Generated Name"]] used to KeyError in that case).
    df = pd.DataFrame(results, columns=["Generated Name", "Time (s)"])
    # Create the export path and close the handle immediately
    # (BUG FIX: the original leaked the open NamedTemporaryFile handle).
    with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as tmp:
        file_path = tmp.name
    df[["Generated Name"]].to_csv(file_path, index=False, header=False)

    # Guard against division by zero when every attempt was rejected.
    avg_time = f"{sum(timings) / len(timings):.2f}s" if timings else "n/a"
    retry_report = (
        f"## Debug Report\n\n"
        f"- **Total Retries:** {retry_count - len(results)}\n"
        f"- **Seed Used:** {seed}\n"
        f"- **Average Sample Time:** {avg_time}\n\n"
        f"### Rejected Candidates:\n" + '\n'.join(rejected or ["None"])
    )
    return file_path, file_path, df, retry_report
# Markdown blurb rendered at the top of the UI.
description = """# KaosGen: A Fantasy Name Generator
`Kaos` is a small GPT-style transformer (~890k parameters) trained from scratch using character-level tokenization.
It excels at fantasy and mythic naming conventions.
"""
# One-click example prompts for the gr.Examples widget (list of [prompt] rows).
examples = [["a forgotten warrior king"], ["queen of the shattered realm"], ["blacksmith of shadows"], ["titan of the blazing sky"], ["a blade that burns through matter"]]
# === Gradio UI: controls on the left, results/export/debug on the right ===
with gr.Blocks() as demo:
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            # Generation controls, passed positionally to generate_names.
            prompt = gr.Textbox(label="Prompt", placeholder="e.g. 'a villain who whispers to shadows'")
            temperature = gr.Slider(0.1, 1.5, step=0.1, value=1.0, label="Temperature")
            top_k = gr.Slider(10, 100, step=10, value=40, label="Top-K Sampling")
            count = gr.Slider(1, 20, step=1, value=5, label="Names to Generate")
            retries = gr.Slider(1, 5, step=1, value=3, label="Max Retries per Name")
            seed = gr.Number(label="Seed", value=1337, precision=0)
            randomize_seed = gr.Checkbox(label="Use Random Seed", value=False)
            generate_btn = gr.Button("🎲 Generate Names")
        with gr.Column():
            # Results table, downloadable .txt export, and a debug report.
            output = gr.Dataframe(headers=["Generated Name", "Time (s)"], datatype=["str", "str"], label="Generated Names", interactive=False)
            download = gr.File(label="📥 Export Names as .txt")
            retry_report = gr.Markdown("", label="Debug Info")
    # NOTE(review): `download` appears twice in outputs, mirroring the
    # callback returning file_path twice — presumably intentional; confirm.
    generate_btn.click(fn=generate_names, inputs=[prompt, temperature, top_k, count, retries, seed, randomize_seed], outputs=[download, download, output, retry_report])
    gr.Examples(examples=examples, inputs=prompt)
demo.launch()