Spaces:
Running
Running
| import os, io, zipfile, replicate, time, logging, requests, streamlit as st, boto3, threading | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| from typing import Dict, Any, List, Tuple, Optional, Union | |
| from uuid import uuid4 | |
| from urllib.parse import urlparse | |
| from functools import lru_cache | |
| import os, base64, logging | |
| from openai import OpenAI | |
| from helpers_function.helper_meta_data import meta_data_helper_function | |
| from database.operations import start_job, finish_job | |
| from database.connections import get_results_collection | |
| from dotenv import load_dotenv | |
# Populate os.environ from a .env file (python-dotenv) before the module-level
# os.getenv() lookups further down (e.g. REPLICATE_API_TOKEN) read it.
load_dotenv()
| def _encode_image_to_base64(image_path): | |
| try: | |
| with open(image_path, "rb") as f: | |
| return base64.b64encode(f.read()).decode("utf-8") | |
| except Exception: | |
| logger.exception(f"Failed to base64 encode image: {image_path}") | |
| return "" | |
logger = logging.getLogger("imagegen_service")
logging.basicConfig(level=logging.INFO)

# Replicate API token from the environment; when unset, generation is skipped.
REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
# Thread-pool ceiling: the same min(32, cpus + 4) formula that
# ThreadPoolExecutor uses for its own default max_workers.
MAX_WORKERS = min(32, (os.cpu_count() or 1) + 4)
# Timeout (seconds) passed to requests.get when downloading images.
REQUEST_TIMEOUT = 30
# Number of tries used by the Replicate, download and R2-upload retry loops.
RETRY_ATTEMPTS = 3

# Model key -> Replicate model id, the aspect ratios listed for it (presumably
# offered in the UI), and the input field the aspect ratio is sent under.
# NOTE(review): "imagegen-4-ultra" looks like a typo for "imagen-4-ultra"
# (cf. "imagen-4"); left unchanged because callers may select by this key.
MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {
    "imagegen-4-ultra": dict(
        id="google/imagen-4-ultra",
        aspect_ratios=["1:1", "16:9", "9:16", "3:4", "4:3"],
        param_name="aspect_ratio",
    ),
    "imagen-4": dict(
        id="google/imagen-4",
        aspect_ratios=["1:1", "16:9", "9:16", "3:4", "4:3"],
        param_name="aspect_ratio",
    ),
    "nano-banana": dict(
        id="google/nano-banana",
        aspect_ratios=["1:1", "16:9", "9:16", "3:4", "4:3"],
        param_name="aspect_ratio",
    ),
    "qwen": dict(
        id="qwen/qwen-image",
        aspect_ratios=["1:1", "16:9", "9:16", "3:4", "4:3", "3:2", "2:3"],
        param_name="aspect_ratio",
    ),
    "seedream-3": dict(
        id="bytedance/seedream-3",
        aspect_ratios=["1:1", "16:9", "9:16", "3:4", "4:3", "3:2", "2:3", "21:9"],
        param_name="aspect_ratio",
    ),
    "recraft-v3": dict(
        id="recraft-ai/recraft-v3",
        aspect_ratios=[
            "1:1", "4:3", "3:4", "3:2", "2:3", "16:9", "9:16", "1:2",
            "2:1", "7:5", "5:7", "4:5", "5:4", "3:5", "5:3",
        ],
        param_name="aspect_ratio",
    ),
    "photon": dict(
        id="luma/photon",
        aspect_ratios=["1:1", "3:4", "4:3", "9:16", "16:9", "9:21", "21:9"],
        param_name="aspect_ratio",
    ),
    "ideogram-v3-quality": dict(
        id="ideogram-ai/ideogram-v3-quality",
        aspect_ratios=[
            "1:3", "3:1", "1:2", "2:1", "9:16", "16:9", "10:16", "16:10",
            "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "1:1",
        ],
        param_name="aspect_ratio",
    ),
}

# Per-thread storage; _s3() memoises one boto3 client (or None) per thread here.
_thread_local = threading.local()
def get_model_config(model_key: str) -> Optional[Dict[str, Any]]:
    """Return the MODEL_REGISTRY entry for *model_key*, or None if the key is unknown.

    The returned dict is the registry's own object, not a copy; treat it as
    read-only.
    """
    return MODEL_REGISTRY.get(model_key)


def _get_model_config_cached(model_key: str) -> Optional[Dict[str, Any]]:
    """Internal alias of get_model_config, kept for existing internal callers.

    Despite the name, no extra caching layer is applied. A dict lookup is
    already O(1) and hands back the same registry object every time. An
    lru_cache here would only risk serving stale entries if MODEL_REGISTRY
    were ever mutated.
    """
    return get_model_config(model_key)
