Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -6,23 +6,23 @@ from pathlib import Path
|
|
| 6 |
import pandas as pd
|
| 7 |
import tempfile
|
| 8 |
import time
|
|
|
|
| 9 |
|
| 10 |
-
# === Constants and Config ===
|
| 11 |
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 12 |
# Fixed seed applied to both torch's and Python's RNGs at import time.
SEED = 1337
|
| 13 |
torch.manual_seed(SEED)
|
| 14 |
random.seed(SEED)
|
| 15 |
|
| 16 |
-
#
|
|
|
|
|
|
|
|
| 17 |
# Load the checkpoint from the working directory, mapping tensors onto DEVICE.
# NOTE(review): torch.load unpickles the file - only load a trusted kaos.pt
# (consider weights_only=True; confirm the checkpoint layout allows it).
ckpt = torch.load("kaos.pt", map_location=DEVICE)
|
| 18 |
# Vocabulary tables stored in the checkpoint: stoi is indexed by token string,
# itos by token id (see their use in sample_once).
stoi, itos = ckpt['stoi'], ckpt['itos']
|
| 19 |
# Special tokens, in the order they are unpacked on the next statement.
SPECIAL = ['<pad>', '<bos>', '<eos>', '<sep>']
|
| 20 |
# Ids of the special tokens; raises KeyError if the checkpoint vocabulary lacks one.
PAD, BOS, EOS, SEP = [stoi[s] for s in SPECIAL]
|
| 21 |
# Vocabulary size handed to GPTSmall when the model is built.
VOCAB_SIZE = len(itos)
|
| 22 |
# Default max_len for GPTSmall and the context window sample_once slices to.
MAX_LEN = 128
|
| 23 |
-
total_retry_count = 0
|
| 24 |
|
| 25 |
-
# === Model ===
|
| 26 |
class GPTSmall(nn.Module):
|
| 27 |
def __init__(self, vocab_size, d_model=256, n_head=8, n_layer=4, dropout=0.2, max_len=MAX_LEN):
|
| 28 |
super().__init__()
|
|
@@ -46,7 +46,7 @@ model = GPTSmall(VOCAB_SIZE).to(DEVICE)
|
|
| 46 |
model.load_state_dict(ckpt['model'])
|
| 47 |
model.eval()
|
| 48 |
|
| 49 |
-
# === Utility ===
|
| 50 |
def proper_case(text):
    """Title-case ``text`` while keeping minor words lowercase.

    Every word is capitalised with ``str.title()``; the connectives
    ``of``, ``the``, ``and``, ``in``, ``on`` and ``a`` are then put back
    in lowercase, except when one of them is the first word.

    Args:
        text: The string to convert.

    Returns:
        The converted string, e.g. ``"queen of the dusk"`` becomes
        ``"Queen of the Dusk"``.
    """
    titled = text.title()

    def _lower_minor(match):
        # Keep the leading word capitalised ("The Ashen King").
        if not titled[:match.start()].strip():
            return match.group(0)
        return match.group(0).lower()

    # title() has already capitalised every word, so the pattern must match
    # case-insensitively; a lowercase-only pattern would never fire.
    return re.sub(r"\b(of|the|and|in|on|a)\b", _lower_minor, titled, flags=re.IGNORECASE)
