# AdGenesis-App — generator_function/image_function.py
import os, io, zipfile, replicate, time, logging, requests, streamlit as st, boto3, threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, Any, List, Tuple, Optional, Union
from uuid import uuid4
from urllib.parse import urlparse
from functools import lru_cache
import os, base64, logging
from openai import OpenAI
from helpers_function.helper_meta_data import meta_data_helper_function
from database.operations import start_job, finish_job
from database.connections import get_results_collection
from dotenv import load_dotenv
load_dotenv()
def _encode_image_to_base64(image_path):
try:
with open(image_path, "rb") as f:
return base64.b64encode(f.read()).decode("utf-8")
except Exception:
logger.exception(f"Failed to base64 encode image: {image_path}")
return ""
# Module-level logger for this service.
logger = logging.getLogger("imagegen_service")
logging.basicConfig(level=logging.INFO)
# Replicate API token; generation is skipped entirely when unset.
REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
# Thread-pool size for I/O-bound fan-out, capped at 32 workers.
MAX_WORKERS = min(32, (os.cpu_count() or 1) + 4)
# HTTP timeout (seconds) and retry budget shared by download/upload/generation loops.
REQUEST_TIMEOUT = 30
RETRY_ATTEMPTS = 3
# Supported text-to-image models. Each entry maps a short key to the Replicate
# model id, the aspect ratios the model accepts, and the name of the input
# parameter used to pass the chosen ratio.
MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {
    "imagegen-4-ultra": {"id": "google/imagen-4-ultra","aspect_ratios": ["1:1","16:9","9:16","3:4","4:3"],"param_name":"aspect_ratio"},
    "imagen-4": {"id": "google/imagen-4","aspect_ratios": ["1:1","16:9","9:16","3:4","4:3"],"param_name":"aspect_ratio"},
    "nano-banana": {"id": "google/nano-banana","aspect_ratios": ["1:1","16:9","9:16","3:4","4:3"],"param_name":"aspect_ratio"},
    "qwen": {"id": "qwen/qwen-image","aspect_ratios": ["1:1","16:9","9:16","3:4","4:3","3:2","2:3"],"param_name":"aspect_ratio"},
    "seedream-3": {"id": "bytedance/seedream-3","aspect_ratios": ["1:1","16:9","9:16","3:4","4:3","3:2","2:3","21:9"],"param_name":"aspect_ratio"},
    "recraft-v3": {"id": "recraft-ai/recraft-v3","aspect_ratios": ["1:1","4:3","3:4","3:2","2:3","16:9","9:16","1:2","2:1","7:5","5:7","4:5","5:4","3:5","5:3"],"param_name":"aspect_ratio"},
    "photon": {"id": "luma/photon","aspect_ratios": ["1:1","3:4","4:3","9:16","16:9","9:21","21:9"],"param_name":"aspect_ratio"},
    "ideogram-v3-quality": {"id": "ideogram-ai/ideogram-v3-quality","aspect_ratios": ["1:3","3:1","1:2","2:1","9:16","16:9","10:16","16:10","2:3","3:2","3:4","4:3","4:5","5:4","1:1"],"param_name":"aspect_ratio"},
}
# Thread-local storage; each worker thread lazily caches its own boto3 client here.
_thread_local = threading.local()
def get_model_config(model_key: str) -> Optional[Dict[str, Any]]:
    """Return the registry entry for *model_key*, or ``None`` if it is unknown."""
    try:
        return MODEL_REGISTRY[model_key]
    except KeyError:
        return None
@lru_cache(maxsize=128)
def _get_model_config_cached(model_key: str) -> Optional[Dict[str, Any]]:
    """Memoized registry lookup (same contract as ``get_model_config``).

    NOTE(review): the cached value is the live registry dict itself, so a
    caller mutating the returned config would alter MODEL_REGISTRY — treat
    the result as read-only.
    """
    return MODEL_REGISTRY.get(model_key, None)
def _s3():
    """Return this thread's cached boto3 client for R2, or ``None``.

    The first call on each thread builds the client; the outcome (client or
    None) is cached on ``_thread_local``, so missing configuration or an init
    failure is remembered for the lifetime of the thread.
    """
    if hasattr(_thread_local, "s3"):
        return _thread_local.s3
    required = ("R2_ENDPOINT", "R2_ACCESS_KEY", "R2_SECRET_KEY", "R2_BUCKET_NAME", "NEW_BASE")
    if not all(os.getenv(name) for name in required):
        # Incomplete configuration: cache the negative result and bail out.
        _thread_local.s3 = None
        return None
    try:
        _thread_local.s3 = boto3.client(
            "s3",
            endpoint_url=os.getenv("R2_ENDPOINT"),
            aws_access_key_id=os.getenv("R2_ACCESS_KEY"),
            aws_secret_access_key=os.getenv("R2_SECRET_KEY"),
            region_name="auto",
        )
    except Exception as exc:
        logger.error(f"S3 init failed: {exc}")
        _thread_local.s3 = None
    return _thread_local.s3
