AdGenesis-App / ui /load_file.py
userIdc2024's picture
Update ui/load_file.py
aa0c905 verified
import streamlit as st
import pymongo
from datetime import datetime, date, timezone
from typing import List, Tuple, Optional, Dict
import requests
from PIL import Image, ImageFile
import io
from dotenv import load_dotenv
import concurrent.futures
import threading
from functools import lru_cache
import contextlib
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
# Load variables from a .env file (if one is found) into the process
# environment; nothing in this section reads them directly — presumably the
# caller of create_streamlit_app does (e.g. for the Mongo URI). TODO confirm.
load_dotenv()
# Tell Pillow to decode image data whose byte stream ends early instead of
# raising, presumably so partially downloaded images can still be shown.
ImageFile.LOAD_TRUNCATED_IMAGES = True
@st.cache_data(show_spinner=False, ttl=60 * 30)
def _fetch_bytes_or_raise(url: str, timeout_s: float = 12.0) -> bytes:
    """Download *url* and return the response body; raise on any failure.

    Failures raise rather than return None on purpose: ``st.cache_data`` only
    memoizes return values, so a transient error is retried on the next call
    instead of being cached for the whole 30-minute TTL.
    """
    # The context manager releases the pooled connection even when
    # raise_for_status() raises.
    with requests.get(url, timeout=(3.05, timeout_s)) as resp:
        resp.raise_for_status()
        return resp.content


def _download_bytes_cached(url: str, timeout_s: float = 12.0) -> Optional[bytes]:
    """Return the (cached) body of *url*, or None if the download failed.

    Args:
        url: Absolute http(s) URL to fetch.
        timeout_s: Read timeout in seconds (connect timeout is fixed at 3.05s).

    Returns:
        The response bytes, or None on any network/HTTP error. Only successful
        downloads are cached.
    """
    try:
        return _fetch_bytes_or_raise(url, timeout_s)
    except Exception:
        return None


