Spaces:
Running
Running
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +154 -400
src/streamlit_app.py
CHANGED
|
@@ -33,20 +33,19 @@ RETRY_ATTEMPTS = 3
|
|
| 33 |
LIBRARY_PAGE_SIZE = 20
|
| 34 |
|
| 35 |
MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {
|
| 36 |
-
"imagegen-4-ultra": {"id": "google/imagen-4-ultra","aspect_ratios": ["1:1",
|
| 37 |
-
"imagen-4": {"id": "google/imagen-4","aspect_ratios": ["1:1",
|
| 38 |
-
"
|
| 39 |
-
"
|
| 40 |
-
"
|
| 41 |
-
"
|
| 42 |
-
"
|
| 43 |
-
"ideogram-v3-quality": {"id": "ideogram-ai/ideogram-v3-quality","aspect_ratios": ["1:3", "3:1", "1:2", "2:1", "9:16", "16:9", "10:16", "16:10", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "1:1"],"param_name": "aspect_ratio"},
|
| 44 |
}
|
| 45 |
|
| 46 |
_thread_local = threading.local()
|
| 47 |
|
| 48 |
def get_mongo_collection():
|
| 49 |
-
if not hasattr(_thread_local,
|
| 50 |
if not MONGO_URI:
|
| 51 |
_thread_local.mongo_collection = None
|
| 52 |
return None
|
|
@@ -54,7 +53,7 @@ def get_mongo_collection():
|
|
| 54 |
client = MongoClient(MONGO_URI, serverSelectionTimeoutMS=3000)
|
| 55 |
db = client[MONGO_DB]
|
| 56 |
collection = db[MONGO_COLLECTION]
|
| 57 |
-
client.admin.command(
|
| 58 |
_thread_local.mongo_collection = collection
|
| 59 |
except Exception as e:
|
| 60 |
logger.error(f"MongoDB connection failed: {e}")
|
|
@@ -62,8 +61,8 @@ def get_mongo_collection():
|
|
| 62 |
return _thread_local.mongo_collection
|
| 63 |
|
| 64 |
def get_s3_client():
|
| 65 |
-
if not hasattr(_thread_local,
|
| 66 |
-
required_vars = ["R2_ENDPOINT",
|
| 67 |
missing = [var for var in required_vars if not os.getenv(var)]
|
| 68 |
if missing:
|
| 69 |
_thread_local.s3_client = None
|
|
@@ -89,23 +88,19 @@ def upload_to_r2_optimized(image_bytes: bytes) -> Optional[str]:
|
|
| 89 |
s3_client = get_s3_client()
|
| 90 |
if not s3_client:
|
| 91 |
return None
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
if attempt == RETRY_ATTEMPTS - 1:
|
| 106 |
-
return None
|
| 107 |
-
time.sleep(2 ** attempt)
|
| 108 |
-
return None
|
| 109 |
|
| 110 |
def generate_one_image_optimized(model_key: str, prompt: str, aspect_ratio: str) -> List[str]:
|
| 111 |
if not REPLICATE_API_TOKEN:
|
|
@@ -113,425 +108,184 @@ def generate_one_image_optimized(model_key: str, prompt: str, aspect_ratio: str)
|
|
| 113 |
config = get_model_config(model_key)
|
| 114 |
if not config:
|
| 115 |
return []
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
elif hasattr(output, "url"):
|
| 130 |
-
urls = [getattr(output, "url")]
|
| 131 |
-
if urls:
|
| 132 |
-
return urls
|
| 133 |
-
except Exception:
|
| 134 |
-
if attempt == RETRY_ATTEMPTS - 1:
|
| 135 |
-
return []
|
| 136 |
-
time.sleep(1)
|
| 137 |
return []
|
| 138 |
|
| 139 |
-
def fetch_image_bytes_optimized(url:
|
| 140 |
-
url_str = getattr(url, "url", str(url))
|
| 141 |
for attempt in range(RETRY_ATTEMPTS):
|
| 142 |
try:
|
| 143 |
-
response = requests.get(
|
| 144 |
-
url_str,
|
| 145 |
-
timeout=REQUEST_TIMEOUT,
|
| 146 |
-
headers={"Cache-Control": "no-cache","Pragma": "no-cache","User-Agent": "Mozilla/5.0 (compatible; ImageBot/1.0)"},
|
| 147 |
-
stream=True
|
| 148 |
-
)
|
| 149 |
response.raise_for_status()
|
| 150 |
-
content
|
| 151 |
-
for chunk in response.iter_content(chunk_size=8192):
|
| 152 |
-
content += chunk
|
| 153 |
-
return content
|
| 154 |
except Exception:
|
| 155 |
if attempt == RETRY_ATTEMPTS - 1:
|
| 156 |
return None
|
| 157 |
time.sleep(1)
|
| 158 |
return None
|
| 159 |
|
| 160 |
-
def process_single_image(args: Tuple[str,
|
| 161 |
model_key, prompt, aspect_ratio, index = args
|
| 162 |
result = {"index": index,"success": False,"source_url": None,"r2_url": None,"error": None}
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
result["error"] = "Failed to upload to R2"
|
| 180 |
-
except Exception as e:
|
| 181 |
-
result["error"] = str(e)
|
| 182 |
return result
|
| 183 |
|
| 184 |
-
def generate_images_parallel(model_key: str, aspect_ratio: str, prompt: str
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
return [], [], [result["error"] or "Generation failed"]
|
| 191 |
-
args_list = [(model_key, prompt, aspect_ratio, i) for i in range(num_images)]
|
| 192 |
-
all_r2_urls: List[str] = []
|
| 193 |
-
all_source_urls: List[str] = []
|
| 194 |
-
generation_errors: List[str] = []
|
| 195 |
-
with ThreadPoolExecutor(max_workers=min(MAX_WORKERS, num_images)) as executor:
|
| 196 |
-
future_to_index = {executor.submit(process_single_image, args): args[3] for args in args_list}
|
| 197 |
-
for future in as_completed(future_to_index):
|
| 198 |
-
try:
|
| 199 |
-
result = future.result()
|
| 200 |
-
if result["success"]:
|
| 201 |
-
all_r2_urls.append(result["r2_url"])
|
| 202 |
-
if result["source_url"]:
|
| 203 |
-
all_source_urls.append(result["source_url"])
|
| 204 |
-
else:
|
| 205 |
-
error_msg = result["error"] or f"Generation {result['index']} failed"
|
| 206 |
-
generation_errors.append(error_msg)
|
| 207 |
-
except Exception as e:
|
| 208 |
-
generation_errors.append(f"Future execution error: {str(e)}")
|
| 209 |
-
return all_r2_urls, all_source_urls, generation_errors
|
| 210 |
|
| 211 |
-
def save_creative_record_optimized(model_key:
|
| 212 |
collection = get_mongo_collection()
|
| 213 |
if collection is None:
|
| 214 |
return None
|
| 215 |
try:
|
| 216 |
-
doc = {
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
"urls": urls,
|
| 221 |
-
"num_images": len(urls),
|
| 222 |
-
"lob": "balraaj",
|
| 223 |
-
"created_at": datetime.utcnow()
|
| 224 |
-
}
|
| 225 |
-
result = collection.insert_one(doc)
|
| 226 |
-
return str(result.inserted_id)
|
| 227 |
-
except Exception:
|
| 228 |
return None
|
| 229 |
|
| 230 |
@st.cache_data(ttl=300)
|
| 231 |
-
def query_creatives_optimized(start_dt:
|
| 232 |
collection = get_mongo_collection()
|
| 233 |
if collection is None:
|
| 234 |
-
return [],
|
| 235 |
try:
|
| 236 |
-
total_count = collection.count_documents({"created_at":
|
| 237 |
-
cursor = (
|
| 238 |
-
|
| 239 |
-
.sort("created_at", -1)
|
| 240 |
-
.skip(page * LIBRARY_PAGE_SIZE)
|
| 241 |
-
.limit(LIBRARY_PAGE_SIZE)
|
| 242 |
-
)
|
| 243 |
-
records = list(cursor)
|
| 244 |
-
return records, total_count
|
| 245 |
except Exception:
|
| 246 |
-
return [],
|
| 247 |
-
|
| 248 |
-
@st.cache_data(ttl=3600)
|
| 249 |
-
def get_image_bytes_cached(url: str) -> Optional[bytes]:
|
| 250 |
-
return fetch_image_bytes_optimized(url)
|
| 251 |
|
| 252 |
-
def display_image_with_download_optimized(url:
|
| 253 |
-
url_str = getattr(url, "url", str(url))
|
| 254 |
try:
|
| 255 |
-
img_bytes =
|
| 256 |
-
if img_bytes
|
| 257 |
st.error("Failed to load image")
|
| 258 |
return
|
| 259 |
st.image(img_bytes, use_container_width=True)
|
| 260 |
-
|
| 261 |
-
base = os.path.basename(path) or "image.png"
|
| 262 |
if not os.path.splitext(base)[1]:
|
| 263 |
-
base
|
| 264 |
-
st.download_button(
|
| 265 |
except Exception as e:
|
| 266 |
st.error(f"Failed to display image: {e}")
|
| 267 |
|
| 268 |
-
def
|
| 269 |
-
if not urls:
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
try:
|
| 275 |
-
img_bytes = fetch_image_bytes_optimized(url)
|
| 276 |
-
if img_bytes:
|
| 277 |
-
path = urlparse(str(url)).path
|
| 278 |
-
base = os.path.basename(path) or f"image_{idx}.png"
|
| 279 |
-
if not os.path.splitext(base)[1]:
|
| 280 |
-
base = f"{base}.png"
|
| 281 |
-
zip_file.writestr(base, img_bytes)
|
| 282 |
-
except Exception:
|
| 283 |
-
pass
|
| 284 |
-
zip_buffer.seek(0)
|
| 285 |
-
st.download_button("Download All Images", data=zip_buffer, file_name=filename, mime="application/zip", use_container_width=True)
|
| 286 |
|
| 287 |
-
def
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
with
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
render_library_page()
|
| 296 |
-
|
| 297 |
-
def render_library_page():
|
| 298 |
-
st.subheader("Creative Library")
|
| 299 |
-
if "library_page" not in st.session_state:
|
| 300 |
-
st.session_state.library_page = 0
|
| 301 |
-
today_utc = datetime.utcnow().date()
|
| 302 |
-
default_start = today_utc - timedelta(days=30)
|
| 303 |
-
c1, c2, c3 = st.columns([1, 1, 1])
|
| 304 |
-
with c1:
|
| 305 |
-
start_date: date = st.date_input("Start date", value=default_start)
|
| 306 |
-
with c2:
|
| 307 |
-
end_date: date = st.date_input("End date", value=today_utc)
|
| 308 |
-
with c3:
|
| 309 |
-
if st.button("Apply Filters", use_container_width=True):
|
| 310 |
-
st.session_state.library_page = 0
|
| 311 |
-
st.cache_data.clear()
|
| 312 |
-
start_dt = datetime.combine(start_date, datetime.min.time())
|
| 313 |
-
end_dt = datetime.combine(end_date + timedelta(days=1), datetime.min.time())
|
| 314 |
-
records, total_count = query_creatives_optimized(start_dt, end_dt, st.session_state.library_page)
|
| 315 |
-
if not records and st.session_state.library_page == 0:
|
| 316 |
-
st.info("No creatives found for the selected dates.")
|
| 317 |
-
return
|
| 318 |
-
start_idx = st.session_state.library_page * LIBRARY_PAGE_SIZE + 1
|
| 319 |
-
end_idx = min(start_idx + len(records) - 1, total_count)
|
| 320 |
-
st.caption(f"Showing {start_idx}-{end_idx} of {total_count} items")
|
| 321 |
-
if total_count > LIBRARY_PAGE_SIZE:
|
| 322 |
-
col1, col2, col3, col4 = st.columns([1, 1, 1, 1])
|
| 323 |
-
with col1:
|
| 324 |
-
if st.button("First", disabled=st.session_state.library_page == 0):
|
| 325 |
-
st.session_state.library_page = 0
|
| 326 |
-
st.rerun()
|
| 327 |
-
with col2:
|
| 328 |
-
if st.button("Previous", disabled=st.session_state.library_page == 0):
|
| 329 |
-
st.session_state.library_page = max(0, st.session_state.library_page - 1)
|
| 330 |
-
st.rerun()
|
| 331 |
-
with col3:
|
| 332 |
-
max_page = (total_count - 1) // LIBRARY_PAGE_SIZE
|
| 333 |
-
if st.button("Next", disabled=st.session_state.library_page >= max_page):
|
| 334 |
-
st.session_state.library_page += 1
|
| 335 |
-
st.rerun()
|
| 336 |
-
with col4:
|
| 337 |
-
max_page = (total_count - 1) // LIBRARY_PAGE_SIZE
|
| 338 |
-
if st.button("Last", disabled=st.session_state.library_page >= max_page):
|
| 339 |
-
st.session_state.library_page = max_page
|
| 340 |
-
st.rerun()
|
| 341 |
-
display_creative_grid_optimized(records)
|
| 342 |
-
|
| 343 |
-
def display_creative_grid_optimized(records: List[Dict[str, Any]]):
|
| 344 |
-
if not records:
|
| 345 |
-
return
|
| 346 |
-
items: List[Tuple[str, str, str, datetime]] = []
|
| 347 |
-
for doc in records:
|
| 348 |
-
created_at = doc.get("created_at", datetime.utcnow())
|
| 349 |
-
model = doc.get("model", "?")
|
| 350 |
-
aspect = doc.get("aspect_ratio", "?")
|
| 351 |
-
urls = doc.get("urls", []) or []
|
| 352 |
-
for url in urls:
|
| 353 |
-
items.append((url, model, aspect, created_at))
|
| 354 |
-
if not items:
|
| 355 |
-
return
|
| 356 |
-
if "selected_urls" not in st.session_state:
|
| 357 |
-
st.session_state.selected_urls = set()
|
| 358 |
-
colA, colB, colC = st.columns([1, 1, 2])
|
| 359 |
-
with colA:
|
| 360 |
-
if st.button("Select all on this page", use_container_width=True, disabled=not items):
|
| 361 |
-
for (url, _, _, _) in items:
|
| 362 |
-
st.session_state.selected_urls.add(getattr(url, "url", str(url)))
|
| 363 |
-
with colB:
|
| 364 |
-
if st.button("Clear selection", use_container_width=True, disabled=not st.session_state.selected_urls):
|
| 365 |
-
st.session_state.selected_urls = set()
|
| 366 |
-
with colC:
|
| 367 |
-
if st.session_state.selected_urls:
|
| 368 |
-
zip_bytes = build_zip_from_selected()
|
| 369 |
-
st.download_button("Download selected (.zip)", data=zip_bytes, file_name="selected_images.zip", mime="application/zip", use_container_width=True)
|
| 370 |
-
cols = st.columns(4)
|
| 371 |
-
for i, (url, model, aspect, created_at) in enumerate(items):
|
| 372 |
-
with cols[i % 4]:
|
| 373 |
-
try:
|
| 374 |
-
display_image_with_download_optimized(url)
|
| 375 |
-
st.caption(f"{model} • {aspect} • {created_at.strftime('%Y-%m-%d %H:%M UTC')}")
|
| 376 |
-
url_str = getattr(url, "url", str(url))
|
| 377 |
-
key = f"lib_sel_{hash(url_str)}_{i}"
|
| 378 |
-
checked = st.checkbox("Select", key=key, value=(url_str in st.session_state.selected_urls))
|
| 379 |
-
if checked:
|
| 380 |
-
st.session_state.selected_urls.add(url_str)
|
| 381 |
-
else:
|
| 382 |
-
st.session_state.selected_urls.discard(url_str)
|
| 383 |
-
except Exception as e:
|
| 384 |
-
st.error(f"Failed to display image: {e}")
|
| 385 |
-
|
| 386 |
-
def build_zip_from_selected() -> bytes:
|
| 387 |
-
buf = io.BytesIO()
|
| 388 |
-
with zipfile.ZipFile(buf, "w", compression=zipfile.ZIP_DEFLATED) as zf:
|
| 389 |
-
for i, u in enumerate(sorted(list(st.session_state.selected_urls)), 1):
|
| 390 |
-
data = get_image_bytes_cached(u)
|
| 391 |
-
if data:
|
| 392 |
-
path = urlparse(u).path
|
| 393 |
-
base = os.path.basename(path) or f"image_{i}.png"
|
| 394 |
if not os.path.splitext(base)[1]:
|
| 395 |
-
base
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
@lru_cache(maxsize=1)
|
| 401 |
-
def check_token_cached(user_token: str) -> Tuple[bool, str]:
|
| 402 |
-
ACCESS_TOKEN = os.getenv("ACCESS_TOKEN")
|
| 403 |
-
if not ACCESS_TOKEN:
|
| 404 |
-
return False, "Server error: Access token not configured."
|
| 405 |
-
if user_token == ACCESS_TOKEN:
|
| 406 |
-
return True, ""
|
| 407 |
-
return False, "Invalid token."
|
| 408 |
-
|
| 409 |
-
def load_json_prompts(file) -> List[Dict[str, Any]]:
|
| 410 |
-
raw = file.getvalue().decode("utf-8", errors="replace")
|
| 411 |
-
data = json.loads(raw)
|
| 412 |
-
|
| 413 |
-
if not isinstance(data, dict) or "prompts" not in data or not isinstance(data["prompts"], list):
|
| 414 |
-
raise ValueError("Invalid JSON. Expected an object with a 'prompts' array of strings.")
|
| 415 |
-
|
| 416 |
-
prompts_out: List[Dict[str, Any]] = []
|
| 417 |
-
for i, item in enumerate(data["prompts"], 1):
|
| 418 |
-
if not isinstance(item, str) or not item.strip():
|
| 419 |
-
raise ValueError(f"'prompts[{i-1}]' must be a non-empty string.")
|
| 420 |
-
prompts_out.append({"id": f"p{i}", "content": item.strip()})
|
| 421 |
|
| 422 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 423 |
|
| 424 |
def render_json_page():
|
| 425 |
st.subheader("Generate from JSON Prompts")
|
| 426 |
-
up
|
| 427 |
-
|
| 428 |
-
col1
|
| 429 |
-
with
|
| 430 |
-
|
| 431 |
-
with col2:
|
| 432 |
-
aspect_options = MODEL_REGISTRY[default_model]["aspect_ratios"]
|
| 433 |
-
default_aspect = st.selectbox("Default Aspect Ratio", aspect_options, index=0, key="json_default_ar")
|
| 434 |
-
|
| 435 |
-
debug_mode = st.checkbox("Debug Mode", value=False, key="json_debug")
|
| 436 |
-
|
| 437 |
if up:
|
| 438 |
try:
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
with st.expander("Preview normalized prompts", expanded=False):
|
| 445 |
-
st.json(prompts_list, expanded=False)
|
| 446 |
-
|
| 447 |
-
if st.button("Generate for All Prompts", type="primary", use_container_width=True):
|
| 448 |
-
handle_bulk_json_generation(prompts_list, default_model, default_aspect, debug_mode)
|
| 449 |
-
except json.JSONDecodeError as e:
|
| 450 |
-
st.error(f"Invalid JSON: {e}")
|
| 451 |
except Exception as e:
|
| 452 |
-
st.error(
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
st.
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
with block:
|
| 475 |
-
st.markdown(f"**Prompt {idx}/{total}** — Model: `{model_key}` • Aspect: `{aspect_ratio}` • Num: `1`")
|
| 476 |
-
st.code(prompt_text or "(empty)", language="markdown")
|
| 477 |
-
|
| 478 |
-
if not prompt_text:
|
| 479 |
-
st.error("Prompt text is empty. Skipping.")
|
| 480 |
-
overall_progress.progress(min(idx / total, 1.0), text=f"Processed {idx}/{total}")
|
| 481 |
-
continue
|
| 482 |
-
|
| 483 |
-
r2_urls, src_urls, gen_errors = generate_images_parallel(model_key, aspect_ratio, prompt_text, num_images)
|
| 484 |
-
rec_id = None
|
| 485 |
-
if r2_urls:
|
| 486 |
-
rec_id = save_creative_record_optimized(model_key, aspect_ratio, prompt_text, r2_urls)
|
| 487 |
-
|
| 488 |
-
if r2_urls:
|
| 489 |
-
st.success(f"Generated 1 image. DB: {rec_id or 'N/A'}")
|
| 490 |
-
display_image_gallery_optimized(r2_urls)
|
| 491 |
-
bulk_download_button(r2_urls, filename=f"prompt_{idx}_image.zip")
|
| 492 |
-
all_generated_urls.extend(r2_urls)
|
| 493 |
-
elif src_urls:
|
| 494 |
-
st.warning("Image generated but R2 upload failed. Showing original:")
|
| 495 |
-
display_image_gallery_optimized(src_urls)
|
| 496 |
-
bulk_download_button(src_urls, filename=f"prompt_{idx}_image.zip")
|
| 497 |
-
all_generated_urls.extend(src_urls)
|
| 498 |
-
else:
|
| 499 |
-
st.error("No image was generated for this prompt.")
|
| 500 |
-
|
| 501 |
-
if gen_errors and debug_mode:
|
| 502 |
-
with st.expander("Errors", expanded=False):
|
| 503 |
-
for e in gen_errors:
|
| 504 |
-
st.error(e)
|
| 505 |
-
errors_total.extend(gen_errors)
|
| 506 |
|
| 507 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 508 |
|
| 509 |
-
|
| 510 |
-
st.
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
if errors_total and debug_mode:
|
| 515 |
-
with st.expander("All Errors Summary", expanded=False):
|
| 516 |
-
for e in errors_total:
|
| 517 |
-
st.error(e)
|
| 518 |
|
| 519 |
def main():
|
| 520 |
-
|
| 521 |
-
if "authenticated" not in st.session_state:
|
| 522 |
-
st.session_state["authenticated"] = False
|
| 523 |
-
if not st.session_state["authenticated"]:
|
| 524 |
-
st.markdown("## Access Required")
|
| 525 |
-
token_input = st.text_input("Enter Access Token", type="password")
|
| 526 |
-
if st.button("Unlock App"):
|
| 527 |
-
ok, error_msg = check_token_cached(token_input)
|
| 528 |
-
if ok:
|
| 529 |
-
st.session_state["authenticated"] = True
|
| 530 |
-
st.rerun()
|
| 531 |
-
else:
|
| 532 |
-
st.error(error_msg)
|
| 533 |
-
else:
|
| 534 |
-
main_app()
|
| 535 |
|
| 536 |
-
if __name__
|
| 537 |
main()
|
|
|
|
| 33 |
LIBRARY_PAGE_SIZE = 20
|
| 34 |
|
| 35 |
MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {
|
| 36 |
+
"imagegen-4-ultra": {"id": "google/imagen-4-ultra","aspect_ratios": ["1:1","16:9","9:16","3:4","4:3"],"param_name": "aspect_ratio"},
|
| 37 |
+
"imagen-4": {"id": "google/imagen-4","aspect_ratios": ["1:1","16:9","9:16","3:4","4:3"],"param_name": "aspect_ratio"},
|
| 38 |
+
"qwen": {"id": "qwen/qwen-image","aspect_ratios": ["1:1","16:9","9:16","3:4","4:3","3:2","2:3"],"param_name": "aspect_ratio"},
|
| 39 |
+
"seedream-3": {"id": "bytedance/seedream-3","aspect_ratios": ["1:1","16:9","9:16","3:4","4:3","3:2","2:3","21:9"],"param_name": "aspect_ratio"},
|
| 40 |
+
"recraft-v3": {"id": "recraft-ai/recraft-v3","aspect_ratios": ["1:1","4:3","3:4","3:2","2:3","16:9","9:16"],"param_name": "aspect_ratio"},
|
| 41 |
+
"photon": {"id": "luma/photon","aspect_ratios": ["1:1","3:4","4:3","9:16","16:9","21:9"],"param_name": "aspect_ratio"},
|
| 42 |
+
"ideogram-v3-quality": {"id": "ideogram-ai/ideogram-v3-quality","aspect_ratios": ["1:1","16:9","9:16","2:3","3:2","4:5","5:4"],"param_name": "aspect_ratio"},
|
|
|
|
| 43 |
}
|
| 44 |
|
| 45 |
_thread_local = threading.local()
|
| 46 |
|
| 47 |
def get_mongo_collection():
|
| 48 |
+
if not hasattr(_thread_local, "mongo_collection"):
|
| 49 |
if not MONGO_URI:
|
| 50 |
_thread_local.mongo_collection = None
|
| 51 |
return None
|
|
|
|
| 53 |
client = MongoClient(MONGO_URI, serverSelectionTimeoutMS=3000)
|
| 54 |
db = client[MONGO_DB]
|
| 55 |
collection = db[MONGO_COLLECTION]
|
| 56 |
+
client.admin.command("ping")
|
| 57 |
_thread_local.mongo_collection = collection
|
| 58 |
except Exception as e:
|
| 59 |
logger.error(f"MongoDB connection failed: {e}")
|
|
|
|
| 61 |
return _thread_local.mongo_collection
|
| 62 |
|
| 63 |
def get_s3_client():
|
| 64 |
+
if not hasattr(_thread_local, "s3_client"):
|
| 65 |
+
required_vars = ["R2_ENDPOINT","R2_ACCESS_KEY","R2_SECRET_KEY","R2_BUCKET_NAME"]
|
| 66 |
missing = [var for var in required_vars if not os.getenv(var)]
|
| 67 |
if missing:
|
| 68 |
_thread_local.s3_client = None
|
|
|
|
| 88 |
s3_client = get_s3_client()
|
| 89 |
if not s3_client:
|
| 90 |
return None
|
| 91 |
+
try:
|
| 92 |
+
filename = f"{uuid4().hex}.png"
|
| 93 |
+
file_key = f"adgenesis_image_file/json/images/{filename}"
|
| 94 |
+
s3_client.put_object(
|
| 95 |
+
Bucket=os.getenv("R2_BUCKET_NAME"),
|
| 96 |
+
Key=file_key,
|
| 97 |
+
Body=image_bytes,
|
| 98 |
+
ContentType="image/png",
|
| 99 |
+
)
|
| 100 |
+
return f"{os.getenv('NEW_BASE').rstrip('/')}/{file_key}"
|
| 101 |
+
except Exception as e:
|
| 102 |
+
logger.error(f"S3 upload failed: {e}")
|
| 103 |
+
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
def generate_one_image_optimized(model_key: str, prompt: str, aspect_ratio: str) -> List[str]:
|
| 106 |
if not REPLICATE_API_TOKEN:
|
|
|
|
| 108 |
config = get_model_config(model_key)
|
| 109 |
if not config:
|
| 110 |
return []
|
| 111 |
+
try:
|
| 112 |
+
model_id = config["id"]
|
| 113 |
+
ar_param = config["param_name"]
|
| 114 |
+
inputs = {"prompt": prompt, ar_param: aspect_ratio}
|
| 115 |
+
output = replicate.run(model_id, input=inputs)
|
| 116 |
+
if isinstance(output, list) and output:
|
| 117 |
+
return [str(output[0])]
|
| 118 |
+
elif isinstance(output, str):
|
| 119 |
+
return [output]
|
| 120 |
+
elif hasattr(output, "url"):
|
| 121 |
+
return [getattr(output, "url")]
|
| 122 |
+
except Exception as e:
|
| 123 |
+
logger.error(f"Replicate error: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
return []
|
| 125 |
|
| 126 |
+
def fetch_image_bytes_optimized(url: str) -> Optional[bytes]:
|
|
|
|
| 127 |
for attempt in range(RETRY_ATTEMPTS):
|
| 128 |
try:
|
| 129 |
+
response = requests.get(url, timeout=REQUEST_TIMEOUT, stream=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
response.raise_for_status()
|
| 131 |
+
return response.content
|
|
|
|
|
|
|
|
|
|
| 132 |
except Exception:
|
| 133 |
if attempt == RETRY_ATTEMPTS - 1:
|
| 134 |
return None
|
| 135 |
time.sleep(1)
|
| 136 |
return None
|
| 137 |
|
| 138 |
+
def process_single_image(args: Tuple[str,str,str,int]) -> Dict[str,Any]:
|
| 139 |
model_key, prompt, aspect_ratio, index = args
|
| 140 |
result = {"index": index,"success": False,"source_url": None,"r2_url": None,"error": None}
|
| 141 |
+
urls = generate_one_image_optimized(model_key, prompt, aspect_ratio)
|
| 142 |
+
if not urls:
|
| 143 |
+
result["error"] = "No URLs returned"
|
| 144 |
+
return result
|
| 145 |
+
source_url = urls[0]
|
| 146 |
+
result["source_url"] = source_url
|
| 147 |
+
img_bytes = fetch_image_bytes_optimized(source_url)
|
| 148 |
+
if not img_bytes:
|
| 149 |
+
result["error"] = "Failed to fetch image bytes"
|
| 150 |
+
return result
|
| 151 |
+
r2_url = upload_to_r2_optimized(img_bytes)
|
| 152 |
+
if r2_url:
|
| 153 |
+
result["r2_url"] = r2_url
|
| 154 |
+
result["success"] = True
|
| 155 |
+
else:
|
| 156 |
+
result["error"] = "Failed to upload to R2"
|
|
|
|
|
|
|
|
|
|
| 157 |
return result
|
| 158 |
|
| 159 |
+
def generate_images_parallel(model_key: str, aspect_ratio: str, prompt: str) -> Tuple[List[str],List[str],List[str]]:
|
| 160 |
+
result = process_single_image((model_key, prompt, aspect_ratio, 0))
|
| 161 |
+
if result["success"]:
|
| 162 |
+
return [result["r2_url"]], [result["source_url"]], []
|
| 163 |
+
else:
|
| 164 |
+
return [], [], [result["error"] or "Generation failed"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
+
def save_creative_record_optimized(model_key:str,aspect_ratio:str,prompt:str,urls:List[str]) -> Optional[str]:
|
| 167 |
collection = get_mongo_collection()
|
| 168 |
if collection is None:
|
| 169 |
return None
|
| 170 |
try:
|
| 171 |
+
doc = {"model": model_key,"aspect_ratio": aspect_ratio,"prompt": prompt,"urls": urls,"num_images": len(urls),"lob": "json_batch","created_at": datetime.utcnow()}
|
| 172 |
+
return str(collection.insert_one(doc).inserted_id)
|
| 173 |
+
except Exception as e:
|
| 174 |
+
logger.error(f"Mongo insert failed: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
return None
|
| 176 |
|
| 177 |
@st.cache_data(ttl=300)
|
| 178 |
+
def query_creatives_optimized(start_dt:datetime,end_dt:datetime,page:int=0)->Tuple[List[Dict[str,Any]],int]:
|
| 179 |
collection = get_mongo_collection()
|
| 180 |
if collection is None:
|
| 181 |
+
return [],0
|
| 182 |
try:
|
| 183 |
+
total_count = collection.count_documents({"created_at":{"$gte":start_dt,"$lt":end_dt},"lob":"json_batch"})
|
| 184 |
+
cursor = collection.find({"created_at":{"$gte":start_dt,"$lt":end_dt},"lob":"json_batch"}).sort("created_at",-1).skip(page*LIBRARY_PAGE_SIZE).limit(LIBRARY_PAGE_SIZE)
|
| 185 |
+
return list(cursor), total_count
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
except Exception:
|
| 187 |
+
return [],0
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
|
| 189 |
+
def display_image_with_download_optimized(url:str):
|
|
|
|
| 190 |
try:
|
| 191 |
+
img_bytes = fetch_image_bytes_optimized(url)
|
| 192 |
+
if not img_bytes:
|
| 193 |
st.error("Failed to load image")
|
| 194 |
return
|
| 195 |
st.image(img_bytes, use_container_width=True)
|
| 196 |
+
base = os.path.basename(urlparse(url).path) or "image.png"
|
|
|
|
| 197 |
if not os.path.splitext(base)[1]:
|
| 198 |
+
base += ".png"
|
| 199 |
+
st.download_button("Download image", data=img_bytes, file_name=base, mime="image/png", use_container_width=True)
|
| 200 |
except Exception as e:
|
| 201 |
st.error(f"Failed to display image: {e}")
|
| 202 |
|
| 203 |
+
def display_image_gallery_optimized(urls: List[str]):
|
| 204 |
+
if not urls: return
|
| 205 |
+
cols = st.columns(min(4,len(urls)))
|
| 206 |
+
for i,url in enumerate(urls):
|
| 207 |
+
with cols[i % len(cols)]:
|
| 208 |
+
display_image_with_download_optimized(url)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
|
| 210 |
+
def bulk_download_button(urls: List[str], filename="images_bundle.zip"):
|
| 211 |
+
if not urls: return
|
| 212 |
+
zip_buffer = io.BytesIO()
|
| 213 |
+
with zipfile.ZipFile(zip_buffer,"w",compression=zipfile.ZIP_DEFLATED) as zipf:
|
| 214 |
+
for i,url in enumerate(urls,1):
|
| 215 |
+
img = fetch_image_bytes_optimized(url)
|
| 216 |
+
if img:
|
| 217 |
+
base = os.path.basename(urlparse(url).path) or f"img_{i}.png"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
if not os.path.splitext(base)[1]:
|
| 219 |
+
base += ".png"
|
| 220 |
+
zipf.writestr(base,img)
|
| 221 |
+
zip_buffer.seek(0)
|
| 222 |
+
st.download_button("Download All Images",data=zip_buffer,file_name=filename,mime="application/zip",use_container_width=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
|
| 224 |
+
def load_json_prompts(file)->List[Dict[str,str]]:
|
| 225 |
+
raw=file.getvalue().decode("utf-8")
|
| 226 |
+
data=json.loads(raw)
|
| 227 |
+
if not isinstance(data,dict) or "prompts" not in data or not isinstance(data["prompts"],list):
|
| 228 |
+
raise ValueError("JSON must be { 'prompts': [ 'string', ... ] }")
|
| 229 |
+
return [{"id":f"p{i}","content":p.strip()} for i,p in enumerate(data["prompts"],1) if isinstance(p,str) and p.strip()]
|
| 230 |
|
| 231 |
def render_json_page():
|
| 232 |
st.subheader("Generate from JSON Prompts")
|
| 233 |
+
up=st.file_uploader("Upload prompts JSON",type=["json"])
|
| 234 |
+
col1,col2=st.columns([1,1])
|
| 235 |
+
with col1: default_model=st.selectbox("Default Model",list(MODEL_REGISTRY.keys()),0)
|
| 236 |
+
with col2: default_aspect=st.selectbox("Default Aspect Ratio",MODEL_REGISTRY[default_model]["aspect_ratios"],0)
|
| 237 |
+
debug=st.checkbox("Debug Mode",False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
if up:
|
| 239 |
try:
|
| 240 |
+
prompts=load_json_prompts(up)
|
| 241 |
+
st.json(prompts)
|
| 242 |
+
if st.button("Generate for All Prompts",type="primary",use_container_width=True):
|
| 243 |
+
handle_bulk_json_generation(prompts,default_model,default_aspect,debug)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
except Exception as e:
|
| 245 |
+
st.error(str(e))
|
| 246 |
+
|
| 247 |
+
def handle_bulk_json_generation(prompts:List[Dict[str,str]],default_model:str,default_aspect:str,debug:bool):
|
| 248 |
+
total=len(prompts)
|
| 249 |
+
all_urls=[]
|
| 250 |
+
for i,p in enumerate(prompts,1):
|
| 251 |
+
st.markdown(f"**Prompt {i}/{total}** — {p['content']}")
|
| 252 |
+
r2,src,errs=generate_images_parallel(default_model,default_aspect,p["content"])
|
| 253 |
+
if r2:
|
| 254 |
+
save_creative_record_optimized(default_model,default_aspect,p["content"],r2)
|
| 255 |
+
display_image_gallery_optimized(r2)
|
| 256 |
+
all_urls.extend(r2)
|
| 257 |
+
elif src:
|
| 258 |
+
display_image_gallery_optimized(src)
|
| 259 |
+
all_urls.extend(src)
|
| 260 |
+
else:
|
| 261 |
+
st.error("No image generated")
|
| 262 |
+
if errs and debug:
|
| 263 |
+
st.error(errs)
|
| 264 |
+
if all_urls:
|
| 265 |
+
st.subheader("Download All")
|
| 266 |
+
bulk_download_button(all_urls,"all_prompts_images.zip")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
|
| 268 |
+
def render_library_page():
|
| 269 |
+
st.subheader("Creative Library")
|
| 270 |
+
today=datetime.utcnow().date()
|
| 271 |
+
start_date=st.date_input("Start date",today-timedelta(days=30))
|
| 272 |
+
end_date=st.date_input("End date",today)
|
| 273 |
+
start_dt=datetime.combine(start_date,datetime.min.time())
|
| 274 |
+
end_dt=datetime.combine(end_date+timedelta(days=1),datetime.min.time())
|
| 275 |
+
records,total=query_creatives_optimized(start_dt,end_dt,0)
|
| 276 |
+
st.caption(f"{total} items")
|
| 277 |
+
for rec in records:
|
| 278 |
+
urls=rec.get("urls",[])
|
| 279 |
+
if urls: display_image_gallery_optimized(urls)
|
| 280 |
|
| 281 |
+
def main_app():
|
| 282 |
+
st.title("File-to-Image Generator")
|
| 283 |
+
page=st.sidebar.radio("Navigation",["Generate from JSON","Creative Library"])
|
| 284 |
+
if page=="Generate from JSON": render_json_page()
|
| 285 |
+
else: render_library_page()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
|
| 287 |
def main():
|
| 288 |
+
main_app()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
|
| 290 |
+
if __name__=="__main__":
|
| 291 |
main()
|