# ImageSearchClip / src/streamlit_app.py
# Snapshot of commit ee979c8 ("Update src/streamlit_app.py") by NEXAS.
import hashlib
import os
import tempfile
import uuid

import fitz
import streamlit as st
from chromadb import PersistentClient
from chromadb.utils.data_loaders import ImageLoader
from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
from PIL import Image
from skimage import data as skdata
from skimage.io import imsave
# Keep everything this app writes (extracted images and the Chroma store)
# under the OS temp directory so it also runs in Streamlit or otherwise
# restricted environments.
TEMP_DIR = tempfile.gettempdir()
IMAGES_DIR, DB_PATH = (
    os.path.join(TEMP_DIR, subdir) for subdir in ("extracted_images", "image_vdb")
)
os.makedirs(IMAGES_DIR, exist_ok=True)
@st.cache_resource
def get_chroma_collection():
    """Open the on-disk Chroma store at DB_PATH and return its "image" collection.

    The collection is created if it does not exist yet. It embeds with
    ``OpenCLIPEmbeddingFunction`` (default settings) and is given an
    ``ImageLoader`` as its data loader, which is what lets the rest of the
    module add and query images by file URI. ``st.cache_resource`` makes
    Streamlit build this once and reuse it across script reruns.
    """
    client = PersistentClient(path=DB_PATH)
    loader = ImageLoader()
    embedder = OpenCLIPEmbeddingFunction()
    return client.get_or_create_collection(
        name="image",
        embedding_function=embedder,
        data_loader=loader,
    )
# Module-level collection handle used by index_images() and
# query_similar_images(). Because get_chroma_collection is wrapped in
# st.cache_resource, Streamlit reruns reuse the same object instead of
# reopening the database and rebuilding the embedding function.
image_collection = get_chroma_collection()
# === Image Extraction ===
def extract_images_from_pdf(pdf_bytes):
    """Extract the embedded images of a PDF and save them under IMAGES_DIR.

    Args:
        pdf_bytes: Raw bytes of a PDF document.

    Returns:
        List of saved file paths, one per distinct embedded image (by xref),
        in the order the images are first encountered page by page. File
        extensions come from PyMuPDF's ``extract_image`` and are not limited
        to PNG/JPEG.
    """
    # Prefix filenames with a digest of the document so images from different
    # PDFs never overwrite each other on disk, which would silently change the
    # picture behind a path that is already indexed.
    doc_tag = hashlib.sha256(pdf_bytes).hexdigest()[:12]
    saved_images = []
    seen_xrefs = set()
    # The context manager closes the document even if extraction fails.
    with fitz.open(stream=pdf_bytes, filetype="pdf") as pdf:
        for page_num, page in enumerate(pdf, start=1):
            for img_idx, img in enumerate(page.get_images(full=True), start=1):
                xref = img[0]
                # An image reused on several pages shares one xref; save it once.
                if xref in seen_xrefs:
                    continue
                seen_xrefs.add(xref)
                base_image = pdf.extract_image(xref)
                if not base_image:
                    continue
                filename = (
                    f"{doc_tag}_page_{page_num}_img_{img_idx}.{base_image['ext']}"
                )
                path = os.path.join(IMAGES_DIR, filename)
                with open(path, "wb") as f:
                    f.write(base_image["image"])
                saved_images.append(path)
    return saved_images
# === Indexing ===
def index_images(image_paths, collection=None):
    """Index the PNG/JPEG files among ``image_paths`` in the Chroma collection.

    Each image's ID is derived deterministically from its path, and entries
    are upserted. Indexing the same file again (for example, on a Streamlit
    rerun) therefore refreshes its embedding instead of adding a duplicate.

    Args:
        image_paths: Iterable of file paths. Files without a .png/.jpg/.jpeg
            extension (case-insensitive) are skipped, and duplicate paths are
            indexed once.
        collection: Chroma collection to write to. Defaults to the
            module-level ``image_collection``.

    Returns:
        Number of images sent to the collection.
    """
    if collection is None:
        collection = image_collection
    uris = sorted(
        {p for p in image_paths if p.lower().endswith((".png", ".jpeg", ".jpg"))}
    )
    if not uris:
        return 0
    # uuid5 is a pure function of its input: the same path maps to the same ID.
    ids = [str(uuid.uuid5(uuid.NAMESPACE_URL, uri)) for uri in uris]
    collection.upsert(ids=ids, uris=uris)
    return len(uris)
