Spaces:
Runtime error
Runtime error
| from keras.models import load_model | |
| from PIL import Image, ImageOps | |
| import numpy as np | |
| import gradio as gr | |
| import pandas as pd | |
| import json | |
| import os | |
| import glob | |
# === READ AND LOAD FILES ===
# All resources are resolved relative to this directory.
folder = '.'
# Per-species metadata table. Columns used below: scientific_name,
# name_<lan>, fun_fact_<lan>, habitat_<lan>, CITES, ND06, IUCN and native.
data = pd.read_csv(os.path.join(folder, 'species_info.csv'))
# UI strings, indexed as translation[section][key][language_code].
# Explicit UTF-8: the files contain Vietnamese text and the platform default
# encoding is not guaranteed to be UTF-8.
with open(os.path.join(folder, 'translation.json'), 'r', encoding='utf-8') as f:
    translation = json.load(f)
# Load the model
model = load_model(os.path.join(folder, 'keras_model.h5'))
# Load label file: one class per line, indexed by the model's output position.
with open(os.path.join(folder, 'labels.txt'), 'r', encoding='utf-8') as f:
    labels = f.readlines()
# === GLOBAL VARIABLES ===
# Current UI language code; set by get_language_code().
language = ''
# Contact text; set by run_btn_click().
article = ""
def format_label(label):
    """Turn a raw labels.txt line into its display label.

    Strips the line terminator and, when the line starts with an integer class
    index followed by a space, drops that index.

    Example: '0 rùa khác\\n' -> 'rùa khác'.

    Lines without a leading integer index are returned with only the line
    terminator removed. A last line with no trailing newline is handled
    correctly; slicing off the final character would truncate it.
    """
    text = label.rstrip('\r\n')
    index, sep, name = text.partition(' ')
    try:
        int(index)
    except ValueError:
        return text
    # A bare integer with no following name is returned unchanged.
    return name if sep else text
def get_name(scientific_name, lan):
    """Return the common name of a species in the requested language.

    Args:
        scientific_name: value matched against the ``scientific_name`` column
            of the global ``data`` table.
        lan: language code; selects the ``name_<lan>`` column.

    Raises:
        IndexError: if no row has this scientific name.
        KeyError: if there is no ``name_<lan>`` column.
    """
    names = data.loc[data['scientific_name'] == scientific_name, f'name_{lan}']
    return names.to_list()[0]
def get_fun_fact(scientific_name, lan):
    """Return the fun fact about a species in the requested language.

    Reads column ``fun_fact_<lan>`` of the first row of the global ``data``
    table whose ``scientific_name`` matches.

    Raises:
        IndexError: if no row has this scientific name.
    """
    is_species = data['scientific_name'] == scientific_name
    facts = data[is_species][f'fun_fact_{lan}']
    return facts.to_list()[0]
def get_law(scientific_name):
    """Return the ``(CITES, ND06)`` cell values for a species.

    Both values come from the first row of the global ``data`` table whose
    ``scientific_name`` matches; they are returned exactly as pandas read them.

    Raises:
        IndexError: if no row has this scientific name.
    """
    species_rows = data[data['scientific_name'] == scientific_name]
    cites_value = species_rows['CITES'].to_list()[0]
    nd06_value = species_rows['ND06'].to_list()[0]
    return cites_value, nd06_value
def get_habitat(scientific_name, lan):
    """Return the habitat text (column ``habitat_<lan>``) for a species.

    Raises:
        IndexError: if no row of the global ``data`` table matches.
    """
    column = f'habitat_{lan}'
    matching = data[data['scientific_name'] == scientific_name]
    return matching[column].to_list()[0]
def get_conservation_status(scientific_name, lan):
    """Return the translated IUCN category of a species.

    The ``IUCN`` cell is scanned for the codes in ``status_list``, in that
    order. The first code found as a substring is translated via
    ``translation['conservation_status'][code][lan]``.

    Returns:
        The translated status string, or None when the cell is missing (pandas
        reads empty cells as float NaN) or contains none of the codes.

    Raises:
        IndexError: if no row of the global ``data`` table matches.
    """
    status_list = ['NE', 'DD', 'LC', 'NT', 'VU', 'EN', 'CR', 'EW', 'EX']
    status = data[data['scientific_name'] == scientific_name]['IUCN'].to_list()[0]
    # A NaN (float) cell would make `code in status` raise TypeError.
    if not isinstance(status, str):
        return None
    for code in status_list:
        if code in status:
            return translation['conservation_status'][code][lan]
    return None