def _s3():
    """Return this thread's boto3 S3 client for Cloudflare R2, building it on first use.

    The client is created from R2_ENDPOINT / R2_ACCESS_KEY / R2_SECRET_KEY with
    region "auto". If any of R2_ENDPOINT, R2_ACCESS_KEY, R2_SECRET_KEY,
    R2_BUCKET_NAME or NEW_BASE is unset or empty, or if client construction
    raises, None is returned instead.

    Whatever is produced (client or None) is memoised on ``_thread_local``. A
    thread that once got None keeps getting None for its lifetime.
    """
    if hasattr(_thread_local, "s3"):
        return _thread_local.s3

    required = ("R2_ENDPOINT", "R2_ACCESS_KEY", "R2_SECRET_KEY", "R2_BUCKET_NAME", "NEW_BASE")
    if not all(os.getenv(name) for name in required):
        _thread_local.s3 = None
        return None

    try:
        client = boto3.client(
            "s3",
            endpoint_url=os.getenv("R2_ENDPOINT"),
            aws_access_key_id=os.getenv("R2_ACCESS_KEY"),
            aws_secret_access_key=os.getenv("R2_SECRET_KEY"),
            region_name="auto",
        )
    except Exception as e:
        logger.error(f"S3 init failed: {e}")
        client = None
    _thread_local.s3 = client
    return client
def _upload_to_r2(image_bytes: bytes) -> Optional[str]:
    """Store *image_bytes* in the R2 bucket as a PNG and return its URL.

    The object key is a fresh random ``<hex>.png`` under a fixed prefix. The
    returned URL is NEW_BASE (trailing "/" stripped) joined with that key.

    Up to RETRY_ATTEMPTS tries are made, sleeping 1s, 2s, 4s, ... between
    them. Each try uses a new key. Returns None when no R2 client is
    available or every try fails; the last error is logged.
    """
    client = _s3()
    if not client:
        return None
    for attempt in range(1, RETRY_ATTEMPTS + 1):
        try:
            object_key = "adgenesis_image_text/creative_adgenesis/images/" + f"{uuid4().hex}.png"
            client.put_object(
                Bucket=os.getenv("R2_BUCKET_NAME"),
                Key=object_key,
                Body=image_bytes,
                ContentType="image/png",
            )
            public_base = os.getenv("NEW_BASE").rstrip("/")
            return f"{public_base}/{object_key}"
        except Exception as exc:
            if attempt == RETRY_ATTEMPTS:
                logger.error(f"R2 upload failed: {exc}")
                return None
            time.sleep(2 ** (attempt - 1))
    return None
def _extract_urls(output: Any) -> List[str]:
    """Normalise a ``replicate.run`` result into a list holding at most one URL.

    Handles three shapes:
    - a non-empty list: its first element, via its ``url`` attribute if present,
      otherwise ``str()``;
    - a bare string;
    - a single object exposing ``url``.

    Anything else yields ``[]``.
    """
    if isinstance(output, list) and output:
        first = output[0]
        return [getattr(first, "url", str(first))]
    if isinstance(output, str):
        return [output]
    if hasattr(output, "url"):
        return [getattr(output, "url")]
    return []


def _generate_one(model_key: str, prompt: str, aspect_ratio: str) -> List[str]:
    """Run one Replicate prediction for *model_key* and return its output URL(s).

    The aspect ratio is sent under the model's configured ``param_name``.

    Returns ``[]`` in any of these cases:
    - REPLICATE_API_TOKEN is unset;
    - the model key is unknown;
    - all RETRY_ATTEMPTS tries either raise or return no usable output.

    Both failures and empty outputs are retried with a 1s pause in between.
    """
    if not REPLICATE_API_TOKEN:
        logger.error("REPLICATE_API_TOKEN is not set; skipping generation.")
        return []
    cfg = _get_model_config_cached(model_key)
    if not cfg:
        logger.error(f"Unknown model key: {model_key}")
        return []
    for attempt in range(RETRY_ATTEMPTS):
        is_last = attempt == RETRY_ATTEMPTS - 1
        try:
            output = replicate.run(cfg["id"], input={"prompt": prompt, cfg["param_name"]: aspect_ratio})
            urls = _extract_urls(output)
        except Exception as e:
            if is_last:
                logger.error(f"replicate run failed: {e}")
                return []
        else:
            if urls:
                return urls
            logger.warning(
                f"replicate returned no usable output for {model_key} "
                f"(attempt {attempt + 1}/{RETRY_ATTEMPTS})"
            )
        if not is_last:
            # Pause before retrying a failure *or* an empty result, so an
            # unexpected output shape doesn't trigger back-to-back re-runs.
            time.sleep(1)
    return []