# Keep the `.clear()` helper the decorated function used to expose.
_download_bytes_cached.clear = _fetch_bytes_or_raise.clear
class ImageGalleryApp:
    """Query completed records from MongoDB and fetch the images they reference.

    Every element of a record's ``urls`` array counts as one gallery image:
    counting and paging in ``search_images_page`` unwind that array.
    """

    def __init__(self, mongo_uri: str, db_name: str, collection_name: str):
        """Connect to the collection and set up an HTTP session with retries.

        Args:
            mongo_uri: MongoDB connection string passed to ``MongoClient``.
            db_name: Database name.
            collection_name: Collection holding the generation records.
        """
        self.client = pymongo.MongoClient(mongo_uri)
        self.db = self.client[db_name]
        self.collection = self.db[collection_name]
        # Not used by any method of this class; kept for external callers.
        self._cache_lock = threading.Lock()
        self.session = requests.Session()
        # Retry GET/HEAD on connection/read errors and on throttling or
        # transient server errors, with exponential backoff.
        retries = Retry(
            total=3,
            connect=3,
            read=3,
            backoff_factor=0.4,
            status_forcelist=[429, 500, 502, 503, 504],
            allowed_methods=["GET", "HEAD"],
        )
        # Large pool: load_images_parallel may use this session from
        # several worker threads at once.
        adapter = HTTPAdapter(pool_connections=64, pool_maxsize=64, max_retries=retries)
        self.session.mount("http://", adapter)
        self.session.mount("https://", adapter)
        # Bounding box (width, height) that loaded images are shrunk to fit.
        self.thumb_max_size = (768, 768)

    def _distinct_with_all(self, field: str) -> List[str]:
        """Return ``["All"]`` + sorted distinct non-null *field* values of completed records.

        Falls back to ``["All"]`` alone if the query or the sort fails.
        """
        try:
            values = self.collection.distinct(field, {"status": "completed", field: {"$ne": None}})
            return ["All"] + sorted(values)
        except Exception:
            return ["All"]

    def get_unique_categories(self) -> List[str]:
        """Return the Category filter options, ``"All"`` first.

        Deliberately not memoized: ``lru_cache`` on a method keys on ``self``
        and pins every instance (and its MongoClient) in a class-level cache,
        and it would make a refresh on a long-lived instance return stale data.
        """
        return self._distinct_with_all("category")

    def get_unique_filenames(self) -> List[str]:
        """Return the File Name filter options, ``"All"`` first (not memoized)."""
        return self._distinct_with_all("file_name")

    def parse_date_input(self, date_input) -> Optional[date]:
        """Coerce a widget/user value into a ``date``.

        Accepts ``date``, ``datetime`` (its date part is used) or a string in
        ``YYYY-MM-DD`` or ``MM/DD/YYYY`` form. Returns None for empty or
        unparseable input.
        """
        if not date_input:
            return None
        # datetime subclasses date, so it must be checked first to return a
        # plain date as annotated.
        if isinstance(date_input, datetime):
            return date_input.date()
        if isinstance(date_input, date):
            return date_input
        if isinstance(date_input, str):
            text = date_input.strip()
            for fmt in ("%Y-%m-%d", "%m/%d/%Y"):
                try:
                    return datetime.strptime(text, fmt).date()
                except ValueError:
                    continue
        return None

    def load_image_from_url(self, url: str) -> Optional["Image.Image"]:
        """Fetch *url* as an RGB PIL image no larger than ``thumb_max_size``.

        The module-level cached downloader is tried first; if it yields
        nothing, the retrying ``self.session`` is used. Returns None on any
        download or decode failure.
        """
        try:
            data = _download_bytes_cached(url)
            if not data:
                with self.session.get(url, timeout=(3.05, 12)) as resp:
                    # Don't hand an HTML error page to Pillow.
                    resp.raise_for_status()
                    data = resp.content
            img = Image.open(io.BytesIO(data))
            # Best effort: keep the original mode if conversion fails.
            with contextlib.suppress(Exception):
                img = img.convert("RGB")
            max_w, max_h = self.thumb_max_size
            if img.size[0] > max_w or img.size[1] > max_h:
                img.thumbnail((max_w, max_h), Image.Resampling.LANCZOS)
            return img
        except Exception:
            return None

    def load_images_parallel(
        self, urls: List[str], max_workers: int = 8
    ) -> List[Tuple[str, Optional["Image.Image"]]]:
        """Load *urls* concurrently with a thread pool.

        Returns ``(url, image_or_None)`` pairs in the same order as *urls*.
        ``max_workers`` is clamped to ``1..len(urls)``; None lets the executor
        pick its default.
        """
        if not urls:
            return []
        workers = None if max_workers is None else min(len(urls), max(1, max_workers))
        with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as ex:
            futures = [ex.submit(self.load_image_from_url, u) for u in urls]
        results: List[Tuple[str, Optional["Image.Image"]]] = []
        for u, fut in zip(urls, futures):
            try:
                img = fut.result()
            except Exception:
                img = None
            results.append((u, img))
        return results

    def search_images_page(
        self,
        category="All",
        file_name="All",
        start_date=None,
        end_date=None,
        lob="search_arb",
        page=0,
        page_size=24,
    ) -> Tuple[List[Dict], int]:
        """Return one page of image entries plus the total number of entries.

        Args:
            category: Exact category to match, or ``"All"`` for no filter.
            file_name: Exact file name to match, or ``"All"`` for no filter.
            start_date: Inclusive lower day bound (anything
                ``parse_date_input`` accepts); ignored if empty or unparseable.
            end_date: Inclusive upper day bound, same rules as *start_date*.
            lob: Accepted for interface compatibility; not used by the query.
            page: Zero-based page index (negative values act as 0).
            page_size: Entries per page (values below 1 act as 1).

        Returns:
            ``(docs, total)``. Each doc has ``url``, ``category``, ``file_name``,
            ``created_at``, ``prompt`` and ``status``. ``total`` counts URL
            entries across all matching records. A failed count yields
            ``([], 0)``; a failed page fetch yields ``([], total)``.
        """
        match = {"status": "completed", "urls": {"$exists": True, "$ne": []}}
        if category != "All":
            match["category"] = category
        if file_name != "All":
            match["file_name"] = file_name

        # Day bounds are naive datetimes; presumably created_at is stored as
        # UTC. NOTE(review): confirm against the writer of these records.
        date_query = {}
        start = self.parse_date_input(start_date)
        end = self.parse_date_input(end_date)
        if start is not None:
            date_query["$gte"] = datetime.combine(start, datetime.min.time())
        if end is not None:
            date_query["$lte"] = datetime.combine(end, datetime.max.time())
        if date_query:
            match["created_at"] = date_query

        try:
            count_doc = list(self.collection.aggregate([{"$match": match}, {"$unwind": "$urls"}, {"$count": "n"}]))
            total = count_doc[0]["n"] if count_doc else 0
        except Exception:
            total = 0
        if total == 0:
            return [], 0

        size = max(1, page_size)
        pipeline = [
            {"$match": match},
            # Keep each URL's array position so the sort below is total.
            {"$unwind": {"path": "$urls", "includeArrayIndex": "url_index"}},
            # Tiebreakers make page boundaries deterministic when several
            # entries share the same created_at.
            {"$sort": {"created_at": -1, "_id": -1, "url_index": 1}},
            {"$skip": max(0, page) * size},
            {"$limit": size},
            {"$project": {"_id": 0, "url": "$urls", "category": 1, "file_name": 1, "created_at": 1, "prompt": 1, "status": 1}},
        ]
        try:
            docs = list(self.collection.aggregate(pipeline, allowDiskUse=True))
        except Exception:
            docs = []
        return docs, total
