JohanBeytell committed on
Commit
4f0da9b
verified
1 Parent(s): e085329

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +166 -0
app.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ import math
5
+ import re
6
+ import unicodedata
7
+ import random
8
+ import os
9
+
10
# --- Load constants and model ---
# Fix all RNG seeds so sampling is reproducible across runs.
SEED = 1337
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # no-op when CUDA is unavailable

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
MAX_LEN = 128  # maximum sequence length supported by the positional embedding

# Special-token names and their ids; the ids below must agree with the
# vocabulary built at training time (<pad>=0, <bos>=1, <eos>=2, <sep>=3).
SPECIAL = ['<pad>', '<bos>', '<eos>', '<sep>']
BOS, EOS, PAD, SEP = 1, 2, 0, 3

# Load vocab
# The checkpoint bundles the char<->id tables alongside the model weights.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
ckpt = torch.load("kaos.pt", map_location=DEVICE)
stoi = ckpt["stoi"]  # char -> token id
itos = ckpt["itos"]  # token id -> char
VOCAB_SIZE = len(itos)
27
+
28
class GPTSmall(nn.Module):
    """Small decoder-only (GPT-style) character model built from
    TransformerEncoderLayer blocks with a causal attention mask.

    Args:
        vocab_size: size of the character vocabulary (output dimension).
        d_model: embedding / hidden width.
        n_head: attention heads per layer.
        n_layer: number of transformer layers.
        dropout: dropout probability inside each layer.
        max_len: maximum sequence length (size of the learned position table).
    """

    def __init__(self, vocab_size, d_model=256, n_head=8, n_layer=4, dropout=0.2, max_len=MAX_LEN):
        super().__init__()
        from copy import deepcopy  # function-scope import keeps file-level imports untouched

        self.tok_emb = nn.Embedding(vocab_size, d_model)
        # Learned absolute positional embedding, one row per position.
        self.pos_emb = nn.Parameter(torch.zeros(1, max_len, d_model))
        nn.init.trunc_normal_(self.pos_emb, std=0.02)
        block = nn.TransformerEncoderLayer(d_model, n_head, d_model * 4, dropout=dropout, batch_first=True)
        # BUG FIX: `[block for _ in range(n_layer)]` placed the SAME layer object
        # in the list n_layer times, so every "layer" shared one set of weights.
        # Deep-copy so each layer has independent parameters (this is what
        # nn.TransformerEncoder does internally). Checkpoints saved by the old
        # code still load: they contain a (duplicated) entry for every
        # blocks.i.* key, so each copy receives the identical weights and
        # inference output is unchanged.
        self.blocks = nn.ModuleList([deepcopy(block) for _ in range(n_layer)])
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, x):
        """Return next-token logits of shape (B, T, vocab_size) for token ids x of shape (B, T)."""
        B, T = x.shape
        tok = self.tok_emb(x) + self.pos_emb[:, :T]
        # Causal mask: position t may not attend to positions > t.
        mask = torch.triu(torch.ones(T, T, device=x.device, dtype=torch.bool), 1)
        pad_mask = (x == PAD)  # hoisted: identical for every layer
        for blk in self.blocks:
            tok = blk(tok, src_key_padding_mask=pad_mask, src_mask=mask)
        return self.head(self.norm(tok))
48
+
49
# Instantiate the model, restore the trained weights from the checkpoint,
# and switch to inference mode (disables dropout).
model = GPTSmall(VOCAB_SIZE).to(DEVICE)
model.load_state_dict(ckpt["model"])
model.eval()
52
+
53
+ # --- Clean + scoring ---
54
def proper_case(text):
    """Title-case *text* but keep short articles/prepositions lower-case.

    BUG FIX: the original pattern used doubled backslashes inside a raw
    string (r"\\b"), which matches a literal backslash character and so the
    substitution never fired; it also lacked IGNORECASE, so the
    capitalised "Of"/"The" produced by .title() could never match the
    lower-case alternatives anyway.
    """
    return re.sub(r"\b(of|the|and|in|on|a)\b",
                  lambda m: m.group(0).lower(),
                  text.title(),
                  flags=re.IGNORECASE)
