# Leadgen-AdGenesis / load_files.py
# (Hugging Face Space file; commit fd642f6 "Update load_files.py" by userIdc2024)
import gradio as gr
import pymongo
from datetime import datetime, date, timezone
import os
from typing import List, Tuple, Optional
import requests
from PIL import Image
import io
from dotenv import load_dotenv
import concurrent.futures
import threading
from functools import lru_cache
from gradio_calendar import Calendar
load_dotenv()
class ImageGalleryApp:
    """Backend for the gallery UI.

    Queries a MongoDB collection of generated-image records (only
    ``status == "completed"`` documents) and downloads their images for
    display in a Gradio gallery.
    """

    def __init__(self, mongo_uri: str, db_name: str, collection_name: str):
        """Initialize the MongoDB connection and per-instance filter caches."""
        self.client = pymongo.MongoClient(mongo_uri)
        self.db = self.client[db_name]
        self.collection = self.db[collection_name]
        # Per-instance caches for the dropdown choices, guarded by a lock.
        # This replaces the previous @lru_cache decorators on the getter
        # methods, which key on `self` and pin every instance in a
        # module-level cache for the process lifetime (flake8-bugbear B019).
        self._categories_cache: Optional[List[str]] = None
        self._filenames_cache: Optional[List[str]] = None
        self._cache_lock = threading.Lock()

    def _distinct_completed_values(self, field: str) -> List[str]:
        """Return sorted distinct values of *field* over completed 'test'
        records, with the "All" wildcard choice prepended.

        Raises on database errors; callers supply the ["All"] fallback.
        """
        pipeline = [
            {"$match": {"status": "completed", "lob": "test"}},
            {"$group": {"_id": f"${field}"}},
            {"$sort": {"_id": 1}},
        ]
        # Drop null/empty group keys (documents missing the field).
        values = [doc["_id"] for doc in self.collection.aggregate(pipeline) if doc["_id"]]
        return ["All"] + values

    def get_unique_categories(self) -> List[str]:
        """Get all unique categories from the database with caching."""
        try:
            with self._cache_lock:
                if self._categories_cache is None:
                    self._categories_cache = self._distinct_completed_values("category")
                return self._categories_cache
        except Exception as e:
            # Best-effort: the UI still works with just the wildcard choice.
            # Unlike the old lru_cache version, the error result is NOT
            # cached, so the next call retries the database.
            print(f"Error fetching categories: {e}")
            return ["All"]

    def get_unique_filenames(self) -> List[str]:
        """Get all unique file names from the database with caching."""
        try:
            with self._cache_lock:
                if self._filenames_cache is None:
                    self._filenames_cache = self._distinct_completed_values("file_name")
                return self._filenames_cache
        except Exception as e:
            print(f"Error fetching filenames: {e}")
            return ["All"]

    def load_image_from_url(self, url: str) -> Optional["Image.Image"]:
        """Download one image; return a PIL image, or None on any failure.

        Images larger than 800x800 are shrunk in place to keep the gallery
        responsive and bound memory use.
        """
        try:
            response = requests.get(url, timeout=5)
            response.raise_for_status()
            image = Image.open(io.BytesIO(response.content))
            max_size = (800, 800)
            if image.size[0] > max_size[0] or image.size[1] > max_size[1]:
                image.thumbnail(max_size, Image.Resampling.LANCZOS)
            return image
        except Exception as e:
            # Never let one bad URL break the whole gallery.
            print(f"Error loading image from {url}: {e}")
            return None

    def load_images_parallel(self, urls: List[str], max_workers: int = 5) -> List["Image.Image"]:
        """Download *urls* concurrently.

        Returns the successfully loaded images in the same order as *urls*.
        (The previous as_completed() implementation returned them in
        completion order, so the gallery no longer lined up with the
        per-document info text.)  Failed downloads are silently skipped —
        load_image_from_url already handles and logs its own errors.
        """
        if not urls:
            return []
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            # executor.map preserves input order.
            results = list(executor.map(self.load_image_from_url, urls))
        return [img for img in results if img is not None]

    def parse_date_input(self, date_input) -> Optional[date]:
        """Coerce a Calendar/string/date input to a ``date``, or None.

        Accepts date objects, "YYYY-MM-DD" strings (Calendar component) and
        "MM/DD/YYYY" strings; anything else — including empty strings and
        unparseable text — yields None (meaning "no filter").
        """
        if not date_input:
            return None
        if isinstance(date_input, date):
            return date_input
        if isinstance(date_input, str):
            try:
                if date_input.count('-') == 2:
                    return datetime.strptime(date_input, '%Y-%m-%d').date()
                elif date_input.count('/') == 2:
                    return datetime.strptime(date_input, '%m/%d/%Y').date()
            except ValueError as e:
                print(f"Error parsing date string '{date_input}': {e}")
        return None

    def search_images(self,
                      category: str = "All",
                      file_name: str = "All",
                      start_date: Optional[date] = None,
                      end_date: Optional[date] = None,
                      lob: str = "test") -> Tuple[List["Image.Image"], str]:
        """Search completed records and return ``(images, info_text)``.

        Args:
            category: category filter, or "All" for no filter.
            file_name: file-name filter, or "All" for no filter.
            start_date / end_date: date window (inclusive); ``None`` means
                "today (UTC)".  Strings and date objects are both accepted.
            lob: line-of-business filter.
        """
        # Resolve the default date window at *call* time.  The previous
        # signature evaluated datetime.now() in the def line, freezing the
        # default at import time — stale for a long-running app.
        today = datetime.now(timezone.utc).date()
        start_date_obj = self.parse_date_input(start_date) if start_date is not None else today
        end_date_obj = self.parse_date_input(end_date) if end_date is not None else today

        query = {
            "lob": lob,
            "status": "completed",                  # only finished generations
            "urls": {"$exists": True, "$ne": []},   # must have something to show
        }
        if category != "All":
            query["category"] = category
        if file_name != "All":
            query["file_name"] = file_name

        # created_at is stored as a datetime, so widen the [start, end]
        # date bounds to whole days.
        date_query = {}
        if start_date_obj:
            date_query["$gte"] = datetime.combine(start_date_obj, datetime.min.time())
        if end_date_obj:
            date_query["$lte"] = datetime.combine(end_date_obj, datetime.max.time())
        if date_query:
            query["created_at"] = date_query

        try:
            # Project only the fields the info panel needs.
            projection = {
                "category": 1,
                "file_name": 1,
                "created_at": 1,
                "urls": 1,
                "status": 1,
                "prompt": 1,
            }
            documents = list(self.collection.find(query, projection).sort("created_at", -1))
            if not documents:
                return [], "No records found matching the criteria."

            # Flatten every URL in document order so the gallery order
            # follows the (newest-first) document order.
            all_urls = [url for doc in documents for url in doc.get("urls", [])]

            print(f"Loading {len(all_urls)} images...")
            images = self.load_images_parallel(all_urls, max_workers=8)

            info_text = f"Found {len(documents)} records (showing {len(images)} images)\n"
            info_text += f"Filter: Status = completed, LOB = {lob}\n"
            if start_date_obj:
                info_text += f"Start Date: {start_date_obj}\n"
            if end_date_obj:
                info_text += f"End Date: {end_date_obj}\n"
            info_text += "\n"
            # NOTE(review): when a document has several URLs this pairing is
            # approximate — entry #k describes document k, not gallery
            # image k.  Kept for compatibility with the original output.
            for i, doc in enumerate(documents[:len(images)]):
                info_text += f"#{i + 1}\n"
                info_text += f"Category: {doc.get('category', 'N/A')}\n"
                info_text += f"File: {doc.get('file_name', 'N/A')}\n"
                info_text += f"Prompt: {doc.get('prompt', 'N/A')}\n"
                info_text += f"Created: {doc.get('created_at', 'N/A')}\n"
                info_text += f"URLs: {len(doc.get('urls', []))} image(s)\n"
                info_text += f"Status: {doc.get('status', 'N/A')}\n"
                info_text += "-" * 30 + "\n"
            return images, info_text
        except Exception as e:
            error_msg = f"Error searching database: {str(e)}"
            print(error_msg)
            return [], error_msg
