Spaces:
Sleeping
Sleeping
| import os | |
| import fitz | |
| import tempfile | |
| import streamlit as st | |
| from PIL import Image | |
| from chromadb import PersistentClient | |
| from chromadb.utils.data_loaders import ImageLoader | |
| from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction | |
| from skimage import data as skdata | |
| from skimage.io import imsave | |
| import uuid | |
# All writable state lives under the system temp directory, which stays
# writable in restricted hosts (e.g. Streamlit deployments) where the app's
# own directory may not be.
TEMP_DIR = tempfile.gettempdir()
# IMAGES_DIR holds extracted/uploaded/demo images; DB_PATH is the Chroma store.
IMAGES_DIR, DB_PATH = (
    os.path.join(TEMP_DIR, leaf) for leaf in ("extracted_images", "image_vdb")
)
os.makedirs(IMAGES_DIR, exist_ok=True)
@st.cache_resource
def get_chroma_collection():
    """Return the persistent Chroma collection that stores image embeddings.

    The collection is named ``"image"``, is stored on disk under ``DB_PATH``,
    embeds items with ``OpenCLIPEmbeddingFunction`` and resolves ``uris`` to
    pixel data through ``ImageLoader``.

    Streamlit re-executes this whole script on every user interaction.
    Without caching, each rerun would open a new ``PersistentClient`` and
    construct a new ``OpenCLIPEmbeddingFunction`` (which loads the OpenCLIP
    model in its constructor), making every click slow.
    ``st.cache_resource`` builds the collection once per server process and
    shares it across reruns and sessions.
    """
    chroma_client = PersistentClient(path=DB_PATH)
    image_loader = ImageLoader()
    embedding_fn = OpenCLIPEmbeddingFunction()
    return chroma_client.get_or_create_collection(
        name="image",
        embedding_function=embedding_fn,
        data_loader=image_loader,
    )


# Shared handle used by the indexing and query helpers.
image_collection = get_chroma_collection()
| # === Image Extraction === | |
def extract_images_from_pdf(pdf_bytes):
    """Extract the embedded images of a PDF and save them under ``IMAGES_DIR``.

    Args:
        pdf_bytes: Raw bytes of a PDF document.

    Returns:
        List of paths of the written image files, in page order. Each file
        keeps the extension PyMuPDF reports for the image, so formats other
        than PNG/JPEG can appear (``index_images`` skips those).

    Notes:
        * File names are prefixed with a short hash of the PDF contents.
          Images from different PDFs therefore no longer overwrite each
          other (all PDFs used to write ``page_1_img_1.*`` etc. into the same
          directory). Re-processing the same PDF reuses the same names
          instead of piling up copies.
        * An image object referenced from several pages (logos, backgrounds)
          is saved only once, so it does not flood search results.
    """
    import hashlib  # only needed here; keeps the top-level imports unchanged

    doc_tag = hashlib.sha1(pdf_bytes).hexdigest()[:12]
    saved_images = []
    seen_xrefs = set()
    # The context manager closes the document even if extraction fails.
    with fitz.open(stream=pdf_bytes, filetype="pdf") as pdf:
        for page_num, page in enumerate(pdf, start=1):
            for img_idx, img in enumerate(page.get_images(full=True), start=1):
                xref = img[0]  # first field of a get_images() entry is the xref
                if xref in seen_xrefs:
                    continue
                seen_xrefs.add(xref)

                base_image = pdf.extract_image(xref)
                if not base_image:  # nothing extractable for this xref
                    continue

                filename = f"{doc_tag}_page_{page_num}_img_{img_idx}.{base_image['ext']}"
                path = os.path.join(IMAGES_DIR, filename)
                with open(path, "wb") as f:
                    f.write(base_image["image"])
                saved_images.append(path)
    return saved_images
| # === Indexing === | |
def index_images(image_paths):
    """Add or refresh image files in the Chroma collection.

    Only paths ending in ``.png``, ``.jpeg`` or ``.jpg`` (case-insensitive)
    are indexed; every other path is ignored.

    Each file's id is derived deterministically from its absolute path
    (UUIDv5), and ``upsert`` is used rather than ``add``. Indexing the same
    file again therefore replaces its entry instead of adding a duplicate.
    Re-indexing happens on every Streamlit rerun and on every app restart
    against the persistent DB. Duplicates would otherwise crowd the top-k
    search results.

    Args:
        image_paths: Iterable of filesystem paths.

    Returns:
        Number of images sent to the collection (0 if none qualified).
    """
    # A set removes repeated paths; duplicate ids in one upsert would be rejected.
    uris = sorted(
        {
            os.path.abspath(path)
            for path in image_paths
            if path.lower().endswith((".png", ".jpeg", ".jpg"))
        }
    )
    if not uris:
        return 0
    ids = [str(uuid.uuid5(uuid.NAMESPACE_URL, uri)) for uri in uris]
    image_collection.upsert(ids=ids, uris=uris)
    return len(uris)