def _fetch(url: Union[str, Any]) -> Optional[bytes]:
    """Download *url* and return the response body, or None if every attempt fails.

    *url* may be a plain string or any object exposing a ``url`` attribute
    (e.g. a Replicate file output). Non-2xx responses count as failures. Up to
    RETRY_ATTEMPTS tries are made, with a 1s pause between them.
    """
    url_str = getattr(url, "url", str(url))
    headers = {"Cache-Control": "no-cache", "Pragma": "no-cache", "User-Agent": "ImageBot/1.0"}
    for attempt in range(RETRY_ATTEMPTS):
        try:
            # The context manager closes the streamed response, returning the
            # connection to the pool even when raise_for_status() or the body
            # read fails.
            with requests.get(url_str, timeout=REQUEST_TIMEOUT, stream=True, headers=headers) as r:
                r.raise_for_status()
                # join() is linear; repeated bytes += chunk is quadratic.
                return b"".join(r.iter_content(8192))
        except Exception as e:
            if attempt == RETRY_ATTEMPTS - 1:
                logger.warning(f"Fetch failed for {url_str}: {e}")
                return None
            time.sleep(1)
    return None
def _process_one(args: Tuple[str, str, str, int, bool]) -> Dict[str, Any]:
    """Produce a single image end-to-end and report the outcome as a dict.

    *args* is packed as ``(model_key, prompt, aspect_ratio, index, private_mode)``
    so the call can be submitted to a thread pool as one argument.

    Pipeline:
    1. generate via Replicate;
    2. download the first output;
    3. pass the bytes through meta_data_helper_function (project helper,
       presumably embeds metadata — verify);
    4. store the result: private mode inlines it as a ``data:image/png;base64``
       URI, otherwise it is uploaded to R2.

    The result dict has keys index, success, source_url, r2_url and error. In
    private mode the data URI is placed under ``r2_url``. Exceptions raised
    inside the pipeline are captured into ``error`` rather than propagated.
    """
    model_key, prompt, aspect_ratio, idx, private_mode = args
    result: Dict[str, Any] = {"index": idx, "success": False, "source_url": None, "r2_url": None, "error": None}
    try:
        generated = _generate_one(model_key, prompt, aspect_ratio)
        if not generated:
            result["error"] = "No URLs returned"
            return result

        first_url = generated[0]
        result["source_url"] = getattr(first_url, "url", str(first_url))

        raw_bytes = _fetch(first_url)
        if not raw_bytes:
            result["error"] = "Fetch failed"
            return result

        tagged = meta_data_helper_function(raw_bytes)
        if private_mode:
            # NOTE(review): assumes the helper's output is PNG data — confirm.
            encoded = base64.b64encode(tagged).decode("utf-8")
            result["r2_url"] = "data:image/png;base64," + encoded
            result["success"] = True
            return result

        stored_url = _upload_to_r2(tagged)
        if stored_url:
            result.update(r2_url=stored_url, success=True)
        else:
            result["error"] = "Upload to R2 failed"
    except Exception as e:
        result["error"] = str(e)
    return result
| def _generate_images_parallel(model_key: str, aspect_ratio: str, prompt: str, num_images: int, *, private_mode: bool = False) -> Tuple[List[str], List[str], List[str]]: | |
| if num_images == 1: | |
| res = _process_one((model_key, prompt, aspect_ratio, 0, private_mode)) | |
| if res["success"]: | |
| return [res["r2_url"]], [res["source_url"]], [] | |
| return [], [], [res["error"] or "Generation failed"] | |
| args = [(model_key, prompt, aspect_ratio, i, private_mode) for i in range(num_images)] | |
| r2, src, errs = [], [], [] | |
| with ThreadPoolExecutor(max_workers=min(MAX_WORKERS, num_images)) as ex: | |
| for fut in as_completed({ex.submit(_process_one, a): a[3] for a in args}): | |
| try: | |
| res = fut.result() | |
| if res["success"]: | |
| if res["r2_url"]: r2.append(res["r2_url"]) | |
| if res["source_url"]: src.append(res["source_url"]) | |
| else: | |
| errs.append(res["error"] or "Generation failed") | |
| except Exception as e: | |
| errs.append(f"Future err: {e}") | |
| # de-dup | |
| r2 = list(dict.fromkeys(r2)); src = list(dict.fromkeys(src)) | |
| return r2, src, errs | |
def generate_images_parallel(model_key: str, aspect_ratio: str, prompt: str, num_images: int, *, private_mode: bool = False) -> Tuple[List[str], List[str], List[str]]:
    """Public name for _generate_images_parallel, kept for backwards compatibility.

    Background tasks import this entry point. Arguments and the
    ``(stored_urls, source_urls, errors)`` result pass through unchanged.
    """
    return _generate_images_parallel(
        model_key=model_key,
        aspect_ratio=aspect_ratio,
        prompt=prompt,
        num_images=num_images,
        private_mode=private_mode,
    )
