# AdGenesis-App / app_pages/image_generation.py
# Last commit 0ca4d13 (verified) by userIdc2024: "Update app_pages/image_generation.py"
from __future__ import annotations
import zipfile, requests, logging, tempfile, shutil, os, base64
import streamlit as st
from generator_function.image_processor import process_zip_and_generate_images
from generator_function.multimodel_image_processor import process_zip_and_generate_images_multimodel
from multimodel_services.model_manager import get_image_to_image_models, get_ui_parameters, is_gpt_model
# Configure root logging for the app (basicConfig is a no-op when the root
# logger already has handlers) and give this module its own named logger.
_LOG_FORMAT = "%(asctime)s [%(levelname)s] %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
def _zip_gallery_images(gallery_items):
if not gallery_items:
return None
temp_dir = tempfile.mkdtemp()
try:
zip_path = tempfile.NamedTemporaryFile(delete=False, suffix=".zip").name
for i, item in enumerate(gallery_items):
filename = None
url = None
data_bytes = None
if isinstance(item, tuple):
if len(item) == 2 and isinstance(item[1], (bytes, bytearray)):
filename = item[0]
data_bytes = item[1]
else:
url = item[0]
elif isinstance(item, str):
if item.startswith("data:image"):
try:
header, encoded = item.split(",", 1)
mime = header.split(";")[0].split("/")[-1] or "png"
data_bytes = base64.b64decode(encoded)
filename = f"image_{i}.{mime}"
except Exception:
logger.error("Failed to decode private image data for zip bundle.")
continue
else:
url = item
else:
continue
try:
if data_bytes is not None:
file_path = os.path.join(temp_dir, filename or f"image_{i}.png")
with open(file_path, "wb") as f:
f.write(data_bytes)
elif url:
ext = url.split('?')[0].split('.')[-1]
ext = ext if ext and len(ext) <= 5 else "png"
file_path = os.path.join(temp_dir, f"image_{i}.{ext}")
if url.startswith(("http://", "https://")):
resp = requests.get(url, timeout=10)
resp.raise_for_status()
with open(file_path, "wb") as f:
f.write(resp.content)
elif os.path.exists(url):
shutil.copy(url, file_path)
except Exception as e:
logger.error(f"Error processing image {url or filename}: {e}")
with zipfile.ZipFile(zip_path, "w") as zipf:
for file_name in os.listdir(temp_dir):
zipf.write(os.path.join(temp_dir, file_name), arcname=file_name)
return zip_path
except Exception as e:
logger.critical(f"Failed to create zip: {e}")
return None
finally:
shutil.rmtree(temp_dir, ignore_errors=True)
def _render_model_parameters(model_key, prefix):
    """Render the parameter widgets a (non-GPT) model declares and return their values.

    Widget specs come from ``get_ui_parameters(model_key)``; only ``select``
    and ``number`` parameter types produce a widget, other types are ignored.
    Returns a ``{param_name: value}`` dict (empty when nothing is declared).
    """
    values = {}
    ui_params = get_ui_parameters(model_key)
    if not ui_params:
        return values
    st.write("**Model Parameters:**")
    for param_name, param_config in ui_params.items():
        label = param_name.replace('_', ' ').title()
        widget_key = f"{prefix}_{param_name}"
        param_type = param_config.get('type', 'text')
        default_value = param_config.get('default', '')
        if param_type == 'select':
            options = param_config.get('options', [])
            values[param_name] = st.selectbox(
                label,
                options,
                index=options.index(default_value) if default_value in options else 0,
                key=widget_key,
            )
        elif param_type == 'number':
            min_val = param_config.get('min', 0)
            max_val = param_config.get('max', 100)
            # A configured default of 0 is a real value; fall back to the
            # minimum only when no default was configured at all.
            has_default = default_value is not None and default_value != ''
            values[param_name] = st.slider(
                label,
                min_value=min_val,
                max_value=max_val,
                value=default_value if has_default else min_val,
                key=widget_key,
            )
    return values