|
| 52 |
|
|
@@ -62,7 +62,6 @@ def clean_name(text, title_case=True, max_repeats=2):
|
|
| 62 |
return re.sub(r"([a-zA-Z])'S\b", lambda m: m.group(1) + "'s", text)
|
| 63 |
|
| 64 |
def sample_once(prompt, temperature=1.0, top_k=40, max_new=40):
|
| 65 |
-
sample_i = time.time()
|
| 66 |
seq = [BOS] + [stoi.get(c, PAD) for c in prompt] + [SEP]
|
| 67 |
for _ in range(max_new):
|
| 68 |
x = torch.tensor(seq[-MAX_LEN:], dtype=torch.long, device=DEVICE)[None]
|
|
@@ -77,56 +76,44 @@ def sample_once(prompt, temperature=1.0, top_k=40, max_new=40):
|
|
| 77 |
break
|
| 78 |
seq.append(idx)
|
| 79 |
generated = [itos[i] for i in seq if i not in {BOS, SEP, EOS, PAD}]
|
| 80 |
-
print(f"Generated token IDs: {seq}")
|
| 81 |
name = ''.join(generated).replace(prompt, "").strip()
|
| 82 |
-
print(f"Sample took: {time.time() - sample_i:.2f}s")
|
| 83 |
return clean_name(name)
|
| 84 |
|
| 85 |
-
# === Generation Function ===
|
| 86 |
def generate_names(prompt, temperature, top_k, count, retries):
    """Generate up to ``count`` names for ``prompt`` and export them to a file.

    Each name gets at most ``retries`` sampling attempts; a candidate is
    accepted once it is at least three characters long. A slot whose
    attempts are all rejected is dropped, so fewer than ``count`` names may
    be returned.

    Args:
        prompt: Text used to condition the model; surrounding whitespace is
            stripped.
        temperature: Sampling temperature forwarded to ``sample_once``.
        top_k: Top-k cutoff forwarded to ``sample_once``.
        count: Number of names requested.
        retries: Maximum sampling attempts per name.

    Returns:
        A ``(df, file_path)`` tuple matching the two outputs wired to the
        click handler: a DataFrame with a ``"Generated Name"`` column and
        the path of a ``.txt`` file holding one name per line.

    Raises:
        gr.Error: If the prompt is empty or longer than 64 characters.
    """
    global total_retry_count
    prompt = prompt.strip()
    if not prompt:
        raise gr.Error("Prompt cannot be empty.")
    if len(prompt) > 64:
        raise gr.Error("Prompt is too long. Please keep it under 64 characters.")

    # NOTE(review): UI number widgets may deliver floats, which range()
    # rejects - coerce defensively; confirm against the slider definitions.
    count, retries = int(count), int(retries)

    results = []
    for _ in range(count):
        for attempt in range(retries):
            if attempt:
                # Only attempts after the first one are retries.
                print("Retrying generation...")
                total_retry_count += 1
            name = sample_once(prompt, temperature=temperature, top_k=top_k)
            if len(name) >= 3:
                results.append({"Generated Name": name})
                break

    # Naming the column keeps the frame shape stable when nothing was accepted.
    df = pd.DataFrame(results, columns=["Generated Name"])

    # Reserve a path and close the handle before pandas writes to it, so no
    # descriptor leaks and the file can be reopened on every platform.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as tmp:
        file_path = tmp.name
    df.to_csv(file_path, index=False, header=False)

    return df, file_path
