"""Gradio app that classifies species images with a pretrained CNN.

Tabs: a markdown home page, an upload-and-classify view, a browser of sample
images per species, and a per-session history of predictions.
"""
import functools
import os
import sqlite3
from collections import Counter
from contextlib import closing

import cv2
import gdown
import gradio as gr
import numpy as np
import torch

from S1_CNN_Model import CNN_Model
from SpeciesDetail import SpeciesDetail, labels

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

MODEL_LINK = "https://drive.google.com/file/d/18-t2jMpXLxtqE-8Bu0_NNNuie_mguSON/view?usp=sharing"
MODEL_PATH = "model.pt"

if not os.path.exists(MODEL_PATH):
    print("Downloading model . . . ")
    gdown.download(MODEL_LINK, MODEL_PATH, fuzzy=True)

# map_location lets a checkpoint that was saved on a GPU load on a CPU-only host.
# NOTE(review): torch.load unpickles arbitrary objects, so only load this trusted
# file. On torch>=2.6 the default weights_only=True may reject a fully pickled
# model; pass weights_only=False there if loading fails.
model: CNN_Model = torch.load(MODEL_PATH, map_location=device)
model.to(device)
model.eval()  # inference only: puts dropout / batch-norm layers (if any) in eval mode
model.device = device


def listdir_full(path: str) -> list[str]:
    """Return every entry of *path*, each prefixed with ``f"{path}/"``."""
    return [f"{path}/{p}" for p in os.listdir(path)]


label_names = [label.name for label in labels]


class History():
    """One (thumbnail, species name) entry of the prediction history."""
    cols = ["Image", "Prediction"]

    def __init__(self, img, name) -> None:
        self.img = resize_image(img)
        self.name = name


DB_PATH = 'my_database.db'


def fetch_data(id: int) -> SpeciesDetail:
    """Load the species-detail row whose ``id`` column equals *id*.

    The first column of the row is dropped and the remaining columns are
    passed positionally to ``SpeciesDetail``.

    Raises:
        LookupError: if ``my_table`` has no row with that id.
    """
    # sqlite3's own connection context manager only commits/rolls back;
    # closing() is what actually releases the connection.
    with closing(sqlite3.connect(DB_PATH)) as conn:
        row = conn.execute('SELECT * FROM my_table WHERE id = ?', (id,)).fetchone()
    if row is None:
        raise LookupError(f"No species detail found for id {id}")
    _, *detail = row
    return SpeciesDetail(*detail)


MAX_IMG_LEN = 160


def resize_image(img, max_len: int = MAX_IMG_LEN):
    """Scale *img* so its longer side is *max_len* pixels, keeping aspect ratio.

    Works for both (H, W) and (H, W, C) arrays; each side is at least 1 pixel.
    """
    h, w = img.shape[:2]
    if w > h:
        new_w, new_h = max_len, max(1, int(h / w * max_len))
    else:
        new_w, new_h = max(1, int(w / h * max_len)), max_len
    return cv2.resize(img, (new_w, new_h))


PD_COLS = ["image", "predicted species"]
MAX_HISTORY = 10
MAX_PREDS = 10


