# AdGenesis-App — generator_function/image_processor.py
# Enhance image analysis and generation with improved error handling,
# format normalization, and user guidance (commit 8ba42b4, by pmachal101723)
import os, zipfile, tempfile, logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Tuple, Optional
from generator_function.image_function import generate_image
from helpers_function.helper_meta_data import meta_data_helper_function
from helpers_function.helpers import upload_image_to_r2
from helpers_function.helpers import is_valid_image
from database.connections import get_results_collection as get_collection
from database.operations import start_job, finish_job
from util.session_state import current_uid
# Module-level logger, named after this module per the standard convention.
logger = logging.getLogger(__name__)
# Results collection handle used for job bookkeeping. May be None (every use
# below is guarded with `COL is not None`), in which case DB logging is skipped.
COL = get_collection()
def _resolve_user_id() -> str:
    """Return the current session's user id, falling back to the
    DEFAULT_USER_ID environment variable and finally to "anonymous"."""
    session_uid = current_uid()
    if session_uid:
        return session_uid
    return os.getenv("DEFAULT_USER_ID", "anonymous")
def process_zip_and_generate_images(
    zip_path: str,
    category: str,
    size: str,
    quality: str,
    user_prompt: str,
    sentiment: str,
    platform: str,
    num_images: int,
    demo_mode: bool,
    existing_images: Optional[List[str]],
    blur: bool,
    uid: str,
) -> List[str]:
    """Generate ad images from either a ZIP of source images or a single image file.

    Extracts the ZIP (when *zip_path* ends in ".zip"), runs generation for each
    valid image, and returns the de-duplicated list of newly generated URLs in
    first-seen order. *demo_mode* caps generation at one image per source.

    Returns:
        New image URLs only (not appended to *existing_images*). On any
        unexpected failure the error is logged and *existing_images* (or [])
        is returned instead, preserving the caller's current gallery.
    """
    num_images = 1 if demo_mode else num_images
    temp_dir: Optional[tempfile.TemporaryDirectory] = None
    try:
        if zip_path.endswith(".zip"):
            temp_dir = extract_zip_file(zip_path)
            image_files = get_valid_image_files(temp_dir)
        else:
            # Single image upload: process it directly, no extraction needed.
            image_files = [(os.path.basename(zip_path), zip_path)]
        results = process_image_files(
            image_files, category, size, quality, user_prompt, sentiment, platform, num_images, blur, uid
        )
        all_urls = [url for entry in results for url in entry["urls"]]
        # dict.fromkeys removes duplicates while keeping first-seen order.
        return list(dict.fromkeys(all_urls))
    except Exception:
        logger.exception(f"Global error during processing file: {zip_path}")
        return existing_images or []
    finally:
        # Previously the TemporaryDirectory was never cleaned up explicitly and
        # only disappeared when the GC finalized it; release it deterministically.
        if temp_dir is not None:
            try:
                temp_dir.cleanup()
            except Exception:
                logger.exception("Failed to clean up temporary directory.")
def extract_zip_file(zip_path: str) -> tempfile.TemporaryDirectory:
    """Extract *zip_path* into a fresh temporary directory.

    Returns the TemporaryDirectory handle itself (callers read ``.name``);
    the extracted files live until the handle is cleaned up or finalized.
    """
    extracted = tempfile.TemporaryDirectory()
    with zipfile.ZipFile(zip_path, "r") as archive:
        archive.extractall(extracted.name)
        logger.info(f"Extracted ZIP file: {zip_path}")
    return extracted
def get_valid_image_files(temp_dir: tempfile.TemporaryDirectory) -> List[Tuple[str, str]]:
    """Scan the top level of *temp_dir* and return (file_name, absolute_path)
    pairs for every entry that looks like a valid image.

    macOS resource-fork folders ("__MACOSX") are skipped, as are non-file
    entries (subdirectories); rejected entries are logged at WARNING level.
    Note: the scan is non-recursive, so images inside nested folders of the
    ZIP are not picked up.
    """
    valid_files: List[Tuple[str, str]] = []
    for file in os.listdir(temp_dir.name):
        if "__MACOSX" in file:
            continue
        file_path = os.path.join(temp_dir.name, file)
        # is_valid_image is given only the name (see call below), so a
        # directory named like an image would otherwise slip through.
        if not os.path.isfile(file_path):
            logger.warning(f"Ignored non-file entry: {file}")
            continue
        if is_valid_image(file):
            valid_files.append((file, file_path))
        else:
            logger.warning(f"Ignored non-image file: {file}")
    logger.info(f"Found {len(valid_files)} valid images.")
    return valid_files