@st.cache_resource(show_spinner=False)
def _get_gallery_app(mongo_uri: str, db_name: str, collection_name: str):
    """Build one ImageGalleryApp per connection triple and reuse it across reruns.

    Streamlit re-executes the script on every interaction; constructing the
    app each time would open a fresh MongoClient and HTTP pool per rerun.
    """
    return ImageGalleryApp(mongo_uri, db_name, collection_name)


def _clear_filter_caches(app) -> None:
    """Drop memoized filter lists, if the app's getters are memoized at all."""
    for getter in (app.get_unique_categories, app.get_unique_filenames):
        clear = getattr(getter, "cache_clear", None)
        if callable(clear):
            clear()


def create_streamlit_app(mongo_uri: str, db_name: str, collection_name: str):
    """Render the gallery page: filter controls, pager and image grid.

    Filter widget values are committed to ``st.session_state`` only when
    Search is clicked. The query then runs on every rerun with the committed
    values, so paging never picks up uncommitted widget edits.

    Args:
        mongo_uri: MongoDB connection string.
        db_name: Database name.
        collection_name: Collection holding the generation records.
    """
    app = _get_gallery_app(mongo_uri, db_name, collection_name)
    state = st.session_state

    def get_filter_choices():
        """Return (categories, filenames) option lists; ["All"] lists on error."""
        try:
            categories = app.get_unique_categories()
            filenames = app.get_unique_filenames()
            return categories, filenames
        except Exception:
            return ["All"], ["All"]

    # Dropdown options are fetched once per browser session (or on Refresh).
    if "categories_list" not in state:
        state["categories_list"], state["filenames_list"] = get_filter_choices()
    state.setdefault("selected_category", "All")
    state.setdefault("selected_filename", "All")
    state.setdefault("selected_lob", "search_arb")
    today = datetime.now(timezone.utc).date()
    state.setdefault("use_date_filter", True)
    state.setdefault("selected_start_date", today)
    state.setdefault("selected_end_date", today)
    state.setdefault("page", 0)
    state.setdefault("page_size", 24)
    state.setdefault("last_query_total", 0)
    state.setdefault("did_search", False)

    # --- Filter widgets (not applied until Search is clicked) ---
    col1, col2 = st.columns([1, 1])
    with col1:
        category = st.selectbox("Category", options=state["categories_list"])
    with col2:
        filename = st.selectbox("File Name", options=state["filenames_list"])
    coldf = st.columns([1, 1, 1])
    with coldf[0]:
        use_date_filter = st.checkbox("Filter by date", value=state["use_date_filter"])
    with coldf[1]:
        start_date = st.date_input("Start Date", value=state["selected_start_date"], disabled=not use_date_filter)
    with coldf[2]:
        end_date = st.date_input("End Date", value=state["selected_end_date"], disabled=not use_date_filter)
    col_misc = st.columns([1, 1, 1, 2])
    page_size_options = [8, 12, 16, 24, 32, 48]
    with col_misc[0]:
        page_size = st.selectbox("Images per page", page_size_options, index=page_size_options.index(state["page_size"]))

    col_btn1, col_btn2, col_btn3 = st.columns([2, 2, 2])
    with col_btn1:
        search_clicked = st.button("🔍 Search", width='content')
    with col_btn2:
        refresh_clicked = st.button("🔄 Refresh Filters", width='content')
    with col_btn3:
        reset_clicked = st.button("♻️ Reset Page", width='content')

    if refresh_clicked:
        # The app instance is reused across reruns, so any memoized lists
        # must be dropped for Refresh to see new values.
        _clear_filter_caches(app)
        state["categories_list"], state["filenames_list"] = get_filter_choices()
        state["selected_category"] = "All"
        state["selected_filename"] = "All"
        state["page"] = 0
        st.rerun()
    if reset_clicked:
        state["page"] = 0
        st.rerun()
    if search_clicked:
        # Commit the widget values; the query below only reads these.
        state["selected_category"] = category
        state["selected_filename"] = filename
        state["use_date_filter"] = use_date_filter
        state["selected_start_date"] = start_date
        state["selected_end_date"] = end_date
        state["page_size"] = page_size
        state["page"] = 0
        state["did_search"] = True
        st.rerun()

    if not state["did_search"]:
        st.info("Set your filters and click **Search** to load images.")
        return

    use_dates = state["use_date_filter"]
    docs, total = app.search_images_page(
        category=state["selected_category"],
        file_name=state["selected_filename"],
        start_date=state["selected_start_date"] if use_dates else None,
        end_date=state["selected_end_date"] if use_dates else None,
        lob=state["selected_lob"],
        page=state["page"],
        page_size=state["page_size"],
    )
    state["last_query_total"] = total
    per_page = state["page_size"]
    total_pages = max(1, (total + per_page - 1) // per_page)
    if state["page"] > total_pages - 1:
        # The result set shrank since this page was selected; jump to the
        # last existing page instead of showing an empty one.
        state["page"] = total_pages - 1
        st.rerun()

    nav1, nav2, nav3 = st.columns([1, 2, 1])
    with nav1:
        if st.button("⬅️ Prev", disabled=(state["page"] <= 0)):
            state["page"] -= 1
            st.rerun()
    with nav2:
        st.markdown(f"<div style='text-align:center'>Page <b>{state['page']+1}</b> of <b>{total_pages}</b> &nbsp;&middot;&nbsp; <b>{total}</b> images total</div>", unsafe_allow_html=True)
    with nav3:
        if st.button("Next ➡️", disabled=(state["page"] >= total_pages - 1)):
            state["page"] += 1
            st.rerun()
    st.divider()

    if total == 0 or not docs:
        st.info("No images found for the current filters.")
        return

    st.markdown("#### Images")
    cols = st.columns(4)
    # Draw grey square placeholders first so the grid layout appears
    # immediately while images download.
    placeholders = []
    for i, _ in enumerate(docs):
        ph = cols[i % 4].empty()
        ph.markdown("<div style='width:100%;aspect-ratio:1/1;border-radius:10px;background:#eee'></div>", unsafe_allow_html=True)
        placeholders.append(ph)

    urls = [d["url"] for d in docs]
    loaded = app.load_images_parallel(urls, max_workers=8)
    url_to_img = {u: img for (u, img) in loaded}
    for i, d in enumerate(docs):
        img = url_to_img.get(d["url"])
        meta = f"{d.get('category','N/A')} | {d.get('file_name','N/A')} | {d.get('created_at','')}"
        if img:
            placeholders[i].image(img, width='stretch', caption=meta)
        else:
            placeholders[i].warning("Failed to load image")