def _upload_to_r2(image_bytes: bytes) -> Optional[str]:
    """Upload PNG bytes to the R2 bucket and return the public URL, or ``None``.

    Retries up to RETRY_ATTEMPTS times with exponential backoff. Returns None
    when the client is unavailable or every attempt fails.
    """
    s3 = _s3()
    if not s3:
        return None
    # Bug fix: the object key previously ended in the hard-coded literal
    # "(unknown)" while the generated unique filename went unused, so every
    # upload overwrote the same object. Build the key from the unique name,
    # and do it once outside the loop so retries target the same key.
    filename = f"{uuid4().hex}.png"
    key = f"adgenesis_image_text/creative_adgenesis/images/{filename}"
    for attempt in range(RETRY_ATTEMPTS):
        try:
            s3.put_object(
                Bucket=os.getenv("R2_BUCKET_NAME"),
                Key=key,
                Body=image_bytes,
                ContentType="image/png",
            )
            return f"{os.getenv('NEW_BASE').rstrip('/')}/{key}"
        except Exception as e:
            if attempt == RETRY_ATTEMPTS - 1:
                logger.error(f"R2 upload failed: {e}")
                return None
            time.sleep(2 ** attempt)
    return None
def _generate_one(model_key: str, prompt: str, aspect_ratio: str) -> List[str]:
    """Run one Replicate prediction and normalize its output to URL strings.

    Returns a single-element list with the image URL, or ``[]`` when the API
    token or model config is missing, or when all retry attempts fail.
    """
    if not REPLICATE_API_TOKEN:
        return []
    cfg = _get_model_config_cached(model_key)
    if not cfg:
        return []
    payload = {"prompt": prompt, cfg["param_name"]: aspect_ratio}
    for attempt in range(RETRY_ATTEMPTS):
        try:
            result = replicate.run(cfg["id"], input=payload)
            # Normalize the three shapes replicate may hand back:
            # a list of outputs, a bare URL string, or an object with .url.
            candidates: List[str] = []
            if isinstance(result, list) and result:
                head = result[0]
                candidates = [getattr(head, "url", str(head))]
            elif isinstance(result, str):
                candidates = [result]
            elif hasattr(result, "url"):
                candidates = [getattr(result, "url")]
            if candidates:
                return candidates
        except Exception as e:
            if attempt == RETRY_ATTEMPTS - 1:
                logger.error(f"replicate run failed: {e}")
                return []
            time.sleep(1)
    return []
def _fetch(url: Union[str, Any]) -> Optional[bytes]:
    """Download *url* (a string or an object exposing ``.url``) and return its bytes.

    Retries up to RETRY_ATTEMPTS times; returns ``None`` if every attempt fails.
    """
    url_str = getattr(url, "url", str(url))
    headers = {"Cache-Control": "no-cache", "Pragma": "no-cache", "User-Agent": "ImageBot/1.0"}
    for attempt in range(RETRY_ATTEMPTS):
        try:
            r = requests.get(url_str, timeout=REQUEST_TIMEOUT, stream=True, headers=headers)
            r.raise_for_status()
            # Fix: the original accumulated with `buf += chunk`, which is
            # quadratic for large bodies; join all chunks in one pass instead.
            return b"".join(r.iter_content(8192))
        except Exception:
            if attempt == RETRY_ATTEMPTS - 1:
                return None
            time.sleep(1)
    return None
def _process_one(args: Tuple[str, str, str, int, bool]) -> Dict[str, Any]:
    """Generate, fetch, metadata-stamp, and store one image.

    *args* is ``(model_key, prompt, aspect_ratio, idx, private_mode)``.
    Returns a dict with keys ``index``, ``success``, ``source_url``,
    ``r2_url`` and ``error``; in private mode ``r2_url`` holds a data URI
    instead of an uploaded URL.
    """
    model_key, prompt, aspect_ratio, idx, private_mode = args
    result: Dict[str, Any] = {
        "index": idx,
        "success": False,
        "source_url": None,
        "r2_url": None,
        "error": None,
    }
    try:
        urls = _generate_one(model_key, prompt, aspect_ratio)
        if not urls:
            result["error"] = "No URLs returned"
            return result
        first = urls[0]
        result["source_url"] = getattr(first, "url", str(first))
        raw = _fetch(first)
        if not raw:
            result["error"] = "Fetch failed"
            return result
        stamped = meta_data_helper_function(raw)
        if private_mode:
            # Private mode: keep the image in-session as a base64 data URI.
            encoded = base64.b64encode(stamped).decode("utf-8")
            result["r2_url"] = "data:image/png;base64," + encoded
            result["success"] = True
            return result
        uploaded = _upload_to_r2(stamped)
        if uploaded:
            result["r2_url"] = uploaded
            result["success"] = True
        else:
            result["error"] = "Upload to R2 failed"
    except Exception as e:
        result["error"] = str(e)
    return result