def get_language_code(lan):
    """Translate a language radio label into a code and remember it.

    'Tiếng Việt' maps to 'vi' and 'English' maps to 'en'. The module-level
    ``language`` is updated on a match. Any other label leaves it unchanged.

    Returns:
        The (possibly unchanged) module-level ``language`` value.
    """
    global language
    for choice, code in (("Tiếng Việt", 'vi'), ("English", 'en')):
        if lan == choice:
            language = code
    return language
def get_species_list():
    """Return the formatted label of every class, in labels.txt order.

    Example:
        ['Indotestudo elongata',
         'Cuora galbinifrons',
         'Cuora mouhotii',
         'Cuora bourreti']
    """
    return list(map(format_label, labels))
def get_species_abbreviation(scientific_name):
    """Return the first letter of each whitespace-separated word.

    Example: 'Indotestudo elongata' -> 'Ie'.
    """
    initials = (word[0] for word in scientific_name.split())
    return ''.join(initials)
def get_species_abbreviation_list():
    """Return the abbreviation of every species from get_species_list().

    Example:
        ['Ie', 'Cg', 'Cm', 'Cb']
    """
    species = get_species_list()
    return list(map(get_species_abbreviation, species))
def get_description(language):
    """Build the Markdown intro text listing every recognised species.

    A label counts as native when it equals the ``scientific_name`` of a row in
    the global ``data`` table with ``native == 'y'``; all other labels are
    non-native. Each species is shown by its name in ``language`` (get_name),
    numbered within its group, and joined with ", ".

    Args:
        language: 'vi' selects the Vietnamese text; any other code selects the
            English text. Previously an unknown code raised UnboundLocalError.
            get_name still needs a ``name_<language>`` column.
    """
    # Build the membership set once instead of re-filtering per label.
    native_species = set(data.loc[data.native == 'y', 'scientific_name'])
    native_names = []
    non_native_names = []
    for line in labels:
        label = format_label(line)
        group = native_names if label in native_species else non_native_names
        group.append(f"({len(group) + 1}) {get_name(label, language)}")

    num_class = len(labels)
    num_native = len(native_names)
    num_non_native = len(non_native_names)
    # join() avoids the dangling ", " the old concatenation left at the end.
    native_list = ', '.join(native_names)
    non_native_list = ', '.join(non_native_names)

    if language == 'vi':
        description = f"""
VNTurtle nhận diện các loài rùa Việt Nam. Mô hình này có thể nhận diện **{num_class}** loại rùa thường xuất hiện ở VN gồm
- **{num_native}** loài bản địa: {native_list} \n\n
- **{num_non_native}** loài ngoại lai: {non_native_list}
"""
    else:
        description = f"""
VNTurtle can recognize turtle species in Vietnam. This model can identify {num_class} common turtles in Vietnam including **{num_native}** native species \n\n
{native_list} \n\n
and **{num_non_native}** non-native species \n\n
{non_native_list}
"""
    return description
def update_language(language):
    """Recompute the language-dependent UI texts for a radio selection.

    Args:
        language: radio label ('Tiếng Việt' or 'English'). get_language_code
            converts it to a code and also updates the module-level language.

    Returns:
        Six values: the description Markdown, the run-button label, then the
        fun-fact, status, law and info accordion labels.
    """
    code = get_language_code(language)
    description = get_description(code)
    run_label = translation['label']['label_run_btn'][code]
    accordion_labels = tuple(
        translation["accordion"][key][code]
        for key in ("fun_fact", "status", "law", "info")
    )
    return (description, run_label) + accordion_labels
def predict(image):
    """Run the Keras classifier on one uploaded image.

    Steps:
        1. Convert the image to RGB.
        2. Resize it to cover 224x224 and center-crop it.
        3. Scale pixels with ``x / 127.0 - 1``.
        4. Feed it to ``model`` as a batch of one.

    Args:
        image: PIL image (the input component uses ``type="pil"``).

    Returns:
        ``model.predict`` output as nested Python lists. Row 0 holds one score
        per entry of ``labels``.

    Raises:
        ValueError: if no image was provided.
    """
    if image is None:
        raise ValueError("No image provided")
    # Batch of one 224x224x3 float image. Named `batch` so it does not shadow
    # the module-level `data` DataFrame.
    batch = np.ndarray(shape=(1, 224, 224, 3), dtype=np.float32)
    size = (224, 224)
    # RGBA / grayscale / palette uploads would not fit the 3-channel slot.
    image = image.convert('RGB')
    # Resize to at least 224x224, then crop from the center (TM2 strategy).
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
    image = ImageOps.fit(image, size, Image.LANCZOS)
    image_array = np.asarray(image)
    # NOTE(review): normalisation kept as-is; presumably it must match what
    # the model was trained with.
    batch[0] = (image_array.astype(np.float32) / 127.0) - 1
    # run the inference
    return model.predict(batch).tolist()
