JohanBeytell commited on
Commit
e490385
·
verified ·
1 Parent(s): c289cc9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -0
app.py CHANGED
@@ -20,6 +20,7 @@ SPECIAL = ['<pad>', '<bos>', '<eos>', '<sep>']
20
  PAD, BOS, EOS, SEP = [stoi[s] for s in SPECIAL]
21
  VOCAB_SIZE = len(itos)
22
  MAX_LEN = 128
 
23
 
24
  # === Model ===
25
  class GPTSmall(nn.Module):
@@ -91,6 +92,8 @@ def generate_names(prompt, temperature, top_k, count, retries):
91
  results = []
92
  for _ in range(count):
93
  for attempt in range(retries):
 
 
94
  name = sample_once(prompt, temperature=temperature, top_k=top_k)
95
  if len(name) >= 4:
96
  results.append({"Generated Name": name})
@@ -98,6 +101,7 @@ def generate_names(prompt, temperature, top_k, count, retries):
98
  df = pd.DataFrame(results)
99
  file_path = tempfile.NamedTemporaryFile(delete=False, suffix=".txt").name
100
  df.to_csv(file_path, index=False, header=False)
 
101
  return df, file_path
102
 
103
  # === UI ===
 
20
  PAD, BOS, EOS, SEP = [stoi[s] for s in SPECIAL]
21
  VOCAB_SIZE = len(itos)
22
  MAX_LEN = 128
23
+ total_retry_count = 0
24
 
25
  # === Model ===
26
  class GPTSmall(nn.Module):
 
92
  results = []
93
  for _ in range(count):
94
  for attempt in range(retries):
95
+ print("Retrying generation...")
96
+ total_retry_count++
97
  name = sample_once(prompt, temperature=temperature, top_k=top_k)
98
  if len(name) >= 4:
99
  results.append({"Generated Name": name})
 
101
  df = pd.DataFrame(results)
102
  file_path = tempfile.NamedTemporaryFile(delete=False, suffix=".txt").name
103
  df.to_csv(file_path, index=False, header=False)
104
+ print(f"Total retries: {total_retry_count}")
105
  return df, file_path
106
 
107
  # === UI ===