Anuj-Panthri's picture
error handling and color changes
65972e1
import tensorflow as tf
import numpy as np
import json
from .basemodel import ModelBaseClass
class Colgen1(ModelBaseClass):
    """Character-level model wrapper that predicts a color for each name.

    Names are lower-cased, stripped of non-alphanumeric characters,
    tokenized per character, one-hot encoded and left-padded to a common
    length before being passed to the loaded Keras model.
    """

    def __init__(self, model_dir):
        """Load the Keras model and the character vocabulary.

        Args:
            model_dir: Directory prefix holding ``model.h5`` and
                ``token_to_idx.txt``. The file names are appended directly
                to this string, so it must end with a path separator.
        """
        self.model = tf.keras.models.load_model(model_dir + "model.h5", compile=False)
        # Use a context manager so the vocabulary file handle is closed
        # deterministically instead of being left to the garbage collector.
        with open(model_dir + "token_to_idx.txt", "r", encoding="utf-8") as vocab_file:
            self.token_to_idx = json.load(vocab_file)
        self.TOKENS = list(self.token_to_idx.keys())

    def tokenize(self, name):
        """Tokenize a single name into a list of vocabulary indices.

        Raises:
            KeyError: If ``name`` contains a character that is not a key of
                ``token_to_idx``.
        """
        return [self.token_to_idx[char] for char in name]

    def one_hot_encode(self, tokens, num_classes):
        """One-hot encode ``tokens`` into vectors of width ``num_classes``."""
        # NOTE(review): presumably every index in token_to_idx is in
        # [0, num_classes) -- confirm against the vocabulary file.
        return tf.keras.utils.to_categorical(tokens, num_classes=num_classes)

    def add_padding(self, one_hot_vectors, num_classes, max_num_tokens):
        """Left-pad ``one_hot_vectors`` with all-zero rows.

        Args:
            one_hot_vectors: np.array of shape (tokens, len(all_tokens)).
            num_classes: Width of each zero row that is prepended.
            max_num_tokens: Target number of rows after padding.

        Returns:
            The input unchanged when it already has at least
            ``max_num_tokens`` rows; otherwise a new array with the zero
            rows stacked in front of the original rows.
        """
        num_of_padding = max_num_tokens - len(one_hot_vectors)
        if num_of_padding <= 0:
            return one_hot_vectors
        # Build the whole padding block in one allocation rather than
        # appending one zero row at a time.
        padding = np.zeros((num_of_padding, num_classes))
        return np.concatenate([padding, one_hot_vectors], axis=0)

    def preprocess(self, names: list[str]):
        """Convert raw names into a padded batch of one-hot sequences.

        Args:
            names: [name, name, name, ...]

        Returns:
            np.array stacking one padded one-hot matrix per name.
        """
        num_classes = len(self.TOKENS)
        one_hots_list = []
        for name in names:
            # Lower-case, then replace every non-alphanumeric character
            # with a space.
            cleaned = "".join(
                char if char.isalnum() else " " for char in name.lower()
            )
            tokens = self.tokenize(cleaned)
            one_hots_list.append(self.one_hot_encode(tokens, num_classes))

        # All examples must have the same number of tokens, so pad each
        # one up to the longest sequence in this batch.
        max_num_tokens = max((len(one_hots) for one_hots in one_hots_list), default=0)
        return np.array([
            self.add_padding(one_hots, num_classes, max_num_tokens)
            for one_hots in one_hots_list
        ])

    def predict(self, names: list):
        """Predict colors for ``names``.

        Returns:
            The model output scaled by 255 and cast to ``uint8``.
        """
        tokens = self.preprocess(names)
        # NOTE(review): assumes the model emits values in [0, 1]; values
        # outside that range would wrap in the uint8 cast -- confirm.
        colors = (self.model.predict(tokens, verbose=0) * 255).astype("uint8")
        return colors