Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import numpy as np | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| import pandas as pd | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| from pipelines.detection.yolo_v8 import Yolov8Pipeline | |
| from pipelines.detection.yolo_stamp import YoloStampPipeline | |
| from pipelines.segmentation.deeplabv3 import DeepLabv3Pipeline | |
| from pipelines.feature_extraction.vae import VaePipeline | |
| from pipelines.feature_extraction.vits8 import Vits8Pipeline | |
| from utils import * | |
# Pretrained pipelines are instantiated once, at import time, so every request
# handled by doc_predict reuses the same model objects. Each from_pretrained
# call receives a repository id and a weights file name; how the weights are
# fetched/cached is up to the pipelines package (not visible from here).

# Stamp detectors (selected by the "Detection model" dropdown).
yolov8 = Yolov8Pipeline.from_pretrained(
    'stamps-labs/yolov8-finetuned', 'weights.pt'
)
yolo_stamp = YoloStampPipeline.from_pretrained(
    'stamps-labs/yolo-stamp', 'weights.pt'
)

# Embedding models (selected by the "Embedding model" dropdown).
vae = VaePipeline.from_pretrained(
    'stamps-labs/vae-encoder', 'weights.pt'
)
vits8 = Vits8Pipeline.from_pretrained(
    'stamps-labs/vits8-stamp', 'weights.pt'
)

# Optional segmentation step applied to each cropped stamp.
dlv3 = DeepLabv3Pipeline.from_pretrained(
    'stamps-labs/deeplabv3-finetuned', 'weights.pt'
)
def _detect_boxes(image, det_choice):
    """Run the detector named by *det_choice* on *image* and return its boxes.

    Args:
        image: RGB PIL image of the document.
        det_choice: 'yolov8' or 'yolo-stamp' (the choices offered in the UI).

    Raises:
        ValueError: if *det_choice* is not a known detector name.
    """
    detectors = {'yolov8': yolov8, 'yolo-stamp': yolo_stamp}
    try:
        detector = detectors[det_choice]
    except KeyError:
        raise ValueError(f"Unknown detection model: {det_choice!r}") from None
    return detector(image)


def _stitch_horizontally(stamps):
    """Paste a non-empty list of PIL images left to right onto one RGB canvas.

    The canvas is as wide as all images combined and as tall as the tallest
    one; images are top-aligned and uncovered canvas area keeps Image.new's
    default fill (black).
    """
    widths, heights = zip(*(stamp.size for stamp in stamps))
    canvas = Image.new('RGB', (sum(widths), max(heights)))
    x_offset = 0
    for stamp in stamps:
        canvas.paste(stamp, (x_offset, 0))
        x_offset += stamp.size[0]
    return canvas


def _embed_stamps(stamps, emb_choice):
    """Encode every stamp with the chosen model and stack the results.

    Args:
        stamps: non-empty list of PIL images.
        emb_choice: 'vits8' or 'vae-encoder' (the choices offered in the UI).

    Returns:
        np.stack of the per-stamp embeddings, one row per stamp
        (assumes each encoder output is 1-D — TODO confirm).

    Raises:
        ValueError: if *emb_choice* is not a known embedding model.
    """
    encoders = {'vits8': vits8, 'vae-encoder': vae}
    try:
        encoder = encoders[emb_choice]
    except KeyError:
        raise ValueError(f"Unknown embedding model: {emb_choice!r}") from None
    return np.stack([encoder(stamp) for stamp in stamps])


def _similarity_figure(similarities):
    """Draw an annotated heatmap of a square similarity matrix.

    Rows and columns are labelled 1..N in stamp order; cell values are printed
    with three decimals.
    """
    fig, ax = plt.subplots()
    labels = range(1, len(similarities) + 1)
    im, _cbar = heatmap(similarities, labels, labels, ax=ax,
                        cmap="YlGn", cbarlabel="Embeddings similarities")
    annotate_heatmap(im, valfmt="{x:.3f}")
    # Drop the figure from pyplot's global registry; otherwise every request
    # leaves one open figure behind (unbounded memory growth and matplotlib's
    # "too many open figures" warning). The Figure object itself remains
    # usable for rendering (e.g. via fig.savefig).
    plt.close(fig)
    return fig


