Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import pymongo | |
| from datetime import datetime, date, timezone | |
| import os | |
| from typing import List, Tuple, Optional | |
| import requests | |
| from PIL import Image | |
| import io | |
| from dotenv import load_dotenv | |
| import concurrent.futures | |
| import threading | |
| from functools import lru_cache | |
| from gradio_calendar import Calendar | |
| load_dotenv() | |
class ImageGalleryApp:
    """Gradio-facing image gallery backed by a MongoDB collection.

    Only documents with ``status == "completed"`` are ever surfaced, and the
    filter dropdowns are additionally restricted to ``lob == "test"`` records.
    """

    # Sentinel for "argument not supplied", so that an explicit ``None``
    # keeps its original meaning of "no date filter".
    _UNSET = object()

    def __init__(self, mongo_uri: str, db_name: str, collection_name: str):
        """Initialize the MongoDB connection.

        Args:
            mongo_uri: Connection string passed straight to ``MongoClient``.
            db_name: Database to select.
            collection_name: Collection holding the image documents.
        """
        self.client = pymongo.MongoClient(mongo_uri)
        self.db = self.client[db_name]
        self.collection = self.db[collection_name]
        # NOTE(review): these cache slots were never consulted by any method;
        # they are kept for compatibility with code that might inspect them.
        # Filter queries always hit Mongo, so the UI "Refresh Filters" button
        # genuinely returns fresh data.
        self._categories_cache = None
        self._filenames_cache = None
        self._cache_lock = threading.Lock()

    def _distinct_completed(self, field: str) -> List[str]:
        """Return sorted, non-empty distinct values of *field* across
        completed 'test' documents (shared by the two filter-choice methods)."""
        pipeline = [
            {"$match": {"status": "completed", "lob": "test"}},
            {"$group": {"_id": f"${field}"}},
            {"$sort": {"_id": 1}},
        ]
        return [doc["_id"] for doc in self.collection.aggregate(pipeline) if doc["_id"]]

    def get_unique_categories(self) -> List[str]:
        """Return ["All"] plus every distinct category of completed records."""
        try:
            return ["All"] + self._distinct_completed("category")
        except Exception as e:
            print(f"Error fetching categories: {e}")
            return ["All"]

    def get_unique_filenames(self) -> List[str]:
        """Return ["All"] plus every distinct file name of completed records."""
        try:
            return ["All"] + self._distinct_completed("file_name")
        except Exception as e:
            print(f"Error fetching filenames: {e}")
            return ["All"]

    def load_image_from_url(self, url: str) -> Optional["Image.Image"]:
        """Download one image, shrinking anything larger than 800x800.

        Returns the PIL image, or None on any network/decode failure.
        Failures are logged, never raised, so parallel batches keep going.
        """
        try:
            response = requests.get(url, timeout=5, stream=True)
            response.raise_for_status()
            image = Image.open(io.BytesIO(response.content))
            # Thumbnail oversized images so the gallery stays responsive.
            max_size = (800, 800)
            if image.size[0] > max_size[0] or image.size[1] > max_size[1]:
                image.thumbnail(max_size, Image.Resampling.LANCZOS)
            return image
        except Exception as e:
            print(f"Error loading image from {url}: {e}")
            return None

    def _fetch_images_ordered(self, urls: List[str], max_workers: int = 5) -> list:
        """Fetch all *urls* concurrently; result[i] corresponds to urls[i]
        and is None where the download failed."""
        if not urls:
            return []
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            # executor.map preserves input order, unlike as_completed().
            # load_image_from_url never raises, so no per-future try/except.
            return list(executor.map(self.load_image_from_url, urls))

    def load_images_parallel(self, urls: List[str], max_workers: int = 5) -> List["Image.Image"]:
        """Load multiple images in parallel.

        Bug fix: results are now returned in the same order as *urls*. The
        previous as_completed() implementation yielded images in completion
        order, which misaligned them with their per-document metadata.
        Failed downloads are simply dropped.
        """
        return [img for img in self._fetch_images_ordered(urls, max_workers) if img is not None]

    def parse_date_input(self, date_input) -> Optional[date]:
        """Coerce a Calendar/date/str input to a ``date``.

        Accepts ``date`` objects, 'YYYY-MM-DD' strings (what the Calendar
        component emits) and 'MM/DD/YYYY' strings. Returns None for empty
        or unparseable input.
        """
        if not date_input or date_input == "":
            return None
        if isinstance(date_input, date):
            return date_input
        if isinstance(date_input, str):
            try:
                # YYYY-MM-DD — the most common format from the Calendar.
                if date_input.count('-') == 2:
                    return datetime.strptime(date_input, '%Y-%m-%d').date()
                # MM/DD/YYYY fallback.
                elif date_input.count('/') == 2:
                    return datetime.strptime(date_input, '%m/%d/%Y').date()
            except ValueError as e:
                print(f"Error parsing date string '{date_input}': {e}")
                return None
        return None

    def search_images(self,
                      category: str = "All",
                      file_name: str = "All",
                      start_date=_UNSET,
                      end_date=_UNSET,
                      lob: str = "test") -> Tuple[List["Image.Image"], str]:
        """Search completed image records and return (images, info_text).

        Args:
            category / file_name: "All" disables the respective filter.
            start_date / end_date: date object, 'YYYY-MM-DD' / 'MM/DD/YYYY'
                string, or None for no bound. When omitted entirely they
                default to *today's* UTC date, computed per call. (Bug fix:
                the old defaults were evaluated once at import time, so a
                long-running server silently filtered on a stale date.)
            lob: line-of-business value the documents must carry.

        Returns:
            A list of PIL images (newest record first, aligned with the
            per-record entries in the info text) and a human-readable
            summary string. On query failure returns ([], error message).
        """
        if start_date is self._UNSET:
            start_date = datetime.now(timezone.utc).date()
        if end_date is self._UNSET:
            end_date = datetime.now(timezone.utc).date()
        start_date_obj = self.parse_date_input(start_date)
        end_date_obj = self.parse_date_input(end_date)

        # Base query: completed records that actually have image URLs.
        query = {
            "lob": lob,
            "status": "completed",
            "urls": {"$exists": True, "$ne": []},
        }
        if category != "All":
            query["category"] = category
        if file_name != "All":
            query["file_name"] = file_name

        # Inclusive created_at window: start of start-day .. end of end-day.
        date_query = {}
        if start_date_obj:
            date_query["$gte"] = datetime.combine(start_date_obj, datetime.min.time())
        if end_date_obj:
            date_query["$lte"] = datetime.combine(end_date_obj, datetime.max.time())
        if date_query:
            query["created_at"] = date_query

        try:
            # Project only the fields the UI displays.
            projection = {
                "category": 1,
                "file_name": 1,
                "created_at": 1,
                "urls": 1,
                "status": 1,
                "prompt": 1,
            }
            documents = list(self.collection.find(query, projection).sort("created_at", -1))
            if not documents:
                return [], "No records found matching the criteria."

            # Flatten to (document, url) pairs so every downloaded image can
            # be traced back to its record even when some downloads fail.
            pairs = [(doc, url) for doc in documents for url in doc.get("urls", [])]
            all_urls = [url for _, url in pairs]
            print(f"Loading {len(all_urls)} images...")
            loaded = self._fetch_images_ordered(all_urls, max_workers=8)
            images = [img for img in loaded if img is not None]

            # Documents with at least one successfully loaded image, kept in
            # result order. Pairs for the same doc are consecutive, so a
            # compare-with-last check suffices for de-duplication.
            shown_docs = []
            for (doc, _), img in zip(pairs, loaded):
                if img is not None and (not shown_docs or shown_docs[-1] is not doc):
                    shown_docs.append(doc)

            # Build the summary with join() instead of quadratic +=.
            lines = [f"Found {len(documents)} records (showing {len(images)} images)\n",
                     f"Filter: Status = completed, LOB = {lob}\n"]
            if start_date_obj:
                lines.append(f"Start Date: {start_date_obj}\n")
            if end_date_obj:
                lines.append(f"End Date: {end_date_obj}\n")
            lines.append("\n")
            for i, doc in enumerate(shown_docs):
                lines.append(f"#{i + 1}\n")
                lines.append(f"Category: {doc.get('category', 'N/A')}\n")
                lines.append(f"File: {doc.get('file_name', 'N/A')}\n")
                lines.append(f"Prompt: {doc.get('prompt', 'N/A')}\n")
                lines.append(f"Created: {doc.get('created_at', 'N/A')}\n")
                lines.append(f"URLs: {len(doc.get('urls', []))} image(s)\n")
                lines.append(f"Status: {doc.get('status', 'N/A')}\n")
                lines.append("-" * 30 + "\n")
            return images, "".join(lines)
        except Exception as e:
            error_msg = f"Error searching database: {str(e)}"
            print(error_msg)
            return [], error_msg
