# === ADDITIONAL UI FEEDBACK + SEED + TIMING ===
import gradio as gr
import torch
import torch.nn as nn
import re, unicodedata, random
from pathlib import Path
import pandas as pd
import tempfile
import time
import os
from valx import detect_profanity, detect_hate_speech

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
SEED = 1337

# Prompt substituted when the user's prompt fails moderation.
DEFAULT_PROMPT = 'a kind king'
# Labels from `detect_hate_speech` that cause a prompt/name to be rejected.
BLOCKED_SPEECH_LABELS = ('Hate Speech', 'Offensive Speech')
MAX_PROMPT_CHARS = 64
MIN_NAME_CHARS = 3

# === Model Loading Diagnostics ===
torch.manual_seed(SEED)
random.seed(SEED)
print(f"📦 Model loading on: {DEVICE}")

# NOTE: torch.load unpickles the checkpoint; only ever point this at a trusted local file.
ckpt = torch.load("kaos.pt", map_location=DEVICE)
stoi, itos = ckpt['stoi'], ckpt['itos']
# NOTE(review): these four special-token strings are blank here (possibly `<...>`
# tokens lost in a copy/paste). As written, all four would resolve to the same id;
# confirm they match the special tokens the checkpoint's `stoi` was built with.
SPECIAL = ['', '', '', '']
PAD, BOS, EOS, SEP = [stoi[s] for s in SPECIAL]
VOCAB_SIZE = len(itos)
MAX_LEN = 128


class GPTSmall(nn.Module):
    """Small decoder-only transformer over character ids.

    Uses `nn.TransformerEncoderLayer` blocks with a causal (upper-triangular)
    attention mask plus a key-padding mask for PAD tokens.
    """

    def __init__(self, vocab_size, d_model=256, n_head=8, n_layer=4, dropout=0.2, max_len=MAX_LEN):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        # Learned absolute position embeddings: inputs longer than max_len are unsupported.
        self.pos_emb = nn.Parameter(torch.zeros(1, max_len, d_model))
        nn.init.trunc_normal_(self.pos_emb, std=0.02)
        block = nn.TransformerEncoderLayer(d_model, n_head, d_model * 4, dropout=dropout, batch_first=True)
        # NOTE: the SAME layer instance is repeated n_layer times, so every depth shares
        # one set of weights. Kept as-is because the loaded checkpoint was presumably
        # trained with this tied structure — confirm before switching to independent copies.
        self.blocks = nn.ModuleList([block for _ in range(n_layer)])
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, x):
        """Map token ids `x` of shape (B, T) to logits of shape (B, T, vocab_size)."""
        _, seq_len = x.shape
        h = self.tok_emb(x) + self.pos_emb[:, :seq_len]
        # True entries (strictly above the diagonal) block attention to future positions.
        causal = torch.triu(torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool), 1)
        pad_mask = x == PAD
        for blk in self.blocks:
            h = blk(h, src_key_padding_mask=pad_mask, src_mask=causal)
        return self.head(self.norm(h))


model = GPTSmall(VOCAB_SIZE).to(DEVICE)
model.load_state_dict(ckpt['model'])
model.eval()

# Minor words kept lowercase by proper_case. IGNORECASE is required because the
# pattern runs on str.title() output, where these words are already "Of", "The", ...
_MINOR_WORDS_RE = re.compile(r"\b(of|the|and|in|on|a)\b", re.IGNORECASE)


def proper_case(text):
    """Title-case `text`, lowercasing minor words (of/the/and/in/on/a) except the first word."""
    titled = _MINOR_WORDS_RE.sub(lambda m: m.group(0).lower(), text.title())
    # The leading word stays capitalized even when it is a minor word ("The ...", "A ...").
    return titled[:1].upper() + titled[1:]