# === Querying ===
def query_similar_images(image_file, top_k=5, collection=None):
    """Return the URIs of the indexed images most similar to ``image_file``.

    Args:
        image_file: File-like object holding the query image (e.g. a Streamlit
            ``UploadedFile``). If it has ``getvalue()``, its full contents are
            used regardless of the current stream position.
        top_k: Maximum number of matches to return.
        collection: Chroma collection to search. Defaults to the module-level
            ``image_collection``.

    Returns:
        List of image URIs in the order Chroma ranks them. The list is empty
        when ``top_k`` is not positive or the collection holds no images.
    """
    if collection is None:
        collection = image_collection
    if top_k <= 0:
        return []
    # Never request more neighbours than the collection holds; this also
    # avoids querying an empty collection at all.
    n_results = min(top_k, collection.count())
    if n_results == 0:
        return []

    # getvalue() returns the whole upload even if the stream was already
    # consumed (e.g. by st.image). Plain file objects fall back to read().
    if hasattr(image_file, "getvalue"):
        payload = image_file.getvalue()
    else:
        payload = image_file.read()
    # Keep the upload's own extension when it has one.
    name = getattr(image_file, "name", None)
    suffix = (os.path.splitext(name)[1] if isinstance(name, str) else "") or ".jpg"

    # query_uris are loaded from disk by the collection's data loader, so stage
    # the bytes in a temp file. The file is removed however the query ends.
    fd, tmp_path = tempfile.mkstemp(suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as tmp:
            tmp.write(payload)
        results = collection.query(
            query_uris=[tmp_path],
            n_results=n_results,
            # Chroma's default `include` does not contain "uris", which would
            # leave results["uris"] as None; request them explicitly.
            include=["uris"],
        )
    finally:
        os.remove(tmp_path)
    uris = results.get("uris") or [[]]
    return list(uris[0] or [])
# === Demo images ===
def load_skimage_demo_images():
    """Save five scikit-image sample pictures into IMAGES_DIR as PNG files.

    All five samples are loaded before any file is written.

    Returns:
        The written file paths, in the order astronaut, coffee, camera,
        chelsea, rocket.
    """
    samples = dict(
        astronaut=skdata.astronaut(),
        coffee=skdata.coffee(),
        camera=skdata.camera(),
        chelsea=skdata.chelsea(),
        rocket=skdata.rocket(),
    )
    written = []
    for label, pixels in samples.items():
        target = os.path.join(IMAGES_DIR, f"{label}.png")
        imsave(target, pixels)
        written.append(target)
    return written
# === Streamlit UI ===
st.title("🔍 Image Similarity Search from PDF or Custom Dataset")


def _upload_digest(files):
    """Return a hex digest identifying uploaded files by name and content."""
    digest = hashlib.sha256()
    for upload in files:
        digest.update(upload.name.encode("utf-8"))
        digest.update(hashlib.sha256(upload.getvalue()).digest())
    return digest.hexdigest()


# Streamlit re-executes this script on every widget interaction, and an
# uploaded file stays in its uploader across those reruns. Remember which
# uploads this session has already extracted/indexed. Otherwise, for example,
# uploading a query image would re-extract and re-index the PDF again.
if "processed_uploads" not in st.session_state:
    st.session_state["processed_uploads"] = {}
processed_uploads = st.session_state["processed_uploads"]

source = st.radio(
    "Select Image Source",
    ["Upload PDF", "Upload Images", "Load Demo Dataset"],
    horizontal=True,
)

if source == "Upload PDF":
    uploaded_pdf = st.file_uploader("📤 Upload PDF", type=["pdf"])
    if uploaded_pdf:
        pdf_key = "pdf:" + _upload_digest([uploaded_pdf])
        if pdf_key not in processed_uploads:
            with st.spinner("Extracting images..."):
                extracted = extract_images_from_pdf(uploaded_pdf.getvalue())
            with st.spinner("Indexing images..."):
                index_images(extracted)
            processed_uploads[pdf_key] = extracted
        images = processed_uploads[pdf_key]
        if images:
            st.success(f"{len(images)} images extracted and indexed.")
            st.image(images, width=150)
        else:
            st.warning("No embedded images were found in this PDF.")
elif source == "Upload Images":
    uploaded_imgs = st.file_uploader(
        "📤 Upload one or more images",
        type=["jpg", "jpeg", "png"],
        accept_multiple_files=True,
    )
    if uploaded_imgs:
        # The filename is supplied by the client; keep only its final
        # component so a crafted name cannot write outside IMAGES_DIR.
        saved_paths = [
            os.path.join(IMAGES_DIR, os.path.basename(img.name))
            for img in uploaded_imgs
        ]
        batch_key = "images:" + _upload_digest(uploaded_imgs)
        if batch_key not in processed_uploads:
            for img, img_path in zip(uploaded_imgs, saved_paths):
                with open(img_path, "wb") as f:
                    f.write(img.getvalue())
            with st.spinner("Indexing images..."):
                index_images(saved_paths)
            processed_uploads[batch_key] = saved_paths
        st.success(f"{len(saved_paths)} images indexed.")
        st.image(saved_paths, width=150)
elif source == "Load Demo Dataset":
    if st.button("🔄 Load Demo Images (skimage)"):
        demo_paths = load_skimage_demo_images()
        # Index the demo set once per session; clicking again only redisplays it.
        if "demo" not in processed_uploads:
            with st.spinner("Indexing images..."):
                index_images(demo_paths)
            processed_uploads["demo"] = demo_paths
        st.success("Demo images loaded and indexed.")
        st.image(demo_paths, width=150)

st.divider()
st.subheader("🔎 Search for Similar Images")
query_img = st.file_uploader("Upload a query image", type=["jpg", "jpeg", "png"])
if query_img:
    st.image(query_img, caption="Query Image", width=200)
    # Rewind in case displaying the upload moved its stream position.
    query_img.seek(0)
    with st.spinner("Searching..."):
        matches = query_similar_images(query_img, top_k=5)
    if matches:
        st.subheader("📊 Top Matches:")
        for match in matches:
            # Indexed files live in the temp directory and may have been removed.
            if match and os.path.exists(match):
                st.image(match, width=200, caption=os.path.basename(match))
            else:
                st.warning(f"Matched image is no longer on disk: {match}")
    else:
        st.info("No matches found. Index a PDF, some images, or the demo dataset first.")