def classify(image: np.ndarray, history):
    """Classify *image* and produce updates for the prediction and history widgets.

    Args:
        image: array from the ``gr.Image`` input, or None if nothing was supplied.
        history: this session's list of (thumbnail, markdown) tuples, or None.

    Returns:
        The prediction Markdown, MAX_PREDS label-ratio Textboxes, MAX_HISTORY
        history Images followed by MAX_HISTORY history Markdowns, and the
        updated history list, in the order wired up in the Blocks layout.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload or select an image first.")
    if history is None:
        history = []
    with torch.no_grad():
        # Gradio delivers RGB; the model is fed the BGR (OpenCV-order) image.
        r, p = model.predict_large_image(cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
    # r presumably holds one predicted label index per image region (TODO confirm);
    # p.item() is used below as the database id of the final prediction.
    # most_common() sorts by descending count, so keep its *first* MAX_PREDS.
    ratios = [
        gr.Textbox(f"{label_names[label]}: {count/len(r)*100:.2f}%", visible=True)
        for label, count in Counter(r.tolist()).most_common(MAX_PREDS)
    ]
    ratios += [gr.Textbox(visible=False)] * (MAX_PREDS - len(ratios))
    detail = fetch_data(p.item())
    pred = gr.Markdown(detail.result_text())
    history.append((resize_image(image), f"\n\n{detail.name}\n\n\n {detail.desc}"))
    hist = history[-MAX_HISTORY:]
    return pred, *ratios, *toggle_history_components(hist), history


def toggle_history_components(history: list[tuple[np.ndarray, str]]):
    """Build updates for the History tab's fixed pool of widgets.

    Shows one Image and one Markdown per entry (the newest MAX_HISTORY entries)
    and hides the rest. All Image updates come first, then all Markdown updates.
    """
    history = history[-MAX_HISTORY:]
    n_hidden = MAX_HISTORY - len(history)
    components = [gr.Image(img, visible=True) for img, _ in history]
    components += [gr.Image(visible=False)] * n_hidden
    components += [gr.Markdown(name, visible=True) for _, name in history]
    components += [gr.Markdown(visible=False)] * n_hidden
    return components


def classification_tab():
    """Lay out the Classification tab.

    Returns:
        (image input, submit button, clear button, prediction Markdown,
        list of MAX_PREDS initially hidden ratio Textboxes).
    """
    with gr.Row():
        with gr.Column():
            image = gr.Image()
            with gr.Row():
                submit = gr.Button("Submit", variant='primary')
                clear = gr.ClearButton(image)
        with gr.Column():
            pred = gr.Markdown("## Predictions")
            ratios = [gr.Textbox(show_label=False, visible=False) for _ in range(MAX_PREDS)]
    return image, submit, clear, pred, ratios


SAMPLE_DIR = "data/image/test_full"
# Size of the sample widget pool: the most images found in any species folder.
MAX_SAMPLE_COUNT = max(
    (len(os.listdir(d)) for d in listdir_full(SAMPLE_DIR) if os.path.isdir(d)),
    default=0,
)


def sample_tab(image_input, tabs):
    """Lay out the Samples tab.

    Picking a species in the dropdown reveals that species' images from
    SAMPLE_DIR. Each image's Submit button copies the image into *image_input*
    and switches *tabs* to the Classification tab (id 0).
    """
    def choose_image(image):
        return gr.Image(image), gr.Image(image), gr.Tabs(selected=0)

    def refresh_samples(species):
        species_dir = f"{SAMPLE_DIR}/{species}"
        # A species with no sample folder simply shows no samples.
        if os.path.isdir(species_dir):
            paths = sorted(listdir_full(species_dir))[:MAX_SAMPLE_COUNT]
        else:
            paths = []
        n_hidden = MAX_SAMPLE_COUNT - len(paths)
        components = [gr.Image(path, visible=True) for path in paths]
        components += [gr.Image(visible=False)] * n_hidden
        components += [gr.Button(visible=True) for _ in paths]
        components += [gr.Button(visible=False)] * n_hidden
        return components

    dropdown = gr.Dropdown(label_names, label="Species", value="Select a Species")
    images = []
    buttons = []

    def sample_panel():
        with gr.Column():
            image = gr.Image(visible=False, interactive=False, min_width=1)
            select = gr.Button("Submit", variant='primary', visible=False)
            images.append(image)
            buttons.append(select)
            select.click(choose_image, image, [image, image_input, tabs])

    with gr.Row():
        for _ in range(MAX_SAMPLE_COUNT):
            sample_panel()
    dropdown.change(refresh_samples, dropdown, images + buttons)


def history_tab():
    """Lay out the History tab as MAX_HISTORY hidden (image, markdown) rows.

    Returns:
        All history Image components followed by all history Markdown
        components, matching the order of toggle_history_components().
    """
    history_imgs = []
    history_names = []
    with gr.Row():
        gr.Markdown("# Image")
        with gr.Column(scale=2):
            gr.Markdown("# Species")
    with gr.Column():
        for _ in range(MAX_HISTORY):
            with gr.Row():
                history_imgs.append(gr.Image(height=200, visible=False))
                with gr.Column(scale=2):
                    history_names.append(gr.Markdown("", visible=False))
    return history_imgs + history_names


with open('homepage.md', 'r', encoding='utf-8') as file:
    home_screen_markdown = file.read()

with gr.Blocks() as demo:
    history = gr.State([])
    with gr.Tabs() as tabs:
        with gr.Tab("Home", id=3):
            gr.Markdown(home_screen_markdown)
        with gr.Tab("Classification", id=0):
            image, submit, clear, pred, ratios = classification_tab()
        with gr.Tab("Samples", id=1):
            sample_tab(image, tabs)
        with gr.Tab("History", id=2):
            table_contents = history_tab()
    submit.click(classify, [image, history], [pred, *ratios, *table_contents, history])

demo.launch()