56
+
57
def clean_name(text, title_case=True, max_repeats=2):
    """Normalise a generated name: fix unicode form, collapse long character
    runs, strip disallowed characters, tidy whitespace and possessives, and
    optionally title-case the result via proper_case().

    BUG FIX: every regex here used doubled backslashes inside raw strings
    (r"\\1", r"\\s", r"\\b", ...), matching literal backslashes instead of
    the intended escapes, and two patterns contained mojibake for the curly
    apostrophe and the Latin-1 accented letter ranges. Both are corrected.
    """
    text = unicodedata.normalize("NFC", text)
    # Collapse any character repeated 3+ times down to `max_repeats` copies.
    text = re.sub(r"(.)\1{2,}", lambda m: m.group(1) * max_repeats, text, flags=re.IGNORECASE)
    # Normalise possessive 'S (curly or straight apostrophe) to 's.
    text = re.sub(r"’S|'S", "'s", text)
    # Keep digits, ASCII + Latin-1 accented letters, apostrophes, hyphens, whitespace.
    text = re.sub(r"[^0-9A-Za-zÀ-ÖØ-öø-ÿ'’\-\s]", "", text)
    text = re.sub(r"\s+", " ", text).strip()
    if title_case:
        text = proper_case(text)
    # Drop doubled articles, e.g. "The The" -> "The".
    text = re.sub(r"\b(The|Of|In|On|A)\s+\1\b", r"\1", text, flags=re.IGNORECASE)
    # Fix any remaining capitalised possessive, e.g. "Kael'S" -> "Kael's".
    text = re.sub(r"([a-zA-Z])'S\b", lambda m: m.group(1) + "'s", text)
    return text
68
+
69
def has_weird_word_lengths(name, min_len=3, max_len=24):
    """Return True if any whitespace-separated word of *name* is shorter
    than min_len or longer than max_len characters."""
    for word in name.split():
        if not (min_len <= len(word) <= max_len):
            return True
    return False
71
+
72
def gibberish_score(name):
    """Fraction of character trigrams in *name* (lower-cased, spaces removed)
    that are NOT in a small whitelist of common English/fantasy trigrams.
    Returns 1.0 for strings too short to contain any trigram."""
    common_tris = {"the", "and", "ing", "ion", "ent", "ati", "for", "her", "ter", "tha", "ere", "nth", "tio", "ver",
                   "his", "hat", "ers", "rea", "all", "ill", "ari", "est", "oth", "eve", "eld", "sky", "dra", "sha", "mir"}
    compact = name.lower().replace(" ", "")
    total = len(compact) - 2  # number of trigrams
    if total <= 0:
        return 1.0
    unknown = sum(compact[i:i + 3] not in common_tris for i in range(total))
    return unknown / total
81
+
82
def pronounceability_score(name):
    """Heuristic score in [0, 1] of how pronounceable *name* is, based on
    the vowel/consonant ratio and how often the string alternates between
    vowel runs and consonant runs, minus a penalty for clusters of three
    or more consonants. Returns 0.0 when no a-z letters remain."""
    letters = re.sub(r"[^a-z]", "", name.lower())
    if not letters:
        return 0.0
    vowels = "aeiouy"
    n_vowel = sum(c in vowels for c in letters)
    n_cons = sum(c not in vowels for c in letters)
    ratio = n_vowel / (n_cons + 1)  # +1 avoids division by zero
    penalty = len(re.findall(r'[^aeiouy]{3,}', letters)) * 0.1
    runs = re.findall(r'[aeiouy]+|[^aeiouy]+', letters)
    smoothness = len(runs) / len(letters)
    raw = (ratio * 0.6) + (smoothness * 0.6) - penalty
    return max(0.0, min(raw, 1.0))
95
+
96
def has_duplicate_articles(name):
    """Return True if *name* repeats an article back-to-back, e.g. "the the".

    BUG FIX: the original raw-string pattern used doubled backslashes
    (r'\\b...\\1...'), which match literal backslash characters, so the
    search could never succeed.
    """
    return bool(re.search(r'\b(the|of|in|on|a)\s+\1\b', name, flags=re.IGNORECASE))
98
+
99
def is_problematic(name):
    """Truthy if *name* looks malformed: a doubled article, a single very
    short word (no spaces and fewer than 5 chars), or a run of 5+ consonants.

    BUG FIX: the doubled-article pattern used doubled backslashes in a raw
    string (r'\\b'), matching literal backslashes and never firing. Callers
    only use the result in boolean context, so the truthy/falsy contract
    (match object / bool / bool) is preserved.
    """
    return (
        re.search(r'\b(the the|of of|in in)\b', name.lower()) or
        (name.count(' ') == 0 and len(name) < 5) or
        len(re.findall(r'[bcdfghjklmnpqrstvwxyz]{5,}', name.lower())) > 0
    )
