# abir0's picture
# Update parameters
# 0acd18a
from pathlib import Path
import gradio as gr
from fastai.vision.all import load_learner, PILImage
# Version suffix of the exported learner file to serve.
VERSION = 3

# Pickled learner, given as a path relative to the working directory.
model_path = Path("models") / f"charts_classifier_v{VERSION}.pkl"

# Example images handed to the Gradio interface. ``Path.iterdir`` yields
# entries in arbitrary order, so sort them for a stable example gallery, and
# keep regular files only so a stray sub-directory is never offered as an image.
test_img_paths = sorted(
    entry for entry in Path("test_images").iterdir() if entry.is_file()
)
# Human-readable class names shown in the demo output.
#
# These are paired *positionally* with the probability vector returned by the
# model, so entries must never be reordered, inserted, or removed without
# retraining/re-exporting the model.
chart_labels = [
    'arc diagram',
    'area chart',
    'bar chart',
    'block diagram',
    'boxplot',
    'bubble chart',
    'cartogram',
    'control chart',
    'dendrogram',
    'flowchart',
    'funnel chart',
    'gantt chart',
    'heatmap',
    'histogram',
    'line graph',
    'matrix diagram',
    'mind map',
    'network graph',
    'neural network diagram',
    'organogram',
    'phase diagram',
    'pie chart',
    'radar chart',
    'scatter plot',
    # Spelling corrected from "snakey chart". The entry deliberately keeps its
    # original slot (after "scatter plot") rather than being re-sorted:
    # presumably the model's class order was derived from the misspelled name
    # -- TODO confirm against the exported learner's vocab.
    'sankey chart',
    'surface plot',
    'timeline chart',
    'venn diagram',
]
class PILImageRGB(PILImage):
    """``PILImage`` subclass whose open arguments request RGB mode.

    NOTE(review): this class is not referenced anywhere else in this file.
    Presumably it has to exist under this exact name so that the pickled
    learner can resolve it while being loaded -- confirm before renaming or
    removing it.
    """

    # matplotlib colormap names are case-sensitive and the registered name is
    # lowercase "viridis"; the previous "Viridis" would be rejected as soon as
    # these arguments reached a matplotlib display call.
    _show_args = {'cmap': 'viridis'}
    _open_args = {'mode': 'RGB'}
# Load the exported learner once at import time; predict_image reuses this
# single module-level instance for every request.
# NOTE(review): load_learner unpickles the file, so model_path must only ever
# point at a trusted artifact.
model = load_learner(model_path)
def predict_image(image):
    """Classify a chart image and return a probability for every label.

    Args:
        image: The image supplied by the Gradio input component; it is passed
            unchanged to ``model.predict``.

    Returns:
        A dict mapping each entry of ``chart_labels`` to the corresponding
        model probability, converted to a plain ``float``.

    Raises:
        ValueError: If the model returns a different number of probabilities
            than there are labels. Without this check ``zip`` would silently
            truncate and the output could be mislabeled.
    """
    # predict() returns a 3-tuple; only the probability vector is used here.
    _, _, probs = model.predict(image)

    if len(probs) != len(chart_labels):
        raise ValueError(
            f"model returned {len(probs)} probabilities "
            f"but {len(chart_labels)} labels are defined"
        )

    # Labels and probabilities are paired purely by position.
    return dict(zip(chart_labels, map(float, probs)))
# Static text shown on the demo page.
title = "Charts Classifier"
description = (
    "<p align=center>This is a demo to classify charts or diagrams out of 28 categories. "
    "Upload or drop any images to see the results.</p>"
)

# Input component configured with shape (256, 256); output component limited
# to the five top classes.
image = gr.inputs.Image(shape=(256, 256))
label = gr.outputs.Label(num_top_classes=5)
# Everything the Gradio interface needs, gathered in one place: the prediction
# callback, its input/output components, the example images, and the page text.
_iface_options = {
    "fn": predict_image,
    "inputs": image,
    "outputs": label,
    "examples": test_img_paths,
    "title": title,
    "description": description,
}

# Build the interface and start serving it (inline display disabled).
iface = gr.Interface(**_iface_options)
iface.launch(inline=False)