JohanBeytell committed on
Commit
6f4e03e
verified
1 Parent(s): 4f0da9b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -126
app.py CHANGED
@@ -1,30 +1,24 @@
1
  import gradio as gr
2
  import torch
3
  import torch.nn as nn
4
- import math
5
- import re
6
- import unicodedata
7
- import random
8
- import os
9
 
10
# --- Load constants and model ---
# Seed every RNG the app touches so sampling is repeatable across restarts.
SEED = 1337
for _seed_fn in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)
del _seed_fn

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_LEN = 128

# Special tokens and their hard-coded vocabulary ids.
SPECIAL = ['<pad>', '<bos>', '<eos>', '<sep>']
PAD, BOS, EOS, SEP = 0, 1, 2, 3

# Load vocab
ckpt = torch.load("kaos.pt", map_location=DEVICE)
stoi, itos = ckpt["stoi"], ckpt["itos"]
VOCAB_SIZE = len(itos)
 
27
 
 
28
  class GPTSmall(nn.Module):
29
  def __init__(self, vocab_size, d_model=256, n_head=8, n_layer=4, dropout=0.2, max_len=MAX_LEN):
30
  super().__init__()
@@ -38,129 +32,71 @@ class GPTSmall(nn.Module):
38
 
39
  def forward(self, x):
40
  B, T = x.shape
41
- tok = self.tok_emb(x)
42
- tok = tok + self.pos_emb[:, :T]
43
  mask = torch.triu(torch.ones(T, T, device=x.device, dtype=torch.bool), 1)
44
  for blk in self.blocks:
45
  tok = blk(tok, src_key_padding_mask=(x == PAD), src_mask=mask)
46
- tok = self.norm(tok)
47
- return self.head(tok)
48
 
49
  model = GPTSmall(VOCAB_SIZE).to(DEVICE)
50
- model.load_state_dict(ckpt["model"])
51
  model.eval()
52
 
53
# --- Clean + scoring ---
# Minor words are lowercased only when they are NOT the first word, so a name
# such as "The Dark One" keeps its leading capital.
_MINOR_WORD_RE = re.compile(r"(?<=\s)(?:of|the|and|in|on|a)(?=\s|$)", re.IGNORECASE)


def proper_case(text):
    """Title-case ``text`` while keeping interior articles/prepositions lowercase.

    The match is case-insensitive because ``str.title()`` has already
    capitalised every word by the time the pattern runs.
    """
    return _MINOR_WORD_RE.sub(lambda m: m.group(0).lower(), text.title())


def clean_name(text, title_case=True, max_repeats=2):
    """Normalise a raw generated name.

    Steps: NFC-normalise, collapse runs of 3+ identical characters
    (case-insensitively) to ``max_repeats`` copies, drop characters outside
    ASCII alphanumerics / Latin-1 letters / apostrophes / hyphen / whitespace,
    squeeze whitespace, optionally title-case, remove an immediately repeated
    article, and repair a capitalised possessive ``'S``.

    Args:
        text: Raw name string.
        title_case: When true, apply ``proper_case``.
        max_repeats: Number of copies kept for a collapsed character run.

    Returns:
        The cleaned name.
    """
    text = unicodedata.normalize("NFC", text)
    text = re.sub(r"(.)\1{2,}", lambda m: m.group(1) * max_repeats, text, flags=re.IGNORECASE)
    # Possessive written in caps, with either a straight or a curly apostrophe.
    text = re.sub(r"[\u2019']S", "'s", text)
    # Non-ASCII ranges are spelled as escapes so the pattern cannot be
    # damaged by a file re-encoding.
    text = re.sub(r"[^0-9A-Za-z\u00C0-\u00D6\u00D8-\u00F6\u00F8-\u00FF'\u2019\-\s]", "", text)
    text = re.sub(r"\s+", " ", text).strip()
    if title_case:
        text = proper_case(text)
    text = re.sub(r"\b(The|Of|In|On|A)\s+\1\b", r"\1", text, flags=re.IGNORECASE)
    # str.title() capitalises the letter after an apostrophe ("King'S").
    text = re.sub(r"([a-zA-Z])(['\u2019])S\b", lambda m: m.group(1) + m.group(2) + "s", text)
    return text
68
-
69
def has_weird_word_lengths(name, min_len=3, max_len=24):
    """Return True when any whitespace-separated word is shorter than
    ``min_len`` or longer than ``max_len`` characters."""
    for word in name.split():
        if not (min_len <= len(word) <= max_len):
            return True
    return False
71
 
