"""Image generation service.

Generates images through Replicate-hosted models, optionally uploads them to
an S3-compatible bucket configured via the ``R2_*`` environment variables,
and renders the results in a Streamlit page. Also exposes ``generate_image``,
which edits a reference image through the OpenAI images API.
"""

import base64
import io
import logging
import os
import threading
import time
import zipfile
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import lru_cache
from typing import Any, Dict, List, Optional, Tuple, Union
from urllib.parse import urlparse
from uuid import uuid4

import replicate
import requests
import streamlit as st
import boto3
from openai import OpenAI
from dotenv import load_dotenv

from helpers_function.helper_meta_data import meta_data_helper_function
from database.operations import start_job, finish_job
from database.connections import get_results_collection

# Must run before any os.getenv() call below so .env values are visible.
load_dotenv()

logger = logging.getLogger("imagegen_service")
logging.basicConfig(level=logging.INFO)

REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
MAX_WORKERS = min(32, (os.cpu_count() or 1) + 4)
REQUEST_TIMEOUT = 30  # passed as `timeout=` to requests.get
RETRY_ATTEMPTS = 3  # attempts per generate / fetch / upload operation

# Maps the model key used by callers to the Replicate model id, the aspect
# ratios listed for it, and the input parameter name carrying the ratio.
MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {
    "imagegen-4-ultra": {
        "id": "google/imagen-4-ultra",
        "aspect_ratios": ["1:1", "16:9", "9:16", "3:4", "4:3"],
        "param_name": "aspect_ratio",
    },
    "imagen-4": {
        "id": "google/imagen-4",
        "aspect_ratios": ["1:1", "16:9", "9:16", "3:4", "4:3"],
        "param_name": "aspect_ratio",
    },
    "nano-banana": {
        "id": "google/nano-banana",
        "aspect_ratios": ["1:1", "16:9", "9:16", "3:4", "4:3"],
        "param_name": "aspect_ratio",
    },
    "qwen": {
        "id": "qwen/qwen-image",
        "aspect_ratios": ["1:1", "16:9", "9:16", "3:4", "4:3", "3:2", "2:3"],
        "param_name": "aspect_ratio",
    },
    "seedream-3": {
        "id": "bytedance/seedream-3",
        "aspect_ratios": [
            "1:1", "16:9", "9:16", "3:4", "4:3", "3:2", "2:3", "21:9",
        ],
        "param_name": "aspect_ratio",
    },
    "recraft-v3": {
        "id": "recraft-ai/recraft-v3",
        "aspect_ratios": [
            "1:1", "4:3", "3:4", "3:2", "2:3", "16:9", "9:16", "1:2",
            "2:1", "7:5", "5:7", "4:5", "5:4", "3:5", "5:3",
        ],
        "param_name": "aspect_ratio",
    },
    "photon": {
        "id": "luma/photon",
        "aspect_ratios": [
            "1:1", "3:4", "4:3", "9:16", "16:9", "9:21", "21:9",
        ],
        "param_name": "aspect_ratio",
    },
    "ideogram-v3-quality": {
        "id": "ideogram-ai/ideogram-v3-quality",
        "aspect_ratios": [
            "1:3", "3:1", "1:2", "2:1", "9:16", "16:9", "10:16", "16:10",
            "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "1:1",
        ],
        "param_name": "aspect_ratio",
    },
}

# Holds one lazily created S3 client per thread (see `_s3`).
_thread_local = threading.local()


def _encode_image_to_base64(image_path):
    """Return the file at *image_path* as a base64 string.

    Returns an empty string (after logging the traceback) if the file cannot
    be read or encoded.
    """
    try:
        with open(image_path, "rb") as f:
            return base64.b64encode(f.read()).decode("utf-8")
    except Exception:
        logger.exception(f"Failed to base64 encode image: {image_path}")
        return ""


def get_model_config(model_key: str) -> Optional[Dict[str, Any]]:
    """Return the ``MODEL_REGISTRY`` entry for *model_key*, or ``None``."""
    return MODEL_REGISTRY.get(model_key)