def _load_display_bytes(url: str) -> Optional[bytes]:
    """Return image bytes for *url*.

    Inline ``data:image`` URIs are base64-decoded, and a malformed URI gives
    None. Anything else is downloaded via _fetch.
    """
    if isinstance(url, str) and url.startswith("data:image"):
        try:
            _, encoded = url.split(",", 1)
            return base64.b64decode(encoded)
        except Exception:
            return None
    return _fetch(url)


def _build_zip(named_images: List[Tuple[str, bytes]]) -> io.BytesIO:
    """Pack ``(file_name, bytes)`` pairs into an in-memory ZIP rewound to offset 0."""
    buffer = io.BytesIO()
    with zipfile.ZipFile(buffer, "w") as zf:
        for fname, data in named_images:
            zf.writestr(fname, data)
    buffer.seek(0)
    return buffer


def _render_gallery(urls: List[str]) -> None:
    """Show each image in up to 4 columns with a download button.

    Also offers a ZIP of all images when more than one image loaded.
    """
    cols = st.columns(min(4, len(urls)) or 1)
    loaded: List[Tuple[str, bytes]] = []
    for i, u in enumerate(urls):
        with cols[i % len(cols)]:
            try:
                b = _load_display_bytes(u)
                if b is None:
                    st.error("Failed to load image")
                    continue
                fname = f"image_{i + 1}.png"
                loaded.append((fname, b))
                st.image(b, width='stretch')
                st.download_button(
                    "Download image ",
                    b,
                    file_name=fname,
                    mime="image/png",
                    width='stretch',
                )
            except Exception as e:
                st.error(f"Display failed: {e}")
    if len(loaded) > 1:
        st.download_button(
            " Download All Images",
            data=_build_zip(loaded),
            file_name="all_images.zip",
            mime="application/zip",
            width='stretch',
        )


def handle_image_generation_optimized(
    *,
    model_key: str,
    aspect_ratio: str,
    prompt: str,
    num_images: int,
    debug_mode: bool = False,
    category: Optional[str] = None,
    platform: Optional[str] = None,
    uid: str,
    private_mode: bool = False,
):
    """Streamlit entry point: generate images, record the job, and render results in place.

    Flow:
    1. Validate the token and inputs.
    2. Outside private mode, open a job row via start_job (category defaults
       to "general").
    3. Run _generate_images_parallel.
    4. Close the job via finish_job as "completed" or "failed".
    5. Show a status message, the image gallery and, when debug_mode is set,
       any per-image errors.

    In private mode nothing is persisted and images arrive as data URIs.
    Returns None; all output goes through Streamlit.

    NOTE(review): the pasted source lost its indentation. Here the gallery
    renders below the status placeholder, and the debug error list shows
    whenever errors exist. Confirm this matches the intended layout.
    """
    if not REPLICATE_API_TOKEN:
        st.error("Missing REPLICATE_API_TOKEN. Set it as an environment variable.")
        return
    clean_prompt = (prompt or "").strip()
    if not clean_prompt:
        st.warning("Please enter a prompt.")
        return
    if num_images < 1:
        st.warning("Please request at least one image.")
        return

    results_col = None if private_mode else get_results_collection()
    db_job_id = None
    if results_col is not None:
        try:
            db_job_id = start_job(
                results_col,
                type="generation",
                created_by=uid,
                category=(category or "general"),
                inputs={"model_key": model_key, "aspect_ratio": aspect_ratio, "num_images": num_images},
                settings={"platform": platform},
                user_prompt=clean_prompt,
            )
        except Exception as e:
            logger.error(f"start_job failed: {e}")

    progress = st.progress(0, text="Starting generation...")
    status = st.empty()
    started = time.monotonic()
    # True once finish_job has recorded a final status, so a later rendering
    # error in the except-branch below cannot overwrite "completed" with "failed".
    job_finished = False
    try:
        with status.container():
            st.info(f"Generating {num_images} image(s)")
        progress.progress(10, text="Running...")
        r2_urls, source_urls, errors = _generate_images_parallel(
            model_key,
            aspect_ratio,
            clean_prompt,
            num_images,
            private_mode=private_mode,
        )
        # NOTE(review): _generate_images_parallel only collects source URLs for
        # successful results, and a non-private success requires an R2 upload.
        # The source_urls fallback is therefore effectively never taken; confirm
        # the intent.
        urls = r2_urls if private_mode else (r2_urls or source_urls)

        if results_col is not None and db_job_id:
            try:
                finish_job(
                    results_col,
                    db_job_id,
                    status="completed" if urls else "failed",
                    outputs_urls=urls or [],
                    provider_update={"errors": errors} if errors else None,
                )
                job_finished = True
            except Exception as e:
                logger.error(f"finish_job failed: {e}")

        progress.progress(100, text="Complete!")
        took = time.monotonic() - started

        if urls:
            with status.container():
                message = f"Generated {len(urls)} image(s) in {took:.1f}s."
                if private_mode:
                    message += " Private mode: results stay local to this session."
                else:
                    message += f" Job ID: {db_job_id or 'N/A'}"
                st.success(message)
            _render_gallery(urls)
        else:
            with status.container():
                st.error("No images were generated.")

        if errors and debug_mode:
            with st.expander("Generation Errors", expanded=True):
                for err in errors:
                    st.error(err)
    except Exception as e:
        logger.exception("Image generation failed")
        if results_col is not None and db_job_id and not job_finished:
            try:
                finish_job(results_col, db_job_id, status="failed")
            except Exception:
                # Best effort: the UI error below still informs the user.
                logger.exception("finish_job failed while recording a failed job")
        with status.container():
            st.error(f"Generation failed: {e}")