def _generate_images_parallel(model_key: str, aspect_ratio: str, prompt: str, num_images: int, *, private_mode: bool = False) -> Tuple[List[str], List[str], List[str]]:
    """Fan out image generation across a thread pool.

    Returns ``(r2_urls, source_urls, errors)``; URL lists are de-duplicated
    with their original order preserved. A single image bypasses the pool.
    """
    if num_images == 1:
        single = _process_one((model_key, prompt, aspect_ratio, 0, private_mode))
        if single["success"]:
            return [single["r2_url"]], [single["source_url"]], []
        return [], [], [single["error"] or "Generation failed"]
    tasks = [(model_key, prompt, aspect_ratio, i, private_mode) for i in range(num_images)]
    r2_urls: List[str] = []
    source_urls: List[str] = []
    errors: List[str] = []
    with ThreadPoolExecutor(max_workers=min(MAX_WORKERS, num_images)) as pool:
        futures = [pool.submit(_process_one, task) for task in tasks]
        for future in as_completed(futures):
            try:
                res = future.result()
            except Exception as e:
                errors.append(f"Future err: {e}")
                continue
            if not res["success"]:
                errors.append(res["error"] or "Generation failed")
                continue
            if res["r2_url"]:
                r2_urls.append(res["r2_url"])
            if res["source_url"]:
                source_urls.append(res["source_url"])
    # dict.fromkeys de-duplicates while keeping first-seen order.
    return list(dict.fromkeys(r2_urls)), list(dict.fromkeys(source_urls)), errors
def generate_images_parallel(model_key: str, aspect_ratio: str, prompt: str, num_images: int, *, private_mode: bool = False) -> Tuple[List[str], List[str], List[str]]:
    """Public wrapper kept for backward compatibility with background tasks.

    Delegates directly to :func:`_generate_images_parallel` and returns its
    ``(r2_urls, source_urls, errors)`` tuple unchanged.
    """
    return _generate_images_parallel(
        model_key,
        aspect_ratio,
        prompt,
        num_images,
        private_mode=private_mode,
    )
