jflo commited on
Commit
86a34fe
·
1 Parent(s): 1e92b51

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -13
app.py CHANGED
@@ -6,7 +6,7 @@ import gradio as gr
6
  all_letters = string.ascii_letters + " .,;'"
7
  n_letters = len(all_letters)
8
 
9
- all_categories = ['Arabic','Chinese','Czech','Dutch','English','French','German','Greek',
10
  'Irish','Italian','Japanese','Korean','Polish','Portuguese','Russian','Scottish',
11
  'Spanish','Vietnamese']
12
 
@@ -35,28 +35,26 @@ def evaluate(line_tensor):
35
  return output
36
 
37
  # Feeding in a name and number of top predictions you want to output
38
- def predict(last_name,n_predictions=3):
39
 
40
  last_name = last_name.title()
41
  with torch.no_grad():
42
  output = evaluate(lineToTensor(last_name))
43
  output = F.softmax(output,dim=1)
44
 
45
- topv,topi = output.topk(n_predictions,1,True)
46
 
47
- top_3_countries = ''
48
- for i in range(n_predictions):
49
- value = topv[0]
50
- category_index = topi[0][i].item()
51
- top_3_countries += f'{all_categories[category_index]}: {round(value[i].item()*100,2)}% confident'
52
- top_3_countries += '\n'
53
- return top_3_countries
54
-
55
- demo = gr.Interface(predict,
56
  inputs = "text",
57
  outputs = "text",
58
  title = "Classify Last Name :)",
59
- description="Classifies last name into one of 18 language of origin. Returns top 3 languages of origin"
60
  )
61
 
62
  demo.launch(inline=False)
 
6
  all_letters = string.ascii_letters + " .,;'"
7
  n_letters = len(all_letters)
8
 
9
+ class_names = ['Arabic','Chinese','Czech','Dutch','English','French','German','Greek',
10
  'Irish','Italian','Japanese','Korean','Polish','Portuguese','Russian','Scottish',
11
  'Spanish','Vietnamese']
12
 
 
35
  return output
36
 
37
  # Feeding in a name and number of top predictions you want to output
38
+ def classify_lastname(last_name,n_predictions=3):
39
 
40
  last_name = last_name.title()
41
  with torch.no_grad():
42
  output = evaluate(lineToTensor(last_name))
43
  output = F.softmax(output,dim=1)
44
 
45
+ top3_prob,top3_catid = torch.topk(output,3)
46
 
47
+ model_output = {}
48
+ for i in range(top3_prob.size(1)):
49
+ model_output[class_names[top3_catid[0][i].item()]] = top3_prob[0][i].item()
50
+
51
+ return model_output
52
+
53
+ demo = gr.Interface(classify_lastname,
 
 
54
  inputs = "text",
55
  outputs = "text",
56
  title = "Classify Last Name :)",
57
+ description="Classifies last name into one of 18 language of origin. Returns confidence % for the top three categories"
58
  )
59
 
60
  demo.launch(inline=False)