# === KaosGen demo: Gradio UI with feedback, seed control, and per-sample timing ===
import gradio as gr
import torch
import torch.nn as nn
import re, unicodedata, random
import pandas as pd
import tempfile
import time
from valx import detect_profanity, detect_hate_speech
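# valx supplies the text-moderation helpers used below to screen both the
# user prompt and each generated name before it is shown.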

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
SEED = 1337

# === Reproducibility seeding + model-loading diagnostics ===
torch.manual_seed(SEED)
random.seed(SEED)
print(f"📦 Model loading on: {DEVICE}")

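# The checkpoint is expected to bundle the trained weights with the character-level
# vocab maps saved at training time (stoi: char -> id, itos: id -> char).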
ckpt = torch.load("kaos.pt", map_location=DEVICE)
stoi, itos = ckpt['stoi'], ckpt['itos']
SPECIAL = ['<pad>', '<bos>', '<eos>', '<sep>']
PAD, BOS, EOS, SEP = [stoi[s] for s in SPECIAL]
VOCAB_SIZE = len(itos)
MAX_LEN = 128

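# A small decoder-only GPT assembled from stock TransformerEncoderLayer blocks;
# the upper-triangular src_mask applied in forward() is what makes them causal.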
class GPTSmall(nn.Module):
    def __init__(self, vocab_size, d_model=256, n_head=8, n_layer=4, dropout=0.2, max_len=MAX_LEN):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Parameter(torch.zeros(1, max_len, d_model))
        nn.init.trunc_normal_(self.pos_emb, std=0.02)
        # One independent TransformerEncoderLayer per depth; reusing a single
        # layer object in the list would silently share weights across all depths.
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model, n_head, d_model * 4, dropout=dropout, batch_first=True)
            for _ in range(n_layer)])
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, x):
        B, T = x.shape
        tok = self.tok_emb(x) + self.pos_emb[:, :T]
        mask = torch.triu(torch.ones(T, T, device=x.device, dtype=torch.bool), 1)
        for blk in self.blocks:
            tok = blk(tok, src_key_padding_mask=(x == PAD), src_mask=mask)
        return self.head(self.norm(tok))

model = GPTSmall(VOCAB_SIZE).to(DEVICE)
model.load_state_dict(ckpt['model'])
model.eval()
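# Optional smoke test (a minimal sketch; uncomment to confirm the checkpoint loaded):
# with torch.no_grad():
#     probe = torch.full((1, 8), BOS, dtype=torch.long, device=DEVICE)
#     assert model(probe).shape == (1, 8, VOCAB_SIZE)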

def proper_case(text):
    # IGNORECASE is required: after .title() the small words are capitalized ("Of", "The"),
    # so a lowercase-only pattern would never match them.
    return re.sub(r"\b(of|the|and|in|on|a)\b", lambda m: m.group(0).lower(), text.title(), flags=re.IGNORECASE)
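# e.g. proper_case("sword of the dawn") -> "Sword of the Dawn"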

def clean_name(text, title_case=True, max_repeats=2):
    text = unicodedata.normalize("NFC", text)
    text = re.sub(r'(.)\1{2,}', lambda m: m.group(1) * max_repeats, text)
    text = text.replace("’", "'")  # normalize curly apostrophes so later patterns handle one form
    text = re.sub(r"'S", "'s", text)
    text = re.sub(r"[^0-9A-Za-zÀ-ÖØ-öø-ÿ'’\-\s]", "", text)
    text = re.sub(r"\s+", " ", text).strip()
    if title_case:
        text = proper_case(text)
    text = re.sub(r'\b(The|Of|In|On|A)\s+\1\b', r'\1', text, flags=re.IGNORECASE)
    return re.sub(r"([a-zA-Z])'S\b", lambda m: m.group(1) + "'s", text)
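# e.g. clean_name("queen   of the the realm") -> "Queen of the Realm"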

def sample_once(prompt, temperature=1.0, top_k=40, max_new=40):
    start_time = time.time()
    seq = [BOS] + [stoi.get(c, PAD) for c in prompt] + [SEP]
    for _ in range(max_new):
        x = torch.tensor(seq[-MAX_LEN:], dtype=torch.long, device=DEVICE)[None]
        with torch.no_grad():
            logits = model(x)[:, -1, :] / temperature
        if top_k:
            v, i = torch.topk(logits, top_k)
            idx = i[0, torch.softmax(v, -1).multinomial(1)].item()
        else:
            idx = torch.softmax(logits, -1).multinomial(1).item()
        if idx == EOS:
            break
        seq.append(idx)
    # Slice off the prompt by position: str.replace() breaks when the prompt
    # contains out-of-vocab characters, since those were mapped to PAD and dropped.
    sep_pos = seq.index(SEP)
    generated = [itos[i] for i in seq[sep_pos + 1:] if i not in {BOS, SEP, EOS, PAD}]
    name = ''.join(generated).strip()
    return clean_name(name), time.time() - start_time
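# Example usage (a sketch; requires the checkpoint loaded above):
# name, elapsed = sample_once("a kind king", temperature=0.9, top_k=40)
# print(f"{name} ({elapsed:.2f}s)")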

def generate_names(prompt, temperature, top_k, count, retries, seed, randomize_seed):
    if randomize_seed:
        seed = random.randint(1, 999999)
    torch.manual_seed(seed)
    random.seed(seed)

    prompt = prompt.strip()
    if not prompt:
        raise gr.Error("Prompt cannot be empty.")
    if len(prompt) > 64:
        raise gr.Error("Prompt is too long. Please keep it under 64 characters.")

    # Validate first, then moderate: a flagged prompt is swapped for a safe default.
    promptx = prompt.lower()
    if detect_profanity([promptx], language='All'):
        gr.Warning("Profanity detected in the prompt, using the default prompt.")
        prompt = 'a kind king'
    elif (hate_speech_result := detect_hate_speech(promptx)) and hate_speech_result[0] in ['Hate Speech', 'Offensive Speech']:
        gr.Warning("Harmful speech detected in the prompt, using the default prompt.")
        prompt = 'a kind king'

    results = []
    rejected = []
    retry_count = 0
    timings = []

    for _ in range(count):
        for attempt in range(retries):
            name, t = sample_once(prompt, temperature=temperature, top_k=top_k)
            namex = name.strip().lower()
            retry_count += 1
            # Flagged names must not reach the results table; reject and retry instead.
            if detect_profanity([namex], language='All'):
                gr.Warning("Profanity detected in the generated name, rejecting...")
                rejected.append(name + " (Profanity Detected)")
                continue
            if (hate_speech_result := detect_hate_speech(namex)) and hate_speech_result[0] in ['Hate Speech', 'Offensive Speech']:
                gr.Warning("Harmful speech detected in the generated name, rejecting...")
                rejected.append(name + " (Harmful Speech Detected)")
                continue
            if len(name) >= 3:
                results.append({"Generated Name": name, "Time (s)": f"{t:.2f}"})
                timings.append(t)
                break
            rejected.append(name)

    # Explicit columns keep the export from crashing when every candidate was rejected.
    df = pd.DataFrame(results, columns=["Generated Name", "Time (s)"])
    file_path = tempfile.NamedTemporaryFile(delete=False, suffix=".txt").name
    df[["Generated Name"]].to_csv(file_path, index=False, header=False)

    avg_time = sum(timings) / len(timings) if timings else 0.0
    retry_report = f"## Debug Report\n\n- **Total Retries:** {retry_count - len(results)}\n- **Seed Used:** {seed}\n- **Average Sample Time:** {avg_time:.2f}s\n\n### Rejected Candidates:\n" + '\n'.join(rejected or ["None"])
    return file_path, df, retry_report

description = """# KaosGen: A Fantasy Name Generator  
`Kaos` is a small GPT-style transformer (~890k parameters) trained from scratch using character-level tokenization.  
It excels at fantasy and mythic naming conventions.
"""

examples = [["a forgotten warrior king"], ["queen of the shattered realm"], ["blacksmith of shadows"], ["titan of the blazing sky"], ["a blade that burns through matter"]]

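# Two-column layout: sampling controls on the left; results, export, and debug info on the right.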
with gr.Blocks() as demo:
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", placeholder="e.g. 'a villain who whispers to shadows'")
            temperature = gr.Slider(0.1, 1.5, step=0.1, value=1.0, label="Temperature")
            top_k = gr.Slider(10, 100, step=10, value=40, label="Top-K Sampling")
            count = gr.Slider(1, 20, step=1, value=5, label="Names to Generate")
            retries = gr.Slider(1, 5, step=1, value=3, label="Max Retries per Name")
            seed = gr.Number(label="Seed", value=1337, precision=0)
            randomize_seed = gr.Checkbox(label="Use Random Seed", value=False)
            generate_btn = gr.Button("🎲 Generate Names")
        with gr.Column():
            output = gr.Dataframe(headers=["Generated Name", "Time (s)"], datatype=["str", "str"], label="Generated Names", interactive=False)
            download = gr.File(label="📥 Export Names as .txt")
            retry_report = gr.Markdown("", label="Debug Info")

    generate_btn.click(fn=generate_names, inputs=[prompt, temperature, top_k, count, retries, seed, randomize_seed], outputs=[download, output, retry_report])
    gr.Examples(examples=examples, inputs=prompt)

demo.launch()