Spaces:
Runtime error
Runtime error
Commit ·
70d2f66
1
Parent(s): 7e62304
adds reverse model and feature that allows you to pick last letters in the name
Browse files
app.py
CHANGED
|
@@ -34,6 +34,14 @@ gpt_micro_weights_path = huggingface_hub.hf_hub_download(
|
|
| 34 |
"jefsnacker/surname_generator",
|
| 35 |
"micro_gpt_weights.pt")
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
with open(mlp_config_path, 'r') as file:
|
| 38 |
mlp_config = yaml.safe_load(file)
|
| 39 |
|
|
@@ -43,6 +51,9 @@ with open(wavenet_config_path, 'r') as file:
|
|
| 43 |
with open(gpt_micro_config_path, 'r') as file:
|
| 44 |
gpt_micro_config = yaml.safe_load(file)
|
| 45 |
|
|
|
|
|
|
|
|
|
|
| 46 |
##################################################################################
|
| 47 |
## MLP
|
| 48 |
##################################################################################
|
|
@@ -314,31 +325,48 @@ gpt_micro = GPT(gpt_micro_config)
|
|
| 314 |
gpt_micro.load_state_dict(torch.load(gpt_micro_weights_path))
|
| 315 |
gpt_micro.eval()
|
| 316 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 317 |
##################################################################################
|
| 318 |
## Gradio App
|
| 319 |
##################################################################################
|
| 320 |
|
| 321 |
-
def generate_names(name_start, number_of_names, model):
|
| 322 |
if model == "MLP":
|
| 323 |
-
|
| 324 |
-
window = mlp_config['window']
|
| 325 |
elif model == "WaveNet":
|
| 326 |
-
|
| 327 |
-
window = wavenet_config['window']
|
| 328 |
elif model == "GPT Micro":
|
| 329 |
-
|
| 330 |
-
|
|
|
|
| 331 |
else:
|
| 332 |
raise Exception("Model not selected")
|
| 333 |
-
|
|
|
|
| 334 |
itos = {s:i for i,s in stoi.items()}
|
| 335 |
|
| 336 |
names = ""
|
| 337 |
for _ in range((int)(number_of_names)):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 338 |
|
| 339 |
# Initialize name with user input
|
| 340 |
-
name = ""
|
| 341 |
-
context = [0] * window
|
| 342 |
for c in name_start.lower():
|
| 343 |
name += c
|
| 344 |
context = context[1:] + [stoi[c]]
|
|
@@ -352,6 +380,8 @@ def generate_names(name_start, number_of_names, model):
|
|
| 352 |
ix = wavenet.sample_char(x)
|
| 353 |
elif model == "GPT Micro":
|
| 354 |
ix = gpt_micro.sample_char(x)
|
|
|
|
|
|
|
| 355 |
else:
|
| 356 |
raise Exception("Model not selected")
|
| 357 |
|
|
@@ -369,8 +399,9 @@ demo = gr.Interface(
|
|
| 369 |
fn=generate_names,
|
| 370 |
inputs=[
|
| 371 |
gr.Textbox(placeholder="Start name with..."),
|
|
|
|
| 372 |
gr.Number(value=5),
|
| 373 |
-
gr.Dropdown(["MLP", "WaveNet", "GPT Micro"], value="GPT
|
| 374 |
],
|
| 375 |
outputs="text",
|
| 376 |
)
|
|
|
|
| 34 |
"jefsnacker/surname_generator",
|
| 35 |
"micro_gpt_weights.pt")
|
| 36 |
|
| 37 |
+
gpt_rev_config_path = huggingface_hub.hf_hub_download(
|
| 38 |
+
"jefsnacker/surname_generator",
|
| 39 |
+
"rev_gpt_config.yaml")
|
| 40 |
+
|
| 41 |
+
gpt_rev_weights_path = huggingface_hub.hf_hub_download(
|
| 42 |
+
"jefsnacker/surname_generator",
|
| 43 |
+
"rev_gpt_weights.pt")
|
| 44 |
+
|
| 45 |
with open(mlp_config_path, 'r') as file:
|
| 46 |
mlp_config = yaml.safe_load(file)
|
| 47 |
|
|
|
|
| 51 |
with open(gpt_micro_config_path, 'r') as file:
|
| 52 |
gpt_micro_config = yaml.safe_load(file)
|
| 53 |
|
| 54 |
+
with open(gpt_rev_config_path, 'r') as file:
|
| 55 |
+
gpt_rev_config = yaml.safe_load(file)
|
| 56 |
+
|
| 57 |
##################################################################################
|
| 58 |
## MLP
|
| 59 |
##################################################################################
|
|
|
|
| 325 |
gpt_micro.load_state_dict(torch.load(gpt_micro_weights_path))
|
| 326 |
gpt_micro.eval()
|
| 327 |
|
| 328 |
+
gpt_rev = GPT(gpt_rev_config)
|
| 329 |
+
gpt_rev.load_state_dict(torch.load(gpt_rev_weights_path))
|
| 330 |
+
gpt_rev.eval()
|
| 331 |
+
|
| 332 |
##################################################################################
|
| 333 |
## Gradio App
|
| 334 |
##################################################################################
|
| 335 |
|
| 336 |
+
def generate_names(name_start, name_end, number_of_names, model):
|
| 337 |
if model == "MLP":
|
| 338 |
+
config = mlp_config
|
|
|
|
| 339 |
elif model == "WaveNet":
|
| 340 |
+
config = wavenet_config
|
|
|
|
| 341 |
elif model == "GPT Micro":
|
| 342 |
+
config = gpt_micro_config
|
| 343 |
+
elif model == "GPT Rev":
|
| 344 |
+
config = gpt_rev_config
|
| 345 |
else:
|
| 346 |
raise Exception("Model not selected")
|
| 347 |
+
|
| 348 |
+
stoi = config['stoi']
|
| 349 |
itos = {s:i for i,s in stoi.items()}
|
| 350 |
|
| 351 |
names = ""
|
| 352 |
for _ in range((int)(number_of_names)):
|
| 353 |
+
name = ""
|
| 354 |
+
context = [0] * config['window']
|
| 355 |
+
|
| 356 |
+
if "num_final_chars_in_dataset" in config:
|
| 357 |
+
# Put final chars in context
|
| 358 |
+
if len(name_end) > config["num_final_chars_in_dataset"]:
|
| 359 |
+
name_end = name_end[-config["num_final_chars_in_dataset"]:]
|
| 360 |
+
print("Only accepts up to " + str(config["num_final_chars_in_dataset"]) + " final chars. Using: " + name_end)
|
| 361 |
+
|
| 362 |
+
for c in name_end:
|
| 363 |
+
context = context[1:] + [stoi[c]]
|
| 364 |
+
context = context[1:] + [stoi['.']]
|
| 365 |
+
|
| 366 |
+
elif (name_end != ""):
|
| 367 |
+
print("Final chars not used. Need to use a model trained with this feature.")
|
| 368 |
|
| 369 |
# Initialize name with user input
|
|
|
|
|
|
|
| 370 |
for c in name_start.lower():
|
| 371 |
name += c
|
| 372 |
context = context[1:] + [stoi[c]]
|
|
|
|
| 380 |
ix = wavenet.sample_char(x)
|
| 381 |
elif model == "GPT Micro":
|
| 382 |
ix = gpt_micro.sample_char(x)
|
| 383 |
+
elif model == "GPT Rev":
|
| 384 |
+
ix = gpt_rev.sample_char(x)
|
| 385 |
else:
|
| 386 |
raise Exception("Model not selected")
|
| 387 |
|
|
|
|
| 399 |
fn=generate_names,
|
| 400 |
inputs=[
|
| 401 |
gr.Textbox(placeholder="Start name with..."),
|
| 402 |
+
gr.Textbox(placeholder="End name with... (only works for rev model)"),
|
| 403 |
gr.Number(value=5),
|
| 404 |
+
gr.Dropdown(["MLP", "WaveNet", "GPT Micro", "GPT Rev"], value="GPT Rev"),
|
| 405 |
],
|
| 406 |
outputs="text",
|
| 407 |
)
|