JohanBeytell committed on
Commit
fc01c21
·
verified ·
1 Parent(s): 94368f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -32
app.py CHANGED
@@ -6,23 +6,23 @@ from pathlib import Path
6
  import pandas as pd
7
  import tempfile
8
  import time
 
9
 
10
- # === Constants and Config ===
11
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
12
  SEED = 1337
13
  torch.manual_seed(SEED)
14
  random.seed(SEED)
15
 
16
- # === Load Checkpoint ===
 
 
17
  ckpt = torch.load("kaos.pt", map_location=DEVICE)
18
  stoi, itos = ckpt['stoi'], ckpt['itos']
19
  SPECIAL = ['<pad>', '<bos>', '<eos>', '<sep>']
20
  PAD, BOS, EOS, SEP = [stoi[s] for s in SPECIAL]
21
  VOCAB_SIZE = len(itos)
22
  MAX_LEN = 128
23
- total_retry_count = 0
24
 
25
- # === Model ===
26
  class GPTSmall(nn.Module):
27
  def __init__(self, vocab_size, d_model=256, n_head=8, n_layer=4, dropout=0.2, max_len=MAX_LEN):
28
  super().__init__()
@@ -46,7 +46,7 @@ model = GPTSmall(VOCAB_SIZE).to(DEVICE)
46
  model.load_state_dict(ckpt['model'])
47
  model.eval()
48
 
49
- # === Utility ===
50
def proper_case(text):
    """Title-case *text* while keeping small words (of/the/and/in/on/a) lowercase.

    ``flags=re.IGNORECASE`` is the fix: ``str.title()`` capitalizes every word,
    so the lowercase-only pattern could never match the capitalized small words
    and the substitution was a no-op.
    """
    return re.sub(
        r"\b(of|the|and|in|on|a)\b",
        lambda m: m.group(0).lower(),
        text.title(),
        flags=re.IGNORECASE,
    )
52
 
@@ -62,7 +62,6 @@ def clean_name(text, title_case=True, max_repeats=2):
62
  return re.sub(r"([a-zA-Z])'S\b", lambda m: m.group(1) + "'s", text)
63
 
64
  def sample_once(prompt, temperature=1.0, top_k=40, max_new=40):
65
- sample_i = time.time()
66
  seq = [BOS] + [stoi.get(c, PAD) for c in prompt] + [SEP]
67
  for _ in range(max_new):
68
  x = torch.tensor(seq[-MAX_LEN:], dtype=torch.long, device=DEVICE)[None]
@@ -77,56 +76,44 @@ def sample_once(prompt, temperature=1.0, top_k=40, max_new=40):
77
  break
78
  seq.append(idx)
79
  generated = [itos[i] for i in seq if i not in {BOS, SEP, EOS, PAD}]
80
- print(f"Generated token IDs: {seq}")
81
  name = ''.join(generated).replace(prompt, "").strip()
82
- print(f"Sample took: {time.time() - sample_i:.2f}s")
83
  return clean_name(name)
84
 
85
- # === Generation Function ===
86
  def generate_names(prompt, temperature, top_k, count, retries):
87
- global total_retry_count
88
  prompt = prompt.strip()
89
  if not prompt:
90
  raise gr.Error("Prompt cannot be empty.")
91
  if len(prompt) > 64:
92
  raise gr.Error("Prompt is too long. Please keep it under 64 characters.")
93
-
94
  results = []
 
 
 
95
  for _ in range(count):
96
  for attempt in range(retries):
97
- print("Retrying generation...")
98
- total_retry_count = total_retry_count + 1
99
  name = sample_once(prompt, temperature=temperature, top_k=top_k)
 
100
  if len(name) >= 3:
101
  results.append({"Generated Name": name})
102
  break
 
 
 
103
  df = pd.DataFrame(results)
104
  file_path = tempfile.NamedTemporaryFile(delete=False, suffix=".txt").name
105
  df.to_csv(file_path, index=False, header=False)
106
- print(f"Total retries: {total_retry_count}")
107
- return df, file_path
 
108
 
109
- # === UI ===
110
  description = """# KaosGen: A Fantasy Name Generator
111
  `Kaos` is a small GPT-style transformer (~890k parameters) trained from scratch using character-level tokenization.
112
  It excels at fantasy and mythic naming conventions.
113
-
114
- Give it a prompt like `'a forgotten warrior king'`, `'priestess of the dusk sea'`, or `'demon of frost'`.
115
- It will generate names for characters, gods, factions, or places.
116
-
117
- ### ⚠️ Disclaimers
118
- - This model may occasionally produce inaccurate, inappropriate, or nonsensical results.
119
- - It is a fantasy tool and **not intended for general-purpose language tasks**.
120
- - The creators are not responsible for any weirdness it spits out. Use responsibly.
121
  """
122
 
123
- examples = [
124
- ["a forgotten warrior king"],
125
- ["queen of the shattered realm"],
126
- ["blacksmith of shadows"],
127
- ["titan of the blazing sky"],
128
- ["a blade that burns through matter"]
129
- ]
130
 
131
  with gr.Blocks() as demo:
132
  gr.Markdown(description)
@@ -141,8 +128,9 @@ with gr.Blocks() as demo:
141
  with gr.Column():
142
  output = gr.Dataframe(headers=["Generated Name"], datatype="str", label="Generated Names", interactive=False)
