Spaces:
Sleeping
Sleeping
# === ADDITIONAL UI FEEDBACK + SEED + TIMING ===
import gradio as gr
import torch
import torch.nn as nn
import re, unicodedata, random
from pathlib import Path
import pandas as pd
import tempfile
import time
import os
from valx import detect_profanity, detect_hate_speech

# Run on GPU when available; model weights and sampling tensors live on this device.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
SEED = 1337  # default seed so sampling is reproducible between runs

# === Model Loading Diagnostics ===
torch.manual_seed(SEED)
random.seed(SEED)
print(f"📦 Model loading on: {DEVICE}")
# Checkpoint bundles trained weights plus the character-level vocab mappings.
ckpt = torch.load("kaos.pt", map_location=DEVICE)
stoi, itos = ckpt['stoi'], ckpt['itos']  # char -> id and id -> char tables
SPECIAL = ['<pad>', '<bos>', '<eos>', '<sep>']
# Resolve the ids of the special tokens once; used throughout sampling.
PAD, BOS, EOS, SEP = [stoi[s] for s in SPECIAL]
VOCAB_SIZE = len(itos)
MAX_LEN = 128  # context window; presumably what the model was trained with — TODO confirm
class GPTSmall(nn.Module):
    """Minimal GPT-style causal language model over a character vocabulary.

    Decoder-only transformer built from ``nn.TransformerEncoderLayer`` blocks
    driven with a causal attention mask. Uses learned absolute positional
    embeddings, a final LayerNorm, and an untied linear head over the vocab.
    """

    def __init__(self, vocab_size, d_model=256, n_head=8, n_layer=4, dropout=0.2, max_len=MAX_LEN):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        # Learned absolute positional embedding, shape (1, max_len, d_model).
        self.pos_emb = nn.Parameter(torch.zeros(1, max_len, d_model))
        nn.init.trunc_normal_(self.pos_emb, std=0.02)
        # BUG FIX: the original created ONE TransformerEncoderLayer and put the
        # same object into the ModuleList n_layer times, so every "layer"
        # shared a single set of weights. Build an independent layer per depth.
        # Checkpoint loading is unaffected: keys blocks.0..blocks.{n-1} exist
        # either way (a checkpoint saved from the shared version stores
        # identical tensors under each key, reproducing the old behavior).
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model, n_head, d_model * 4, dropout=dropout, batch_first=True)
            for _ in range(n_layer)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, x):
        """Return next-token logits of shape (B, T, vocab_size) for id batch ``x`` (B, T)."""
        B, T = x.shape
        tok = self.tok_emb(x) + self.pos_emb[:, :T]
        # Causal mask: True above the diagonal blocks attention to future positions.
        mask = torch.triu(torch.ones(T, T, device=x.device, dtype=torch.bool), 1)
        for blk in self.blocks:
            # PAD positions are masked out as keys so padding never attends in.
            tok = blk(tok, src_key_padding_mask=(x == PAD), src_mask=mask)
        return self.head(self.norm(tok))
# Instantiate the model, restore the trained weights from the checkpoint,
# and switch to eval mode (disables dropout) for inference.
model = GPTSmall(VOCAB_SIZE).to(DEVICE)
model.load_state_dict(ckpt['model'])
model.eval()
def proper_case(text):
    """Title-case ``text`` while keeping minor words (of/the/and/in/on/a) lowercase.

    BUG FIX: the original regex listed the minor words in lowercase with no
    re.IGNORECASE, but ``str.title()`` has already capitalized every word, so
    the pattern could never match and the function was a plain ``title()``.
    Match case-insensitively instead, and keep the very first word
    capitalized as conventional title case requires.
    """
    titled = text.title()

    def _lower_minor(m):
        # Leave the leading word of the string capitalized.
        return m.group(0) if m.start() == 0 else m.group(0).lower()

    return re.sub(r"\b(of|the|and|in|on|a)\b", _lower_minor, titled, flags=re.IGNORECASE)
