# === ADDITIONAL UI FEEDBACK + SEED + TIMING ===
import gradio as gr
import torch
import torch.nn as nn
import re, unicodedata, random
from pathlib import Path
import pandas as pd
import tempfile
import time
import os
from valx import detect_profanity, detect_hate_speech
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
SEED = 1337
torch.manual_seed(SEED)
random.seed(SEED)

# === Model Loading Diagnostics ===
print(f"📦 Model loading on: {DEVICE}")
ckpt = torch.load("kaos.pt", map_location=DEVICE)
stoi, itos = ckpt['stoi'], ckpt['itos']
SPECIAL = ['<pad>', '<bos>', '<eos>', '<sep>']
PAD, BOS, EOS, SEP = [stoi[s] for s in SPECIAL]
VOCAB_SIZE = len(itos)
MAX_LEN = 128
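# The checkpoint provides the character-level vocab (stoi/itos); the four special
# tokens above are looked up from it, and MAX_LEN must match the positional-embedding
# length the model was trained with for load_state_dict to succeed.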
class GPTSmall(nn.Module):
    def __init__(self, vocab_size, d_model=256, n_head=8, n_layer=4, dropout=0.2, max_len=MAX_LEN):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Parameter(torch.zeros(1, max_len, d_model))
        nn.init.trunc_normal_(self.pos_emb, std=0.02)
        block = nn.TransformerEncoderLayer(d_model, n_head, d_model * 4, dropout=dropout, batch_first=True)
        # Note: the same layer instance is registered n_layer times, so all blocks share weights.
        self.blocks = nn.ModuleList([block for _ in range(n_layer)])
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, x):
        B, T = x.shape
        tok = self.tok_emb(x) + self.pos_emb[:, :T]
        # Causal mask: each position may only attend to itself and earlier positions.
        mask = torch.triu(torch.ones(T, T, device=x.device, dtype=torch.bool), 1)
        for blk in self.blocks:
            tok = blk(tok, src_key_padding_mask=(x == PAD), src_mask=mask)
        return self.head(self.norm(tok))
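# Restore the trained weights from the checkpoint and switch to eval mode so that
# dropout is disabled during generation.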
model = GPTSmall(VOCAB_SIZE).to(DEVICE)
model.load_state_dict(ckpt['model'])
model.eval()
def proper_case(text):
    # Title-case the text, then lower-case common stop words again (the original
    # pattern was case-sensitive and never matched the title-cased words).
    return re.sub(r"\b(of|the|and|in|on|a)\b", lambda m: m.group(0).lower(), text.title(), flags=re.IGNORECASE)

def clean_name(text, title_case=True, max_repeats=2):
    text = unicodedata.normalize("NFC", text)
    # Collapse any character repeated more than max_repeats times in a row.
    text = re.sub(r'(.)\1{2,}', lambda m: m.group(1) * max_repeats, text)
    text = re.sub(r"’S|\'S", "'s", text)
    text = re.sub(r"[^0-9A-Za-zÀ-ÖØ-öø-ÿ'’\-\s]", "", text)
    text = re.sub(r"\s+", " ", text).strip()
    if title_case:
        text = proper_case(text)
    # Drop duplicated stop words such as "The The".
    text = re.sub(r'\b(The|Of|In|On|A)\s+\1\b', r'\1', text, flags=re.IGNORECASE)
    return re.sub(r"([a-zA-Z])'S\b", lambda m: m.group(1) + "'s", text)
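# sample_once draws a single candidate autoregressively: the prompt is encoded
# character by character with stoi, then the model predicts one character at a time
# (temperature-scaled logits, optional top-k filtering) until <eos> or max_new steps.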
def sample_once(prompt, temperature=1.0, top_k=40, max_new=40):
    start_time = time.time()
    seq = [BOS] + [stoi.get(c, PAD) for c in prompt] + [SEP]
    for _ in range(max_new):
        x = torch.tensor(seq[-MAX_LEN:], dtype=torch.long, device=DEVICE)[None]
        with torch.no_grad():
            logits = model(x)[:, -1, :] / temperature
        if top_k:
            k = min(int(top_k), VOCAB_SIZE)  # guard against k exceeding the vocabulary size
            v, i = torch.topk(logits, k)
            idx = i[0, torch.softmax(v, -1).multinomial(1)].item()
        else:
            idx = torch.softmax(logits, -1).multinomial(1).item()
        if idx == EOS:
            break
        seq.append(idx)
    generated = [itos[i] for i in seq if i not in {BOS, SEP, EOS, PAD}]
    name = ''.join(generated).replace(prompt, "").strip()
    return clean_name(name), time.time() - start_time
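# generate_names is the Gradio callback: it screens the prompt with valx, samples up
# to `retries` candidates for each of the `count` requested names, and returns the
# export file, the results table, and a Markdown debug report.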
def generate_names(prompt, temperature, top_k, count, retries, seed, randomize_seed):
    if randomize_seed:
        seed = random.randint(1, 999999)
    torch.manual_seed(seed)
    random.seed(seed)
    prompt = prompt.strip()
    promptx = prompt.lower()
    if detect_profanity([promptx], language='All'):
        gr.Warning("Profanity detected in the prompt, using the default prompt.")
        prompt = 'a kind king'
    elif (hate_speech_result := detect_hate_speech(promptx)) and hate_speech_result[0] in ['Hate Speech', 'Offensive Speech']:
        gr.Warning('Harmful speech detected in the prompt, using the default prompt.')
        prompt = 'a kind king'
    if not prompt:
        raise gr.Error("Prompt cannot be empty.")
    if len(prompt) > 64:
        raise gr.Error("Prompt is too long. Please keep it to 64 characters or fewer.")
    results = []
    rejected = []
    retry_count = 0
    timings = []
    for _ in range(int(count)):
        for attempt in range(int(retries)):
            name, t = sample_once(prompt, temperature=temperature, top_k=top_k)
            namex = name.strip().lower()
            retry_count += 1
            if detect_profanity([namex], language='All'):
                gr.Warning("Profanity detected in the generated name, flagging...")
                rejected.append(name + " (Profanity Detected)")
                continue  # flagged candidates never reach the results table
            if (hate_speech_result := detect_hate_speech(namex)) and hate_speech_result[0] in ['Hate Speech', 'Offensive Speech']:
                gr.Warning('Harmful speech detected in the generated name, flagging...')
                rejected.append(name + " (Harmful Speech Detected)")
                continue
            if len(name) >= 3:
                results.append({"Generated Name": name, "Time (s)": f"{t:.2f}"})
                timings.append(t)
                break
            else:
                rejected.append(name)
    df = pd.DataFrame(results, columns=["Generated Name", "Time (s)"])
    file_path = tempfile.NamedTemporaryFile(delete=False, suffix=".txt").name
    df[["Generated Name"]].to_csv(file_path, index=False, header=False)
    avg_time = sum(timings) / len(timings) if timings else 0.0  # avoid division by zero when every candidate was rejected
    retry_report = (
        f"## Debug Report\n\n"
        f"- **Total Retries:** {retry_count - len(results)}\n"
        f"- **Seed Used:** {seed}\n"
        f"- **Average Sample Time:** {avg_time:.2f}s\n\n"
        f"### Rejected Candidates:\n" + '\n'.join(rejected or ["None"])
    )
    return file_path, df, retry_report
description = """# KaosGen: A Fantasy Name Generator
`Kaos` is a small GPT-style transformer (~890k parameters) trained from scratch using character-level tokenization.
It excels at fantasy and mythic naming conventions.
"""
examples = [["a forgotten warrior king"], ["queen of the shattered realm"], ["blacksmith of shadows"], ["titan of the blazing sky"], ["a blade that burns through matter"]]
with gr.Blocks() as demo:
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", placeholder="e.g. 'a villain who whispers to shadows'")
            temperature = gr.Slider(0.1, 1.5, step=0.1, value=1.0, label="Temperature")
            top_k = gr.Slider(10, 100, step=10, value=40, label="Top-K Sampling")
            count = gr.Slider(1, 20, step=1, value=5, label="Names to Generate")
            retries = gr.Slider(1, 5, step=1, value=3, label="Max Retries per Name")
            seed = gr.Number(label="Seed", value=1337, precision=0)
            randomize_seed = gr.Checkbox(label="Use Random Seed", value=False)
            generate_btn = gr.Button("🎲 Generate Names")
        with gr.Column():
            output = gr.Dataframe(headers=["Generated Name", "Time (s)"], datatype=["str", "str"], label="Generated Names", interactive=False)
            download = gr.File(label="📥 Export Names as .txt")
            retry_report = gr.Markdown("", label="Debug Info")
    generate_btn.click(fn=generate_names, inputs=[prompt, temperature, top_k, count, retries, seed, randomize_seed], outputs=[download, output, retry_report])
    gr.Examples(examples=examples, inputs=prompt)

demo.launch()
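# Local use (an assumption, not stated in the original file): with torch, gradio,
# pandas and valx installed and kaos.pt alongside this script, running the script
# starts the Gradio server via demo.launch().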