72
def gibberish_score(name):
    """Return the share of character trigrams in ``name`` that are not in a
    fixed whitelist of familiar trigrams.

    Spaces are removed and the text is lowercased first. Inputs too short to
    contain a trigram score 1.0.
    """
    familiar = frozenset((
        "the", "and", "ing", "ion", "ent", "ati", "for", "her", "ter", "tha",
        "ere", "nth", "tio", "ver", "his", "hat", "ers", "rea", "all", "ill",
        "ari", "est", "oth", "eve", "eld", "sky", "dra", "sha", "mir",
    ))
    squashed = "".join(name.lower().split(" "))
    total = len(squashed) - 2
    if total <= 0:
        return 1.0
    misses = 0
    for start in range(total):
        if squashed[start:start + 3] not in familiar:
            misses += 1
    return misses / total
81
-
82
def pronounceability_score(name):
    """Heuristic score in [0.0, 1.0] built from the vowel/consonant ratio,
    how often vowel and consonant runs alternate, and a penalty for runs of
    three or more consonants. Only the letters a-z are considered; ``y`` is
    treated as a vowel. Returns 0.0 when no letters remain.
    """
    letters = re.sub(r"[^a-z]", "", name.lower())
    if len(letters) == 0:
        return 0.0

    vowel_total = len([ch for ch in letters if ch in "aeiouy"])
    consonant_total = len(letters) - vowel_total
    vc_ratio = vowel_total / (consonant_total + 1)

    heavy_clusters = re.findall(r'[^aeiouy]{3,}', letters)
    cluster_penalty = len(heavy_clusters) * 0.1

    runs = re.findall(r'[aeiouy]+|[^aeiouy]+', letters)
    smoothness = len(runs) / len(letters)

    score = (vc_ratio * 0.6) + (smoothness * 0.6) - cluster_penalty
    return max(0.0, min(score, 1.0))
95
-
96
def has_duplicate_articles(name):
    """Return True when an article/preposition (the, of, in, on, a) is
    immediately repeated, e.g. "The the King". Case-insensitive.

    The pattern uses single backslashes: inside a raw string ``\\b`` would
    match a literal backslash followed by ``b`` and never fire.
    """
    return bool(re.search(r"\b(the|of|in|on|a)\s+\1\b", name, flags=re.IGNORECASE))
98
-
99
def is_problematic(name):
    """Return True when ``name`` shows one of three red flags:

    * a doubled "the the" / "of of" / "in in",
    * a single word shorter than five characters,
    * a run of five or more consonants.

    Always returns a bool.
    """
    lowered = name.lower()
    # Single backslashes: in a raw string "\\b" is a literal backslash + "b".
    if re.search(r"\b(the the|of of|in in)\b", lowered):
        return True
    if " " not in name and len(name) < 5:
        return True
    return re.search(r"[bcdfghjklmnpqrstvwxyz]{5,}", lowered) is not None
105
-
106
def is_too_weird(name):
    """Return True when any word is longer than 14 characters or the name
    contains a run of five or more consonants.

    Always returns a bool (the regex match object is not leaked to callers).
    """
    if any(len(word) > 14 for word in name.split()):
        return True
    return re.search(r"[bcdfghjklmnpqrstvwxyz]{5,}", name.lower()) is not None
111
-
112
def _sample_once(prompt, max_new=24, temperature=1.0, top_k=40):
    """Sample one raw name from the model for ``prompt``.

    The prompt is encoded character by character (unknown characters map to
    PAD), wrapped as ``BOS + prompt + SEP``, and up to ``max_new`` tokens are
    sampled until EOS (or a ``"</s>"`` vocabulary entry) is produced.

    Args:
        prompt: Conditioning text.
        max_new: Maximum number of tokens to generate.
        temperature: Softmax temperature; clamped to a small positive value
            so that 0 does not produce NaN probabilities.
        top_k: Sample only among the ``top_k`` most likely tokens; a falsy
            value samples from the full distribution. Clamped to the
            vocabulary size.

    Returns:
        The decoded text that follows the SEP token, stripped.
    """
    prefix = [BOS] + [stoi.get(c, PAD) for c in prompt] + [SEP]
    seq = list(prefix)
    temperature = max(float(temperature), 1e-6)
    with torch.no_grad():
        for _ in range(max_new):
            # Only the last MAX_LEN tokens are fed to the model.
            x = torch.tensor(seq[-MAX_LEN:], dtype=torch.long, device=DEVICE).unsqueeze(0)
            logits = model(x)[:, -1, :] / temperature
            if top_k:
                # torch.topk needs an integer k no larger than the last dim.
                k = max(1, min(int(top_k), logits.size(-1)))
                values, indices = torch.topk(logits, k)
                pick = torch.softmax(values, -1).multinomial(1)
                idx = indices[0, pick].item()
            else:
                idx = torch.softmax(logits, -1).multinomial(1).item()
            if idx == EOS or itos[idx] == "</s>":
                break
            seq.append(idx)

    # Decode only what was generated after the BOS/prompt/SEP prefix.
    decoded = []
    for idx in seq[len(prefix):]:
        if idx == EOS or itos[idx] == "</s>":
            break
        if idx != PAD:
            decoded.append(itos[idx])
    return ''.join(decoded).strip()