def create_gradio_app(mongo_uri: str, db_name: str, collection_name: str):
    """Create and launch the Gradio application.

    Wires an ImageGalleryApp (MongoDB-backed) to a Blocks UI with category/
    file-name dropdowns, a date-range calendar pair, a fixed LOB field, an
    image gallery, and a text panel showing per-record metadata.
    """
    app = ImageGalleryApp(mongo_uri, db_name, collection_name)

    def update_gallery(category, file_name, start_date, end_date, lob):
        # Run a search; any failure is surfaced as text in the info panel
        # instead of crashing the UI callback.
        try:
            images, info = app.search_images(category, file_name, start_date, end_date, lob)
            return images, info
        except Exception as e:
            return [], f"Error: {str(e)}"

    def get_filter_choices():
        """Query the distinct category/file-name choices, falling back to
        just ["All"] on any database error."""
        try:
            categories = app.get_unique_categories()
            filenames = app.get_unique_filenames()
            return categories, filenames
        except Exception as e:
            print(f"Error loading filter choices: {e}")
            return ["All"], ["All"]

    # Create Gradio interface
    with gr.Blocks(title="Image Gallery", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Generated Images")
        # NOTE(review): this State is never read or written by any handler
        # below — it appears to be leftover scaffolding for a loading spinner.
        loading_state = gr.State(False)
        with gr.Row():
            with gr.Column(scale=1):
                category_dropdown = gr.Dropdown(
                    choices=["All"],  # real choices filled in by demo.load()
                    value="All",
                    label="Category",
                    interactive=True
                )
                filename_dropdown = gr.Dropdown(
                    choices=["All"],  # real choices filled in by demo.load()
                    value="All",
                    label="File Name",
                    interactive=True
                )
                # LOB is displayed but locked to "test" for this deployment.
                lob_input = gr.Textbox(
                    value="test",
                    label="LOB (Line of Business)",
                    interactive=False
                )
        with gr.Row():
            # Both calendars default to today's UTC date (as a string; the
            # backend's parse_date_input handles the conversion).
            start_date_calendar = Calendar(
                label="Start Date",
                info="Select start date for filtering",
                value=str(datetime.now(timezone.utc).date())
            )
            end_date_calendar = Calendar(
                label="End Date",
                info="Select end date for filtering",
                value=str(datetime.now(timezone.utc).date())
            )
        with gr.Row():
            search_btn = gr.Button("🔍 Search Images", variant="primary")
            refresh_btn = gr.Button("🔄 Refresh Filters")
        with gr.Row():
            with gr.Column(scale=2):
                gallery = gr.Gallery(
                    label="Images (Status: completed)",
                    show_label=True,
                    elem_id="gallery",
                    columns=4,  # Increased columns for better layout
                    rows=3,
                    height="700px",
                    object_fit="contain",
                    preview=True  # Enable preview mode
                )
            with gr.Column(scale=1):
                info_text = gr.Textbox(
                    label="Document Information",
                    lines=25,
                    max_lines=25,
                    interactive=False
                )

        # Event handlers
        def search_with_loading(*args):
            # Thin pass-through kept as the click target for update_gallery.
            return update_gallery(*args)

        search_btn.click(
            fn=search_with_loading,
            inputs=[category_dropdown, filename_dropdown, start_date_calendar, end_date_calendar, lob_input],
            outputs=[gallery, info_text]
        )

        def refresh_filters():
            # Re-query the dropdown choices from the database on demand.
            try:
                categories, filenames = get_filter_choices()
                return (
                    gr.Dropdown(choices=categories, value="All"),
                    gr.Dropdown(choices=filenames, value="All")
                )
            except Exception as e:
                # On failure, return the existing components unchanged.
                print(f"Error refreshing filters: {e}")
                return category_dropdown, filename_dropdown

        refresh_btn.click(
            fn=refresh_filters,
            outputs=[category_dropdown, filename_dropdown]
        )

        # Initialize filters and load initial data
        def initialize_app():
            # Runs once on page load: populate both dropdowns and perform an
            # initial search with the default (today-only) date window.
            try:
                categories, filenames = get_filter_choices()
                images, info = app.search_images()  # Load fewer images initially
                return (
                    gr.Dropdown(choices=categories, value="All"),
                    gr.Dropdown(choices=filenames, value="All"),
                    images,
                    info
                )
            except Exception as e:
                # NOTE(review): this error path returns the component objects
                # themselves rather than gr.Dropdown(...) updates — presumably
                # Gradio treats that as "no change"; confirm against the
                # Gradio version in use.
                error_msg = f"Error initializing app: {str(e)}"
                return category_dropdown, filename_dropdown, [], error_msg

        demo.load(
            fn=initialize_app,
            outputs=[category_dropdown, filename_dropdown, gallery, info_text]
        )

    return demo