Spaces:
Runtime error
Runtime error
| import gzip | |
| import io | |
| import json | |
| import random | |
| import re | |
| import tempfile | |
| from typing import Dict, List, Optional | |
| from PIL import Image | |
| import requests | |
| import streamlit as st | |
# Shared requests session, reused for the data-file and image downloads.
http_session = requests.Session()

# Open Food Facts endpoints.
API_URL = "https://world.openfoodfacts.org/api/v0"
PRODUCT_URL = f"{API_URL}/product"
OFF_IMAGE_BASE_URL = "https://static.openfoodfacts.org/images/products"

# Cuts a barcode into three 3-character groups plus the remainder; the groups
# become directory segments of the product image path.
BARCODE_PATH_REGEX = re.compile(r"^(...)(...)(...)(.*)$")
def load_nn_data(url: str):
    """Download and decode a gzipped JSON file of nearest-neighbor records.

    The file holds one JSON object whose keys are logo IDs. JSON object keys
    are always strings, so they are converted to ``int`` here to match the
    integer IDs used elsewhere.

    :param url: URL of the ``.json.gz`` file.
    :return: dict mapping integer logo ID to its record (presumably holding
        ``ids``, ``distances`` and ``annotation``, as read by
        ``display_predictions``).
    :raises requests.HTTPError: if the server answers with an error status.
    """
    r = http_session.get(url, timeout=60)
    # Surface HTTP failures directly instead of letting an error page fail
    # later inside the gzip decoder with an unrelated-looking message.
    r.raise_for_status()
    # JSON is UTF-8; don't depend on the host's locale encoding.
    with gzip.open(io.BytesIO(r.content), "rt", encoding="utf-8") as f:
        return {int(key): value for key, value in json.load(f).items()}
def load_logo_data(url: str):
    """Download and decode a gzipped JSON Lines file of logo annotations.

    Each non-empty line is a JSON object with an ``id`` field. The result is
    keyed by ``int(item["id"])`` and the values are the full parsed objects.

    :param url: URL of the ``.jsonl.gz`` file.
    :return: dict mapping integer logo ID to its annotation dict.
    :raises requests.HTTPError: if the server answers with an error status.
    """
    r = http_session.get(url, timeout=60)
    r.raise_for_status()
    # JSON Lines is UTF-8; don't depend on the host's locale encoding.
    with gzip.open(io.BytesIO(r.content), "rt", encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing empty line) rather than letting
        # json.loads("") raise.
        items = (json.loads(line) for line in map(str.strip, f) if line)
        # The generator must be consumed while the file is still open.
        return {int(item["id"]): item for item in items}
def get_image_from_url(
    image_url: str,
    error_raise: bool = False,
    session: Optional[requests.Session] = None,
) -> Optional[Image.Image]:
    """Download an image and decode it with Pillow.

    :param image_url: URL of the image to fetch.
    :param error_raise: if True, raise on an HTTP error status instead of
        returning None.
    :param session: requests session to issue the request with. When
        omitted, a plain ``requests.get`` is used.
    :return: the fully loaded image, or None if the response status is not 200.
    :raises requests.HTTPError: if ``error_raise`` is True and the response
        carries an error status.
    """
    # Honor the session supplied by the caller.
    getter = session.get if session is not None else requests.get
    r = getter(image_url, timeout=30)
    if error_raise:
        r.raise_for_status()
    if r.status_code != 200:
        return None
    # Decode straight from memory. Image.open is lazy, so force the pixel
    # data to load now; the image then depends on no external file/stream.
    image = Image.open(io.BytesIO(r.content))
    image.load()
    return image
def split_barcode(barcode: str) -> List[str]:
    """Split a digit-only barcode into the path segments of its image folder.

    With the current ``BARCODE_PATH_REGEX``, barcodes of 9+ characters become
    three 3-character groups followed by the remainder (dropped when empty).
    Shorter barcodes are returned as a single segment.

    :raises ValueError: if ``barcode`` is not made only of digits.
    """
    if barcode.isdigit():
        found = BARCODE_PATH_REGEX.fullmatch(barcode)
        if found is None:
            return [barcode]
        # Drop the empty trailing group produced by exactly-9-digit barcodes.
        return list(filter(None, found.groups()))
    raise ValueError("unknown barcode format: {}".format(barcode))
def get_cropped_image(barcode: str, image_id: str, bounding_box):
    """Fetch an Open Food Facts product image and crop it to a bounding box.

    :param barcode: product barcode, used to build the image path.
    :param image_id: identifier of the image within the product.
    :param bounding_box: ``(ymin, xmin, ymax, xmax)`` in relative coordinates
        (each value is scaled by the image height or width below).
    :return: the cropped image, or None if the image could not be fetched.
    """
    url = OFF_IMAGE_BASE_URL + generate_image_path(barcode, image_id)
    image = get_image_from_url(url, session=http_session)
    if image is None:
        return None
    ymin, xmin, ymax, xmax = bounding_box
    width, height = image.width, image.height
    # Pillow expects (left, upper, right, lower) in pixels.
    box = (xmin * width, ymin * height, xmax * width, ymax * height)
    return image.crop(box)