# Ranked predictions of the last run: display label -> score rounded to 2 dp.
result = {}
# Species selected for the gallery / info view.
best_prediction = ''


def interpret_prediction(prediction):
    """Store the classes of ``prediction`` scoring above 0.01 in ``result``.

    Args:
        prediction: nested list as returned by predict(). Row 0 holds one score
            per entry of ``labels``.

    ``result`` is rebuilt highest score first, keyed by formatted label, with
    scores rounded to two decimals. Returns None.
    """
    global result
    scores = prediction[0]
    # argsort is ascending; reverse it to walk from the best class down.
    best_first = np.argsort(prediction).tolist()[0][::-1]
    result = {
        format_label(labels[idx]): round(scores[idx], 2)
        for idx in best_first
        if scores[idx] > 0.01
    }
def run_btn_click(image):
    """Classify ``image`` and fill the five result slots of the UI.

    Side effects:
        - Resets ``best_prediction`` to None.
        - Sets ``article`` to the ATP contact text for the current language.
        - Rebuilds ``result`` via interpret_prediction().

    Returns:
        A tuple of 18 Gradio updates, in the order of the run_btn.click outputs:
        - the result accordion (opened and shown);
        - (image, percentage, view button) for each of the five slots;
        - the info-section accordion (hidden);
        - an empty gallery.
        Slots beyond the number of predictions are hidden and show empty.JPG.
    """
    global best_prediction, article
    max_slots = 5
    empty_image = os.path.join(folder, 'examples', 'empty.JPG')

    best_prediction = None
    article = translation["info"]["ATP_contact"][language]
    interpret_prediction(predict(image))

    # More than five classes can clear the 0.01 threshold, but only five
    # slots exist; indexing past them used to raise IndexError.
    shown = list(result.items())[:max_slots]

    updates = [gr.Accordion.update(open=True, visible=True)]
    for slot in range(max_slots):
        if slot < len(shown):
            species, percent = shown[slot]
            visible = True
            image_path = os.path.join(
                folder, 'examples', f'test_{get_species_abbreviation(species)}.JPG')
            percent_text = f'{round(percent * 100)}%'
        else:
            species, visible, image_path, percent_text = "", False, empty_image, ""
        updates.extend([
            gr.Image.update(value=image_path, visible=visible),
            gr.HighlightedText.update(value=[('', percent_text)], label=species, visible=visible),
            gr.Button.update(visible=visible),
        ])
    updates.append(gr.Accordion.update(visible=False))
    updates.append([])
    return tuple(updates)
def _show_gallery_for_rank(rank):
    """Select the ``rank``-th species of ``result`` and list its gallery images.

    Sets the module-level ``best_prediction`` to that species and returns the
    sorted paths matching ``<folder>/gallery/<species>/*``. When ``result`` has
    fewer than ``rank + 1`` entries, returns None and leaves
    ``best_prediction`` unchanged.
    """
    global best_prediction
    for i, name in enumerate(result):
        if i == rank:
            best_prediction = name
            # glob order is filesystem-dependent; sort for a stable gallery.
            return sorted(glob.glob(os.path.join(folder, 'gallery', name, '*')))
    return None


def get_image_gallery_species_1():
    """Select the 1st-ranked prediction and return its gallery images."""
    return _show_gallery_for_rank(0)


def get_image_gallery_species_2():
    """Select the 2nd-ranked prediction and return its gallery images."""
    return _show_gallery_for_rank(1)


def get_image_gallery_species_3():
    """Select the 3rd-ranked prediction and return its gallery images."""
    return _show_gallery_for_rank(2)


def get_image_gallery_species_4():
    """Select the 4th-ranked prediction and return its gallery images."""
    return _show_gallery_for_rank(3)


def get_image_gallery_species_5():
    """Select the 5th-ranked prediction and return its gallery images."""
    return _show_gallery_for_rank(4)