137
-
138
def generate_name(prompt, min_chars=4, min_words=1, min_score=0.55, max_retries=3, temperature=1.0, temp_decay=0.85, max_gibberish=0.5):
    """Sample names for ``prompt`` until one passes every quality filter.

    Each retry multiplies the sampling temperature by ``temp_decay``. A
    candidate is accepted when it is long enough, pronounceable enough, not
    too gibberish-like, and not flagged by any of the structural checks.
    If no candidate passes within ``max_retries`` attempts, the last cleaned
    candidate is returned (an empty string when ``max_retries`` is 0).
    """
    candidate = ""
    for attempt in range(max_retries):
        cooled = temperature * (temp_decay ** attempt)
        candidate = clean_name(_sample_once(prompt, temperature=cooled))

        # Guard clauses: the first failing check sends us to the next attempt.
        if len(candidate) < min_chars or len(candidate.split()) < min_words:
            continue
        if pronounceability_score(candidate) < min_score:
            continue
        if gibberish_score(candidate) > max_gibberish:
            continue
        if has_duplicate_articles(candidate) or has_weird_word_lengths(candidate):
            continue
        if is_too_weird(candidate) or is_problematic(candidate):
            continue
        return candidate
    return candidate
157
-
158
def ui_fn(prompt):
    """Generate three names for ``prompt`` and return them one per line."""
    names = [generate_name(prompt) for _ in range(3)]
    # A real newline: "\\n" would print a literal backslash followed by "n".
    return "\n".join(names)
161
-
162
# Single text-in / text-out interface around ui_fn.
_interface_options = dict(
    fn=ui_fn,
    inputs="text",
    outputs="text",
    title="Fantasy Name Generator",
    description="Enter a character or world prompt to generate fantasy names.",
)
demo = gr.Interface(**_interface_options)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
 
1
  import gradio as gr
2
  import torch
3
  import torch.nn as nn
4
+ import re, unicodedata, random, math
5
+ from pathlib import Path
 
 
 
6
 
7
# === Constants and Config ===
SEED = 1337
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Seed both RNGs so sampling is repeatable across restarts.
torch.manual_seed(SEED)
random.seed(SEED)

# === Load Checkpoint ===
ckpt = torch.load("kaos.pt", map_location=DEVICE)
stoi = ckpt['stoi']
itos = ckpt['itos']
SPECIAL = ['<pad>', '<bos>', '<eos>', '<sep>']
# Special-token ids are looked up in the checkpoint vocabulary.
PAD, BOS, EOS, SEP = (stoi[token] for token in SPECIAL)
VOCAB_SIZE = len(itos)
MAX_LEN = 128  # match training
20
 
21
+ # === Model ===
22
  class GPTSmall(nn.Module):
23
  def __init__(self, vocab_size, d_model=256, n_head=8, n_layer=4, dropout=0.2, max_len=MAX_LEN):
24
  super().__init__()
 
32
 
33
  def forward(self, x):
34
  B, T = x.shape
35
+ tok = self.tok_emb(x) + self.pos_emb[:, :T]
 
36
  mask = torch.triu(torch.ones(T, T, device=x.device, dtype=torch.bool), 1)
37
  for blk in self.blocks:
38
  tok = blk(tok, src_key_padding_mask=(x == PAD), src_mask=mask)
39
+ return self.head(self.norm(tok))
 
40
 
41
  model = GPTSmall(VOCAB_SIZE).to(DEVICE)
42
+ model.load_state_dict(ckpt['model'])
43
  model.eval()
44
 
45
# === Utility ===
# Minor words are lowercased only when they are NOT the first word, so a name
# such as "The Dark One" keeps its leading capital.
_MINOR_WORD_RE = re.compile(r"(?<=\s)(?:of|the|and|in|on|a)(?=\s|$)", re.IGNORECASE)


def proper_case(text):
    """Title-case ``text`` while keeping interior articles/prepositions lowercase.

    The match is case-insensitive because ``str.title()`` has already
    capitalised every word by the time the pattern runs.
    """
    return _MINOR_WORD_RE.sub(lambda m: m.group(0).lower(), text.title())


