Spaces:
Build error
Build error
| """ | |
| streamlit run app.py --server.address 0.0.0.0 | |
| """ | |
| from __future__ import annotations | |
| import os | |
| from time import time | |
| from typing import Literal | |
| import streamlit as st | |
| import torch | |
| from open_clip import create_model_and_transforms, get_tokenizer | |
| from openai import OpenAI | |
| from qdrant_client import QdrantClient | |
| from qdrant_client.http import models | |
| if os.getenv("SPACE_ID"): | |
| USE_HF_SPACE = True | |
| os.environ["HF_HOME"] = "/data/.huggingface" | |
| os.environ["HF_DATASETS_CACHE"] = "/data/.huggingface" | |
| else: | |
| USE_HF_SPACE = False | |
| # for tokenizer | |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
| OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY") | |
| QDRANT_API_ENDPOINT = os.environ.get("QDRANT_API_ENDPOINT") | |
| QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY") | |
| BASE_IMAGE_URL = "https://storage.googleapis.com/secons-site-images/photo/" | |
| TargetImageType = Literal["xsmall", "small", "medium", "large"] | |
| if not QDRANT_API_ENDPOINT or not QDRANT_API_KEY: | |
| raise ValueError("env: QDRANT_API_ENDPOINT or QDRANT_API_KEY is not set.") | |
def get_image_url(image_name: str, image_type: TargetImageType = "xsmall") -> str:
    """Return the public CDN URL for *image_name* at the given size variant."""
    return "".join((BASE_IMAGE_URL, image_type, "/", image_name, ".webp"))
@st.cache_resource
def get_model_preprocess_tokenizer(
    target_model: str = "xlm-roberta-base-ViT-B-32",
    pretrained: str = "laion5B-s13B-b90k",
):
    """Load the CLIP model, image preprocess transform and tokenizer.

    Decorated with ``st.cache_resource`` so the (large) weights are loaded
    once per process: both ``app()`` (whose ``# for cache`` call implies this
    intent) and ``get_text_features()`` call this on every Streamlit rerun,
    and without caching the model would be reloaded on each search.

    Returns:
        (model, preprocess, tokenizer) tuple from open_clip.
    """
    model, _, preprocess = create_model_and_transforms(
        target_model, pretrained=pretrained
    )
    tokenizer = get_tokenizer(target_model)
    return model, preprocess, tokenizer
@st.cache_resource
def get_qdrant_client():
    """Return a Qdrant client for the configured endpoint.

    Cached with ``st.cache_resource`` so the HTTP client (and its connection
    pool) is created once per process instead of on every search rerun.
    Credentials come from QDRANT_API_ENDPOINT / QDRANT_API_KEY, validated at
    module import.
    """
    return QdrantClient(
        url=QDRANT_API_ENDPOINT,
        api_key=QDRANT_API_KEY,
    )
def get_text_features(text: str) -> list[float]:
    """Encode *text* into a unit-norm CLIP embedding.

    Returns the L2-normalized feature vector as a plain Python list so it
    can be passed directly to Qdrant as a query vector.
    """
    # The image preprocess transform is not needed for text encoding.
    model, _, tokenizer = get_model_preprocess_tokenizer()
    text_tokenized = tokenizer([text])
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        text_features = model.encode_text(text_tokenized)  # type: ignore
        text_features /= text_features.norm(dim=-1, keepdim=True)  # L2-normalize
    # tensor -> list for JSON-serializable transport to Qdrant
    return text_features[0].tolist()
def app():
    """Render the search page: text input, Qdrant vector search, image grid."""
    get_model_preprocess_tokenizer()  # for cache: warm the model up front
    st.title("secon.dev site search")

    query = st.text_input("Search", key="search_text")
    if not query:
        return

    st.write("searching...")
    started = time()
    client = get_qdrant_client()
    query_vector = get_text_features(query)
    hits = client.search(
        collection_name="images-clip",
        query_vector=query_vector,
        limit=50,
    )
    elapsed = time() - started
    st.write(f"elapsed: {elapsed:.2f} sec")
    st.write(f"total: {len(hits)}")

    # Build parallel image-URL / caption lists; fall back to "unknown"
    # when a hit carries no payload.
    images = []
    captions = []
    for hit in hits:
        name = hit.payload["name"] if hit.payload else "unknown"
        images.append(get_image_url(name, image_type="xsmall"))
        captions.append(f"{name} ({hit.score:.4f})")

    # Render the results six images per row.
    group_size = 6
    for offset in range(0, len(images), group_size):
        st.image(
            images[offset : offset + group_size],
            caption=captions[offset : offset + group_size],
            width=160,
        )
if __name__ == "__main__":
    # Wide layout + site favicon, then hand off to the page body.
    st.set_page_config(
        layout="wide",
        page_icon="https://secon.dev/images/profile_usa.png",
    )
    app()