def display_info():
    """Reveal the info accordions and fill them for the selected species.

    Uses ``best_prediction`` (set by a view button). If no species has been
    selected yet, it falls back to the top-ranked entry of ``result``. Before,
    this case raised IndexError.

    Returns:
        Ten values, in the order of the result_verify_btn.click outputs:
        - six accordion updates: info section shown, result section collapsed,
          and the four detail accordions shown;
        - the fun-fact, status, law and info Markdown strings.
    """
    species = best_prediction
    if not species and result:
        # Verify clicked before any view button: use the model's top guess.
        species = next(iter(result))

    cites, nd06 = get_law(species)
    fun_fact = f"{get_fun_fact(species, language)}."
    status = f"{get_conservation_status(species, language)}"
    law = f'CITES: {cites}, NĐ06: {nd06}'

    info = ""
    # pandas reads empty cells as NaN, whose str() is 'nan' and never "".
    # Check for a missing value explicitly so unlisted species get no
    # protection advice.
    if pd.notna(nd06) and str(nd06).strip() != "":
        law_protection = translation["info"]["law_protection"][language]
        report = translation["info"]["report"][language]
        deliver = translation["info"]["deliver"][language]
        release = translation["info"]["release"][language] + f" **{get_habitat(species, language)}**"
        info = f"- {law_protection}\n\n- {report}\n\n- {deliver}\n\n- {release}"

    return gr.Accordion.update(visible=True), \
        gr.Accordion.update(open=False), \
        gr.Accordion.update(visible=True), \
        gr.Accordion.update(visible=True), \
        gr.Accordion.update(visible=True), \
        gr.Accordion.update(visible=True), \
        fun_fact, status, law, info