143
  download = gr.File(label="📥 Export Names as .txt")
 
144
 
145
- generate_btn.click(fn=generate_names, inputs=[prompt, temperature, top_k, count, retries], outputs=[output, download])
146
  gr.Examples(examples=examples, inputs=prompt)
147
 
148
  demo.launch()
 
6
  import pandas as pd
7
  import tempfile
8
  import time
9
+ import os
10
 
 
11
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
12
  SEED = 1337
13
  torch.manual_seed(SEED)
14
  random.seed(SEED)
15
 
16
+ # Log model load details
17
+ print(f"📦 Model loading on: {DEVICE}")
18
+
19
  ckpt = torch.load("kaos.pt", map_location=DEVICE)
20
  stoi, itos = ckpt['stoi'], ckpt['itos']
21
  SPECIAL = ['<pad>', '<bos>', '<eos>', '<sep>']
22
  PAD, BOS, EOS, SEP = [stoi[s] for s in SPECIAL]
23
  VOCAB_SIZE = len(itos)
24
  MAX_LEN = 128
 
25
 
 
26
  class GPTSmall(nn.Module):
27
  def __init__(self, vocab_size, d_model=256, n_head=8, n_layer=4, dropout=0.2, max_len=MAX_LEN):
28
  super().__init__()
 
46
  model.load_state_dict(ckpt['model'])
47
  model.eval()
48
 
49
+ # === Utility Functions ===
50
def proper_case(text):
    """Title-case *text* while keeping small words (of/the/and/in/on/a) lowercase.

    ``flags=re.IGNORECASE`` is the fix: ``str.title()`` capitalizes every word,
    so the lowercase-only pattern could never match the capitalized small words
    and the substitution was a no-op.
    """
    return re.sub(
        r"\b(of|the|and|in|on|a)\b",
        lambda m: m.group(0).lower(),
        text.title(),
        flags=re.IGNORECASE,
    )
52
 
 
62
  return re.sub(r"([a-zA-Z])'S\b", lambda m: m.group(1) + "'s", text)
63
 
64
  def sample_once(prompt, temperature=1.0, top_k=40, max_new=40):
 
65
  seq = [BOS] + [stoi.get(c, PAD) for c in prompt] + [SEP]
66
  for _ in range(max_new):
67
  x = torch.tensor(seq[-MAX_LEN:], dtype=torch.long, device=DEVICE)[None]
 
76
  break
77
  seq.append(idx)
78
  generated = [itos[i] for i in seq if i not in {BOS, SEP, EOS, PAD}]
 
79
  name = ''.join(generated).replace(prompt, "").strip()
 
80
  return clean_name(name)
81
 
 
82
def generate_names(prompt, temperature, top_k, count, retries):
    """Generate up to ``count`` names for ``prompt``, retrying each slot up to ``retries`` times.

    Returns a DataFrame of accepted names, the path of a .txt export of that
    frame, and a debug report listing the retry total and rejected candidates.

    Raises:
        gr.Error: if the prompt is empty or longer than 64 characters.
    """
    prompt = prompt.strip()
    if not prompt:
        raise gr.Error("Prompt cannot be empty.")
    if len(prompt) > 64:
        raise gr.Error("Prompt is too long. Please keep it under 64 characters.")

    results = []
    rejected = []
    retry_count = 0

    for _ in range(count):
        # Sample until a candidate passes the length filter or retries run out.
        for _attempt in range(retries):
            candidate = sample_once(prompt, temperature=temperature, top_k=top_k)
            retry_count += 1
            if len(candidate) >= 3:  # accept names of 3+ characters only
                results.append({"Generated Name": candidate})
                break
            rejected.append(candidate)  # too short: record and try again

    df = pd.DataFrame(results)
    # delete=False keeps the temp file alive so Gradio can serve it for download.
    file_path = tempfile.NamedTemporaryFile(delete=False, suffix=".txt").name
    df.to_csv(file_path, index=False, header=False)

    retry_report = f"Total Retries: {retry_count - len(results)}\n\nRejected Candidates:\n" + '\n'.join(rejected or ["None"])
    return df, file_path, retry_report
109
 
110
+ # === Gradio UI ===
111
  description = """# KaosGen: A Fantasy Name Generator
112
  `Kaos` is a small GPT-style transformer (~890k parameters) trained from scratch using character-level tokenization.
113
  It excels at fantasy and mythic naming conventions.
 
 
 
 
 
 
 
 
114
  """
115
 
116
+ examples = [["a forgotten warrior king"], ["queen of the shattered realm"], ["blacksmith of shadows"], ["titan of the blazing sky"], ["a blade that burns through matter"]]
 
 
 
 
 
 
117
 
118
  with gr.Blocks() as demo:
119
  gr.Markdown(description)
 
128
  with gr.Column():
129
  output = gr.Dataframe(headers=["Generated Name"], datatype="str", label="Generated Names", interactive=False)
130
  download = gr.File(label="📥 Export Names as .txt")
131
+ retry_report = gr.Textbox(label="Debug Info: Retries & Rejected Names", lines=6, interactive=False)
132
 
133
+ generate_btn.click(fn=generate_names, inputs=[prompt, temperature, top_k, count, retries], outputs=[output, download, retry_report])
134
  gr.Examples(examples=examples, inputs=prompt)
135
 
136
  demo.launch()