| def clean_name(text, title_case=True, max_repeats=2): | |
| text = unicodedata.normalize("NFC", text) | |
| text = re.sub(r'(.)\1{2,}', lambda m: m.group(1) * max_repeats, text) | |
| text = re.sub(r"’S|\'S", "'s", text) | |
| text = re.sub(r"[^0-9A-Za-zÀ-ÖØ-öø-ÿ'’\-\s]", "", text) | |
| text = re.sub(r"\s+", " ", text).strip() | |
| if title_case: | |
| text = proper_case(text) | |
| text = re.sub(r'\b(The|Of|In|On|A)\s+\1\b', r'\1', text, flags=re.IGNORECASE) | |
| return re.sub(r"([a-zA-Z])'S\b", lambda m: m.group(1) + "'s", text) | |
def sample_once(prompt, temperature=1.0, top_k=40, max_new=40):
    """Autoregressively sample one name continuation for ``prompt``.

    Encodes the prompt as <bos> + chars + <sep>, then samples up to ``max_new``
    tokens (temperature-scaled, optionally top-k restricted), stopping at <eos>.

    Returns:
        (cleaned_name, elapsed_seconds)
    """
    start_time = time.time()
    # Unknown prompt characters fall back to PAD (they carry no signal).
    seq = [BOS] + [stoi.get(c, PAD) for c in prompt] + [SEP]
    # BUG FIX: the original decoded the whole sequence and then did
    # ``.replace(prompt, "")``, which (a) removed EVERY occurrence of the
    # prompt substring, mangling names that legitimately contain it, and
    # (b) failed to strip the echo when prompt chars were unknown (mapped to
    # PAD and filtered). Remember the prefix length and decode only the
    # newly generated tokens instead.
    prefix_len = len(seq)
    for _ in range(max_new):
        # Feed at most the last MAX_LEN tokens (model context window).
        x = torch.tensor(seq[-MAX_LEN:], dtype=torch.long, device=DEVICE)[None]
        with torch.no_grad():
            logits = model(x)[:, -1, :] / temperature
        if top_k:
            # Restrict sampling to the k most likely tokens.
            v, i = torch.topk(logits, top_k)
            idx = i[0, torch.softmax(v, -1).multinomial(1)].item()
        else:
            idx = torch.softmax(logits, -1).multinomial(1).item()
        if idx == EOS:
            break
        seq.append(idx)
    generated = [itos[i] for i in seq[prefix_len:] if i not in {BOS, SEP, EOS, PAD}]
    name = ''.join(generated).strip()
    return clean_name(name), time.time() - start_time
