Spaces:
Runtime error
Runtime error
Commit
·
34a8736
1
Parent(s):
70d2f66
better error handling
Browse files
app.py
CHANGED
|
@@ -334,56 +334,73 @@ gpt_rev.eval()
|
|
| 334 |
##################################################################################
|
| 335 |
|
| 336 |
def generate_names(name_start, name_end, number_of_names, model):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 337 |
if model == "MLP":
|
| 338 |
config = mlp_config
|
|
|
|
| 339 |
elif model == "WaveNet":
|
| 340 |
config = wavenet_config
|
|
|
|
| 341 |
elif model == "GPT Micro":
|
| 342 |
config = gpt_micro_config
|
|
|
|
| 343 |
elif model == "GPT Rev":
|
| 344 |
config = gpt_rev_config
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
else:
|
| 346 |
-
|
| 347 |
|
| 348 |
stoi = config['stoi']
|
| 349 |
itos = {s:i for i,s in stoi.items()}
|
| 350 |
|
| 351 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 352 |
for _ in range((int)(number_of_names)):
|
| 353 |
name = ""
|
| 354 |
context = [0] * config['window']
|
| 355 |
|
| 356 |
if "num_final_chars_in_dataset" in config:
|
| 357 |
-
# Put final chars in context
|
| 358 |
-
if len(name_end) > config["num_final_chars_in_dataset"]:
|
| 359 |
-
name_end = name_end[-config["num_final_chars_in_dataset"]:]
|
| 360 |
-
print("Only accepts up to " + str(config["num_final_chars_in_dataset"]) + " final chars. Using: " + name_end)
|
| 361 |
-
|
| 362 |
for c in name_end:
|
| 363 |
context = context[1:] + [stoi[c]]
|
| 364 |
context = context[1:] + [stoi['.']]
|
| 365 |
-
|
| 366 |
-
elif (name_end != ""):
|
| 367 |
-
print("Final chars not used. Need to use a model trained with this feature.")
|
| 368 |
|
| 369 |
# Initialize name with user input
|
| 370 |
-
for c in name_start
|
| 371 |
name += c
|
| 372 |
context = context[1:] + [stoi[c]]
|
| 373 |
|
| 374 |
# Run inference to finish off the name
|
| 375 |
while True:
|
| 376 |
x = torch.tensor(context).view(1, -1)
|
| 377 |
-
|
| 378 |
-
ix = mlp.sample_char(x)
|
| 379 |
-
elif model == "WaveNet":
|
| 380 |
-
ix = wavenet.sample_char(x)
|
| 381 |
-
elif model == "GPT Micro":
|
| 382 |
-
ix = gpt_micro.sample_char(x)
|
| 383 |
-
elif model == "GPT Rev":
|
| 384 |
-
ix = gpt_rev.sample_char(x)
|
| 385 |
-
else:
|
| 386 |
-
raise Exception("Model not selected")
|
| 387 |
|
| 388 |
context = context[1:] + [ix]
|
| 389 |
name += itos[ix]
|
|
@@ -391,9 +408,9 @@ def generate_names(name_start, name_end, number_of_names, model):
|
|
| 391 |
if ix == 0:
|
| 392 |
break
|
| 393 |
|
| 394 |
-
|
| 395 |
|
| 396 |
-
return
|
| 397 |
|
| 398 |
demo = gr.Interface(
|
| 399 |
fn=generate_names,
|
|
@@ -401,7 +418,7 @@ demo = gr.Interface(
|
|
| 401 |
gr.Textbox(placeholder="Start name with..."),
|
| 402 |
gr.Textbox(placeholder="End name with... (only works for rev model)"),
|
| 403 |
gr.Number(value=5),
|
| 404 |
-
gr.Dropdown(["MLP", "WaveNet", "GPT Micro", "GPT Rev"], value="GPT Rev"),
|
| 405 |
],
|
| 406 |
outputs="text",
|
| 407 |
)
|
|
|
|
| 334 |
##################################################################################
|
| 335 |
|
| 336 |
def generate_names(name_start, name_end, number_of_names, model):
|
| 337 |
+
if number_of_names < 0:
|
| 338 |
+
return "Error: Please enter a positive number of names to generate!"
|
| 339 |
+
|
| 340 |
+
# Select model
|
| 341 |
if model == "MLP":
|
| 342 |
config = mlp_config
|
| 343 |
+
sample_fcn = mlp.sample_char
|
| 344 |
elif model == "WaveNet":
|
| 345 |
config = wavenet_config
|
| 346 |
+
sample_fcn = wavenet.sample_char
|
| 347 |
elif model == "GPT Micro":
|
| 348 |
config = gpt_micro_config
|
| 349 |
+
sample_fcn = gpt_micro.sample_char
|
| 350 |
elif model == "GPT Rev":
|
| 351 |
config = gpt_rev_config
|
| 352 |
+
sample_fcn = gpt_rev.sample_char
|
| 353 |
+
elif model == "GPT First Rev":
|
| 354 |
+
# TODO: Change model!
|
| 355 |
+
config = gpt_rev_config
|
| 356 |
+
sample_fcn = gpt_rev.sample_char
|
| 357 |
else:
|
| 358 |
+
return "Error: Model not selected"
|
| 359 |
|
| 360 |
stoi = config['stoi']
|
| 361 |
itos = {s:i for i,s in stoi.items()}
|
| 362 |
|
| 363 |
+
output = ""
|
| 364 |
+
|
| 365 |
+
# Sanitize user inputs, and append errors to output
|
| 366 |
+
name_end = name_end.lower()
|
| 367 |
+
name_start = name_start.lower()
|
| 368 |
+
|
| 369 |
+
for c in name_end:
|
| 370 |
+
if c not in stoi:
|
| 371 |
+
return "Please change name end. \"" + c + "\" not included in the training set."
|
| 372 |
+
|
| 373 |
+
for c in name_start:
|
| 374 |
+
if c not in stoi:
|
| 375 |
+
return "Please change name start. \"" + c + "\" not included in the training set."
|
| 376 |
+
|
| 377 |
+
if "num_final_chars_in_dataset" in config and len(name_end) > config["num_final_chars_in_dataset"]:
|
| 378 |
+
name_end = name_end[-config["num_final_chars_in_dataset"]:]
|
| 379 |
+
output += "Only accepts up to " + str(config["num_final_chars_in_dataset"]) + " final chars. Using: " + str(name_end) + "\n"
|
| 380 |
+
|
| 381 |
+
elif "num_final_chars_in_dataset" not in config and name_end != "":
|
| 382 |
+
output += "Final chars not used. Need to use a \"Rev\" model trained with this feature.\n"
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
## Print requested names
|
| 386 |
for _ in range((int)(number_of_names)):
|
| 387 |
name = ""
|
| 388 |
context = [0] * config['window']
|
| 389 |
|
| 390 |
if "num_final_chars_in_dataset" in config:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 391 |
for c in name_end:
|
| 392 |
context = context[1:] + [stoi[c]]
|
| 393 |
context = context[1:] + [stoi['.']]
|
|
|
|
|
|
|
|
|
|
| 394 |
|
| 395 |
# Initialize name with user input
|
| 396 |
+
for c in name_start:
|
| 397 |
name += c
|
| 398 |
context = context[1:] + [stoi[c]]
|
| 399 |
|
| 400 |
# Run inference to finish off the name
|
| 401 |
while True:
|
| 402 |
x = torch.tensor(context).view(1, -1)
|
| 403 |
+
ix = sample_fcn(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 404 |
|
| 405 |
context = context[1:] + [ix]
|
| 406 |
name += itos[ix]
|
|
|
|
| 408 |
if ix == 0:
|
| 409 |
break
|
| 410 |
|
| 411 |
+
output += name + "\n"
|
| 412 |
|
| 413 |
+
return output
|
| 414 |
|
| 415 |
demo = gr.Interface(
|
| 416 |
fn=generate_names,
|
|
|
|
| 418 |
gr.Textbox(placeholder="Start name with..."),
|
| 419 |
gr.Textbox(placeholder="End name with... (only works for rev model)"),
|
| 420 |
gr.Number(value=5),
|
| 421 |
+
gr.Dropdown(["MLP", "WaveNet", "GPT Micro", "GPT Rev", "GPT First Rev"], value="GPT Rev"),
|
| 422 |
],
|
| 423 |
outputs="text",
|
| 424 |
)
|