def clean_name(text, title_case=True, max_repeats=2):
    """Normalize a raw generated name for display.

    Args:
        text: raw decoded model output.
        title_case: apply `proper_case` when True.
        max_repeats: runs of 3+ identical characters are collapsed to this many copies.

    Returns:
        The cleaned name (only letters, digits, apostrophes, hyphens and single spaces).
    """
    text = unicodedata.normalize("NFC", text)
    text = re.sub(r'(.)\1{2,}', lambda m: m.group(1) * max_repeats, text)
    text = re.sub(r"’S|\'S", "'s", text)
    text = re.sub(r"[^0-9A-Za-zÀ-ÖØ-öø-ÿ'’\-\s]", "", text)
    text = re.sub(r"\s+", " ", text).strip()
    if title_case:
        text = proper_case(text)
    # Drop an immediately repeated article/preposition ("The The" -> "The").
    text = re.sub(r'\b(The|Of|In|On|A)\s+\1\b', r'\1', text, flags=re.IGNORECASE)
    # str.title() capitalizes after apostrophes ("King'S" / "King’S"); restore the
    # possessive for any letter (including accented ones) before either apostrophe.
    return re.sub(r"(?<=[^\W\d_])['’]S\b", "'s", text)


def sample_once(prompt, temperature=1.0, top_k=40, max_new=40):
    """Autoregressively sample one name for `prompt`.

    Args:
        prompt: conditioning text; characters missing from the vocabulary are encoded as PAD.
        temperature: logits are divided by this before sampling.
        top_k: sample only among the k most likely tokens (falsy = full distribution).
            Clamped to the vocabulary size.
        max_new: maximum number of tokens to generate.

    Returns:
        (cleaned_name, elapsed_seconds)
    """
    started = time.perf_counter()
    seq = [BOS] + [stoi.get(c, PAD) for c in prompt] + [SEP]
    prefix_len = len(seq)
    # torch.topk raises if k exceeds the vocabulary size, and the UI slider goes up to 100.
    k = min(int(top_k), VOCAB_SIZE) if top_k else 0
    with torch.no_grad():
        for _ in range(max_new):
            x = torch.tensor(seq[-MAX_LEN:], dtype=torch.long, device=DEVICE)[None]
            logits = model(x)[:, -1, :] / temperature
            if k:
                vals, idxs = torch.topk(logits, k)
                idx = idxs[0, torch.softmax(vals, -1).multinomial(1)].item()
            else:
                idx = torch.softmax(logits, -1).multinomial(1).item()
            if idx == EOS:
                break
            seq.append(idx)
    # Decode only the generated tokens; stripping the prompt text afterwards fails
    # whenever a prompt character was out-of-vocabulary (encoded as PAD and dropped).
    special = {BOS, SEP, EOS, PAD}
    name = ''.join(itos[i] for i in seq[prefix_len:] if i not in special).strip()
    return clean_name(name), time.perf_counter() - started


def _moderate(text):
    """Return "profanity" or "harmful" if `text` fails moderation, else None."""
    lowered = text.strip().lower()
    if detect_profanity([lowered], language='All'):
        return "profanity"
    verdict = detect_hate_speech(lowered)
    if verdict and verdict[0] in BLOCKED_SPEECH_LABELS:
        return "harmful"
    return None


