Spaces:
Runtime error
Runtime error
| __copyright__ = "Copyright (C) 2023 Ali Mustapha" | |
| __license__ = "GPL-3.0-or-later" | |
| import tensorflow as tf | |
| import numpy as np | |
| import re | |
| import unicodedata | |
| from utils import data_utils | |
| from unidecode import unidecode | |
| # from utils import self | |
class GenderPredictor:
    """Predict a gender label for a person's name using a saved Keras model.

    Labels are integer class indices produced by the model. Label ``2`` also
    serves as the predictor's "unknown" verdict: it is returned for
    email-like input, very short or non-alphabetic names, and on any
    prediction failure. NOTE(review): labels 0/1 presumably map to the two
    gender classes the model was trained on -- confirm against training data.
    """

    # Class index used for "cannot decide".
    _UNKNOWN_LABEL = 2
    # Strings of the form "something@something" (no whitespace) count as emails.
    _EMAIL_RE = re.compile(r"^[^\s@]+@[^\s@]+$")
    # Separator characters inside names/handles that are turned into spaces.
    # str.maketrans requires equal-length from/to strings, hence " " * len(...).
    _SEPARATORS = r"-._\/+"
    _SEPARATOR_TABLE = str.maketrans(_SEPARATORS, " " * len(_SEPARATORS))
    # Translation table that deletes ASCII digits.
    _DIGIT_TABLE = str.maketrans("", "", "0123456789")

    def __init__(self, model_path):
        """Load and compile the Keras model stored at *model_path*."""
        self.model = self.load_model(model_path)

    def load_model(self, path):
        """Load a saved Keras model from *path* and compile it.

        The model is only compiled (categorical cross-entropy loss, Adam
        optimizer, accuracy metric); no training happens here.
        """
        model = tf.keras.models.load_model(path)
        model.compile(
            loss=tf.keras.losses.categorical_crossentropy,
            optimizer=tf.keras.optimizers.Adam(),
            metrics=['accuracy'],
        )
        return model

    @classmethod
    def _normalize_separators(cls, name):
        """Replace every separator character (- . _ \\ / +) in *name* with a space."""
        return name.translate(cls._SEPARATOR_TABLE)

    @classmethod
    def _candidate_name(cls, name):
        """Strip ASCII digits from *name* and choose the text to send to the model.

        Returns the first whitespace-separated token when it is longer than
        two characters, otherwise the whole digit-free string, or ``""`` when
        only whitespace remains.
        """
        name = name.translate(cls._DIGIT_TABLE)
        tokens = name.split()
        if not tokens:
            return ""
        return tokens[0] if len(tokens[0]) > 2 else name

    def predict_gender(self, name):
        """Predict a gender label for *name*.

        Args:
            name: Raw input string (a name, a handle, or an email address).

        Returns:
            ``(label, proba)``. ``label`` is an integer class index, with
            ``2`` meaning unknown. ``proba`` is an integer percentage, and is
            ``100`` whenever the label was not taken from a non-unknown model
            output.
        """
        proba = 100
        if self._EMAIL_RE.match(name):
            return self._UNKNOWN_LABEL, proba

        name = self._normalize_separators(name)
        name = data_utils.text_to_romanize(name)
        name = data_utils.remove_spaces_from_ends(name)

        if (len(name) < 3 or data_utils.is_most_common_char(name)) and data_utils.is_roman_language(name):
            return self._UNKNOWN_LABEL, proba
        if not data_utils.is_alpha(name):
            return self._UNKNOWN_LABEL, proba

        name = self._candidate_name(name)
        if not name:
            # Nothing usable left once digits are removed.
            return self._UNKNOWN_LABEL, proba

        try:
            predictions_proba = self.model.predict([name], verbose=0).astype('float')
            prediction, proba = self.get_label(predictions_proba)
            parts = name.split()
            if prediction == self._UNKNOWN_LABEL and len(parts) > 1:
                # Undecided on the full string: classify each word on its own
                # and let data_utils.find_common_item pick the combined label.
                # predict_gender returns (label, proba); only labels are combined.
                part_labels = [self.predict_gender(part)[0] for part in parts]
                prediction = data_utils.find_common_item(part_labels)
        except Exception:
            # Best effort: a failure yields "unknown" instead of an error, but
            # it is logged so that broken models or helpers are not invisible.
            import logging
            logging.getLogger(__name__).warning(
                "Gender prediction failed; returning unknown label", exc_info=True
            )
            prediction, proba = self._UNKNOWN_LABEL, 100
        return prediction, proba

    def get_label(self, predictions_proba, unknown_threshold=0.1):
        """Turn model output probabilities into ``(label, proba)``.

        Args:
            predictions_proba: Batch of per-class probability rows, as returned
                by ``model.predict``. Column 2 is the unknown class. If several
                rows are given, the last one decides (``predict_gender``
                always passes a single-row batch).
            unknown_threshold: If the unknown-class probability is at least
                this value, the unknown label wins outright.

        Returns:
            ``(label, proba)``. ``label`` is the chosen class index as an
            ``int``. ``proba`` is that class's probability as a truncated
            integer percentage, or ``100`` for the unknown label.

        Raises:
            ValueError: If *predictions_proba* contains no rows.
        """
        # len() rather than truthiness so that numpy arrays work as well.
        if len(predictions_proba) == 0:
            raise ValueError("predictions_proba must contain at least one row")
        row = predictions_proba[-1]
        if row[self._UNKNOWN_LABEL] >= unknown_threshold:
            label = self._UNKNOWN_LABEL
        else:
            label = int(np.argmax(row))
        if label == self._UNKNOWN_LABEL:
            return label, 100
        return label, int(row[label] * 100)