@lru_cache(maxsize=128)
def _get_model_config_cached(model_key: str) -> Optional[Dict[str, Any]]:
    """Memoized variant of the registry lookup.

    NOTE: misses are cached as well, so a key added to ``MODEL_REGISTRY``
    after a failed lookup stays ``None`` until ``cache_clear()`` is called.
    """
    return MODEL_REGISTRY.get(model_key)


def _s3():
    """Return this thread's S3 client, creating it on first use.

    Returns ``None`` when any required environment variable is missing or
    client construction fails; that outcome is remembered for the thread.
    """
    if not hasattr(_thread_local, "s3"):
        needed = [
            "R2_ENDPOINT",
            "R2_ACCESS_KEY",
            "R2_SECRET_KEY",
            "R2_BUCKET_NAME",
            "NEW_BASE",
        ]
        if any(not os.getenv(k) for k in needed):
            _thread_local.s3 = None
            return None
        try:
            _thread_local.s3 = boto3.client(
                "s3",
                endpoint_url=os.getenv("R2_ENDPOINT"),
                aws_access_key_id=os.getenv("R2_ACCESS_KEY"),
                aws_secret_access_key=os.getenv("R2_SECRET_KEY"),
                region_name="auto",
            )
        except Exception as e:
            logger.error(f"S3 init failed: {e}")
            _thread_local.s3 = None
    return _thread_local.s3


def _upload_to_r2(image_bytes: bytes) -> Optional[str]:
    """Upload PNG bytes under a random key and return the resulting URL.

    The URL is ``NEW_BASE`` joined with the object key. Retries up to
    ``RETRY_ATTEMPTS`` times, sleeping ``2 ** attempt`` seconds between
    attempts. Returns ``None`` when no client is available or every attempt
    fails.
    """
    s3 = _s3()
    if not s3:
        return None

    for attempt in range(RETRY_ATTEMPTS):
        try:
            filename = f"{uuid4().hex}.png"
            key = f"adgenesis_image_text/creative_adgenesis/images/{filename}"
            s3.put_object(
                Bucket=os.getenv("R2_BUCKET_NAME"),
                Key=key,
                Body=image_bytes,
                ContentType="image/png",
            )
            # NEW_BASE is guaranteed non-empty here: `_s3` returns None
            # when it is missing.
            return f"{os.getenv('NEW_BASE').rstrip('/')}/{key}"
        except Exception as e:
            if attempt == RETRY_ATTEMPTS - 1:
                logger.error(f"R2 upload failed: {e}")
                return None
            time.sleep(2 ** attempt)
    return None


def _generate_one(model_key: str, prompt: str, aspect_ratio: str) -> List[str]:
    """Run one generation on Replicate and return at most one output URL.

    Returns an empty list when the API token or model config is missing, or
    when no attempt produced a usable output.
    """
    if not REPLICATE_API_TOKEN:
        return []
    cfg = _get_model_config_cached(model_key)
    if not cfg:
        return []

    for attempt in range(RETRY_ATTEMPTS):
        try:
            output = replicate.run(
                cfg["id"],
                input={"prompt": prompt, cfg["param_name"]: aspect_ratio},
            )
            urls: List[str] = []
            if isinstance(output, list) and output:
                # Only the first item of a list output is used.
                first = output[0]
                url = getattr(first, "url", str(first))
                urls = [url]
            elif isinstance(output, str):
                urls = [output]
            elif hasattr(output, "url"):
                urls = [getattr(output, "url")]
            if urls:
                return urls
            # An output with no usable URL falls through to the next attempt.
        except Exception as e:
            if attempt == RETRY_ATTEMPTS - 1:
                logger.error(f"replicate run failed: {e}")
                return []
            time.sleep(1)
    return []