def generate_names(prompt, temperature, top_k, count, retries, seed, randomize_seed):
    """Gradio callback: generate up to `count` moderated names for `prompt`.

    Each name gets up to `retries` sampling attempts; attempts that fail moderation
    or are shorter than MIN_NAME_CHARS are recorded as rejected and never returned.

    Returns:
        (txt_path, txt_path, DataFrame of names/timings, markdown debug report)

    Raises:
        gr.Error: if the prompt is empty or longer than MAX_PROMPT_CHARS.
    """
    if randomize_seed:
        # Use an OS-backed generator: the global `random` state is reseeded below on
        # every call, so drawing from it would yield a deterministic "random" seed chain.
        seed = random.SystemRandom().randint(1, 999999)
    elif seed is None:  # the Seed field was cleared in the UI
        seed = SEED
    seed = int(seed)
    torch.manual_seed(seed)
    random.seed(seed)

    prompt = prompt.strip()
    prompt_verdict = _moderate(prompt)
    if prompt_verdict == "profanity":
        gr.Warning("Profanity detected in the prompt, using the default prompt.")
        prompt = DEFAULT_PROMPT
    elif prompt_verdict == "harmful":
        gr.Warning('Harmful speech detected in the prompt, using default prompt.')
        prompt = DEFAULT_PROMPT

    if not prompt:
        raise gr.Error("Prompt cannot be empty.")
    if len(prompt) > MAX_PROMPT_CHARS:
        raise gr.Error("Prompt is too long. Please keep it under 64 characters.")

    results = []
    rejected = []
    timings = []
    attempts = 0
    # Slider values may arrive as floats; range() needs ints.
    for _ in range(int(count)):
        for _attempt in range(int(retries)):
            name, elapsed = sample_once(prompt, temperature=temperature, top_k=top_k)
            attempts += 1
            name_verdict = _moderate(name)
            if name_verdict == "profanity":
                gr.Warning("Profanity detected in the generated name, flagging...")
                rejected.append(name + " (Profanity Detected)")
                continue  # a flagged name must never reach the results
            if name_verdict == "harmful":
                gr.Warning('Harmful speech detected in the generated name, flagging...')
                rejected.append(name + " (Harmful Speech Detected)")
                continue
            if len(name) >= MIN_NAME_CHARS:
                results.append({"Generated Name": name, "Time (s)": f"{elapsed:.2f}"})
                timings.append(elapsed)
                break
            rejected.append(name + " (Too Short)")

    # Explicit columns keep the column selection below valid when nothing was accepted.
    df = pd.DataFrame(results, columns=["Generated Name", "Time (s)"])
    # Context manager closes the handle; delete=False keeps the file for Gradio to serve.
    with tempfile.NamedTemporaryFile(
        "w", delete=False, suffix=".txt", encoding="utf-8", newline=""
    ) as fh:
        df[["Generated Name"]].to_csv(fh, index=False, header=False)
        file_path = fh.name

    avg_time = f"{sum(timings) / len(timings):.2f}s" if timings else "n/a"
    # Bullet each candidate: bare lines would be merged into one Markdown paragraph.
    rejected_md = '\n'.join(f"- {r}" for r in (rejected or ["None"]))
    retry_report = (
        "## Debug Report\n\n"
        f"- **Total Retries:** {attempts - len(results)}\n"
        f"- **Seed Used:** {seed}\n"
        f"- **Average Sample Time:** {avg_time}\n\n"
        "### Rejected Candidates:\n"
        + rejected_md
    )
    return file_path, file_path, df, retry_report


description = """# KaosGen: A Fantasy Name Generator

`Kaos` is a small GPT-style transformer (~890k parameters) trained from scratch using character-level tokenization. It excels at fantasy and mythic naming conventions.
"""

examples = [
    ["a forgotten warrior king"],
    ["queen of the shattered realm"],
    ["blacksmith of shadows"],
    ["titan of the blazing sky"],
    ["a blade that burns through matter"],
]

with gr.Blocks() as demo:
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", placeholder="e.g. 'a villain who whispers to shadows'")
            temperature = gr.Slider(0.1, 1.5, step=0.1, value=1.0, label="Temperature")
            top_k = gr.Slider(10, 100, step=10, value=40, label="Top-K Sampling")
            count = gr.Slider(1, 20, step=1, value=5, label="Names to Generate")
            retries = gr.Slider(1, 5, step=1, value=3, label="Max Retries per Name")
            seed = gr.Number(label="Seed", value=SEED, precision=0)
            randomize_seed = gr.Checkbox(label="Use Random Seed", value=False)
            generate_btn = gr.Button("🎲 Generate Names")
        with gr.Column():
            output = gr.Dataframe(
                headers=["Generated Name", "Time (s)"],
                datatype=["str", "str"],
                label="Generated Names",
                interactive=False,
            )
            download = gr.File(label="📥 Export Names as .txt")
            retry_report = gr.Markdown("", label="Debug Info")

    generate_btn.click(
        fn=generate_names,
        inputs=[prompt, temperature, top_k, count, retries, seed, randomize_seed],
        outputs=[download, download, output, retry_report],
    )
    gr.Examples(examples=examples, inputs=prompt)

if __name__ == "__main__":
    demo.launch()