def generate_image_path(barcode: str, image_id: str) -> str:
    """Return the image path ``/<segment>/.../<image_id>.jpg`` for a product.

    The segments come from ``split_barcode``; the result is meant to be
    appended to ``OFF_IMAGE_BASE_URL``.
    """
    segments = "/".join(split_barcode(barcode))
    return f"/{segments}/{image_id}.jpg"
def display_predictions(
    logo_data: Dict,
    nn_data: Dict,
    logo_id: Optional[int] = None,
):
    """Render a logo crop followed by the crops of its nearest neighbors.

    :param logo_data: logo annotations keyed by logo ID. Each entry provides
        ``barcode``, ``image_id`` and ``bounding_box``.
    :param nn_data: nearest-neighbor records keyed by logo ID. Each entry
        provides ``ids``, ``distances`` and ``annotation``.
    :param logo_id: logo to display. When falsy (None or 0), a random ID
        from ``nn_data`` is used.
    """
    if not logo_id:
        if not nn_data:
            st.warning("No nearest-neighbor data available.")
            return
        logo_id = random.choice(list(nn_data))
    st.write(f"Logo ID: {logo_id}")

    # A user-typed ID may be absent from either dataset; report it in the UI
    # instead of crashing the app with a KeyError traceback.
    if logo_id not in logo_data or logo_id not in nn_data:
        st.error(f"Logo ID {logo_id} not found in the loaded data.")
        return

    logo = logo_data[logo_id]
    logo_nn_data = nn_data[logo_id]
    nn_ids = logo_nn_data["ids"]
    nn_distances = logo_nn_data["distances"]
    annotation = logo_nn_data["annotation"]

    cropped_image = get_cropped_image(
        logo["barcode"], logo["image_id"], logo["bounding_box"]
    )
    if cropped_image is None:
        st.warning(f"Could not download the image of logo {logo_id}.")
        return
    st.image(cropped_image, annotation, width=200)

    cropped_images: List[Image.Image] = []
    captions: List[str] = []
    progress_bar = st.progress(0)
    total = len(nn_ids)
    for i, (closest_id, distance) in enumerate(zip(nn_ids, nn_distances), start=1):
        progress_bar.progress(i / total)
        closest_logo = logo_data.get(closest_id)
        if closest_logo is None:
            # Neighbor without annotation data: nothing to crop.
            continue
        cropped_image = get_cropped_image(
            closest_logo["barcode"],
            closest_logo["image_id"],
            closest_logo["bounding_box"],
        )
        if cropped_image is None:
            continue
        if cropped_image.height > cropped_image.width:
            # Lay tall crops on their side. expand=True grows the canvas to
            # fit the rotated image; without it the output keeps the original
            # size and the rotated content is clipped.
            cropped_image = cropped_image.rotate(90, expand=True)
        cropped_images.append(cropped_image)
        captions.append(f"distance: {distance}")

    if cropped_images:
        st.image(cropped_images, captions, width=200)
# --- Streamlit page --------------------------------------------------------
# NOTE(review): Streamlit re-executes this script on every widget change, so
# both data files are re-downloaded each time. Consider wrapping the loaders
# with the caching decorator offered by the installed Streamlit version.
st.sidebar.title("Logo Nearest Neighbors Demo")
st.sidebar.write(
    "Get first 100 nearest neighbors for a random annotated logo.\n\n"
    "CLIP model is used to generate embeddings, and nearest neighbors "
    "are computed either using a brute-force approach or with ANN."
)
# 0 (the default) means "no ID given". It is mapped to None so that
# display_predictions picks a random logo.
logo_id = st.sidebar.number_input("logo ID", min_value=0, step=1) or None
approximate = st.sidebar.checkbox(
    "ANN (HNSW)",
    value=False,
    help="Display approximate neighbors (instead of real "
    "neighbors computed using brute-force approach)",
)

nn_file_name = (
    "hnsw_50_closest_neighbours" if approximate else "exact_100_neighbours"
)
nn_data = load_nn_data(
    f"https://static.openfoodfacts.org/data/logos/{nn_file_name}.json.gz"
)
logo_data = load_logo_data(
    "https://static.openfoodfacts.org/data/logos/logo_annotations.jsonl.gz"
)

if approximate:
    st.write("Using approximate nearest neighbors method")
else:
    st.write("Using exact (brute-force) nearest neighbors method")

display_predictions(logo_data=logo_data, nn_data=nn_data, logo_id=logo_id)