def _fetch(url: Union[str, Any]) -> Optional[bytes]:
    """Download *url* and return the response body, or ``None`` on failure.

    *url* may be a string or an object exposing a ``url`` attribute. Retries
    up to ``RETRY_ATTEMPTS`` times with a one-second pause between attempts.
    """
    url_str = getattr(url, "url", str(url))
    for attempt in range(RETRY_ATTEMPTS):
        try:
            # The context manager releases the streamed connection even when
            # raise_for_status() or the body read raises.
            with requests.get(
                url_str,
                timeout=REQUEST_TIMEOUT,
                stream=True,
                headers={
                    "Cache-Control": "no-cache",
                    "Pragma": "no-cache",
                    "User-Agent": "ImageBot/1.0",
                },
            ) as r:
                r.raise_for_status()
                # join() avoids the quadratic cost of repeated bytes `+=`.
                return b"".join(r.iter_content(8192))
        except Exception as e:
            if attempt == RETRY_ATTEMPTS - 1:
                # Log only the exception type: request error messages
                # typically embed the URL being fetched.
                logger.error(f"image fetch failed: {type(e).__name__}")
                return None
            time.sleep(1)
    return None


def _process_one(args: Tuple[str, str, str, int, bool]) -> Dict[str, Any]:
    """Generate, download, tag and store a single image.

    *args* is ``(model_key, prompt, aspect_ratio, idx, private_mode)``.
    Returns a dict with keys ``index``, ``success``, ``source_url``,
    ``r2_url`` and ``error``. In private mode ``r2_url`` holds a base64
    ``data:`` URI instead of an uploaded URL. Never raises: exceptions are
    reported through ``error``.
    """
    model_key, prompt, aspect_ratio, idx, private_mode = args
    out: Dict[str, Any] = {
        "index": idx,
        "success": False,
        "source_url": None,
        "r2_url": None,
        "error": None,
    }
    try:
        urls = _generate_one(model_key, prompt, aspect_ratio)
        if not urls:
            out["error"] = "No URLs returned"
            return out

        src = urls[0]
        out["source_url"] = getattr(src, "url", str(src))

        b = _fetch(src)
        if not b:
            out["error"] = "Fetch failed"
            return out

        image_with_metadata = meta_data_helper_function(b)

        if private_mode:
            data_uri = (
                "data:image/png;base64,"
                + base64.b64encode(image_with_metadata).decode("utf-8")
            )
            out["r2_url"] = data_uri
            out["success"] = True
        else:
            r2 = _upload_to_r2(image_with_metadata)
            if r2:
                out["r2_url"] = r2
                out["success"] = True
            else:
                out["error"] = "Upload to R2 failed"
    except Exception as e:
        out["error"] = str(e)
    return out


def _generate_images_parallel(
    model_key: str,
    aspect_ratio: str,
    prompt: str,
    num_images: int,
    *,
    private_mode: bool = False,
) -> Tuple[List[str], List[str], List[str]]:
    """Generate *num_images* images, using a thread pool when more than one.

    Returns ``(r2_urls, source_urls, errors)``. With several images the URL
    lists follow completion order and are de-duplicated.
    """
    if num_images <= 0:
        # ThreadPoolExecutor rejects max_workers=0 with ValueError; report
        # the bad request through the errors list instead of crashing.
        return [], [], ["num_images must be at least 1"]

    if num_images == 1:
        res = _process_one((model_key, prompt, aspect_ratio, 0, private_mode))
        if res["success"]:
            return [res["r2_url"]], [res["source_url"]], []
        return [], [], [res["error"] or "Generation failed"]

    args = [
        (model_key, prompt, aspect_ratio, i, private_mode)
        for i in range(num_images)
    ]
    r2: List[str] = []
    src: List[str] = []
    errs: List[str] = []

    with ThreadPoolExecutor(max_workers=min(MAX_WORKERS, num_images)) as ex:
        futures = [ex.submit(_process_one, a) for a in args]
        for fut in as_completed(futures):
            try:
                res = fut.result()
                if res["success"]:
                    if res["r2_url"]:
                        r2.append(res["r2_url"])
                    if res["source_url"]:
                        src.append(res["source_url"])
                else:
                    errs.append(res["error"] or "Generation failed")
            except Exception as e:
                errs.append(f"Future err: {e}")

    # de-dup while preserving first-seen order
    r2 = list(dict.fromkeys(r2))
    src = list(dict.fromkeys(src))
    return r2, src, errs