def _render_generation_options(gpt_model, prefix):
    """Render the generation controls and return their values.

    GPT models additionally expose blur/size/quality; other models get the
    fixed values ``blur=False``, ``size="auto"``, ``quality="auto"``.

    Returns:
        ``(blur, size, quality, sentiment, platform, num_images)``.
    """
    if gpt_model:
        cols = st.columns([2, 2.5, 2.5, 2.5, 2.5, 2])
        with cols[0]:
            blur = st.checkbox("Blur Image", key=f"{prefix}_blur")
        with cols[1]:
            size = st.selectbox("Image Size", ["auto", "1024x1024", "1536x1024", "1024x1536"], key=f"{prefix}_size")
        with cols[2]:
            quality = st.selectbox("Quality", ["auto", "low", "medium"], key=f"{prefix}_quality")
        common_cols = cols[3:]
    else:
        common_cols = st.columns([2.5, 2.5, 2.5])
        blur, size, quality = False, "auto", "auto"
    # Controls shared by every model, rendered into the last three columns.
    with common_cols[0]:
        sentiment = st.selectbox("Sentiment", ["as original image", "positive", "negative"], key=f"{prefix}_sentiment")
    with common_cols[1]:
        platform = st.selectbox("Platform", ["Facebook", "Native", "Newsbreak"], key=f"{prefix}_platform")
    with common_cols[2]:
        num_images = st.slider("No. of Images to be generated:", 1, 10, value=2, step=1, key=f"{prefix}_num_images")
    return blur, size, quality, sentiment, platform, num_images


def _save_upload(uploaded_file, temp_dir):
    """Write *uploaded_file* into *temp_dir* under its original base name and return the path."""
    # The client controls the name; strip directory parts so it cannot escape temp_dir.
    safe_name = os.path.basename(uploaded_file.name.replace("\\", "/")) or "upload"
    temp_path = os.path.join(temp_dir, safe_name)
    with open(temp_path, "wb") as f:
        # getvalue() returns the whole buffer regardless of the read position.
        f.write(uploaded_file.getvalue())
    return temp_path


def _decode_image_data_url(data_url):
    """Decode a ``data:image/...;base64,`` URL into ``(bytes, mime, extension)``; ``None`` on failure."""
    try:
        header, encoded = data_url.split(",", 1)
        img_bytes = base64.b64decode(encoded)
    except ValueError:  # also covers binascii.Error
        return None
    mime = header[len("data:"):].split(";")[0] or "image/png"
    ext = mime.split("/")[-1].split("+")[0] or "png"
    return img_bytes, mime, ext


def _render_gallery(images, prefix):
    """Show generated images in a 4-column grid; inline data-URL images also get a download button."""
    cols = st.columns(4)
    for idx, img in enumerate(images):
        with cols[idx % 4]:
            if isinstance(img, str) and img.startswith("data:image"):
                decoded = _decode_image_data_url(img)
                if decoded and decoded[0]:
                    img_bytes, mime, ext = decoded
                    st.image(img_bytes, width='stretch')
                    st.download_button(
                        f"Download image {idx + 1}",
                        img_bytes,
                        file_name=f"private_image_{idx + 1}.{ext}",
                        mime=mime,
                        key=f"{prefix}_private_dl_{idx}",
                    )
                else:
                    st.warning("Unable to display private image data.")
            else:
                st.image(img, width='stretch')


def _render_zip_download(gallery_key, prefix):
    """Bundle the session gallery into a ZIP and offer it through a download button."""
    with st.spinner(" Preparing for download..."):
        zip_path = _zip_gallery_images(st.session_state.get(gallery_key, []))
        if not zip_path:
            st.warning("No images to zip and download.")
            return
        try:
            with open(zip_path, "rb") as f:
                zip_bytes = f.read()
        finally:
            # The temporary archive is only needed long enough to read it.
            try:
                os.remove(zip_path)
            except OSError:
                pass
        st.download_button(
            "Download ZIP",
            data=zip_bytes,
            file_name="generated_images.zip",
            mime="application/zip",
            key=f"{prefix}_zip_dl_btn",
            width='stretch',
        )