# Radio option selected when the app starts (maps to code 'vi').
default_lan = 'Tiếng Việt'
# NOTE(review): the nesting of the components below defines the page layout;
# confirm it matches the intended design.
with gr.Blocks() as demo:
    gr.Markdown("# VNTurtle")
    # Language switcher; its change event re-renders the translated texts.
    radio_lan = gr.Radio(choices=['Tiếng Việt', 'English'], value=default_lan, label='Ngôn ngữ/Language', show_label=True, interactive=True)
    # Intro text listing the recognised species in the default language.
    md_des = gr.Markdown(get_description(get_language_code(default_lan)))
    with gr.Row(equal_height=True):
        # Uploaded photo ("Uploaded image") next to the reference gallery
        # ("Comparison images") that the view buttons fill.
        inp = gr.Image(type="pil", show_label=True, label='Ảnh tải lên', interactive=True).style(height=250)
        gallery = gr.Gallery(show_label=True, label='Ảnh đối chiếu').style(grid=[4], height="auto")
    with gr.Row():
        run_btn = gr.Button(translation['label']['label_run_btn'][get_language_code(default_lan)])
        result_verify_btn = gr.Button(translation['label']['label_verify_btn'][get_language_code(default_lan)], visible=True)
    # Hidden until run_btn_click reveals it with up to five ranked predictions.
    accordion_result_section = gr.Accordion(translation["accordion"]["result_section"][get_language_code(default_lan)], open=True, visible=False)
    with accordion_result_section:
        with gr.Row() as display_result:
            # Slot 1 (top prediction): percentage, example image, view button.
            with gr.Column(scale=0.2, min_width=150) as result_1:
                # color_map lists every integer percentage '0%'..'100%' so any
                # score string produced by run_btn_click gets this slot's colour.
                result_percent_1 = gr.HighlightedText(show_label=True, visible=False).style(color_map={f'{i}%': 'green' for i in range(101)})
                # result_percent_1 = gr.Markdown("", visible=False)
                result_img_1 = gr.Image(interactive=False, visible=False, show_label=False)
                result_view_btn_1 = gr.Button(translation['label']['label_check_btn'][get_language_code(default_lan)], visible=False)
            # Slots 2-5 repeat the same layout, each with its own colour.
            with gr.Column(scale=0.2, min_width=150) as result_2:
                result_percent_2 = gr.HighlightedText(show_label=True, visible=False).style(color_map={f'{i}%': 'yellow' for i in range(101)})
                result_img_2 = gr.Image(interactive=False, visible=False, show_label=False)
                result_view_btn_2 = gr.Button(translation['label']['label_check_btn'][get_language_code(default_lan)], visible=False)
            with gr.Column(scale=0.2, min_width=150) as result_3:
                result_percent_3 = gr.HighlightedText(show_label=True, visible=False).style(color_map={f'{i}%': 'orange' for i in range(101)})
                result_img_3 = gr.Image(interactive=False, visible=False, show_label=False)
                result_view_btn_3 = gr.Button(translation['label']['label_check_btn'][get_language_code(default_lan)], visible=False)
            with gr.Column(scale=0.2, min_width=150) as result_4:
                result_percent_4 = gr.HighlightedText(show_label=True, visible=False).style(color_map={f'{i}%': 'chocolate' for i in range(101)})
                result_img_4 = gr.Image(interactive=False, visible=False, show_label=False)
                result_view_btn_4 = gr.Button(translation['label']['label_check_btn'][get_language_code(default_lan)], visible=False)
            with gr.Column(scale=0.2, min_width=150) as result_5:
                result_percent_5 = gr.HighlightedText(show_label=True, visible=False).style(color_map={f'{i}%': 'grey' for i in range(101)})
                result_img_5 = gr.Image(interactive=False, visible=False, show_label=False)
                result_view_btn_5 = gr.Button(translation['label']['label_check_btn'][get_language_code(default_lan)], visible=False)
    # Hidden until display_info reveals it after the verify button is clicked.
    accordion_info_section = gr.Accordion(translation['accordion']['info_section'][get_language_code(default_lan)], visible=False, open=True)
    with accordion_info_section:
        # Four collapsible detail panels, each holding one Markdown output.
        accordion_fun_fact = gr.Accordion(translation["accordion"]["fun_fact"][get_language_code(default_lan)], open=False, visible=False)
        accordion_status = gr.Accordion(translation["accordion"]["status"][get_language_code(default_lan)], open=False, visible=False)
        accordion_law = gr.Accordion(translation["accordion"]["law"][get_language_code(default_lan)], open=False, visible=False)
        accordion_info = gr.Accordion(translation["accordion"]["info"][get_language_code(default_lan)], open=False, visible=False)
        with accordion_fun_fact:
            md_fun_fact = gr.Markdown()
        with accordion_status:
            md_status = gr.Markdown()
        with accordion_law:
            md_law = gr.Markdown()
        with accordion_info:
            md_info = gr.Markdown()
    gr.Markdown("---")
    # Example photos ("Test images"): one per species, in labels.txt order.
    with gr.Accordion("🌅 Ảnh thử nghiệm", open=False):
        # `language` here is whatever get_language_code last set ('vi' from the
        # default_lan calls above).
        # NOTE(review): each example row has two values (image path, species
        # name) but only one input component; confirm how Gradio treats the
        # extra column.
        gr.Examples(
            examples=[[os.path.join(folder, 'examples', f'test_{get_species_abbreviation(s)}.JPG'), get_name(s, language)] for s in get_species_list()],
            inputs=[inp],
            label=""
        )
    # --- Event wiring: output lists must match each handler's return order. ---
    radio_lan.change(fn=update_language, inputs=[radio_lan], outputs=[
        md_des,
        run_btn,
        accordion_fun_fact,
        accordion_status,
        accordion_law,
        accordion_info
    ])
    run_btn.click(fn=run_btn_click, inputs=inp, outputs= [
        accordion_result_section,
        # md_fun_fact, md_status, md_law, md_info,
        result_img_1, result_percent_1, result_view_btn_1,
        result_img_2, result_percent_2, result_view_btn_2,
        result_img_3, result_percent_3, result_view_btn_3,
        result_img_4, result_percent_4, result_view_btn_4,
        result_img_5, result_percent_5, result_view_btn_5,
        # accordion_fun_fact, accordion_status, accordion_law, accordion_info,
        accordion_info_section,
        gallery
    ], show_progress=True, scroll_to_output=True)
    # Each view button selects the species in its slot and shows its gallery.
    result_view_btn_1.click(fn=get_image_gallery_species_1, outputs=gallery)
    result_view_btn_2.click(fn=get_image_gallery_species_2, outputs=gallery)
    result_view_btn_3.click(fn=get_image_gallery_species_3, outputs=gallery)
    result_view_btn_4.click(fn=get_image_gallery_species_4, outputs=gallery)
    result_view_btn_5.click(fn=get_image_gallery_species_5, outputs=gallery)
    # Verify shows the fun fact / status / law / advice for the selected species.
    result_verify_btn.click(fn=display_info, outputs=[
        accordion_info_section,
        accordion_result_section,
        accordion_fun_fact,
        accordion_status,
        accordion_law,
        accordion_info,
        md_fun_fact,
        md_status,
        md_law,
        md_info,
    ], scroll_to_output=True)
# Start the Gradio server.
demo.launch(debug=False)