File size: 1,788 Bytes
7ba0490
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6dfaa17
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import streamlit as st
from joblib import load
import unicodedata
import torch 
import torch.nn as nn

# Load the label encoder.
# Streamlit re-executes this script from the top on every widget interaction,
# so the load is wrapped in st.cache_resource to read the file only once per
# server process instead of on every rerun.
@st.cache_resource
def _load_encoder(path="Country_Encoder"):
    """Return the object stored in the joblib file at *path*.

    NOTE: joblib files are pickle-based; only load files from a trusted source.
    """
    return load(path)


Encoder = _load_encoder()
# Class labels, indexed by the model's output position.
# Presumably a fitted scikit-learn LabelEncoder -- confirm against the training code.
all_classes = Encoder.classes_

st.title("Name Classification Based on Last Name")

# Define RNN Model
class SimpleRNN(nn.Module):
    """RNN sequence classifier: ``nn.RNN`` -> dropout -> linear layer.

    Only the RNN output at the last time step is used to produce the
    per-class scores.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of the RNN hidden state.
        output_size: Number of output classes.
        num_layers: Number of stacked RNN layers.
        dropout: Dropout probability applied to the last-step output
            before the linear layer. Defaults to 0.3.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        num_layers: int,
        dropout: float = 0.3,
    ) -> None:
        super().__init__()
        # Attribute names are part of the saved state_dict keys; do not rename.
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class scores of shape ``(batch, output_size)``.

        Args:
            x: Input of shape ``(batch, sequence, input_size)`` (the RNN is
                built with ``batch_first=True``).
        """
        out, _ = self.rnn(x)
        last_step = out[:, -1, :]  # keep only the final time step
        return self.fc(self.dropout(last_step))

# Load Model Weights
# Streamlit re-executes this script on every widget interaction; caching the
# constructed model avoids rebuilding it and re-reading the weights each rerun.
@st.cache_resource
def _load_model(num_classes, weights_path="rnn_model.pth"):
    """Build the classifier, load its weights on CPU, and put it in eval mode.

    Args:
        num_classes: Size of the output layer (one score per class).
        weights_path: Path to the saved ``state_dict`` file.
    """
    # input_size is 1: each time step is a single character code.
    net = SimpleRNN(1, 256, num_classes, 1)
    net.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
    net.eval()  # disables dropout for inference
    return net


model1 = _load_model(len(all_classes))

# Text Input for Name
name = st.text_input("Enter Last Name")

# Convert Unicode to ASCII
def unicode_to_ascii(s):
    """Case-fold *s* and strip combining marks (accents) from it.

    The string is case-folded, decomposed with NFD normalization, and every
    character in Unicode category ``Mn`` (nonspacing mark) is removed.
    Characters that have no such decomposition (e.g. ``ø``) are kept as-is,
    so the result is not guaranteed to be pure ASCII.
    """
    decomposed = unicodedata.normalize('NFD', s.casefold())
    kept = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    return ''.join(kept)

# Fixed sequence length fed to the model; names are truncated or zero-padded to it.
MAX_NAME_LEN = 20

if st.button("Submit"):
    # Strip surrounding whitespace so stray spaces are not encoded as characters.
    cleaned = unicode_to_ascii(name.strip())

    if not cleaned:
        # Without this guard an empty name is encoded as pure padding and the
        # model still emits a (meaningless) class.
        st.warning("Please enter a last name.")
    else:
        # Encode each character as its code point, truncated to the fixed length.
        name_ascii = [ord(letter) for letter in cleaned][:MAX_NAME_LEN]
        # Right-pad with zeros up to the fixed length.
        name_ascii += [0] * (MAX_NAME_LEN - len(name_ascii))

        # Convert to Tensor, shape: (batch=1, sequence=MAX_NAME_LEN, input=1)
        X = torch.tensor(name_ascii, dtype=torch.float32).view(1, MAX_NAME_LEN, 1)

        with torch.no_grad():
            pred = model1(X)

        # Get Predicted Class: index of the highest score along the class axis.
        idx = torch.argmax(pred, dim=1).item()
        class_ = all_classes[idx]
        st.success(f"Predicted Class: {class_}")