def render_bulk_image_generator(uid: str, prefix: str = "ig_img"):
    """Render the bulk image-to-image generator section.

    The section lets the user:

    * upload a ZIP or a single image;
    * pick a model and its options;
    * generate variations (a demo run or all images) and preview them;
    * download every result as a ZIP.

    Results are stored in ``st.session_state[f"{prefix}_gallery"]`` so the
    "Download All" button can bundle them on a later rerun.

    Args:
        uid: User id forwarded to the generation processors.
        prefix: Namespace for every widget and session-state key, so the
            section can appear more than once on a page without key clashes.
    """
    private_mode = st.session_state.get("private_mode_enabled", False)
    zip_file = st.file_uploader(
        "Upload Zip or Single File",
        type=["zip", "png", "jpg", "jpeg"],
        key=f"{prefix}_upload",
    )
    category = st.text_input("Category", key=f"{prefix}_category")

    # Model selection
    models = get_image_to_image_models()
    if not models:
        # Nothing to select or generate with; avoid indexing an empty registry.
        st.warning("No image-to-image models are available.")
        return
    model_names = [model['display_name'] for model in models.values()]
    selected_model_display = st.selectbox("Select Model", model_names, key=f"{prefix}_model")
    # First model whose display name matches, in registry order.
    selected_model_key = next(
        (key for key, model in models.items() if model['display_name'] == selected_model_display),
        next(iter(models)),
    )
    gpt_model = is_gpt_model(selected_model_key)

    # Dynamic parameters are only exposed for non-GPT models.
    dynamic_params = {} if gpt_model else _render_model_parameters(selected_model_key, prefix)
    blur, size, quality, sentiment, platform, num_images = _render_generation_options(gpt_model, prefix)

    user_prompt = st.text_area("User Prompt", height=100, key=f"{prefix}_user_prompt")
    colA, colB, colC = st.columns([1, 1, 1])
    with colA:
        demo_btn = st.button("Generate Demo Image", key=f"{prefix}_demo")
    with colB:
        gen_all_btn = st.button("Generate All Images", key=f"{prefix}_gen_all")
    with colC:
        download_btn = st.button("Download All", key=f"{prefix}_download_all")

    gallery_key = f"{prefix}_gallery"
    st.session_state.setdefault(gallery_key, [])
    if private_mode:
        st.info("Private mode is active. Generated variations stay only in this session and are not uploaded to storage or the AI Library.")

    if demo_btn or gen_all_btn:
        if not (zip_file and category and user_prompt):
            st.warning("Please upload a file and fill all required fields.")
        else:
            # Clear previous images when starting new generation
            st.session_state[gallery_key] = []
            with st.spinner(" Generating image variations..."):
                temp_dir = tempfile.mkdtemp()
                try:
                    temp_path = _save_upload(zip_file, temp_dir)
                    # Positional arguments shared by both processors; None is passed
                    # instead of the gallery so they do not append to it themselves.
                    common_args = (
                        temp_path, category, size, quality, user_prompt, sentiment,
                        platform, num_images, demo_btn, None, blur, uid,
                    )
                    if gpt_model:
                        images = process_zip_and_generate_images(
                            *common_args, private_mode=private_mode,
                        )
                    else:
                        images = process_zip_and_generate_images_multimodel(
                            *common_args, selected_model_key, dynamic_params,
                            private_mode=private_mode,
                        )
                    images = images or []
                    st.session_state[gallery_key] = images
                    # Rendered before temp_dir is removed.
                    # NOTE(review): if a processor returns paths inside temp_dir,
                    # they are gone once this block ends — confirm outputs live elsewhere.
                    if images:
                        _render_gallery(images, prefix)
                        st.success("Image generation completed!")
                    else:
                        st.info("No images generated.")
                except Exception as e:
                    st.error(f"Error: {e}")
                    logger.exception("Generation failed (embedded).")
                finally:
                    shutil.rmtree(temp_dir, ignore_errors=True)

    # ----------------------------
    # Download all images
    # ----------------------------
    if download_btn:
        _render_zip_download(gallery_key, prefix)