def generate_image(file_path, size, quality, category, sentiment, user_prompt, platform, blur, i=None):
    """Create a new ad image with OpenAI ``gpt-image-1``, using *file_path* as the reference.

    Args:
        file_path: Path of the reference ("winning ad") image; opened in binary mode.
        size: Passed straight through to ``client.images.edit``.
        quality: Passed straight through to ``client.images.edit``.
        category: Interpolated into the creative-brief prompt.
        sentiment: Interpolated into the creative-brief prompt.
        user_prompt: Interpolated into the creative-brief prompt.
        platform: Interpolated into the creative-brief prompt.
        blur: When truthy the prompt asks for a blurred background, otherwise
            a non-blurred one.
        i: Not used by this function (presumably kept for existing callers).

    Returns:
        The generated image as raw bytes, decoded from the response's base64 payload.

    Raises:
        RuntimeError: OPENAI_API_KEY is unset, or the response carries no image data.
        Exception: any error from opening the file or calling the API is logged
            and re-raised unchanged.
    """
    try:
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            logger.critical("OPENAI_API_KEY is not set.")
            raise RuntimeError("OPENAI_API_KEY is missing")
        client = OpenAI(api_key=api_key)
        # No trailing period here: the prompt template appends its own.
        background = "blurred background" if blur else "not blurred background"
        prompt = (
            f"You are a top-tier performance digital marketer and creative strategist with 15+ years of expertise in affiliate marketing.\n"
            f"Your objective is to analyze the provided winning ad image, deconstruct its concept, visual composition, and color scheme, and generate a fresh, conversion-focused ad visual tailored for the {category} niche.\n"
            # Ends with "\n" so this sentence doesn't run into "Create ..." below.
            f"The new design should convey a {sentiment} sentiment and incorporate the user instruction: \n {user_prompt}.\n If user has given multple choices or options to be include in the image so choose randomly relevant to the reference image.\n"
            f"Create a visually compelling ad optimized for {platform} Ads that is scroll-stopping, pattern-interrupting, and designed to drive high CTR and Conversion Rate. Utilize striking color combinations, dynamic contrast levels, and strategic layout compositions to command attention while aligning with the target audience avatar.\n"
            f"Make sure the images should be realistic, not be stocky at all and raw which should look like they are shot from an iPhone with {background}."
        )
        with open(file_path, "rb") as img_file:
            result = client.images.edit(
                model="gpt-image-1",
                prompt=prompt,
                image=img_file,
                size=size,
                quality=quality,
            )
        image_base64 = result.data[0].b64_json
        if not image_base64:
            raise RuntimeError("OpenAI response contained no image data")
        image_bytes = base64.b64decode(image_base64)
        logger.info(f"Successfully generated image for {file_path}")
        return image_bytes
    except Exception as e:
        logger.exception(f"Failed to generate image for {file_path}: {e}")
        raise