def clean_name(text, title_case=True, max_repeats=2):
    """Normalise a raw generated name.

    Steps: NFC-normalise, collapse runs of 3+ identical characters to
    ``max_repeats`` copies, drop characters outside ASCII alphanumerics /
    Latin-1 letters / apostrophes / hyphen / whitespace, squeeze whitespace,
    optionally title-case, remove an immediately repeated article, and repair
    a capitalised possessive ``'S``.

    Args:
        text: Raw name string.
        title_case: When true, apply ``proper_case``.
        max_repeats: Number of copies kept for a collapsed character run.

    Returns:
        The cleaned name.
    """
    text = unicodedata.normalize("NFC", text)
    text = re.sub(r"(.)\1{2,}", lambda m: m.group(1) * max_repeats, text)
    # Possessive written in caps, with either a straight or a curly apostrophe.
    text = re.sub(r"[\u2019']S", "'s", text)
    # Non-ASCII ranges are spelled as escapes so the pattern cannot be
    # damaged by a file re-encoding.
    text = re.sub(r"[^0-9A-Za-z\u00C0-\u00D6\u00D8-\u00F6\u00F8-\u00FF'\u2019\-\s]", "", text)
    text = re.sub(r"\s+", " ", text).strip()
    if title_case:
        text = proper_case(text)
    text = re.sub(r"\b(The|Of|In|On|A)\s+\1\b", r"\1", text, flags=re.IGNORECASE)
    # str.title() capitalises the letter after an apostrophe ("King'S").
    return re.sub(r"([a-zA-Z])(['\u2019])S\b", lambda m: m.group(1) + m.group(2) + "s", text)
 
 
 
 
59
 
60
def sample_once(prompt, temperature=1.0, top_k=40, max_new=24):
    """Sample one cleaned name from the model for ``prompt``.

    The prompt is encoded character by character (unknown characters map to
    PAD), wrapped as ``BOS + prompt + SEP``, and up to ``max_new`` tokens are
    sampled until EOS is produced.

    Args:
        prompt: Conditioning text.
        temperature: Softmax temperature; clamped to a small positive value
            so that 0 does not produce NaN probabilities.
        top_k: Sample only among the ``top_k`` most likely tokens; a falsy
            value samples from the full distribution. Coerced to int and
            clamped to the vocabulary size.
        max_new: Maximum number of tokens to generate.

    Returns:
        The generated text (prompt excluded) passed through ``clean_name``.
    """
    prefix = [BOS] + [stoi.get(c, PAD) for c in prompt] + [SEP]
    seq = list(prefix)
    temperature = max(float(temperature), 1e-6)
    with torch.no_grad():
        for _ in range(int(max_new)):
            # Only the last MAX_LEN tokens are fed to the model.
            x = torch.tensor(seq[-MAX_LEN:], dtype=torch.long, device=DEVICE)[None]
            logits = model(x)[:, -1, :] / temperature
            if top_k:
                # torch.topk needs an integer k no larger than the last dim.
                k = max(1, min(int(top_k), logits.size(-1)))
                values, indices = torch.topk(logits, k)
                pick = torch.softmax(values, -1).multinomial(1)
                idx = indices[0, pick].item()
            else:
                idx = torch.softmax(logits, -1).multinomial(1).item()
            if idx == EOS:
                break
            seq.append(idx)

    # Decode only the tokens generated after the prefix; decoding the whole
    # sequence would prepend the prompt text to every name.
    specials = {BOS, SEP, EOS, PAD}
    name = ''.join(itos[i] for i in seq[len(prefix):] if i not in specials)
    return clean_name(name)
77
+
78
+ # === Gradio UI ===
79
def generate_ui(prompt, temperature, top_k, count):
    """Generate ``count`` names for ``prompt`` and return them one per line.

    ``top_k`` and ``count`` are coerced to int because UI slider values may
    arrive as floats, which ``range`` rejects.
    """
    count = int(count)
    top_k = int(top_k)
    return "\n".join(
        sample_once(prompt, temperature=temperature, top_k=top_k)
        for _ in range(count)
    )
85
+
86
# Markdown shown at the top of the page. The leading emoji is written as an
# escape so it survives any re-encoding of this file.
description = """\U0001F3AD **Fantasy Name Generator**
Give it a prompt like `a forgotten warrior king` or `mistress of the black swamp` and it'll generate creative fantasy-style names.
This model is trained from scratch and runs entirely on PyTorch."""

# Layout: prompt row, a row of sampling controls, then the button and output.
with gr.Blocks() as demo:
    gr.Markdown(description)
    with gr.Row():
        prompt = gr.Textbox(label="Prompt", placeholder="e.g. 'a villain who whispers to shadows'", lines=1)
    with gr.Row():
        temperature = gr.Slider(0.1, 1.5, step=0.1, value=1.0, label="Temperature")
        top_k = gr.Slider(10, 100, step=10, value=40, label="Top-K")
        count = gr.Slider(1, 5, step=1, value=3, label="Names to Generate")
    generate_btn = gr.Button("Generate Names")
    output = gr.Textbox(label="Generated Names", lines=5)
    # The input order must match generate_ui's positional parameters.
    generate_btn.click(fn=generate_ui, inputs=[prompt, temperature, top_k, count], outputs=output)

demo.launch()