def doc_predict(image, det_choice, seg_choice, emb_choice):
    """Detect, optionally segment, embed and compare the stamps in a document.

    Args:
        image: document as a PIL image (the UI's gr.Image uses type="pil").
        det_choice: detection model name, 'yolov8' or 'yolo-stamp'.
        seg_choice: truthy to pass each cropped stamp through the DeepLabv3
            pipeline before embedding; falsy to use the raw crops.
        emb_choice: embedding model name, 'vits8' or 'vae-encoder'.

    Returns:
        A 5-tuple in the order of the interface outputs: the document with
        boxes drawn, a DataFrame of boxes (columns x1, y1, x2, y2), the stamps
        pasted side by side, the stacked embeddings, and a cosine-similarity
        heatmap figure. When no stamp is detected, the last three are None
        instead of the request crashing.

    Raises:
        ValueError: if no image is supplied or a model name is unknown.
    """
    if image is None:
        raise ValueError("Please provide a document image.")
    image = image.convert('RGB')

    boxes = _detect_boxes(image, det_choice)
    image_with_boxes = visualize_bbox(image, boxes)
    # NOTE(review): boxes presumably is an (N, 4) tensor/array of x1, y1, x2, y2
    # (each row supports .tolist() and is framed with those 4 columns) — confirm.
    df_boxes = pd.DataFrame(boxes, columns=['x1', 'y1', 'x2', 'y2'])

    stamps = []
    for box in boxes:
        crop = image.crop(box.tolist())
        stamps.append(dlv3(crop) if seg_choice else crop)

    if not stamps:
        # Nothing detected: np.stack([]) and cosine_similarity would raise, so
        # return the annotated image and the (empty) box table and leave the
        # remaining outputs empty.
        return image_with_boxes, df_boxes, None, None, None

    embeddings = _embed_stamps(stamps, emb_choice)
    figure = _similarity_figure(cosine_similarity(embeddings))
    return (image_with_boxes, df_boxes, _stitch_horizontally(stamps),
            embeddings, figure)
# Example rows passed to gr.Interface(examples=...). Each row follows the
# order of doc_inputs: image path, detection model, use-segmentation flag,
# embedding model.
doc_examples = [
    ['examples/1.jpg', 'yolov8', True, 'vits8'],
    ['examples/2.jpg', 'yolo-stamp', False, 'vae-encoder'],
    ['examples/3.jpg', 'yolov8', True, 'vits8'],
]

# Input components, positionally matching doc_predict's parameters
# (image, det_choice, seg_choice, emb_choice).
doc_inputs = [
    gr.Image(type="pil", label="Document image"),
    gr.Dropdown(
        choices=['yolov8', 'yolo-stamp'],
        value='yolov8',
        label='Detection model',
    ),
    gr.Checkbox(label="Use segmentation model"),
    gr.Dropdown(
        choices=['vits8', 'vae-encoder'],
        value='vits8',
        label='Embedding model',
    ),
]

# Output components, positionally matching the 5-tuple doc_predict returns.
doc_outputs = [
    gr.Image(type="pil", label="Document with bounding boxes"),
    gr.DataFrame(type='pandas', label="Bounding boxes"),
    gr.Image(type="pil", label="Segmented stamps"),
    gr.DataFrame(type='numpy', label="Embeddings"),
    gr.Plot(label="Cosine Similarities"),
]
# Top-level Gradio app: a Blocks container holding one tab that renders the
# single-document Interface (its inputs, outputs and examples defined above).
with gr.Blocks() as demo:
    with gr.Tab("Single document"):
        gr.Interface(doc_predict, doc_inputs, doc_outputs, examples=doc_examples)

# inline=False: do not embed the app in an interactive (notebook) output cell.
demo.launch(inline=False)