def generate_images_parallel(
    model_key: str,
    aspect_ratio: str,
    prompt: str,
    num_images: int,
    *,
    private_mode: bool = False,
) -> Tuple[List[str], List[str], List[str]]:
    """Back-compat public export used by background tasks."""
    return _generate_images_parallel(
        model_key, aspect_ratio, prompt, num_images, private_mode=private_mode
    )


def _load_image_bytes(ref: Any) -> Optional[bytes]:
    """Return image bytes for a ``data:image`` URI or a fetchable URL.

    Returns ``None`` when the data URI cannot be decoded or the download
    fails.
    """
    if isinstance(ref, str) and ref.startswith("data:image"):
        try:
            _, encoded = ref.split(",", 1)
            return base64.b64decode(encoded)
        except ValueError:
            # Covers a missing comma (unpack error) and invalid base64
            # (binascii.Error is a ValueError subclass).
            return None
    return _fetch(ref)


def _build_zip(files: List[Tuple[str, bytes]]) -> io.BytesIO:
    """Pack ``(file_name, data)`` pairs into an in-memory ZIP archive.

    The returned buffer is rewound to the start.
    """
    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, "w") as zf:
        for fname, data in files:
            zf.writestr(fname, data)
    zip_buffer.seek(0)
    return zip_buffer


def handle_image_generation_optimized(
    *,
    model_key: str,
    aspect_ratio: str,
    prompt: str,
    num_images: int,
    debug_mode: bool = False,
    category: Optional[str] = None,
    platform: Optional[str] = None,
    uid: str,
    private_mode: bool = False,
):
    """
    Streamlit-friendly wrapper: kicks off parallel gen, persists a job row,
    and renders results in-place (no return value).

    In private mode no results collection is opened, so no job row is
    written. Generation errors are listed only when *debug_mode* is true.
    """
    if not REPLICATE_API_TOKEN:
        st.error(
            "Missing REPLICATE_API_TOKEN. Set it as an environment variable."
        )
        return
    if not prompt.strip():
        st.warning("Please enter a prompt.")
        return

    created_by = uid
    results_col = None if private_mode else get_results_collection()
    db_job_id = None

    if results_col is not None:
        try:
            db_job_id = start_job(
                results_col,
                type="generation",
                created_by=created_by,
                category=(category or "general"),
                inputs={
                    "model_key": model_key,
                    "aspect_ratio": aspect_ratio,
                    "num_images": num_images,
                },
                settings={"platform": platform},
                user_prompt=prompt.strip(),
            )
        except Exception as e:
            # Job tracking is best-effort: generation proceeds without it.
            logger.error(f"start_job failed: {e}")

    progress = st.progress(0, text="Starting generation...")
    status = st.empty()
    start = time.time()

    try:
        with status.container():
            st.info(f"Generating {num_images} image(s)")
        progress.progress(10, text="Running...")

        r2_urls, source_urls, errors = _generate_images_parallel(
            model_key,
            aspect_ratio,
            prompt.strip(),
            num_images,
            private_mode=private_mode,
        )
        # Private mode never falls back to the provider's source URLs.
        urls = r2_urls if private_mode else (r2_urls or source_urls)

        if results_col is not None and db_job_id:
            try:
                finish_job(
                    results_col,
                    db_job_id,
                    status="completed" if urls else "failed",
                    outputs_urls=urls or [],
                    provider_update={"errors": errors} if errors else None,
                )
            except Exception as e:
                logger.error(f"finish_job failed: {e}")

        progress.progress(100, text="Complete!")
        took = time.time() - start

        if urls:
            with status.container():
                message = f"Generated {len(urls)} image(s) in {took:.1f}s."
                if not private_mode:
                    message += f" Job ID: {db_job_id or 'N/A'}"
                else:
                    message += (
                        " Private mode: results stay local to this session."
                    )
                st.success(message)

            cols = st.columns(min(4, len(urls)) or 1)
            image_bytes_list: List[Tuple[str, bytes]] = []
            for i, u in enumerate(urls):
                with cols[i % len(cols)]:
                    try:
                        b = _load_image_bytes(u)
                        if b is None:
                            st.error("Failed to load image")
                            continue
                        image_bytes_list.append((f"image_{i + 1}.png", b))
                        st.image(b, width='stretch')
                        st.download_button(
                            "Download image ",
                            b,
                            file_name=f"image_{i + 1}.png",
                            mime="image/png",
                            width='stretch',
                        )
                    except Exception as e:
                        st.error(f"Display failed: {e}")

            # A bundle download is only offered for two or more images.
            if len(image_bytes_list) > 1:
                st.download_button(
                    " Download All Images",
                    data=_build_zip(image_bytes_list),
                    file_name="all_images.zip",
                    mime="application/zip",
                    width='stretch',
                )
        else:
            with status.container():
                st.error("No images were generated.")

        if errors and debug_mode:
            with st.expander("Generation Errors", expanded=True):
                for err in errors:
                    st.error(err)
    except Exception as e:
        if results_col is not None and db_job_id:
            try:
                finish_job(results_col, db_job_id, status="failed")
            except Exception as finish_err:
                # Still best-effort, but no longer silent.
                logger.error(f"finish_job failed: {finish_err}")
        with status.container():
            st.error(f"Generation failed: {e}")


