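# Geo-GenderStudy / get_gender.py
# Residue scraped from the hosting page (not part of the program):
#   AliMustapha's picture
#   Commit message and hash as shown on the page: "fix the name handler" (d86077e)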
__copyright__ = "Copyright (C) 2023 Ali Mustapha"
__license__ = "GPL-3.0-or-later"
import tensorflow as tf
import numpy as np
import re
import unicodedata
from utils import data_utils
from unidecode import unidecode
# from utils import self
class GenderPredictor:
    """Predict a gender class index for a name using a saved Keras model.

    Results are ``(label, proba)`` pairs: ``label`` is a class index of the
    model and ``proba`` is an integer percentage.  Index ``2`` is treated
    throughout as the "undetermined" label: it is returned (with ``proba``
    100) for e-mail addresses, for names rejected by the ``data_utils``
    checks, when prediction fails, and whenever the model puts at least
    ``UNKNOWN_THRESHOLD`` probability on it.

    NOTE(review): what class indices 0 and 1 stand for is defined by the
    trained model and is not visible in this file.
    """

    # Class index used for "could not determine".
    UNKNOWN_LABEL = 2
    # Probability on UNKNOWN_LABEL at or above which the result is undetermined.
    UNKNOWN_THRESHOLD = 0.1
    # Percentage reported alongside every undetermined result.
    FULL_CONFIDENCE = 100

    # Built once instead of on every predict_gender() call.
    _EMAIL_RE = re.compile(r"^[^\s@]+@[^\s@]+$")
    # Map each separator to a single space.  Built from a dict so the table
    # never depends on two string literals having the same length (the
    # two-string form of str.maketrans raises ValueError when they differ).
    _SEPARATORS_TO_SPACE = str.maketrans(dict.fromkeys(r"-._\/+", " "))
    _STRIP_DIGITS = str.maketrans("", "", "0123456789")

    def __init__(self, model_path):
        """Load the model stored at *model_path* into ``self.model``."""
        self.model = self.load_model(model_path)

    def load_model(self, path):
        """Load the Keras model saved at *path*, compile it and return it.

        The model is compiled with categorical cross-entropy loss, an Adam
        optimizer with default settings and the ``accuracy`` metric.
        """
        model = tf.keras.models.load_model(path)
        # Only (re)compiles the loaded model; no training happens here.
        model.compile(
            loss=tf.keras.losses.categorical_crossentropy,
            optimizer=tf.keras.optimizers.Adam(),
            metrics=["accuracy"],
        )
        return model

    def predict_gender(self, name):
        """Return ``(label, proba)`` for *name*.

        Steps:

        1. Strings that look like e-mail addresses are undetermined.
        2. The separators ``- . _ \\ / +`` become spaces, then the name goes
           through ``data_utils.text_to_romanize`` and
           ``data_utils.remove_spaces_from_ends``.
        3. Names rejected by the ``data_utils`` sanity checks are
           undetermined.
        4. Digits are removed.  If the first word has more than two
           characters only that word is classified, otherwise the whole
           name is.
        5. If the model's answer is undetermined and the name has several
           words, each word is classified on its own and the labels are
           combined with ``data_utils.find_common_item``.

        Any exception raised while predicting is swallowed on purpose
        (best effort) and yields the undetermined result.

        Args:
            name: Raw name string.

        Returns:
            Tuple ``(label, proba)``; ``(2, 100)`` when undetermined.
        """
        unknown = (self.UNKNOWN_LABEL, self.FULL_CONFIDENCE)

        if self._EMAIL_RE.match(name):
            return unknown

        name = name.translate(self._SEPARATORS_TO_SPACE)
        name = data_utils.text_to_romanize(name)
        name = data_utils.remove_spaces_from_ends(name)

        # Evaluation order (and short-circuiting) of the helper calls is
        # kept exactly as before.
        if (
            len(name) < 3 or data_utils.is_most_common_char(name)
        ) and data_utils.is_roman_language(name):
            return unknown
        if not data_utils.is_alpha(name):
            return unknown

        name = name.translate(self._STRIP_DIGITS)
        words = name.split()
        if not words:
            # Nothing left after removing digits; previously this raised
            # IndexError on ``name.split()[0]``.
            return unknown
        if len(words[0]) > 2:
            name = words[0]

        try:
            probabilities = self.model.predict([name], verbose=0).astype("float")
            prediction, proba = self.get_label(probabilities)
            if prediction == self.UNKNOWN_LABEL:
                parts = name.split()
                if len(parts) > 1:
                    # predict_gender() returns (label, proba); vote on the
                    # labels only so this method keeps returning a flat
                    # (label, proba) pair.
                    part_labels = [self.predict_gender(part)[0] for part in parts]
                    # presumably picks the most frequent label — confirm
                    # against data_utils.find_common_item.
                    prediction = data_utils.find_common_item(part_labels)
                    # proba stays at the value from get_label (100 here);
                    # the per-part percentages are not combined.
        except Exception:
            # Deliberate best effort: a failing model or helper must not
            # break the caller.  Return the full undetermined pair so a
            # stale proba from get_label cannot leak out.
            return unknown
        return prediction, proba

    def get_label(self, predictions_proba):
        """Turn model output into ``(label, proba)``.

        Only one row is expected; when several are given the last one
        decides the result.

        Args:
            predictions_proba: Sequence of rows of per-class scores, each
                row having at least three entries.

        Returns:
            ``(2, 100)`` when the score at index 2 is at least
            ``UNKNOWN_THRESHOLD`` (or index 2 is the arg-max), otherwise
            ``(argmax, int(score * 100))``.

        Raises:
            ValueError: If *predictions_proba* has no rows.
        """
        if len(predictions_proba) == 0:
            raise ValueError("predictions_proba must contain at least one row")

        row = predictions_proba[-1]
        if row[self.UNKNOWN_LABEL] >= self.UNKNOWN_THRESHOLD:
            return self.UNKNOWN_LABEL, self.FULL_CONFIDENCE

        # Plain int rather than a NumPy integer, so results compare and
        # serialize like ordinary Python values.
        label = int(np.argmax(row))
        if label == self.UNKNOWN_LABEL:
            return self.UNKNOWN_LABEL, self.FULL_CONFIDENCE
        return label, int(row[label] * 100)