Spaces:
Running
Running
| import streamlit as st | |
| import pymongo | |
| from datetime import datetime, date, timezone | |
| from typing import List, Tuple, Optional, Dict | |
| import requests | |
| from PIL import Image, ImageFile | |
| import io | |
| from dotenv import load_dotenv | |
| import concurrent.futures | |
| import threading | |
| from functools import lru_cache | |
| import contextlib | |
| from requests.adapters import HTTPAdapter | |
| from urllib3.util.retry import Retry | |
# Load environment variables from a local .env file (e.g. the Mongo URI) before
# any configuration is read.
load_dotenv()
# Let Pillow decode partially-downloaded/truncated images instead of raising,
# so a cut-off network read still yields a usable (partial) thumbnail.
ImageFile.LOAD_TRUNCATED_IMAGES = True
| def _download_bytes_cached(url: str, timeout_s: float = 12.0) -> Optional[bytes]: | |
| try: | |
| r = requests.get(url, timeout=(3.05, timeout_s), stream=True) | |
| r.raise_for_status() | |
| return r.content | |
| except Exception: | |
| return None | |
class ImageGalleryApp:
    """MongoDB-backed image gallery backend: filter lookups, paged queries and
    parallel image download/decoding.

    Documents are assumed to carry at least ``status``, ``category``,
    ``file_name``, ``created_at`` and a ``urls`` list — schema inferred from
    the queries below; confirm against the writer side.
    """

    def __init__(self, mongo_uri: str, db_name: str, collection_name: str):
        self.client = pymongo.MongoClient(mongo_uri)
        self.db = self.client[db_name]
        self.collection = self.db[collection_name]
        # Reserved for guarding shared caches (currently unused in this class).
        self._cache_lock = threading.Lock()
        # One pooled HTTP session with automatic retries for transient errors
        # and rate limiting (429/5xx), shared by all image downloads.
        self.session = requests.Session()
        retries = Retry(
            total=3, connect=3, read=3, backoff_factor=0.4,
            status_forcelist=[429, 500, 502, 503, 504],
            allowed_methods=["GET", "HEAD"],
        )
        adapter = HTTPAdapter(pool_connections=64, pool_maxsize=64, max_retries=retries)
        self.session.mount("http://", adapter)
        self.session.mount("https://", adapter)
        # Maximum (width, height) used when downscaling fetched images.
        self.thumb_max_size = (768, 768)

    def get_unique_categories(self) -> List[str]:
        """Distinct non-null categories of completed docs, prefixed with "All".

        Falls back to just ["All"] if the query fails (e.g. Mongo unreachable).
        """
        try:
            categories = self.collection.distinct(
                "category", {"status": "completed", "category": {"$ne": None}}
            )
            return ["All"] + sorted(categories)
        except Exception:
            return ["All"]

    def get_unique_filenames(self) -> List[str]:
        """Distinct non-null file names of completed docs, prefixed with "All".

        Falls back to just ["All"] if the query fails.
        """
        try:
            filenames = self.collection.distinct(
                "file_name", {"status": "completed", "file_name": {"$ne": None}}
            )
            return ["All"] + sorted(filenames)
        except Exception:
            return ["All"]

    def parse_date_input(self, date_input) -> Optional[date]:
        """Coerce a UI date value to ``date``; return None when unparseable.

        Accepts ``date``/``datetime`` instances (returned as-is — note that a
        ``datetime`` passes the isinstance check unchanged) and strings in
        ``YYYY-MM-DD`` or ``MM/DD/YYYY`` form.
        """
        if not date_input or date_input == "":
            return None
        if isinstance(date_input, date):
            return date_input
        if isinstance(date_input, str):
            from datetime import datetime as _dt
            for fmt in ("%Y-%m-%d", "%m/%d/%Y"):
                try:
                    return _dt.strptime(date_input, fmt).date()
                except Exception:
                    pass
        return None

    def load_image_from_url(self, url: str) -> Optional["Image.Image"]:
        """Fetch and decode one image; return a PIL image or None on failure.

        Tries the module-level cached download first, then falls back to the
        pooled session. The result is converted to RGB (best effort) and
        downscaled in place to ``self.thumb_max_size`` when larger.
        """
        try:
            data = _download_bytes_cached(url) or self.session.get(
                url, timeout=(3.05, 12), stream=True
            ).content
            img = Image.open(io.BytesIO(data))
            with contextlib.suppress(Exception):
                img = img.convert("RGB")
            # Use the configured thumbnail bound instead of the previously
            # hard-coded 768 so __init__'s setting is authoritative.
            max_w, max_h = self.thumb_max_size
            if img.size[0] > max_w or img.size[1] > max_h:
                img.thumbnail(self.thumb_max_size, Image.Resampling.LANCZOS)
            return img
        except Exception:
            return None

    def load_images_parallel(
        self, urls: List[str], max_workers: int = 8
    ) -> List[Tuple[str, Optional["Image.Image"]]]:
        """Download/decode ``urls`` concurrently.

        Returns (url, image-or-None) pairs in completion order (NOT input
        order); callers should key results by URL.
        """
        results: List[Tuple[str, Optional["Image.Image"]]] = []
        if not urls:
            return results
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
            future_map = {ex.submit(self.load_image_from_url, u): u for u in urls}
            for fut in concurrent.futures.as_completed(future_map):
                u = future_map[fut]
                try:
                    img = fut.result()
                except Exception:
                    img = None
                results.append((u, img))
        return results

    def search_images_page(
        self, category="All", file_name="All", start_date=None, end_date=None,
        lob="search_arb", page=0, page_size=24,
    ) -> Tuple[List[Dict], int]:
        """Return one page of per-URL documents plus the total matching URL count.

        Each returned dict describes a single image URL (documents are
        ``$unwind``-ed on ``urls``), newest first. ``lob`` is accepted for
        interface compatibility but is currently unused by the query.
        On query failure returns ([], 0) / partial results rather than raising.
        """
        match = {"status": "completed", "urls": {"$exists": True, "$ne": []}}
        if category != "All":
            match["category"] = category
        if file_name != "All":
            match["file_name"] = file_name
        from datetime import datetime as _dt
        if start_date or end_date:
            date_query = {}
            # parse_date_input returns None for unparseable input; the original
            # code passed that None straight into datetime.combine, raising an
            # uncaught TypeError. Guard each bound before using it.
            sd = self.parse_date_input(start_date)
            ed = self.parse_date_input(end_date)
            if sd:
                date_query["$gte"] = _dt.combine(sd, _dt.min.time())
            if ed:
                # End of day (23:59:59.999999) to make the range inclusive.
                date_query["$lte"] = _dt.combine(ed, _dt.max.time())
            if date_query:
                match["created_at"] = date_query
        # Count pass: total number of individual URLs matching the filters.
        try:
            count_pipeline = [{"$match": match}, {"$unwind": "$urls"}, {"$count": "n"}]
            count_doc = list(self.collection.aggregate(count_pipeline))
            total = count_doc[0]["n"] if count_doc else 0
        except Exception:
            total = 0
        if total == 0:
            return [], 0
        # Page pass: newest first, skip/limit clamped to sane values.
        pipeline = [
            {"$match": match},
            {"$unwind": "$urls"},
            {"$sort": {"created_at": -1}},
            {"$skip": max(0, page) * max(1, page_size)},
            {"$limit": max(1, page_size)},
            {"$project": {
                "_id": 0, "url": "$urls", "category": 1, "file_name": 1,
                "created_at": 1, "prompt": 1, "status": 1,
            }},
        ]
        try:
            docs = list(self.collection.aggregate(pipeline, allowDiskUse=True))
        except Exception:
            docs = []
        return docs, total