105
+
106
def is_too_weird(name):
    """Truthy if *name* contains an over-long word (more than 14 chars) or a
    run of five or more consonants anywhere (checked case-insensitively).
    Note: returns either a bool or a re.Match object, as the original did;
    callers use it in boolean context only."""
    if any(len(token) > 14 for token in name.split()):
        return True
    return re.search(r"[bcdfghjklmnpqrstvwxyz]{5,}", name.lower())
111
+
112
def _sample_once(prompt, max_new=24, temperature=1.0, top_k=40):
    """Autoregressively sample one candidate name from the model.

    The input sequence is <bos> + prompt characters + <sep>; generation then
    appends up to `max_new` sampled token ids, stopping early on <eos> (or on
    a literal "</s>" vocabulary entry). Returns the decoded text that follows
    the <sep> marker.

    Parameters:
        prompt: conditioning string; characters missing from stoi map to PAD.
        max_new: maximum number of tokens to generate.
        temperature: logits divisor; <1 sharpens, >1 flattens the distribution.
        top_k: if truthy, sample only among the top-k logits; otherwise sample
            from the full softmax distribution.
    """
    seq = [BOS] + [stoi.get(c, PAD) for c in prompt] + [SEP]
    with torch.no_grad():
        for _ in range(max_new):
            # Feed at most the last MAX_LEN tokens (the model's positional limit).
            x = torch.tensor(seq[-MAX_LEN:], dtype=torch.long, device=DEVICE).unsqueeze(0)
            logits = model(x)[:, -1, :] / temperature  # next-token logits
            if top_k:
                # Sample an index within the top-k, then map back to the vocab id.
                v, i = torch.topk(logits, top_k)
                idx = i[0, torch.softmax(v, -1).multinomial(1)].item()
            else:
                idx = torch.softmax(logits, -1).multinomial(1).item()
            if idx == EOS or itos[idx] == "</s>":
                break
            seq.append(idx)
    # Decode everything after the first <sep>; fall back to the whole sequence
    # if no <sep> is present (defensive — the prefix above always adds one).
    try:
        start = seq.index(SEP) + 1
    except ValueError:
        start = 0
    decoded = []
    for idx in seq[start:]:
        if idx == EOS or itos[idx] == "</s>":
            break
        if idx != PAD:  # skip padding ids mapped from unknown prompt chars
            decoded.append(itos[idx])
    return ''.join(decoded).strip()
137
+
138
def generate_name(prompt, min_chars=4, min_words=1, min_score=0.55, max_retries=3, temperature=1.0, temp_decay=0.85, max_gibberish=0.5):
    """Sample names until one passes every quality filter, cooling the
    sampling temperature by `temp_decay` on each retry. If all attempts
    fail, the last (cleaned) attempt is returned as a fallback."""
    candidate = ""
    for attempt in range(max_retries):
        cooled = temperature * (temp_decay ** attempt)
        candidate = clean_name(_sample_once(prompt, temperature=cooled))
        acceptable = (
            len(candidate) >= min_chars
            and len(candidate.split()) >= min_words
            and pronounceability_score(candidate) >= min_score
            and gibberish_score(candidate) <= max_gibberish
            and not has_duplicate_articles(candidate)
            and not has_weird_word_lengths(candidate)
        )
        if acceptable and not is_too_weird(candidate) and not is_problematic(candidate):
            return candidate
    return candidate
157
+
158
def ui_fn(prompt):
    """Gradio callback: generate three names for *prompt*, one per line.

    BUG FIX: the original joined with "\\n" (literal backslash + n), so the
    UI showed all three names on one line separated by the two characters
    backslash-n. Joining with a real newline shows one name per line.
    """
    names = [generate_name(prompt) for _ in range(3)]
    return "\n".join(names)
161
+
162
# Minimal Gradio UI: a free-text prompt in, the generated names out.
demo = gr.Interface(fn=ui_fn, inputs="text", outputs="text", title="Fantasy Name Generator",
                    description="Enter a character or world prompt to generate fantasy names.")

# Launch the web server only when run as a script, so importing this module
# does not start the app.
if __name__ == "__main__":
    demo.launch()