from html import escape
from io import BytesIO
import base64
from multiprocessing.dummy import Pool

from PIL import Image, ImageDraw, ImageOps
import streamlit as st
import pandas as pd
import numpy as np
import torch
# from transformers import CLIPProcessor, CLIPModel
# from transformers import OwlViTProcessor, OwlViTForObjectDetection
# from transformers.image_utils import ImageFeatureExtractionMixin
import pickle as pkl

# sketches
from streamlit_drawable_canvas import st_canvas
from torchvision import transforms

# model
import os
# The src package lives two levels down; it must be importable from here.
from src.model_LN_prompt import Model
from src.options import opts
from datasets import load_dataset
DEBUG = False

if DEBUG:
    MODEL = "vit-base-patch32"
else:
    MODEL = "vit-large-patch14-336"
# CLIP/OWL-ViT identifiers kept from an earlier version (see the commented-out
# transformers imports above); they are not used by the current app.
CLIP_MODEL = f"openai/clip-{MODEL}"
OWL_MODEL = "google/owlvit-base-patch32"

if not DEBUG and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

HEIGHT = 350
N_RESULTS = 5

from huggingface_hub import hf_hub_download, login

token = os.getenv("HUGGINGFACE_TOKEN")
# Authenticate using the token
login(token=token)
color = st.get_option("theme.primaryColor")
if color is None:
    color = (0, 255, 0)
else:
    # Parse a "#rrggbb" hex string into an (r, g, b) tuple.
    color = tuple(int(color.lstrip("#")[i: i + 2], 16) for i in (0, 2, 4))
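# Example: a theme primaryColor of "#00ff00" parses to (0, 255, 0).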
def load():
    path_images = "data/doc_explore/DocExplore_images/"
    path_model = hf_hub_download(
        repo_id="CHSTR/DocExplore", filename="epoch=16-mAP=0.66_triplet.ckpt"
    )  # local copy: "models/epoch=16-mAP=0.66_triplet.ckpt"

    model = Model()
    model_checkpoint = torch.load(path_model, map_location=device)
    model.load_state_dict(model_checkpoint['state_dict'])
    model.eval()
    # Checkpoint variants:
    #   'model_60k_images_073.ckpt'  -> model trained on 60k images without PiDiNet
    #   'modified_model_083.ckpt'    -> model trained on 60k images with PiDiNet
    #   'original_model_083.ckpt'    -> original model trained on 60k images with PiDiNet
    print("Model loaded successfully")

    embeddings_file_1 = hf_hub_download(repo_id="CHSTR/DocExplore", filename="dino_flicker_docexplore_groundingDINO.pkl")
    embeddings_file_0 = hf_hub_download(repo_id="CHSTR/DocExplore", filename="docexp_embeddings.pkl")
    embeddings = {
        0: pkl.load(open(embeddings_file_0, "rb")),
        1: pkl.load(open(embeddings_file_1, "rb")),
    }
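    # Each entry of embeddings[corpus_id] is unpacked by image_search() below
    # as a (feature_vector, bounding_box, page_label) tuple.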
    # embeddings = {
    #     0: pkl.load(open("docexp_embeddings.pkl", "rb")),
    #     1: pkl.load(open("dino_flicker_docexplore_groundingDINO.pkl", "rb"))
    # }

    # Update the image paths inside the embeddings:
    # for i in range(len(embeddings[0])):
    #     print(embeddings[0][i])
    #     embeddings[0][i] = (embeddings[0][i][0], path_images + "/".join(embeddings[0][i][1].split("/")[:-3]))
    # for i in range(len(embeddings[1])):
    #     print(embeddings[1][i])
    #     embeddings[1][i] = (embeddings[1][i][0], path_images + "/".join(embeddings[1][i][1].split("/")[:-3]))

    return model, path_images, embeddings
| print("Cargando modelos...") | |
| model, path_images, embeddings = load() | |
| source = {0: "\nDocExplore SAM", 1: "\nDocExplore GroundingDINO"} | |
| stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 5) | |
| dataset_transforms = transforms.Compose([ | |
| transforms.Resize((224, 224)), | |
| transforms.ToTensor(), | |
| transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | |
| ]) | |
def compute_text_embeddings(sketch):
    # Despite the name (kept from an earlier CLIP text-query version, see the
    # commented-out code below), this embeds a sketch tensor with the model.
    with torch.no_grad():
        sketch_feat = model(sketch.to(device), dtype='sketch')
    return sketch_feat
    # inputs = clip_processor(text=list_of_strings, return_tensors="pt", padding=True).to(
    #     device
    # )
    # with torch.no_grad():
    #     result = clip_model.get_text_features(**inputs).detach().cpu().numpy()
    # return result / np.linalg.norm(result, axis=1, keepdims=True)
    # return torch.randn(1, 768)
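# Example (hypothetical, outside the Streamlit flow): embed a saved sketch.
#   sketch = dataset_transforms(Image.open("draw.jpg").convert("RGB")).unsqueeze(0)
#   feat = compute_text_embeddings(sketch)  # (1, embedding_dim) tensor on `device`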
def image_search(query, corpus, n_results=N_RESULTS):
    query_embedding = compute_text_embeddings(query)
    corpus_id = 0 if corpus == "DocExplore SAM" else 1
    image_features = torch.tensor([item[0] for item in embeddings[corpus_id]]).to(device)
    bbox_of_images = torch.tensor([item[1] for item in embeddings[corpus_id]]).to(device)
    label_of_images = torch.tensor([item[2] for item in embeddings[corpus_id]]).to(device)

    # Rank all corpus regions by dot-product similarity to the sketch embedding.
    dot_product = (image_features @ query_embedding.T)[:, 0]
    _, max_indices = torch.topk(dot_product, n_results, dim=0, largest=True, sorted=True)

    return [
        (path_images + "page" + str(i) + ".jpg",)
        for i in label_of_images[max_indices].cpu().numpy().tolist()
    ], bbox_of_images[max_indices], dot_product[max_indices]
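# Note: the ranking above uses a raw dot product; it coincides with cosine
# similarity only if the stored features and the sketch embedding are
# L2-normalized upstream.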
def make_square(img, fill_color=(255, 255, 255)):
    x, y = img.size
    size = max(x, y)
    # Note: the new canvas keeps the original (x, y) size, so no square padding
    # actually happens here; `size` is unused, and the crop that would undo the
    # padding is commented out in draw_reshape_encode below.
    new_img = Image.new("RGB", (x, y), fill_color)
    new_img.paste(img)
    return new_img, x, y
def get_images(paths):
    def process_image(path):
        return make_square(Image.open(path))

    with Pool(N_RESULTS) as pool:  # context manager so the pool is released
        processed = pool.map(process_image, paths)

    imgs, xs, ys = [], [], []
    for img, x, y in processed:
        imgs.append(img)
        xs.append(x)
        ys.append(y)
    return imgs, xs, ys
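# Note: multiprocessing.dummy.Pool is a thread pool (not separate processes),
# which suits this I/O-bound image loading.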
def keep_best_boxes(boxes, scores, score_threshold=0.1, max_iou=0.8):
    # Greedy non-maximum suppression: drop boxes that overlap a higher-scoring
    # box (IoU > max_iou) or are almost fully contained in another box.
    candidates = []
    for box, score in zip(boxes, scores):
        box = [round(i, 0) for i in box.tolist()]
        if score >= score_threshold:
            candidates.append((box, float(score)))

    to_ignore = set()
    for i in range(len(candidates) - 1):
        if i in to_ignore:
            continue
        for j in range(i + 1, len(candidates)):
            if j in to_ignore:
                continue
            xmin1, ymin1, xmax1, ymax1 = candidates[i][0]
            xmin2, ymin2, xmax2, ymax2 = candidates[j][0]
            if xmax1 < xmin2 or xmax2 < xmin1 or ymax1 < ymin2 or ymax2 < ymin1:
                continue  # no overlap
            # The two middle values of the sorted coordinates bound the intersection.
            xmin_inter, xmax_inter = sorted([xmin1, xmax1, xmin2, xmax2])[1:3]
            ymin_inter, ymax_inter = sorted([ymin1, ymax1, ymin2, ymax2])[1:3]
            area_inter = (xmax_inter - xmin_inter) * (ymax_inter - ymin_inter)
            area1 = (xmax1 - xmin1) * (ymax1 - ymin1)
            area2 = (xmax2 - xmin2) * (ymax2 - ymin2)
            iou = area_inter / (area1 + area2 - area_inter)
            if iou > max_iou:
                if candidates[i][1] > candidates[j][1]:
                    to_ignore.add(j)
                else:
                    to_ignore.add(i)
                    break
            else:
                if area_inter / area1 > 0.9:
                    if candidates[i][1] < 1.1 * candidates[j][1]:
                        to_ignore.add(i)
                if area_inter / area2 > 0.9:
                    if 1.1 * candidates[i][1] > candidates[j][1]:
                        to_ignore.add(j)
    return [candidates[i][0] for i in range(len(candidates)) if i not in to_ignore]
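# Example (hypothetical inputs): the smaller box is >90% contained in the
# larger, higher-scoring one, so it is suppressed.
#   boxes  = torch.tensor([[0., 0., 10., 10.], [1., 1., 9., 9.]])
#   scores = torch.tensor([0.9, 0.5])
#   keep_best_boxes(boxes, scores)  # -> [[0.0, 0.0, 10.0, 10.0]]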
def convert_pil_to_base64(image):
    img_buffer = BytesIO()
    image.save(img_buffer, format="JPEG")
    byte_data = img_buffer.getvalue()
    base64_str = base64.b64encode(byte_data)
    return base64_str
def draw_reshape_encode(img, boxes, x, y):
    boxes = [boxes.tolist()]  # single bbox tensor -> list with one box
    image = img.copy()
    draw = ImageDraw.Draw(image)
    new_x, new_y = int(x * HEIGHT / y), HEIGHT
    for box in boxes:
        print("box:", box)
        draw.rectangle(
            [(box[0], box[1]), (box[2], box[3])],  # (x_min, y_min), (x_max, y_max)
            outline=color,  # box color from the Streamlit theme
            width=10,       # outline width in pixels
        )
    # if x > y:
    #     image = image.crop((0, (x - y) / 2, x, x - (x - y) / 2))
    # else:
    #     image = image.crop(((y - x) / 2, 0, y - (y - x) / 2, y))
    return convert_pil_to_base64(image.resize((new_x, new_y)))
def get_html(url_list, encoded_images):
    html = "<div style='margin-top: 20px; max-width: 1200px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
    for i in range(len(url_list)):
        title, encoded = url_list[i][0], encoded_images[i]
        html += (
            f"<img title='{escape(title)}' style='height: {HEIGHT}px; "
            f"margin: 1px' src='data:image/jpeg;base64,{encoded.decode()}'>"
        )
    html += "</div>"
    return html
description = """
# Sketch-based Detection

This app retrieves images from the [DocExplore](https://www.docexplore.eu/?lang=en) dataset based on a sketch query.

**Tip 1**: you can draw a sketch in the canvas.

**Tip 2**: you can change the size of the stroke with the slider.

The model used in this application is DINOv2, trained in a self-supervised manner on the Flickr25k dataset.
"""
div_style = {
    "display": "flex",
    "justify-content": "center",
    "flex-wrap": "wrap",
}
def main():
    st.markdown(
        """
        <style>
        .block-container{
            max-width: 1600px;
        }
        div.row-widget > div{
            flex-direction: row;
            display: flex;
            justify-content: center;
        }
        div.row-widget.stRadio > div > label{
            margin-left: 5px;
            margin-right: 5px;
        }
        .row-widget {
            margin-top: -25px;
        }
        section > div:first-child {
            padding-top: 30px;
        }
        div.appview-container > section:first-child{
            max-width: 320px;
        }
        #MainMenu {
            visibility: hidden;
        }
        .stMarkdown {
            display: grid;
            place-items: center;
        }
        </style>
        """,
        unsafe_allow_html=True,
    )
    st.sidebar.markdown(description)
    st.title("One-Shot Detection")

    # Create two main columns
    left_col, right_col = st.columns([0.2, 0.8])  # adjust the weights as needed

    with left_col:
        # Canvas for drawing
        canvas_result = st_canvas(
            background_color="#eee",
            stroke_width=stroke_width,
            update_streamlit=True,
            height=300,
            width=300,
            key="color_annotation_app",
        )

        # Input controls
        query = [0]  # placeholder so the search below always runs
        corpus = st.radio("", ["DocExplore SAM", "DocExplore GroundingDINO"], index=0)
        # score_threshold = st.slider(
        #     "Score threshold", min_value=0.01, max_value=1.0, value=0.5, step=0.01
        # )
    with right_col:
        if canvas_result.image_data is not None:
            draw = Image.fromarray(canvas_result.image_data.astype("uint8"))
            draw = ImageOps.pad(draw.convert("RGB"), size=(224, 224))
            draw.save("draw.jpg")

            # Mirror dataset_transforms: resize, tensorize, ImageNet-normalize.
            draw_tensor = transforms.ToTensor()(draw)
            draw_tensor = transforms.Resize((224, 224))(draw_tensor)
            draw_tensor = transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            )(draw_tensor)
            draw_tensor = draw_tensor.unsqueeze(0)
        else:
            return

        if len(query) > 0:
            retrieved, bbox_of_images, dot_product = image_search(draw_tensor, corpus)
            imgs, xs, ys = get_images([x[0] for x in retrieved])
            encoded_images = []
            for image_idx in range(len(imgs)):
                img0, x, y = imgs[image_idx], xs[image_idx], ys[image_idx]
                encoded_images.append(draw_reshape_encode(img0, bbox_of_images[image_idx], x, y))
            st.markdown(get_html(retrieved, encoded_images), unsafe_allow_html=True)
if __name__ == "__main__":
    main()