# Author: AnnasBlackHat
# Commit: c49a9ad ("basic ui")
# Size: 1.94 kB
import gradio as gr
import requests
import random
from src.classification_model import ClassificationModel
# Only for dummy data: the downloaded text is split into one label per line
# and used by predict() to fabricate fake classification results.
# A timeout keeps app start-up from hanging forever on a stalled connection.
response = requests.get("https://git.io/JJkYN", timeout=10)
# Fail loudly instead of turning an HTTP error page into "labels".
response.raise_for_status()
# splitlines() also handles "\r\n" endings; blank lines (e.g. the one left by a
# trailing newline with split("\n")) would otherwise become an empty label.
labels = [line.strip() for line in response.text.splitlines() if line.strip()]
clf = ClassificationModel()
model_names = clf.get_model_names()
# One gr.Label output component per model, appended while the UI is built;
# predict() relies on output_labels[i] corresponding to model_names[i].
output_labels = []
def predict(models, img_urls, img_files):
    """Return (dummy) classification results for the selected models.

    Every output label is first reset to hidden. Then, for each selected
    model, a label with random scores is made visible. No real inference is
    done yet: ``img_urls`` and ``img_files`` are accepted for the UI wiring
    but are currently unused.

    Args:
        models: Model names chosen in the dropdown. Each must be an entry of
            ``model_names``. ``None`` or empty means nothing is selected.
        img_urls: Text from the "Image Urls" textbox (unused for now).
        img_files: Files from the upload widget (unused for now).

    Returns:
        dict mapping each component in ``output_labels`` to a ``gr.Label``
        update. Unselected models are hidden; selected models are visible and
        carry a ``{label: score}`` value.
    """
    print(f'model choosen: {models}')
    model_predictions = {}
    # Hide every output first so results from a previous run disappear when
    # their model is deselected.
    for i, name in enumerate(model_names):
        model_predictions[output_labels[i]] = gr.Label(label=f'# {name}', visible=False)
        print(f'id {i} invisible')
    # The dropdown value may be None when nothing is selected; treat it as empty.
    for m in models or []:
        idx = model_names.index(m)
        print(f' {m} idx: ', idx)
        # Dummy scores for up to five *distinct* labels. Drawing indices with
        # randrange could repeat a label and silently shrink the dict.
        picked = random.sample(labels, min(5, len(labels)))
        result = {label: random.uniform(0, 1.0) for label in picked}
        # "3 seconds" looks like a placeholder timing string — TODO replace with a measured duration.
        model_predictions[output_labels[idx]] = gr.Label(label=f'# {m}, 3 seconds', value=result, visible=True)
    return model_predictions
# UI layout: the left column holds the inputs (model picker, image URLs, image
# uploads, Classify button); the right column holds one result label per model.
with gr.Blocks() as demo:
    gr.Markdown("# Image Classification Benchmark")
    with gr.Row():
        with gr.Column(scale=1):
            model = gr.Dropdown(choices=model_names, multiselect=True, label='Choose the model')
            img_urls = gr.Textbox(label='Image Urls (separated with comma)')
            img_files = gr.File(label='Upload Files', file_count='multiple', file_types=['image'])
            apply = gr.Button("Classify", variant='primary')
        with gr.Column(scale=1):
            # Build the outputs from the same model_names list that predict()
            # indexes into, so output_labels[i] always matches model_names[i].
            # A second clf.get_model_names() call is not guaranteed to return
            # the same list or order.
            for name in model_names:
                output_labels.append(gr.Label(label=f'# {name}'))
    # predict() returns a dict keyed by these components, so every label it may
    # update must be listed as an output.
    apply.click(
        fn=predict,
        inputs=[model, img_urls, img_files],
        outputs=output_labels,
    )
if __name__ == "__main__":
    # Serve the Blocks UI defined above when run as a script.
    demo.launch()