def create_gradio_app(mongo_uri: str, db_name: str, collection_name: str):
    """Build the Gradio Blocks UI wired to an ImageGalleryApp backend.

    Returns the (un-launched) ``gr.Blocks`` demo; the caller is expected to
    call ``.launch()`` on it.
    """
    app = ImageGalleryApp(mongo_uri, db_name, collection_name)

    def update_gallery(category, file_name, start_date, end_date, lob):
        # Thin UI adapter: forward the filter widgets to the backend and
        # surface any failure as text instead of crashing the callback.
        try:
            images, info = app.search_images(category, file_name, start_date, end_date, lob)
            return images, info
        except Exception as e:
            return [], f"Error: {str(e)}"

    def get_filter_choices():
        """Get filter choices with loading indicator"""
        # Fetch both dropdowns' choices; fall back to just the "All"
        # wildcard on any backend error so the UI stays usable.
        try:
            categories = app.get_unique_categories()
            filenames = app.get_unique_filenames()
            return categories, filenames
        except Exception as e:
            print(f"Error loading filter choices: {e}")
            return ["All"], ["All"]

    # Create Gradio interface.  Component creation order inside the context
    # managers defines the page layout — do not reorder.
    with gr.Blocks(title="Image Gallery", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Generated Images")
        # Loading state
        # NOTE(review): loading_state is created but never read or written by
        # any event handler — presumably a leftover from a spinner feature.
        loading_state = gr.State(False)
        with gr.Row():
            with gr.Column(scale=1):
                # Filter controls; choices are populated by initialize_app /
                # refresh_filters at runtime.
                category_dropdown = gr.Dropdown(
                    choices=["All"],
                    value="All",
                    label="Category",
                    interactive=True
                )
                filename_dropdown = gr.Dropdown(
                    choices=["All"],
                    value="All",
                    label="File Name",
                    interactive=True
                )
                # LOB is fixed to "test" (non-interactive) but still passed
                # through to search_images as a regular input.
                lob_input = gr.Textbox(
                    value="test",
                    label="LOB (Line of Business)",
                    interactive=False
                )
        with gr.Row():
            # Both calendars default to today's UTC date at app start.
            start_date_calendar = Calendar(
                label="Start Date",
                info="Select start date for filtering",
                value=str(datetime.now(timezone.utc).date())
            )
            end_date_calendar = Calendar(
                label="End Date",
                info="Select end date for filtering",
                value=str(datetime.now(timezone.utc).date())
            )
        with gr.Row():
            search_btn = gr.Button("🔍 Search Images", variant="primary")
            refresh_btn = gr.Button("🔄 Refresh Filters")
        with gr.Row():
            with gr.Column(scale=2):
                gallery = gr.Gallery(
                    label="Images (Status: completed)",
                    show_label=True,
                    elem_id="gallery",
                    columns=4,  # Increased columns for better layout
                    rows=3,
                    height="700px",
                    object_fit="contain",
                    preview=True  # Enable preview mode
                )
            with gr.Column(scale=1):
                info_text = gr.Textbox(
                    label="Document Information",
                    lines=25,
                    max_lines=25,
                    interactive=False
                )

        # Event handlers
        def search_with_loading(*args):
            # Pass-through wrapper (the loading indicator it was named for
            # is not implemented; see loading_state above).
            return update_gallery(*args)

        search_btn.click(
            fn=search_with_loading,
            inputs=[category_dropdown, filename_dropdown, start_date_calendar, end_date_calendar, lob_input],
            outputs=[gallery, info_text]
        )

        def refresh_filters():
            # Re-query the dropdown choices.  Returning fresh gr.Dropdown
            # objects is the Gradio 4.x way to update a component's config.
            try:
                categories, filenames = get_filter_choices()
                return (
                    gr.Dropdown(choices=categories, value="All"),
                    gr.Dropdown(choices=filenames, value="All")
                )
            except Exception as e:
                # NOTE(review): this returns the original component objects,
                # not update payloads — confirm this is a no-op (rather than
                # an error) on the Gradio version in use.
                print(f"Error refreshing filters: {e}")
                return category_dropdown, filename_dropdown

        refresh_btn.click(
            fn=refresh_filters,
            outputs=[category_dropdown, filename_dropdown]
        )

        # Initialize filters and load initial data
        def initialize_app():
            # Runs once on page load: populate both dropdowns and perform a
            # default search (today's date, all categories/files).
            try:
                categories, filenames = get_filter_choices()
                images, info = app.search_images()  # Load fewer images initially
                return (
                    gr.Dropdown(choices=categories, value="All"),
                    gr.Dropdown(choices=filenames, value="All"),
                    images,
                    info
                )
            except Exception as e:
                # Same caveat as refresh_filters: components are returned
                # as-is on error; the error text still reaches the info box.
                error_msg = f"Error initializing app: {str(e)}"
                return category_dropdown, filename_dropdown, [], error_msg

        demo.load(
            fn=initialize_app,
            outputs=[category_dropdown, filename_dropdown, gallery, info_text]
        )
    return demo