def handle_image_generation_optimized(
    *,
    model_key: str,
    aspect_ratio: str,
    prompt: str,
    num_images: int,
    debug_mode: bool = False,
    category: Optional[str] = None,
    platform: Optional[str] = None,
    uid: str,
    private_mode: bool = False,
) -> None:
    """
    Streamlit-friendly wrapper: kicks off parallel gen, persists a job row,
    and renders results in-place (no return value).

    Keyword-only parameters:
        model_key/aspect_ratio/prompt/num_images: generation inputs.
        debug_mode: when True, render generation errors in an expander.
        category/platform: optional metadata stored on the job row.
        uid: id of the requesting user (stored as created_by).
        private_mode: skip all DB persistence; images stay in-session as data URIs.
    """
    # Hard preconditions: a Replicate token and a non-empty prompt.
    if not REPLICATE_API_TOKEN:
        st.error("Missing REPLICATE_API_TOKEN. Set it as an environment variable.")
        return
    if not prompt.strip():
        st.warning("Please enter a prompt.")
        return
    created_by = uid
    # Private mode opts out of persistence entirely (no collection, no job row).
    results_col = None if private_mode else get_results_collection()
    db_job_id = None
    if results_col is not None:
        try:
            db_job_id = start_job(
                results_col,
                type="generation",
                created_by=created_by,
                category=(category or "general"),
                inputs={"model_key": model_key, "aspect_ratio": aspect_ratio, "num_images": num_images},
                settings={"platform": platform},
                user_prompt=prompt.strip(),
            )
        except Exception as e:
            # Persistence is best-effort; generation proceeds without a job row.
            logger.error(f"start_job failed: {e}")
    progress = st.progress(0, text="Starting generation...")
    status = st.empty()
    start = time.time()
    try:
        with status.container():
            st.info(f"Generating {num_images} image(s)")
            progress.progress(10, text="Running...")
        r2_urls, source_urls, errors = _generate_images_parallel(
            model_key,
            aspect_ratio,
            prompt.strip(),
            num_images,
            private_mode=private_mode,
        )
        # Private mode yields only data URIs; otherwise prefer R2 URLs and
        # fall back to the provider's source URLs.
        urls = r2_urls if private_mode else (r2_urls or source_urls)
        if results_col is not None and db_job_id:
            try:
                finish_job(
                    results_col,
                    db_job_id,
                    status="completed" if urls else "failed",
                    outputs_urls=urls or [],
                    provider_update={"errors": errors} if errors else None,
                )
            except Exception as e:
                logger.error(f"finish_job failed: {e}")
        progress.progress(100, text="Complete!")
        took = time.time() - start
        if urls:
            with status.container():
                message = f"Generated {len(urls)} image(s) in {took:.1f}s."
                if not private_mode:
                    message += f" Job ID: {db_job_id or 'N/A'}"
                else:
                    message += " Private mode: results stay local to this session."
                st.success(message)
            # Lay images out in up to four columns.
            cols = st.columns(min(4, len(urls)) or 1)
            image_bytes_list = []
            for i, u in enumerate(urls):
                with cols[i % len(cols)]:
                    try:
                        if isinstance(u, str) and u.startswith("data:image"):
                            # Private-mode data URI: decode inline instead of fetching.
                            try:
                                _, encoded = u.split(",", 1)
                                b = base64.b64decode(encoded)
                            except Exception:
                                b = None
                        else:
                            b = _fetch(u)
                        if b is None:
                            st.error("Failed to load image")
                            continue
                        image_bytes_list.append((f"image_{i + 1}.png", b))
                        st.image(b, width='stretch')
                        st.download_button(
                            f"Download image ",
                            b,
                            file_name=f"image_{i + 1}.png",
                            mime="image/png",
                            width='stretch',
                        )
                    except Exception as e:
                        st.error(f"Display failed: {e}")
            # Offer a single ZIP download when more than one image rendered.
            if len(image_bytes_list) > 1:
                zip_buffer = io.BytesIO()
                with zipfile.ZipFile(zip_buffer, "w") as zf:
                    for fname, b in image_bytes_list:
                        zf.writestr(fname, b)
                zip_buffer.seek(0)
                st.download_button(
                    " Download All Images",
                    data=zip_buffer,
                    file_name="all_images.zip",
                    mime="application/zip",
                    width='stretch',
                )
        else:
            with status.container():
                st.error("No images were generated.")
            if errors and debug_mode:
                with st.expander("Generation Errors", expanded=True):
                    for e in errors:
                        st.error(e)
    except Exception as e:
        # Mark the job failed (best-effort) and surface the error in the UI.
        if results_col is not None and db_job_id:
            try:
                finish_job(results_col, db_job_id, status="failed")
            except Exception:
                pass
        with status.container():
            st.error(f"Generation failed: {e}")
def generate_image(file_path, size, quality, category, sentiment, user_prompt, platform, blur, i=None):
    """Edit a reference ad image with OpenAI's gpt-image-1 and return PNG bytes.

    Builds a marketing-focused edit prompt from *category*, *sentiment*,
    *user_prompt*, *platform* and *blur*, sends the image at *file_path* to the
    images.edit endpoint, and returns the decoded result bytes. Raises
    RuntimeError when OPENAI_API_KEY is unset; re-raises any API failure after
    logging it. (*i* is accepted for caller compatibility and unused.)
    """
    try:
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            logger.critical("OPENAI_API_KEY is not set.")
            raise RuntimeError("OPENAI_API_KEY is missing")
        client = OpenAI(api_key=api_key)
        background = "blurred background." if blur else " not blurred background."
        # Assemble the full instruction prompt once, before the API call.
        prompt_text = (
            f"You are a top-tier performance digital marketer and creative strategist with 15+ years of expertise in affiliate marketing.\n"
            f"Your objective is to analyze the provided winning ad image, deconstruct its concept, visual composition, and color scheme, and generate a fresh, conversion-focused ad visual tailored for the {category} niche.\n"
            f"The new design should convey a {sentiment} sentiment and incorporate the user instruction: \n {user_prompt}.\n If user has given multple choices or options to be include in the image so choose randomly relevant to the reference image."
            f"Create a visually compelling ad optimized for {platform} Ads that is scroll-stopping, pattern-interrupting, and designed to drive high CTR and Conversion Rate. Utilize striking color combinations, dynamic contrast levels, and strategic layout compositions to command attention while aligning with the target audience avatar.\n"
            f"Make sure the images should be realistic, not be stocky at all and raw which should look like they are shot from an iPhone with {background}."
        )
        with open(file_path, "rb") as reference:
            response = client.images.edit(
                model="gpt-image-1",
                prompt=prompt_text,
                image=reference,
                size=size,
                quality=quality,
            )
        decoded = base64.b64decode(response.data[0].b64_json)
        logger.info(f"Successfully generated image for {file_path}")
        return decoded
    except Exception as e:
        logger.exception(f"Failed to generate image for {file_path}: {e}")
        raise