def create_streamlit_app(mongo_uri: str, db_name: str, collection_name: str):
    """Render the Streamlit gallery UI: filter widgets, search, pagination and
    a 4-column image grid fed by ImageGalleryApp.

    Widget values are persisted into st.session_state on Search so pagination
    reruns query with the values from search time, not the live widget state.
    """
    app = ImageGalleryApp(mongo_uri, db_name, collection_name)
    def get_filter_choices():
        # Query distinct categories/file names; degrade to the bare "All"
        # option if MongoDB is unreachable.
        try:
            categories = app.get_unique_categories()
            filenames = app.get_unique_filenames()
            return categories, filenames
        except Exception:
            return ["All"], ["All"]
    # Populate filter choices once per session; "Refresh Filters" re-queries.
    if "categories_list" not in st.session_state:
        st.session_state["categories_list"], st.session_state["filenames_list"] = get_filter_choices()
    # Default UI state; setdefault only writes on the first run of a session.
    st.session_state.setdefault("selected_category", "All")
    st.session_state.setdefault("selected_filename", "All")
    st.session_state.setdefault("selected_lob", "search_arb")
    today = datetime.now(timezone.utc).date()
    st.session_state.setdefault("use_date_filter", True)
    st.session_state.setdefault("selected_start_date", today)
    st.session_state.setdefault("selected_end_date", today)
    st.session_state.setdefault("page", 0)
    st.session_state.setdefault("page_size", 24)
    st.session_state.setdefault("last_query_total", 0)
    st.session_state.setdefault("did_search", False)
    # --- Filter widgets ---
    col1, col2= st.columns([1,1])
    with col1:
        category = st.selectbox("Category", options=st.session_state["categories_list"])
    with col2:
        filename = st.selectbox("File Name", options=st.session_state["filenames_list"])
    coldf = st.columns([1,1,1])
    with coldf[0]:
        use_date_filter = st.checkbox("Filter by date", value=st.session_state["use_date_filter"])
    with coldf[1]:
        start_date = st.date_input("Start Date", value=st.session_state["selected_start_date"], disabled=not use_date_filter)
    with coldf[2]:
        end_date = st.date_input("End Date", value=st.session_state["selected_end_date"], disabled=not use_date_filter)
    col_misc = st.columns([1,1,1,2])
    with col_misc[0]:
        page_size = st.selectbox("Images per page", [8,12,16,24,32,48], index=[8,12,16,24,32,48].index(st.session_state["page_size"]))
    # --- Action buttons ---
    # NOTE(review): st.button(width=...) exists only in recent Streamlit
    # releases (older APIs used use_container_width) — confirm the deployed
    # version supports it.
    col_btn1, col_btn2, col_btn3 = st.columns([2,2,2])
    with col_btn1:
        search_clicked = st.button("🔍 Search", width='content')
    with col_btn2:
        refresh_clicked = st.button("🔄 Refresh Filters", width='content')
    with col_btn3:
        reset_clicked = st.button("♻️ Reset Page", width='content')
    if refresh_clicked:
        # Re-query filter choices, reset selections/page, redraw from scratch.
        st.session_state["categories_list"], st.session_state["filenames_list"] = get_filter_choices()
        st.session_state["selected_category"] = "All"; st.session_state["selected_filename"] = "All"; st.session_state["page"] = 0; st.rerun()
    if reset_clicked:
        st.session_state["page"] = 0; st.rerun()
    if search_clicked:
        # Snapshot current widget values into session state, reset to page 0,
        # and rerun so the results branch below uses the snapshot.
        st.session_state["selected_category"] = category; st.session_state["selected_filename"] = filename
        st.session_state["use_date_filter"] = use_date_filter; st.session_state["selected_start_date"] = start_date; st.session_state["selected_end_date"] = end_date
        st.session_state["page_size"] = page_size; st.session_state["page"] = 0; st.session_state["did_search"] = True; st.rerun()
    if st.session_state["did_search"]:
        # Date bounds apply only if the checkbox was on when Search was clicked.
        _start = st.session_state["selected_start_date"] if st.session_state["use_date_filter"] else None
        _end = st.session_state["selected_end_date"] if st.session_state["use_date_filter"] else None
        docs, total = app.search_images_page(category=st.session_state["selected_category"], file_name=st.session_state["selected_filename"], start_date=_start, end_date=_end, lob=st.session_state["selected_lob"], page=st.session_state["page"], page_size=st.session_state["page_size"])
        st.session_state["last_query_total"] = total
        # Ceiling division; always at least one page so the label renders.
        total_pages = max(1, (total + st.session_state["page_size"] - 1) // st.session_state["page_size"])
        nav1, nav2, nav3 = st.columns([1,2,1])
        with nav1:
            if st.button("⬅️ Prev", disabled=(st.session_state["page"] <= 0)): st.session_state["page"] -= 1; st.rerun()
        with nav2:
            st.markdown(f"<div style='text-align:center'>Page <b>{st.session_state['page']+1}</b> of <b>{total_pages}</b> · <b>{total}</b> images total</div>", unsafe_allow_html=True)
        with nav3:
            if st.button("Next ➡️", disabled=(st.session_state["page"] >= total_pages - 1)): st.session_state["page"] += 1; st.rerun()
        st.divider()
        if total == 0 or not docs:
            st.info("No images found for the current filters."); return
        # Draw grey placeholder tiles first so the grid layout appears before
        # any image download completes.
        st.markdown("#### Images"); cols = st.columns(4)
        placeholders = []
        for i, _ in enumerate(docs):
            ph = cols[i % 4].empty(); ph.markdown("<div style='width:100%;aspect-ratio:1/1;border-radius:10px;background:#eee'></div>", unsafe_allow_html=True); placeholders.append(ph)
        urls = [d["url"] for d in docs]
        # Fetch the page's images concurrently; results arrive in completion
        # order, so map them back by URL before filling the placeholders.
        loaded = app.load_images_parallel(urls, max_workers=8)
        url_to_img = {u: img for (u, img) in loaded}
        for i, d in enumerate(docs):
            img = url_to_img.get(d["url"]); meta = f"{d.get('category','N/A')} | {d.get('file_name','N/A')} | {d.get('created_at','')}"
            if img: placeholders[i].image(img,width='stretch', caption=meta)
            else: placeholders[i].warning("Failed to load image")
    else:
        st.info("Set your filters and click **Search** to load images.")