"""Gradio UI for the TM-TKO trademark-logo image classifiers.

Each dropdown entry maps to a fastai learner stored in the
``amandasarubbi/tm-tko-models`` Hugging Face repository. The selected
learner is downloaded and loaded once (then cached), and is used to predict
a design-code label for the uploaded image.
"""
#########################################################################
# imports
import functools
import os
import pathlib
import platform
import threading

from fastai.vision.all import *
import gradio as gr
from huggingface_hub import hf_hub_download

#########################################################################
# Access token for the HF model repository. It is read from the environment
# instead of being committed to source control. When neither variable is
# set, None is passed and huggingface_hub falls back to its own default
# credential lookup.
# NOTE(review): a token was previously hard-coded in this file; it is public
# now and should be revoked on huggingface.co.
ACCESS_TOKEN = os.environ.get("HF_ACCESS_TOKEN") or os.environ.get("HF_TOKEN")

#########################################################################
# Path objects pickled on Windows unpickle as WindowsPath, which pathlib
# refuses to instantiate on other OSes. Aliasing it to PosixPath lets those
# learners load here. (Presumably the .pkl files were exported on Windows.)
# The result is deliberately not stored as `plt`, which would clobber the
# matplotlib alias exported by the fastai star import.
if platform.system() != 'Windows':
    pathlib.WindowsPath = pathlib.PosixPath

#########################################################################
# Dropdown label -> learner file in the HF repository. This single table
# drives both the UI choices and the model lookup in predict().
MODEL_FILES = {
    'Geometric Figures & Solids': 'geometric_model.pkl',
    'Scenery, Natural Phenomena': 'landscape_model.pkl',
    'Human & Supernatural Beings': 'human_model.pkl',
    'Colors & Characters': 'colors_model.pkl',
    'Buildings, Dwellings & Furniture': 'buildings.pkl',
    'Animals': 'animals.pkl',
}

# Learners are cached and therefore shared across requests. Learner.predict
# runs through the learner's own dataloader/callback machinery, so inference
# on a shared learner is serialized.
_PREDICT_LOCK = threading.Lock()


@functools.lru_cache(maxsize=len(MODEL_FILES))
def import_model(model_name):
    """Download *model_name* from the model repo and return it as a Learner.

    Results are cached, so each file is downloaded and unpickled only once
    per process rather than on every prediction.

    Args:
        model_name: File name of the exported learner inside the
            ``amandasarubbi/tm-tko-models`` repository.

    Returns:
        The fastai Learner loaded onto the CPU.
    """
    path = hf_hub_download(repo_id='amandasarubbi/tm-tko-models',
                           filename=model_name,
                           use_auth_token=ACCESS_TOKEN,
                           repo_type='model')
    # load_learner unpickles the file. Only point this at a trusted repository.
    return load_learner(path, cpu=True)


#########################################################################
# Function to predict outputs
def predict(img, model_name):
    """Classify *img* with the learner selected by *model_name*.

    Args:
        img: Image value from the Gradio image input. It is passed unchanged
            to ``Learner.predict``.
        model_name: One of the keys of ``MODEL_FILES`` (the dropdown choice).

    Returns:
        The predicted class label as a string.

    Raises:
        ValueError: If no image was supplied or *model_name* is not a known
            model (e.g. the dropdown was left empty).
    """
    if img is None:
        raise ValueError("Please upload an image.")
    try:
        filename = MODEL_FILES[model_name]
    except KeyError:
        raise ValueError(
            f"Unknown model {model_name!r}; choose one of {list(MODEL_FILES)}"
        ) from None

    learn = import_model(filename)
    with _PREDICT_LOCK:
        preds = learn.predict(img)
    # Learner.predict returns (decoded label, label index, probabilities).
    return str(preds[0])


#########################################################################
title = "TM-TKO Trademark Logo Image Classification Model"
description = ("Users can upload a logo image and choose a model category to "
               "get US design-code standard predictions from a trained model "
               "that utilizes the benchmark ResNet50 architecture.")

iFace = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(label="Upload Logo Here"),
        gr.Dropdown(choices=list(MODEL_FILES), label='Choose a Model'),
    ],
    outputs=gr.Label(label="TM-TKO Trademark Classification Model"),
    title=title,
    description=description,
)

if __name__ == "__main__":
    iFace.launch()