import streamlit as st
import pymongo
from datetime import datetime, date, timezone
from typing import List, Tuple, Optional, Dict
import requests
from PIL import Image, ImageFile
import io
from dotenv import load_dotenv
import concurrent.futures
import threading
import contextlib
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

load_dotenv()
ImageFile.LOAD_TRUNCATED_IMAGES = True
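

# The download helper lives at module level (not on the class) so that
# st.cache_data hashes only the url/timeout arguments; a bound method would
# pull `self` into the cache key, which Streamlit's hasher rejects (typically
# forcing a `_self` workaround). The decorator below is restored to match the
# docstring's promise of a 30 min TTL cache; show_spinner=False keeps the
# skeleton placeholders as the only loading UI.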
@st.cache_data(ttl=1800, show_spinner=False)
def _download_bytes_cached(url: str, timeout_s: float = 12.0) -> Optional[bytes]:
    """Download image bytes from a URL with a Streamlit cache (30 min TTL)."""
    try:
        r = requests.get(url, timeout=(3.05, timeout_s), stream=True)
        r.raise_for_status()
        return r.content
    except Exception:
        return None


class ImageGalleryApp:
    def __init__(self, mongo_uri: str, db_name: str, collection_name: str):
        """Initialize MongoDB and HTTP session with retry pooling."""
        self.client = pymongo.MongoClient(mongo_uri)
        self.db = self.client[db_name]
        self.collection = self.db[collection_name]
        self._cache_lock = threading.Lock()
        self.session = requests.Session()
        retries = Retry(
            total=3,
            connect=3,
            read=3,
            backoff_factor=0.4,
            status_forcelist=[429, 500, 502, 503, 504],
            allowed_methods=["GET", "HEAD"],
        )
        adapter = HTTPAdapter(pool_connections=64, pool_maxsize=64, max_retries=retries)
        self.session.mount("http://", adapter)
        self.session.mount("https://", adapter)
        self.thumb_max_size = (768, 768)

    def get_unique_categories(self) -> List[str]:
        """Return all unique categories from the DB."""
        try:
            categories = self.collection.distinct(
                "category",
                {"status": "completed", "lob": "leadgen_vivek", "category": {"$ne": None}}
            )
            return ["All"] + sorted(categories)
        except Exception as e:
            print(f"Error fetching categories: {e}")
            return ["All"]

    def get_unique_filenames(self) -> List[str]:
        """Return all unique file names from the DB."""
        try:
            filenames = self.collection.distinct(
                "file_name",
                {"status": "completed", "lob": "leadgen_vivek", "file_name": {"$ne": None}}
            )
            return ["All"] + sorted(filenames)
        except Exception as e:
            print(f"Error fetching filenames: {e}")
            return ["All"]

    def parse_date_input(self, date_input) -> Optional[date]:
        """Coerce a date or a 'YYYY-MM-DD' / 'MM/DD/YYYY' string into a date; None otherwise."""
        if not date_input:
            return None
        if isinstance(date_input, date):
            return date_input
        if isinstance(date_input, str):
            try:
                if date_input.count('-') == 2:
                    return datetime.strptime(date_input, '%Y-%m-%d').date()
                elif date_input.count('/') == 2:
                    return datetime.strptime(date_input, '%m/%d/%Y').date()
            except ValueError as e:
                print(f"Error parsing date string '{date_input}': {e}")
        return None

    def load_image_from_url(self, url: str) -> Optional[Image.Image]:
        """Load and thumbnail an image from a URL, preferring the cached downloader."""
        try:
            data = _download_bytes_cached(url)
            if data is None:
                # Cache miss or cached failure: retry through the pooled session,
                # which carries the Retry/backoff policy from __init__.
                r = self.session.get(url, timeout=(3.05, 12), stream=True)
                r.raise_for_status()
                data = r.content
            img = Image.open(io.BytesIO(data))
            with contextlib.suppress(Exception):
                img = img.convert("RGB")
            if img.size[0] > self.thumb_max_size[0] or img.size[1] > self.thumb_max_size[1]:
                img.thumbnail(self.thumb_max_size, Image.Resampling.LANCZOS)
            return img
        except Exception as e:
            print(f"Error loading image from {url}: {e}")
            return None

    def load_images_parallel(self, urls: List[str], max_workers: int = 8) -> List[Tuple[str, Optional[Image.Image]]]:
        """Load multiple images in parallel; returns (url, image) pairs in completion order."""
        results: List[Tuple[str, Optional[Image.Image]]] = []
        if not urls:
            return results
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
            future_map = {ex.submit(self.load_image_from_url, u): u for u in urls}
            # as_completed yields futures in completion order, not submission
            # order, so callers must re-key the results by URL.
            for fut in concurrent.futures.as_completed(future_map):
                u = future_map[fut]
                try:
                    img = fut.result()
                except Exception as e:
                    print(f"Error in parallel loading: {e}")
                    img = None
                results.append((u, img))
        return results

    def search_images_page(
        self,
        category: str = "All",
        file_name: str = "All",
        start_date: Optional[date] = None,
        end_date: Optional[date] = None,
        lob: str = "leadgen_vivek",
        page: int = 0,
        page_size: int = 24,
    ) -> Tuple[List[Dict], int]:
        """
        Paginated search for images by filters.

        Returns:
            - docs: list of {url, category, file_name, created_at, prompt, status}
            - total: total number of image URLs matching the filters
        """
        match = {
            "lob": lob,
            "status": "completed",
            "urls": {"$exists": True, "$ne": []}
        }
        if category != "All":
            match["category"] = category
        if file_name != "All":
            match["file_name"] = file_name
        start_date_obj = self.parse_date_input(start_date)
        end_date_obj = self.parse_date_input(end_date)
        if start_date_obj or end_date_obj:
            date_query = {}
            if start_date_obj:
                date_query["$gte"] = datetime.combine(start_date_obj, datetime.min.time())
            if end_date_obj:
                date_query["$lte"] = datetime.combine(end_date_obj, datetime.max.time())
            if date_query:
                match["created_at"] = date_query
        try:
            # Count individual URLs (one document may hold many), so the total
            # lines up with the unwound rows paginated below.
            count_pipeline = [
                {"$match": match},
                {"$unwind": "$urls"},
                {"$count": "n"}
            ]
            count_doc = list(self.collection.aggregate(count_pipeline))
            total = count_doc[0]["n"] if count_doc else 0
        except Exception as e:
            print(f"Count error: {e}")
            total = 0
        if total == 0:
            return [], 0
        # $skip/$limit apply after $unwind, so each page holds exactly
        # page_size image URLs regardless of how URLs cluster per document.
        pipeline = [
            {"$match": match},
            {"$unwind": "$urls"},
            {"$sort": {"created_at": -1}},
            {"$skip": max(0, page) * max(1, page_size)},
            {"$limit": max(1, page_size)},
            {"$project": {
                "_id": 0,
                "url": "$urls",
                "category": 1,
                "file_name": 1,
                "created_at": 1,
                "prompt": 1,
                "status": 1
            }}
        ]
        try:
            docs = list(self.collection.aggregate(pipeline, allowDiskUse=True))
        except Exception as e:
            print(f"Aggregation error: {e}")
            docs = []
        return docs, total
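
    # Suggestion (not part of the original app): the $match + $sort stages in
    # search_images_page would benefit from a compound index mirroring the
    # filter keys, created once at deploy time rather than per request, e.g.:
    #
    #     collection.create_index([("lob", 1), ("status", 1), ("created_at", -1)])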


def create_streamlit_app(mongo_uri: str, db_name: str, collection_name: str):
    """Main Streamlit UI for the image gallery."""
    app = ImageGalleryApp(mongo_uri, db_name, collection_name)

    def get_filter_choices():
        """Fetch filter choices for category and filename."""
        try:
            categories = app.get_unique_categories()
            filenames = app.get_unique_filenames()
            return categories, filenames
        except Exception:
            return ["All"], ["All"]

    # Session state defaults
    if "categories_list" not in st.session_state:
        st.session_state["categories_list"], st.session_state["filenames_list"] = get_filter_choices()
    st.session_state.setdefault("selected_category", "All")
    st.session_state.setdefault("selected_filename", "All")
    st.session_state.setdefault("selected_lob", "leadgen_vivek")
    today = datetime.now(timezone.utc).date()
    st.session_state.setdefault("use_date_filter", True)
    st.session_state.setdefault("selected_start_date", today)
    st.session_state.setdefault("selected_end_date", today)
    st.session_state.setdefault("page", 0)
    st.session_state.setdefault("page_size", 24)
    st.session_state.setdefault("last_query_total", 0)
    st.session_state.setdefault("did_search", False)

    # Custom styles for UI
    st.markdown("""
        <style>
        .orange-btn > button { background:#FF7300 !important; color:#fff !important; font-weight:600 !important; }
        .skel {
            width: 100%;
            aspect-ratio: 1 / 1;
            border-radius: 10px;
            background: linear-gradient(90deg, #eee 25%, #f5f5f5 37%, #eee 63%);
            background-size: 400% 100%;
            animation: shimmer 1.4s ease infinite;
        }
        @keyframes shimmer {
            0% { background-position: 0% 0; }
            100% { background-position: -135% 0; }
        }
        </style>
    """, unsafe_allow_html=True)

    # Filter controls
    col1, col2, col3 = st.columns([1, 1, 1])
    with col1:
        category = st.selectbox(
            "Category",
            options=st.session_state["categories_list"],
            index=st.session_state["categories_list"].index(st.session_state["selected_category"])
            if st.session_state["selected_category"] in st.session_state["categories_list"] else 0
        )
    with col2:
        filename = st.selectbox(
            "File Name",
            options=st.session_state["filenames_list"],
            index=st.session_state["filenames_list"].index(st.session_state["selected_filename"])
            if st.session_state["selected_filename"] in st.session_state["filenames_list"] else 0
        )
    with col3:
        lob = st.text_input("LOB (Line of Business)", value=st.session_state["selected_lob"])

    coldf = st.columns([1, 1, 1])
    with coldf[0]:
        use_date_filter = st.checkbox("Filter by date", value=st.session_state["use_date_filter"])
    with coldf[1]:
        start_date = st.date_input(
            "Start Date",
            value=st.session_state["selected_start_date"],
            disabled=not use_date_filter
        )
    with coldf[2]:
        end_date = st.date_input(
            "End Date",
            value=st.session_state["selected_end_date"],
            disabled=not use_date_filter
        )

    col_misc = st.columns([1, 1, 1, 2])
    with col_misc[0]:
        page_size = st.selectbox(
            "Images per page", [8, 12, 16, 24, 32, 48],
            index=[8, 12, 16, 24, 32, 48].index(st.session_state["page_size"])
        )

    col_btn1, col_btn2, col_btn3 = st.columns([2, 2, 2])
    with col_btn1:
        search_clicked = st.button("🔍 Search", use_container_width=True)
    with col_btn2:
        refresh_clicked = st.button("🔄 Refresh Filters", use_container_width=True)
    with col_btn3:
        reset_clicked = st.button("♻️ Reset Page", use_container_width=True)

    # Button events
    if refresh_clicked:
        st.session_state["categories_list"], st.session_state["filenames_list"] = get_filter_choices()
        st.session_state["selected_category"] = "All"
        st.session_state["selected_filename"] = "All"
        st.session_state["page"] = 0
        st.rerun()
    if reset_clicked:
        st.session_state["page"] = 0
        st.rerun()
    if search_clicked:
        st.session_state["selected_category"] = category
        st.session_state["selected_filename"] = filename
        st.session_state["selected_lob"] = lob
        st.session_state["use_date_filter"] = use_date_filter
        st.session_state["selected_start_date"] = start_date
        st.session_state["selected_end_date"] = end_date
        st.session_state["page_size"] = page_size
        st.session_state["page"] = 0
        st.session_state["did_search"] = True
        st.rerun()

    # Results display
    if st.session_state["did_search"]:
        _start = st.session_state["selected_start_date"] if st.session_state["use_date_filter"] else None
        _end = st.session_state["selected_end_date"] if st.session_state["use_date_filter"] else None
        docs, total = app.search_images_page(
            category=st.session_state["selected_category"],
            file_name=st.session_state["selected_filename"],
            start_date=_start,
            end_date=_end,
            lob=st.session_state["selected_lob"],
            page=st.session_state["page"],
            page_size=st.session_state["page_size"]
        )
        st.session_state["last_query_total"] = total
        total_pages = max(1, (total + st.session_state["page_size"] - 1) // st.session_state["page_size"])

        nav1, nav2, nav3 = st.columns([1, 2, 1])
        with nav1:
            if st.button("⬅️ Prev", disabled=(st.session_state["page"] <= 0)):
                st.session_state["page"] -= 1
                st.rerun()
        with nav2:
            st.markdown(
                f"<div style='text-align:center'>Page <b>{st.session_state['page']+1}</b> of <b>{total_pages}</b>"
                f" · <b>{total}</b> images total</div>",
                unsafe_allow_html=True
            )
        with nav3:
            if st.button("Next ➡️", disabled=(st.session_state["page"] >= total_pages - 1)):
                st.session_state["page"] += 1
                st.rerun()

        st.divider()
        if total == 0 or not docs:
            st.info("No images found for the current filters.")
            return

        st.markdown("#### Images")
        cols = st.columns(4)

        # Show skeletons while the images load
        placeholders = []
        for i, _ in enumerate(docs):
            ph = cols[i % 4].empty()
            ph.markdown("<div class='skel'></div>", unsafe_allow_html=True)
            placeholders.append(ph)

        urls = [d["url"] for d in docs]
        loaded = app.load_images_parallel(urls, max_workers=8)
        url_to_img = {u: img for (u, img) in loaded}

        # Fill placeholders with loaded images, re-keyed by URL since the
        # parallel loader returns results in completion order
        for i, d in enumerate(docs):
            img = url_to_img.get(d["url"])
            meta = f"{d.get('category', 'N/A')} | {d.get('file_name', 'N/A')} | {d.get('created_at', '')}"
            if img:
                placeholders[i].image(img, use_container_width=True, caption=meta)
            else:
                placeholders[i].warning("Failed to load image")
    else:
        st.info("Set your filters and click **Search** to load images.")