|
| 106 |
-
|
| 107 |
-
|
|
|
|
| 108 |
|
| 109 |
-
# === UI ===
|
| 110 |
description = """# KaosGen: A Fantasy Name Generator
|
| 111 |
`Kaos` is a small GPT-style transformer (~890k parameters) trained from scratch using character-level tokenization.
|
| 112 |
It excels at fantasy and mythic naming conventions.
|
| 113 |
-
|
| 114 |
-
Give it a prompt like `'a forgotten warrior king'`, `'priestess of the dusk sea'`, or `'demon of frost'`.
|
| 115 |
-
It will generate names for characters, gods, factions, or places.
|
| 116 |
-
|
| 117 |
-
### ⚠️ Disclaimers
|
| 118 |
-
- This model may occasionally produce inaccurate, inappropriate, or nonsensical results.
|
| 119 |
-
- It is a fantasy tool and **not intended for general-purpose language tasks**.
|
| 120 |
-
- The creators are not responsible for any weirdness it spits out. Use responsibly.
|
| 121 |
"""
|
| 122 |
|
| 123 |
-
examples = [
|
| 124 |
-
["a forgotten warrior king"],
|
| 125 |
-
["queen of the shattered realm"],
|
| 126 |
-
["blacksmith of shadows"],
|
| 127 |
-
["titan of the blazing sky"],
|
| 128 |
-
["a blade that burns through matter"]
|
| 129 |
-
]
|
| 130 |
|
| 131 |
with gr.Blocks() as demo:
|
| 132 |
gr.Markdown(description)
|
|
@@ -141,8 +128,9 @@ with gr.Blocks() as demo:
|
|
| 141 |
with gr.Column():
|
| 142 |
output = gr.Dataframe(headers=["Generated Name"], datatype="str", label="Generated Names", interactive=False)
|
| 143 |
download = gr.File(label="📥 Export Names as .txt")
|
|
|
|
| 144 |
|
| 145 |
-
generate_btn.click(fn=generate_names, inputs=[prompt, temperature, top_k, count, retries], outputs=[output, download])
|
| 146 |
gr.Examples(examples=examples, inputs=prompt)
|
| 147 |
|
| 148 |
demo.launch()
|
|
|
|
| 6 |
import pandas as pd
|
| 7 |
import tempfile
|
| 8 |
import time
|
| 9 |
+
import os
|
| 10 |
|
|
|
|
| 11 |
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 12 |
# Fixed seed applied to both torch's and Python's RNGs at import time.
SEED = 1337
|
| 13 |
torch.manual_seed(SEED)
|
| 14 |
random.seed(SEED)
|
| 15 |
|
| 16 |
+
# Log model load details
|
| 17 |
+
print(f"📦 Model loading on: {DEVICE}")
|
| 18 |
+
|
| 19 |
# Load the checkpoint from the working directory, mapping tensors onto DEVICE.
# NOTE(review): torch.load unpickles the file - only load a trusted kaos.pt
# (consider weights_only=True; confirm the checkpoint layout allows it).
ckpt = torch.load("kaos.pt", map_location=DEVICE)
|
| 20 |
# Vocabulary tables stored in the checkpoint: stoi is indexed by token string,
# itos by token id (see their use in sample_once).
stoi, itos = ckpt['stoi'], ckpt['itos']
|
| 21 |
# Special tokens, in the order they are unpacked on the next statement.
SPECIAL = ['<pad>', '<bos>', '<eos>', '<sep>']
|
| 22 |
# Ids of the special tokens; raises KeyError if the checkpoint vocabulary lacks one.
PAD, BOS, EOS, SEP = [stoi[s] for s in SPECIAL]
|
| 23 |
# Vocabulary size handed to GPTSmall when the model is built.
VOCAB_SIZE = len(itos)
|
| 24 |
# Default max_len for GPTSmall and the context window sample_once slices to.
MAX_LEN = 128
|
|
|
|
| 25 |
|
|
|
|
| 26 |
class GPTSmall(nn.Module):
|
| 27 |
def __init__(self, vocab_size, d_model=256, n_head=8, n_layer=4, dropout=0.2, max_len=MAX_LEN):
|
| 28 |
super().__init__()
|
|
|
|
| 46 |
model.load_state_dict(ckpt['model'])
|
| 47 |
model.eval()
|
| 48 |
|
| 49 |
+
# === Utility Functions ===
|
| 50 |
def proper_case(text):
    """Title-case ``text`` while keeping minor words lowercase.

    Every word is capitalised with ``str.title()``; the connectives
    ``of``, ``the``, ``and``, ``in``, ``on`` and ``a`` are then put back
    in lowercase, except when one of them is the first word.

    Args:
        text: The string to convert.

    Returns:
        The converted string, e.g. ``"queen of the dusk"`` becomes
        ``"Queen of the Dusk"``.
    """
    titled = text.title()

    def _lower_minor(match):
        # Keep the leading word capitalised ("The Ashen King").
        if not titled[:match.start()].strip():
            return match.group(0)
        return match.group(0).lower()

    # title() has already capitalised every word, so the pattern must match
    # case-insensitively; a lowercase-only pattern would never fire.
    return re.sub(r"\b(of|the|and|in|on|a)\b", _lower_minor, titled, flags=re.IGNORECASE)