def generate_names(prompt, temperature, top_k, count, retries, seed, randomize_seed):
    """Generate ``count`` names, retrying flagged/too-short candidates.

    Safety-screens both the prompt and each generated name with valx profanity
    and hate-speech detectors. Returns (file_path, file_path, dataframe,
    markdown_report) matching the Gradio outputs wiring.
    """
    if randomize_seed:
        seed = random.randint(1, 999999)
    torch.manual_seed(seed)
    random.seed(seed)
    prompt = prompt.strip()
    promptx = prompt.lower()
    if detect_profanity([promptx], language='All'):
        gr.Warning("Profanity detected in the prompt, using the default prompt.")
        prompt = 'a kind king'
    elif (hate_speech_result := detect_hate_speech(promptx)) and hate_speech_result[0] in ['Hate Speech', 'Offensive Speech']:
        gr.Warning('Harmful speech detected in the prompt, using default prompt.')
        prompt = 'a kind king'
    if not prompt:
        raise gr.Error("Prompt cannot be empty.")
    if len(prompt) > 64:
        raise gr.Error("Prompt is too long. Please keep it under 64 characters.")
    results = []
    rejected = []
    retry_count = 0
    timings = []
    for _ in range(count):
        for _attempt in range(retries):
            name, t = sample_once(prompt, temperature=temperature, top_k=top_k)
            namex = name.strip().lower()
            # BUG FIX: flagged names previously fell through to the length
            # check and could still be ACCEPTED into results; reject and
            # retry instead.
            if detect_profanity([namex], language='All'):
                gr.Warning("Profanity detected in the generated name, flagging...")
                rejected.append(name + " (Profanity Detected)")
                retry_count += 1
                continue
            if (hate_speech_result := detect_hate_speech(namex)) and hate_speech_result[0] in ['Hate Speech', 'Offensive Speech']:
                gr.Warning('Harmful speech detected in the generated name, flagging...')
                rejected.append(name + " (Harmful Speech Detected)")
                retry_count += 1
                continue
            if len(name) >= 3:
                results.append({"Generated Name": name, "Time (s)": f"{t:.2f}"})
                timings.append(t)
                break
            rejected.append(name)
            retry_count += 1
    # Fix columns explicitly: an empty results list would otherwise build a
    # column-less DataFrame and df[["Generated Name"]] would raise KeyError.
    df = pd.DataFrame(results, columns=["Generated Name", "Time (s)"])
    # Close the handle before writing through pandas (required on Windows).
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".txt")
    tmp.close()
    file_path = tmp.name
    df[["Generated Name"]].to_csv(file_path, index=False, header=False)
    # Guard against division by zero when every candidate was rejected.
    avg_time = sum(timings) / len(timings) if timings else 0.0
    retry_report = (
        f"## Debug Report\n\n"
        f"- **Total Retries:** {retry_count}\n"
        f"- **Seed Used:** {seed}\n"
        f"- **Average Sample Time:** {avg_time:.2f}s\n\n"
        "### Rejected Candidates:\n" + '\n'.join(rejected or ["None"])
    )
    return file_path, file_path, df, retry_report
# Markdown blurb rendered at the top of the UI.
description = """# KaosGen: A Fantasy Name Generator
`Kaos` is a small GPT-style transformer (~890k parameters) trained from scratch using character-level tokenization.
It excels at fantasy and mythic naming conventions.
"""
# Example prompts wired into gr.Examples in the UI below.
examples = [["a forgotten warrior king"], ["queen of the shattered realm"], ["blacksmith of shadows"], ["titan of the blazing sky"], ["a blade that burns through matter"]]
# === Gradio UI: input controls on the left, results/export/debug on the right ===
with gr.Blocks() as demo:
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            # Sampling controls.
            prompt = gr.Textbox(label="Prompt", placeholder="e.g. 'a villain who whispers to shadows'")
            temperature = gr.Slider(0.1, 1.5, step=0.1, value=1.0, label="Temperature")
            top_k = gr.Slider(10, 100, step=10, value=40, label="Top-K Sampling")
            count = gr.Slider(1, 20, step=1, value=5, label="Names to Generate")
            retries = gr.Slider(1, 5, step=1, value=3, label="Max Retries per Name")
            seed = gr.Number(label="Seed", value=1337, precision=0)
            randomize_seed = gr.Checkbox(label="Use Random Seed", value=False)
            generate_btn = gr.Button("🎲 Generate Names")
        with gr.Column():
            # Result table, .txt export, and debug report.
            output = gr.Dataframe(headers=["Generated Name", "Time (s)"], datatype=["str", "str"], label="Generated Names", interactive=False)
            download = gr.File(label="📥 Export Names as .txt")
            retry_report = gr.Markdown("", label="Debug Info")
    # NOTE(review): `download` is listed twice in outputs, matching the two
    # file_path values generate_names returns — presumably intentional; verify.
    generate_btn.click(fn=generate_names, inputs=[prompt, temperature, top_k, count, retries, seed, randomize_seed], outputs=[download, download, output, retry_report])
    gr.Examples(examples=examples, inputs=prompt)
demo.launch()