JohanBeytell commited on
Commit
86f9dff
·
verified ·
1 Parent(s): 8463652

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -0
app.py CHANGED
@@ -5,6 +5,7 @@ import re, unicodedata, random
5
  from pathlib import Path
6
  import pandas as pd
7
  import tempfile
 
8
 
9
  # === Constants and Config ===
10
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -60,6 +61,7 @@ def clean_name(text, title_case=True, max_repeats=2):
60
  return re.sub(r"([a-zA-Z])'S\b", lambda m: m.group(1) + "'s", text)
61
 
62
  def sample_once(prompt, temperature=1.0, top_k=40, max_new=24):
 
63
  seq = [BOS] + [stoi.get(c, PAD) for c in prompt] + [SEP]
64
  for _ in range(max_new):
65
  x = torch.tensor(seq[-MAX_LEN:], dtype=torch.long, device=DEVICE)[None]
@@ -75,6 +77,7 @@ def sample_once(prompt, temperature=1.0, top_k=40, max_new=24):
75
  seq.append(idx)
76
  generated = [itos[i] for i in seq if i not in {BOS, SEP, EOS, PAD}]
77
  name = ''.join(generated).replace(prompt, "").strip()
 
78
  return clean_name(name)
79
 
80
  # === Generation Function ===
 
5
  from pathlib import Path
6
  import pandas as pd
7
  import tempfile
8
+ import time
9
 
10
  # === Constants and Config ===
11
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
 
61
  return re.sub(r"([a-zA-Z])'S\b", lambda m: m.group(1) + "'s", text)
62
 
63
  def sample_once(prompt, temperature=1.0, top_k=40, max_new=24):
64
+ sample_i = time.time()
65
  seq = [BOS] + [stoi.get(c, PAD) for c in prompt] + [SEP]
66
  for _ in range(max_new):
67
  x = torch.tensor(seq[-MAX_LEN:], dtype=torch.long, device=DEVICE)[None]
 
77
  seq.append(idx)
78
  generated = [itos[i] for i in seq if i not in {BOS, SEP, EOS, PAD}]
79
  name = ''.join(generated).replace(prompt, "").strip()
80
+ print(f"Sample took: {time.time() - sample_start:.2f}s")
81
  return clean_name(name)
82
 
83
  # === Generation Function ===