|
| 52 |
|
|
|
|
| 62 |
return re.sub(r"([a-zA-Z])'S\b", lambda m: m.group(1) + "'s", text)
|
| 63 |
|
| 64 |
def sample_once(prompt, temperature=1.0, top_k=40, max_new=40):
|
|
|
|
| 65 |
seq = [BOS] + [stoi.get(c, PAD) for c in prompt] + [SEP]
|
| 66 |
for _ in range(max_new):
|
| 67 |
x = torch.tensor(seq[-MAX_LEN:], dtype=torch.long, device=DEVICE)[None]
|
|
|
|
| 76 |
break
|
| 77 |
seq.append(idx)
|
| 78 |
generated = [itos[i] for i in seq if i not in {BOS, SEP, EOS, PAD}]
|
|
|
|
| 79 |
name = ''.join(generated).replace(prompt, "").strip()
|
|
|
|
| 80 |
return clean_name(name)
|
| 81 |
|
|
|
|
| 82 |
def generate_names(prompt, temperature, top_k, count, retries):
    """Generate up to ``count`` names for ``prompt`` and report rejections.

    Each name gets at most ``retries`` sampling attempts; a candidate is
    accepted once it is at least three characters long, otherwise it is
    recorded as rejected. A slot whose attempts are all rejected is dropped,
    so fewer than ``count`` names may be returned.

    Args:
        prompt: Text used to condition the model; surrounding whitespace is
            stripped.
        temperature: Sampling temperature forwarded to ``sample_once``.
        top_k: Top-k cutoff forwarded to ``sample_once``.
        count: Number of names requested.
        retries: Maximum sampling attempts per name.

    Returns:
        A ``(df, file_path, retry_report)`` tuple: a DataFrame with a
        ``"Generated Name"`` column, the path of a ``.txt`` file holding one
        name per line, and a text summary of the rejected candidates.

    Raises:
        gr.Error: If the prompt is empty or longer than 64 characters.
    """
    prompt = prompt.strip()
    if not prompt:
        raise gr.Error("Prompt cannot be empty.")
    if len(prompt) > 64:
        raise gr.Error("Prompt is too long. Please keep it under 64 characters.")

    # NOTE(review): UI number widgets may deliver floats, which range()
    # rejects - coerce defensively; confirm against the slider definitions.
    count, retries = int(count), int(retries)

    results = []
    rejected = []
    for _ in range(count):
        for _attempt in range(retries):
            name = sample_once(prompt, temperature=temperature, top_k=top_k)
            if len(name) >= 3:
                results.append({"Generated Name": name})
                break
            rejected.append(name)

    # Naming the column keeps the frame shape stable when nothing was accepted.
    df = pd.DataFrame(results, columns=["Generated Name"])

    # Reserve a path and close the handle before pandas writes to it, so no
    # descriptor leaks and the file can be reopened on every platform.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as tmp:
        file_path = tmp.name
    df.to_csv(file_path, index=False, header=False)

    # Every attempt lands in exactly one of the two lists, so the number of
    # rejected candidates equals attempts minus accepted names.
    retry_report = f"Total Retries: {len(rejected)}\n\nRejected Candidates:\n" + '\n'.join(rejected or ["None"])
    return df, file_path, retry_report
|
| 109 |
|
| 110 |
+
# === Gradio UI ===
|
| 111 |
description = """# KaosGen: A Fantasy Name Generator
|
| 112 |
`Kaos` is a small GPT-style transformer (~890k parameters) trained from scratch using character-level tokenization.
|
| 113 |
It excels at fantasy and mythic naming conventions.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
"""
|
| 115 |
|
| 116 |
+
examples = [["a forgotten warrior king"], ["queen of the shattered realm"], ["blacksmith of shadows"], ["titan of the blazing sky"], ["a blade that burns through matter"]]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
|
| 118 |
with gr.Blocks() as demo:
|
| 119 |
gr.Markdown(description)
|
|
|
|
| 128 |
with gr.Column():
|
| 129 |
output = gr.Dataframe(headers=["Generated Name"], datatype="str", label="Generated Names", interactive=False)
|
| 130 |
download = gr.File(label="📥 Export Names as .txt")
|
| 131 |
+
retry_report = gr.Textbox(label="Debug Info: Retries & Rejected Names", lines=6, interactive=False)
|
| 132 |
|
| 133 |
+
generate_btn.click(fn=generate_names, inputs=[prompt, temperature, top_k, count, retries], outputs=[output, download, retry_report])
|
| 134 |
gr.Examples(examples=examples, inputs=prompt)
|
| 135 |
|
| 136 |
demo.launch()
|