"""Gradio demo that classifies an uploaded chart/diagram image with a fastai model."""

from pathlib import Path

import gradio as gr
from fastai.vision.all import load_learner, PILImage

VERSION = 3

# Resolve data files relative to this script rather than the current working
# directory, so the app also works when launched from elsewhere.
BASE_DIR = Path(__file__).resolve().parent
model_path = BASE_DIR / "models" / f"charts_classifier_v{VERSION}.pkl"

# Gradio's documented ``examples`` format is a list of samples, each a list of
# input values.  Sort for a stable gallery order (``iterdir`` order is
# arbitrary) and skip anything that is not a regular file.
test_img_paths = [
    [str(img_path)]
    for img_path in sorted((BASE_DIR / "test_images").iterdir())
    if img_path.is_file()
]

# Human-readable list of the categories the model was trained on.  Kept for
# reference/compatibility; predictions use the learner's own vocab instead
# (see ``predict_image``).  NOTE(review): 'snakey chart' looks like a typo for
# 'sankey chart', but it must match the training labels — confirm before
# renaming.
chart_labels = [
    'arc diagram', 'area chart', 'bar chart', 'block diagram', 'boxplot',
    'bubble chart', 'cartogram', 'control chart', 'dendrogram', 'flowchart',
    'funnel chart', 'gantt chart', 'heatmap', 'histogram', 'line graph',
    'matrix diagram', 'mind map', 'network graph', 'neural network diagram',
    'organogram', 'phase diagram', 'pie chart', 'radar chart', 'scatter plot',
    'snakey chart', 'surface plot', 'timeline chart', 'venn diagram',
]


class PILImageRGB(PILImage):
    """PILImage variant that always opens images in RGB mode.

    Not referenced directly below, but presumably needed so ``load_learner``
    can unpickle a learner whose pipeline was built with this class — keep it
    defined at module level.  TODO(review): confirm against the training code.
    """

    # matplotlib colormap names are case-sensitive; 'viridis' is the valid name.
    _show_args, _open_args = {'cmap': 'viridis'}, {'mode': 'RGB'}


model = load_learner(model_path)


def predict_image(image):
    """Classify *image* and return a ``{label: probability}`` mapping.

    Args:
        image: the image passed in by the Gradio ``Image`` input component.

    Returns:
        dict mapping every class name to its predicted probability as a
        plain ``float`` (the format ``gr.outputs.Label`` expects).
    """
    _, _, probs = model.predict(image)
    # ``probs`` is ordered like the learner's vocab, so take the class names
    # from the model itself; zipping against a hand-maintained list would
    # silently mislabel predictions if the two ever drifted apart.
    return {label: float(prob) for label, prob in zip(model.dls.vocab, probs)}


# NOTE: ``gr.inputs`` / ``gr.outputs`` are the legacy Gradio component
# namespaces; they are kept here to match the Gradio version this app targets.
image = gr.inputs.Image(shape=(256, 256))
label = gr.outputs.Label(num_top_classes=5)

title = "Charts Classifier"
description = (
    "This is a demo to classify charts or diagrams out of 28 categories. "
    "Upload or drop any images to see the results."
)

iface = gr.Interface(
    fn=predict_image,
    inputs=image,
    outputs=label,
    examples=test_img_paths,
    title=title,
    description=description,
)

if __name__ == "__main__":
    iface.launch(inline=False)