Spaces:
Running
Running
| import streamlit as st | |
| import os | |
| import torch | |
| from pathlib import Path | |
| from transformers import AutoTokenizer | |
| from jax import numpy as jnp | |
| import json | |
| import requests | |
| import zipfile | |
| import io | |
| import natsort | |
| from PIL import Image as PilImage | |
| from torchvision import datasets, transforms | |
| from torchvision.transforms import CenterCrop, Normalize, Resize, ToTensor | |
| from torchvision.transforms.functional import InterpolationMode | |
| from tqdm import tqdm | |
| from modeling_hybrid_clip import FlaxHybridCLIP | |
| import utils | |
def get_model():
    """Return the ``clip-italian/clip-italian`` checkpoint as a FlaxHybridCLIP."""
    checkpoint = "clip-italian/clip-italian"
    return FlaxHybridCLIP.from_pretrained(checkpoint)
def download_images():
    """Make sure the ``photos/`` folder contains the Unsplash photo set.

    Nothing is done when ``photos/`` already exists and is non-empty.
    Otherwise the folder is created and ``unsplash-25k-photos.zip`` is
    extracted into it. A copy of the archive found in the working directory
    is used as-is; if there is none, the archive is downloaded from sbert.net
    and unpacked from memory.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        zipfile.BadZipFile: If the archive is not a valid zip file.
    """
    img_folder = "photos/"
    if os.path.exists(img_folder) and os.listdir(img_folder):
        return

    os.makedirs(img_folder, exist_ok=True)
    photo_filename = "unsplash-25k-photos.zip"

    if os.path.exists(photo_filename):
        # An archive that is already on disk still has to be unpacked,
        # otherwise the photo folder would stay empty.
        archive_source = photo_filename
    else:
        print(f"Downloading {photo_filename}...")
        # The timeout keeps a stalled server from blocking the app forever.
        with requests.get(
            "http://sbert.net/datasets/" + photo_filename,
            stream=True,
            timeout=60,
        ) as response:
            # Fail on HTTP errors here instead of handing an error page to
            # zipfile and getting an opaque BadZipFile.
            response.raise_for_status()
            archive_source = io.BytesIO(response.content)

    print("Extracting the dataset...")
    with zipfile.ZipFile(archive_source) as archive:
        archive.extractall(path=img_folder)
    print("Done.")
def get_image_features():
    """Load ``static/features/features.npy`` with ``jnp.load`` and return it."""
    features_file = "static/features/features.npy"
    return jnp.load(features_file)
| """ | |
| # π Ciao! | |
| # CLIP Italian Demo (Flax Community Week) | |
| """ | |
# Set before the tokenizer is created further down.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

query = st.text_input("Insert a query text")
if query:
    with st.spinner("Computing in progress..."):
        # Load the model a single time per query; calling get_model() again
        # later would only repeat the same work.
        model = get_model()
        download_images()
        image_features = get_image_features()

        tokenizer = AutoTokenizer.from_pretrained(
            "dbmdz/bert-base-italian-xxl-uncased", cache_dir=None, use_fast=True
        )

        # Preprocessing applied to each photo: resize and center-crop to the
        # vision tower's input size, convert to a tensor, then normalize with
        # the per-channel mean/std given below.
        image_size = model.config.vision_config.image_size
        val_preprocess = transforms.Compose(
            [
                Resize([image_size], interpolation=InterpolationMode.BICUBIC),
                CenterCrop(image_size),
                ToTensor(),
                Normalize(
                    (0.48145466, 0.4578275, 0.40821073),
                    (0.26862954, 0.26130258, 0.27577711),
                ),
            ]
        )
        dataset = utils.CustomDataSet("photos/", transform=val_preprocess)

        # n=2 is forwarded to utils.find_image; presumably the number of
        # images to return -- confirm against utils.
        image_paths = utils.find_image(
            query, model, dataset, tokenizer, image_features, n=2
        )

    st.image(image_paths)
def read_markdown_file(markdown_file):
    """Return the text content of a markdown file.

    The file is decoded as UTF-8 explicitly so that non-ASCII characters
    read the same regardless of the host's locale encoding.

    Args:
        markdown_file: Path to the file, as a string or ``Path``.

    Returns:
        The file content as a single string.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    return Path(markdown_file).read_text(encoding="utf-8")
# Render the project readme on the page; inline HTML in the file is allowed
# through unsafe_allow_html.
intro_markdown = read_markdown_file("readme.md")
st.markdown(body=intro_markdown, unsafe_allow_html=True)