| # === Querying === | |
def query_similar_images(image_file, top_k=5):
    """Return the URIs of the indexed images most similar to *image_file*.

    Args:
        image_file: File-like object holding the query image (e.g. a
            Streamlit ``UploadedFile``). ``getvalue()`` is used when
            available, so the result does not depend on how much of the
            stream was already consumed (e.g. by ``st.image``).
        top_k: Maximum number of matches to return.

    Returns:
        List of image URIs, best match first. The list is empty when the
        collection is empty or ``top_k`` < 1.
    """
    # Never ask for more neighbours than exist; an empty index has none.
    n_results = min(top_k, image_collection.count())
    if n_results < 1:
        return []

    if hasattr(image_file, "getvalue"):
        data = image_file.getvalue()
    else:
        data = image_file.read()

    # The collection's data loader reads query_uris from disk, so persist the
    # upload to a temporary file for the duration of the query.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
        tmp.write(data)
        tmp_path = tmp.name
    try:
        results = image_collection.query(
            query_uris=[tmp_path],
            n_results=n_results,
            # URIs are not in Chroma's default ``include`` set; without this,
            # results["uris"] is None and indexing it fails.
            include=["uris", "distances"],
        )
        return results["uris"][0]
    finally:
        os.remove(tmp_path)
| # === Demo images === | |
def load_skimage_demo_images():
    """Write a handful of scikit-image sample pictures as PNGs to IMAGES_DIR.

    The samples (astronaut, coffee, camera, chelsea, rocket) are all loaded
    first and then saved in that same order as ``<name>.png``.

    Returns:
        The list of written file paths, in sample order.
    """
    sample_names = ("astronaut", "coffee", "camera", "chelsea", "rocket")
    pictures = {label: getattr(skdata, label)() for label in sample_names}

    written = []
    for label, pixels in pictures.items():
        target = os.path.join(IMAGES_DIR, f"{label}.png")
        imsave(target, pixels)
        written.append(target)
    return written
| # === Streamlit UI === | |
st.title("π Image Similarity Search from PDF or Custom Dataset")
source = st.radio(
    "Select Image Source",
    ["Upload PDF", "Upload Images", "Load Demo Dataset"],
    horizontal=True,
)

# Streamlit re-runs this whole script on every widget interaction, including
# uploading a query image in the search section. Remember which inputs this
# session has already indexed, so they are not re-extracted and re-added to
# the collection on every rerun. Re-adding creates duplicate entries that
# crowd the top-k results. Keys are a heuristic: (kind, file name, size).
if "indexed_inputs" not in st.session_state:
    st.session_state["indexed_inputs"] = {}
indexed_inputs = st.session_state["indexed_inputs"]

if source == "Upload PDF":
    uploaded_pdf = st.file_uploader("π€ Upload PDF", type=["pdf"])
    if uploaded_pdf:
        pdf_key = ("pdf", uploaded_pdf.name, uploaded_pdf.size)
        if pdf_key not in indexed_inputs:
            with st.spinner("Extracting images..."):
                # getvalue() returns the whole upload regardless of read position.
                pdf_images = extract_images_from_pdf(uploaded_pdf.getvalue())
                index_images(pdf_images)
            indexed_inputs[pdf_key] = pdf_images
        images = indexed_inputs[pdf_key]
        st.success(f"{len(images)} images extracted and indexed.")
        if images:
            st.image(images, width=150)
elif source == "Upload Images":
    uploaded_imgs = st.file_uploader(
        "π€ Upload one or more images",
        type=["jpg", "jpeg", "png"],
        accept_multiple_files=True,
    )
    if uploaded_imgs:
        saved_paths = []
        new_paths = []
        new_keys = []
        for img in uploaded_imgs:
            # The file name is client-supplied: keep only its last component
            # so it cannot point outside IMAGES_DIR (e.g. "../../x.png").
            img_path = os.path.join(IMAGES_DIR, os.path.basename(img.name))
            img_key = ("image", img.name, img.size)
            if img_key not in indexed_inputs:
                with open(img_path, "wb") as f:
                    f.write(img.getvalue())
                new_paths.append(img_path)
                new_keys.append(img_key)
            saved_paths.append(img_path)
        if new_paths:
            index_images(new_paths)
            # Mark as indexed only after indexing succeeded.
            for key, path in zip(new_keys, new_paths):
                indexed_inputs[key] = [path]
        st.success(f"{len(saved_paths)} images indexed.")
        st.image(saved_paths, width=150)
elif source == "Load Demo Dataset":
    if st.button("π Load Demo Images (skimage)"):
        demo_paths = load_skimage_demo_images()
        if ("demo",) not in indexed_inputs:
            index_images(demo_paths)
            indexed_inputs[("demo",)] = demo_paths
        st.success("Demo images loaded and indexed.")
        st.image(demo_paths, width=150)
st.divider()
st.subheader("π Search for Similar Images")
query_img = st.file_uploader("Upload a query image", type=["jpg", "jpeg", "png"])
if query_img:
    st.image(query_img, caption="Query Image", width=200)
    if image_collection.count() == 0:
        # Nothing to compare against yet. Querying an empty index yields
        # nothing useful and, depending on the Chroma version, may raise.
        st.warning("No images are indexed yet. Add images above before searching.")
    else:
        with st.spinner("Searching..."):
            matches = query_similar_images(query_img, top_k=5)
        st.subheader("π Top Matches:")
        # The same file may have been indexed more than once; show each
        # distinct image once, keeping the ranking order.
        unique_matches = list(dict.fromkeys(matches))
        if not unique_matches:
            st.info("No matches found.")
        for match in unique_matches:
            st.image(match, width=200, caption=os.path.basename(match))