kanneboinakumar commited on
Commit
7ba0490
·
verified ·
1 Parent(s): 2c25843

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -58
app.py CHANGED
@@ -1,59 +1,59 @@
1
- import streamlit as st
2
- from joblib import load
3
- import unicodedata
4
- import torch
5
- import torch.nn as nn
6
-
7
- # Load Encoder and Model
8
- Encoder = load("Country_Encoder")
9
- all_classes = Encoder.classes_
10
-
11
- st.title("Name Classification Based on Last Name")
12
-
13
- # Define RNN Model
14
- class SimpleRNN(nn.Module):
15
- def __init__(self, input_size, hidden_size, output_size, num_layers):
16
- super(SimpleRNN, self).__init__()
17
- self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
18
- self.dropout = nn.Dropout(0.3)
19
- self.fc = nn.Linear(hidden_size, output_size)
20
-
21
- def forward(self, x):
22
- out, _ = self.rnn(x)
23
- out = self.dropout(out[:, -1, :]) # Take last time step output
24
- out = self.fc(out)
25
- return out
26
-
27
- # Load Model Weights
28
- model1 = SimpleRNN(1, 512, len(all_classes), 1) # input_size should be 1 for ASCII values
29
- model1.load_state_dict(torch.load("rnn_model.pth", map_location=torch.device('cpu')))
30
- model1.eval()
31
-
32
- # Text Input for Name
33
- name = st.text_input("Enter Last Name")
34
-
35
- # Convert Unicode to ASCII
36
- def unicode_to_ascii(s):
37
- s = s.casefold()
38
- return ''.join(
39
- c for c in unicodedata.normalize('NFD', s)
40
- if unicodedata.category(c) != 'Mn'
41
- )
42
-
43
- if st.button("Submit"):
44
- name = unicode_to_ascii(name)
45
- name_ascii = [ord(letter) for letter in name]
46
-
47
- # Padding or Truncation to 20 characters
48
- name_ascii = name_ascii[:20] + [0] * (20 - len(name_ascii))
49
-
50
- # Convert to Tensor (reshape for RNN input)
51
- X = torch.tensor(name_ascii, dtype=torch.float32).view(1, 20, 1) # Shape: (batch, sequence, input)
52
-
53
- with torch.no_grad():
54
- pred = model1(X)
55
-
56
- # Get Predicted Class
57
- idx = torch.argmax(pred).item()
58
- class_ = all_classes[idx]
59
  st.success(f"Predicted Class: {class_}")
 
1
+ import streamlit as st
2
+ from joblib import load
3
+ import unicodedata
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ # Load Encoder and Model
8
+ Encoder = load("Country_Encoder")
9
+ all_classes = Encoder.classes_
10
+
11
+ st.title("Name Classification Based on Last Name")
12
+
13
+ # Define RNN Model
14
+ class SimpleRNN(nn.Module):
15
+ def __init__(self, input_size, hidden_size, output_size, num_layers):
16
+ super(SimpleRNN, self).__init__()
17
+ self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
18
+ self.dropout = nn.Dropout(0.3)
19
+ self.fc = nn.Linear(hidden_size, output_size)
20
+
21
+ def forward(self, x):
22
+ out, _ = self.rnn(x)
23
+ out = self.dropout(out[:, -1, :]) # Take last time step output
24
+ out = self.fc(out)
25
+ return out
26
+
27
+ # Load Model Weights
28
+ model1 = SimpleRNN(1, 256, len(all_classes), 1) # input_size should be 1 for ASCII values
29
+ model1.load_state_dict(torch.load("rnn_model.pth", map_location=torch.device('cpu')))
30
+ model1.eval()
31
+
32
+ # Text Input for Name
33
+ name = st.text_input("Enter Last Name")
34
+
35
+ # Convert Unicode to ASCII
36
+ def unicode_to_ascii(s):
37
+ s = s.casefold()
38
+ return ''.join(
39
+ c for c in unicodedata.normalize('NFD', s)
40
+ if unicodedata.category(c) != 'Mn'
41
+ )
42
+
43
+ if st.button("Submit"):
44
+ name = unicode_to_ascii(name)
45
+ name_ascii = [ord(letter) for letter in name]
46
+
47
+ # Padding or Truncation to 20 characters
48
+ name_ascii = name_ascii[:20] + [0] * (20 - len(name_ascii))
49
+
50
+ # Convert to Tensor (reshape for RNN input)
51
+ X = torch.tensor(name_ascii, dtype=torch.float32).view(1, 20, 1) # Shape: (batch, sequence, input)
52
+
53
+ with torch.no_grad():
54
+ pred = model1(X)
55
+
56
+ # Get Predicted Class
57
+ idx = torch.argmax(pred).item()
58
+ class_ = all_classes[idx]
59
  st.success(f"Predicted Class: {class_}")