def process_image_files(image_files: List[Tuple[str, str]], category: str, size: str,
    quality: str, user_prompt: str, sentiment: str, platform: str, num_images: int, blur: bool, uid: str,
) -> List[dict]:
    """Fan out image generation across a thread pool, one task per source image.

    For each file a DB job record is opened first (best-effort: failures are
    logged and generation proceeds without bookkeeping), then the per-image
    work is submitted. Returns the non-empty per-file result dicts
    ({"file_name", "urls"}) in completion order.
    """
    collected: List[dict] = []
    with ThreadPoolExecutor(max_workers=5) as pool:
        pending = []
        for file_name, file_path in image_files:
            job_id: Optional[str] = None
            if COL is not None:
                # Open the job record before submitting, so the worker can
                # close it with the final status.
                try:
                    job_settings = {"size": size, "quality": quality, "sentiment": sentiment, "platform": platform, "num_images": num_images, "blur": bool(blur)}
                    job_inputs = {"file_name": file_name, "mode": "img_or_zip"}
                    job_id = start_job(
                        COL,
                        type="variation",
                        created_by=uid,
                        category=category or "general",
                        inputs=job_inputs,
                        settings=job_settings,
                        user_prompt=user_prompt,
                    )
                except Exception:
                    logger.exception("Failed to start DB job; continuing without DB logging.")
            pending.append(
                pool.submit(
                    process_single_image,
                    file_name, file_path, category, size, quality, user_prompt, sentiment, platform, num_images, blur, job_id,
                )
            )
        for done in as_completed(pending):
            try:
                outcome = done.result()
            except Exception:
                logger.exception("Unhandled exception during image processing thread.")
            else:
                if outcome:
                    collected.append(outcome)
    return collected
def process_single_image(file_name: str, file_path: str, category: str, size: str, quality: str, user_prompt: str, sentiment: str,
    platform: str, num_images: int, blur: bool, job_id: Optional[str],
) -> Optional[dict]:
    """Generate all variations for one source image and close its DB job.

    Returns {"file_name": ..., "urls": [...]} when at least one image was
    produced, else None. The DB job (if any) is marked "completed" or
    "failed" accordingly; DB errors are logged but never propagated.
    """
    try:
        image_urls = generate_images_from_prompts(
            file_path, size, quality, category, sentiment, user_prompt, platform, num_images, blur
        )
        if COL is not None and job_id:
            try:
                finish_job(COL, job_id, status=("completed" if image_urls else "failed"), outputs_urls=image_urls)
            except Exception:
                logger.exception("Failed to finish DB job.")
        if image_urls:
            return {"file_name": file_name, "urls": image_urls}
        return None
    except Exception as e:
        # exception() records the full traceback; the bare error() it replaces
        # dropped it, making failures here hard to diagnose.
        logger.exception(f"Processing failed for {file_name}: {e}")
        if COL is not None and job_id:
            try:
                finish_job(COL, job_id, status="failed", outputs_urls=[])
            except Exception:
                logger.exception("Also failed to mark DB job as failed.")
        return None
def generate_images_from_prompts(
    file_path: str, size: str, quality: str, category: str, sentiment: str, user_prompt: str,
    platform: str, num_images: int, blur: bool,
) -> List[str]:
    """Generate up to *num_images* variations of the image at *file_path* in
    parallel, upload each to R2, and return the uploaded URLs.

    Individual failures are logged and skipped, so the result may contain
    fewer than *num_images* entries. Returns [] when *num_images* <= 0.
    """
    # Guard: ThreadPoolExecutor raises ValueError when max_workers <= 0,
    # which min(10, num_images) would produce for non-positive counts.
    if num_images <= 0:
        return []
    image_urls: List[str] = []

    def worker(i: int) -> Optional[str]:
        # One generation/upload attempt; any failure is logged and yields None.
        try:
            image_bytes = generate_image(file_path, size, quality, category, sentiment, user_prompt, platform, blur, i)
            if not image_bytes:
                return None
            image_with_metadata = meta_data_helper_function(image_bytes)
            return upload_image_to_r2(image_with_metadata)
        except Exception as e:
            logger.error(f"Image generation failed: {e}")
            return None

    with ThreadPoolExecutor(max_workers=min(10, num_images)) as executor:
        futures = [executor.submit(worker, i) for i in range(num_images)]
        for future in as_completed(futures):
            result = future.result()
            if result:
                image_urls.append(result)
    return image_urls