def generate_image(
    file_path,
    size,
    quality,
    category,
    sentiment,
    user_prompt,
    platform,
    blur,
    i=None,
):
    """Create an ad image from a reference image via the OpenAI images API.

    Opens *file_path*, sends it to ``client.images.edit`` with model
    ``gpt-image-1`` and a prompt built from *category*, *sentiment*,
    *user_prompt*, *platform* and *blur*, and returns the decoded image
    bytes. The *i* parameter is accepted but not used.

    Raises:
        RuntimeError: if ``OPENAI_API_KEY`` is not set.
        Exception: any other failure is logged with its traceback and
            re-raised unchanged.
    """
    try:
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            logger.critical("OPENAI_API_KEY is not set.")
            raise RuntimeError("OPENAI_API_KEY is missing")

        client = OpenAI(api_key=api_key)
        with open(file_path, "rb") as img_file:
            background = (
                "blurred background." if blur else " not blurred background."
            )
            result = client.images.edit(
                model="gpt-image-1",
                prompt=(
                    f"You are a top-tier performance digital marketer and creative strategist with 15+ years of expertise in affiliate marketing.\n"
                    f"Your objective is to analyze the provided winning ad image, deconstruct its concept, visual composition, and color scheme, and generate a fresh, conversion-focused ad visual tailored for the {category} niche.\n"
                    f"The new design should convey a {sentiment} sentiment and incorporate the user instruction: \n {user_prompt}.\n If user has given multple choices or options to be include in the image so choose randomly relevant to the reference image."
                    f"Create a visually compelling ad optimized for {platform} Ads that is scroll-stopping, pattern-interrupting, and designed to drive high CTR and Conversion Rate. Utilize striking color combinations, dynamic contrast levels, and strategic layout compositions to command attention while aligning with the target audience avatar.\n"
                    f"Make sure the images should be realistic, not be stocky at all and raw which should look like they are shot from an iPhone with {background}."
                ),
                image=img_file,
                size=size,
                quality=quality,
            )

        image_base64 = result.data[0].b64_json
        image_bytes = base64.b64decode(image_base64)
        logger.info(f"Successfully generated image for {file_path}")
        return image_bytes
    except Exception as e:
        logger.exception(f"Failed to generate image for {file_path}: {e}")
        raise