"""Streamlit app: extract images from PDFs (or uploads / skimage demos),
index them in a persistent ChromaDB collection with OpenCLIP embeddings,
and search for visually similar images by example.
"""

import os
import tempfile
import uuid

import fitz  # PyMuPDF
import streamlit as st
from chromadb import PersistentClient
from chromadb.utils.data_loaders import ImageLoader
from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
from skimage import data as skdata
from skimage.io import imsave

# Use safe temp directories for Streamlit or restricted environments.
TEMP_DIR = tempfile.gettempdir()
IMAGES_DIR = os.path.join(TEMP_DIR, "extracted_images")
DB_PATH = os.path.join(TEMP_DIR, "image_vdb")
os.makedirs(IMAGES_DIR, exist_ok=True)


@st.cache_resource
def get_chroma_collection():
    """Create (once per Streamlit session) the persistent Chroma collection.

    Uses OpenCLIP embeddings and an ImageLoader so records can be added and
    queried by image URI rather than precomputed vectors.

    Returns:
        chromadb.Collection: the "image" collection backed by DB_PATH.
    """
    chroma_client = PersistentClient(path=DB_PATH)
    image_loader = ImageLoader()
    embedding_fn = OpenCLIPEmbeddingFunction()
    collection = chroma_client.get_or_create_collection(
        name="image",
        embedding_function=embedding_fn,
        data_loader=image_loader,
    )
    return collection


image_collection = get_chroma_collection()


# === Image Extraction ===
def extract_images_from_pdf(pdf_bytes):
    """Extract every embedded raster image from a PDF byte stream.

    Each image is written to IMAGES_DIR as
    ``page_<P>_img_<I>.<ext>`` (1-based page/image indices).

    Args:
        pdf_bytes: raw bytes of the PDF file.

    Returns:
        list[str]: filesystem paths of the saved images, in page order.
    """
    pdf = fitz.open(stream=pdf_bytes, filetype="pdf")
    saved_images = []
    try:
        for page_num in range(len(pdf)):
            page = pdf.load_page(page_num)
            images = page.get_images(full=True)
            for img_idx, img in enumerate(images):
                xref = img[0]  # first tuple element is the image xref
                base_image = pdf.extract_image(xref)
                img_bytes = base_image["image"]
                ext = base_image["ext"]
                filename = f"page_{page_num+1}_img_{img_idx+1}.{ext}"
                path = os.path.join(IMAGES_DIR, filename)
                with open(path, "wb") as f:
                    f.write(img_bytes)
                saved_images.append(path)
    finally:
        pdf.close()  # release the document handle even on error
    return saved_images


# === Indexing ===
def index_images(image_paths):
    """Index image files into the Chroma collection by URI.

    IDs are derived deterministically from each path (UUIDv5), and
    ``upsert`` is used so re-indexing the same files is idempotent
    instead of creating duplicate records on every rerun.

    Args:
        image_paths: iterable of filesystem paths; only .png/.jpg/.jpeg
            files are indexed, others are silently skipped.
    """
    ids = []
    uris = []
    for path in sorted(image_paths):
        if path.lower().endswith((".png", ".jpeg", ".jpg")):
            # Stable ID per path -> reruns update instead of duplicating.
            ids.append(str(uuid.uuid5(uuid.NAMESPACE_URL, path)))
            uris.append(path)
    if ids:
        image_collection.upsert(ids=ids, uris=uris)


# === Querying ===
def query_similar_images(image_file, top_k=5):
    """Find the indexed images most similar to an uploaded query image.

    Args:
        image_file: a file-like object (e.g. Streamlit UploadedFile)
            containing the query image bytes.
        top_k: number of nearest neighbours to return.

    Returns:
        list[str]: URIs of the top matches, most similar first.
    """
    # Chroma's query_uris needs a real file on disk; write the upload
    # to a temp file and remove it afterwards.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
        tmp.write(image_file.read())
        tmp_path = tmp.name
    try:
        # URIs are not in the default include set; request them explicitly,
        # otherwise results['uris'] would be missing/None.
        results = image_collection.query(
            query_uris=[tmp_path],
            n_results=top_k,
            include=["uris"],
        )
        return results['uris'][0]
    finally:
        os.remove(tmp_path)


# === Demo images ===
def load_skimage_demo_images():
    """Save a small set of scikit-image sample images to IMAGES_DIR.

    Returns:
        list[str]: paths of the saved PNG files.
    """
    demo_images = {
        "astronaut": skdata.astronaut(),
        "coffee": skdata.coffee(),
        "camera": skdata.camera(),
        "chelsea": skdata.chelsea(),
        "rocket": skdata.rocket(),
    }
    saved_paths = []
    for name, img in demo_images.items():
        path = os.path.join(IMAGES_DIR, f"{name}.png")
        imsave(path, img)
        saved_paths.append(path)
    return saved_paths


# === Streamlit UI ===
st.title("🔍 Image Similarity Search from PDF or Custom Dataset")

source = st.radio(
    "Select Image Source",
    ["Upload PDF", "Upload Images", "Load Demo Dataset"],
    horizontal=True,
)

if source == "Upload PDF":
    uploaded_pdf = st.file_uploader("📤 Upload PDF", type=["pdf"])
    if uploaded_pdf:
        with st.spinner("Extracting images..."):
            images = extract_images_from_pdf(uploaded_pdf.read())
            index_images(images)
        st.success(f"{len(images)} images extracted and indexed.")
        st.image(images, width=150)

elif source == "Upload Images":
    uploaded_imgs = st.file_uploader(
        "📤 Upload one or more images",
        type=["jpg", "jpeg", "png"],
        accept_multiple_files=True,
    )
    if uploaded_imgs:
        saved_paths = []
        for img in uploaded_imgs:
            img_path = os.path.join(IMAGES_DIR, img.name)
            with open(img_path, "wb") as f:
                f.write(img.read())
            saved_paths.append(img_path)
        index_images(saved_paths)
        st.success(f"{len(saved_paths)} images indexed.")
        st.image(saved_paths, width=150)

elif source == "Load Demo Dataset":
    if st.button("🔄 Load Demo Images (skimage)"):
        demo_paths = load_skimage_demo_images()
        index_images(demo_paths)
        st.success("Demo images loaded and indexed.")
        st.image(demo_paths, width=150)

st.divider()
st.subheader("🔎 Search for Similar Images")

query_img = st.file_uploader("Upload a query image", type=["jpg", "jpeg", "png"])
if query_img:
    st.image(query_img, caption="Query Image", width=200)
    with st.spinner("Searching..."):
        matches = query_similar_images(query_img, top_k=5)
    st.subheader("📊 Top Matches:")
    for match in matches:
        st.image(match, width=200, caption=os.path.basename(match))