userIdc2024 commited on
Commit
0428d2d
·
verified ·
1 Parent(s): aded900

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +167 -100
src/streamlit_app.py CHANGED
@@ -1,23 +1,23 @@
1
- # main.py
2
  import os
3
  import io
4
  import zipfile
5
- import replicate
6
  import time
7
  import logging
8
  from concurrent.futures import ThreadPoolExecutor, as_completed
9
  from datetime import datetime, timedelta, date
10
  from typing import Dict, Any, List, Tuple, Optional
 
 
 
 
 
11
  import requests
12
  import streamlit as st
13
  from pymongo import MongoClient
14
  import boto3
 
15
  from uuid import uuid4
16
  from dotenv import load_dotenv
17
- from urllib.parse import urlparse
18
- import threading
19
- from functools import lru_cache
20
- import json
21
 
22
  load_dotenv()
23
 
@@ -32,44 +32,63 @@ MONGO_URI = os.getenv("MONGO_URI")
32
  MONGO_DB = os.getenv("MONGO_DB", "adgenesis_image_text")
33
  MONGO_COLLECTION = os.getenv("MONGO_COLLECTION", "creatives")
34
 
35
- MAX_WORKERS = min(32, (os.cpu_count() or 1) + 4)
36
  REQUEST_TIMEOUT = 30
37
  RETRY_ATTEMPTS = 3
38
- LIBRARY_PAGE_SIZE = 20
 
39
 
40
  MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {
41
- "imagegen-4-ultra": {"id": "google/imagen-4-ultra", "aspect_ratios": ["1:1","16:9","9:16","3:4","4:3"], "param_name": "aspect_ratio"},
42
- "imagen-4": {"id": "google/imagen-4", "aspect_ratios": ["1:1","16:9","9:16","3:4","4:3"], "param_name": "aspect_ratio"},
43
- "qwen": {"id": "qwen/qwen-image", "aspect_ratios": ["1:1","16:9","9:16","3:4","4:3","3:2","2:3"], "param_name": "aspect_ratio"},
44
- "seedream-3": {"id": "bytedance/seedream-3", "aspect_ratios": ["1:1","16:9","9:16","3:4","4:3","3:2","2:3","21:9"], "param_name": "aspect_ratio"},
45
- "recraft-v3": {"id": "recraft-ai/recraft-v3", "aspect_ratios": ["1:1","4:3","3:4","3:2","2:3","16:9","9:16"], "param_name": "aspect_ratio"},
 
46
  }
47
 
48
  _thread_local = threading.local()
49
 
50
  # ----------------------------
51
- # Infra helpers
 
 
 
 
 
 
 
 
 
 
 
 
52
  # ----------------------------
 
 
 
 
 
 
53
  def get_mongo_collection():
54
- if not hasattr(_thread_local, 'mongo_collection'):
55
  if not MONGO_URI:
56
  _thread_local.mongo_collection = None
57
  return None
58
  try:
59
  client = MongoClient(MONGO_URI, serverSelectionTimeoutMS=3000)
60
  db = client[MONGO_DB]
61
- collection = db[MONGO_COLLECTION]
62
- client.admin.command('ping')
63
- _thread_local.mongo_collection = collection
64
  except Exception as e:
65
  logger.error(f"MongoDB connection failed: {e}")
66
  _thread_local.mongo_collection = None
67
  return _thread_local.mongo_collection
68
 
69
  def get_s3_client():
70
- if not hasattr(_thread_local, 's3_client'):
71
- required_vars = ["R2_ENDPOINT", "R2_ACCESS_KEY", "R2_SECRET_KEY", "R2_BUCKET_NAME"]
72
- if any(not os.getenv(v) for v in required_vars):
73
  _thread_local.s3_client = None
74
  return None
75
  try:
@@ -81,20 +100,16 @@ def get_s3_client():
81
  region_name="auto",
82
  )
83
  except Exception as e:
84
- logger.error(f"S3 client initialization failed: {e}")
85
  _thread_local.s3_client = None
86
  return _thread_local.s3_client
87
 
88
- @lru_cache(maxsize=128)
89
- def get_model_config(model_key: str) -> Optional[Dict[str, Any]]:
90
- return MODEL_REGISTRY.get(model_key)
91
-
92
  # ----------------------------
93
- # Upload & generation
94
  # ----------------------------
95
  def upload_to_r2(image_bytes: bytes) -> Optional[str]:
96
  s3 = get_s3_client()
97
- if not s3:
98
  return None
99
  try:
100
  filename = f"{uuid4().hex}.png"
@@ -107,52 +122,67 @@ def upload_to_r2(image_bytes: bytes) -> Optional[str]:
107
  )
108
  return f"{os.getenv('NEW_BASE').rstrip('/')}/{file_key}"
109
  except Exception as e:
110
- logger.error(f"Upload failed: {e}")
111
  return None
112
 
113
  def generate_one(model_key: str, prompt: str, aspect_ratio: str) -> List[str]:
114
- if not REPLICATE_API_TOKEN:
 
 
 
 
115
  return []
116
- config = get_model_config(model_key)
117
- if not config:
118
  return []
119
  try:
120
- output = replicate.run(config["id"], input={"prompt": prompt, config["param_name"]: aspect_ratio})
121
  if isinstance(output, list) and output:
122
  return [str(output[0])]
123
- elif isinstance(output, str):
124
  return [output]
 
 
125
  return []
126
  except Exception as e:
127
  logger.error(f"Replicate error: {e}")
128
  return []
129
 
130
  def fetch_bytes(url: str) -> Optional[bytes]:
131
- for _ in range(RETRY_ATTEMPTS):
132
  try:
133
  r = requests.get(url, timeout=REQUEST_TIMEOUT, stream=True)
134
  r.raise_for_status()
135
  return r.content
136
  except Exception:
 
 
137
  time.sleep(1)
138
  return None
139
 
140
- def process_prompt(i: int, text: str, model: str, aspect: str):
 
 
 
 
 
 
141
  urls = generate_one(model, text, aspect)
142
  if not urls:
143
  return {"idx": i, "urls": [], "error": "No URLs"}
144
- img_bytes = fetch_bytes(urls[0])
145
- if not img_bytes:
 
146
  return {"idx": i, "urls": [], "error": "Fetch failed"}
147
- r2 = upload_to_r2(img_bytes)
148
- return {"idx": i, "urls": [r2 or urls[0]], "error": None}
149
 
150
  # ----------------------------
151
  # Persistence
152
  # ----------------------------
153
  def save_record(model: str, aspect: str, prompt: str, urls: List[str]):
154
  coll = get_mongo_collection()
155
- if not coll:
156
  return None
157
  try:
158
  return str(coll.insert_one({
@@ -161,7 +191,7 @@ def save_record(model: str, aspect: str, prompt: str, urls: List[str]):
161
  "prompt": prompt,
162
  "urls": urls,
163
  "lob": "balraaj",
164
- "created_at": datetime.utcnow()
165
  }).inserted_id)
166
  except Exception as e:
167
  logger.error(f"Mongo insert failed: {e}")
@@ -170,20 +200,23 @@ def save_record(model: str, aspect: str, prompt: str, urls: List[str]):
170
  @st.cache_data(ttl=300)
171
  def query_records(start: datetime, end: datetime) -> List[Dict[str, Any]]:
172
  coll = get_mongo_collection()
173
- if not coll:
 
174
  return []
175
  try:
176
  return list(coll.find(
177
  {"created_at": {"$gte": start, "$lt": end}, "lob": "balraaj"}
178
  ).sort("created_at", -1).limit(LIBRARY_PAGE_SIZE))
179
- except Exception:
 
180
  return []
181
 
182
  # ----------------------------
183
  # Gallery helpers
184
  # ----------------------------
185
  def display_gallery(urls: List[str]):
186
- if not urls: return
 
187
  cols = st.columns(4)
188
  for i, u in enumerate(urls):
189
  with cols[i % 4]:
@@ -191,112 +224,146 @@ def display_gallery(urls: List[str]):
191
  img = fetch_bytes(u)
192
  if img:
193
  st.image(img, use_container_width=True)
194
- except:
195
- st.error("Failed")
196
 
197
  def bulk_zip(urls: List[str]):
 
 
198
  buf = io.BytesIO()
199
- with zipfile.ZipFile(buf, "w") as z:
200
  for i, u in enumerate(urls, 1):
201
  data = fetch_bytes(u)
202
  if data:
203
- name = f"image_{i}.png"
204
- z.writestr(name, data)
205
  buf.seek(0)
206
- st.download_button("Download All", buf, "images.zip", "application/zip")
207
 
208
  # ----------------------------
209
- # JSON loader & runner
210
  # ----------------------------
211
  def load_json(file) -> List[str]:
212
  data = json.loads(file.getvalue().decode("utf-8"))
213
- if not isinstance(data, dict) or "prompts" not in data:
214
- raise ValueError("JSON must be { 'prompts': [ ... ] }")
215
- return [p for p in data["prompts"] if isinstance(p, str) and p.strip()]
 
 
 
216
 
217
  def run_batch(prompts: List[str], model: str, aspect: str):
218
  total = len(prompts)
219
- status = [st.empty() for _ in prompts]
220
- progress = st.progress(0, f"0/{total}")
221
-
222
- all_urls = []
223
- with ThreadPoolExecutor(max_workers=min(MAX_WORKERS, total)) as ex:
224
- futs = {ex.submit(process_prompt, i, p, model, aspect): i for i,p in enumerate(prompts,1)}
225
- done = 0
226
- for f in as_completed(futs):
227
- i = futs[f]
228
- try:
229
- res = f.result()
230
- except Exception as e:
231
- res = {"idx": i, "urls": [], "error": str(e)}
232
- if res["urls"]:
233
- save_record(model, aspect, prompts[i-1], res["urls"])
234
- status[i-1].success(f"Prompt {i}/{total} ")
235
- all_urls.extend(res["urls"])
236
- else:
237
- status[i-1].error(f"Prompt {i}/{total} ✗ ({res['error']})")
238
- done += 1
239
- progress.progress(done/total, f"{done}/{total}")
 
 
 
240
 
241
  if all_urls:
242
  st.subheader("Gallery")
243
  display_gallery(all_urls)
244
  bulk_zip(all_urls)
 
 
245
 
246
  # ----------------------------
247
  # Pages
248
  # ----------------------------
249
  def render_json_page():
250
  st.subheader("Generate from JSON")
 
251
  up = st.file_uploader("Upload JSON", type=["json"])
252
- col1,col2 = st.columns([1,1])
253
- with col1: model = st.selectbox("Model", list(MODEL_REGISTRY.keys()), 0)
254
- with col2: aspect = st.selectbox("Aspect", MODEL_REGISTRY[model]["aspect_ratios"], 0)
 
 
255
 
256
  if up:
257
- prompts = load_json(up)
258
- st.json(prompts)
259
- if st.button("Generate", type="primary", use_container_width=True):
260
- run_batch(prompts, model, aspect)
 
 
 
 
 
 
261
 
262
  def render_library_page():
263
  st.subheader("Creative Library")
 
264
  today = datetime.utcnow().date()
265
- start = st.date_input("Start", today - timedelta(days=30))
266
- end = st.date_input("End", today)
267
- records = query_records(datetime.combine(start, datetime.min.time()),
268
- datetime.combine(end+timedelta(days=1), datetime.min.time()))
269
- all_urls = []
 
 
 
270
  for r in records:
271
- all_urls.extend(r.get("urls", []))
 
272
  if all_urls:
 
273
  display_gallery(all_urls)
274
  bulk_zip(all_urls)
275
  else:
276
- st.info("No records found.")
277
 
278
  # ----------------------------
279
  # Auth
280
  # ----------------------------
281
  @lru_cache(maxsize=1)
282
- def check_token(tok: str):
283
- return tok == os.getenv("ACCESS_TOKEN")
 
 
 
 
284
 
285
  def main_app():
 
286
  st.title("File-to-Image Generator")
287
- page = st.sidebar.radio("Menu", ["Generate from JSON","Creative Library"])
288
- if page=="Generate from JSON": render_json_page()
289
- else: render_library_page()
 
 
290
 
291
  def main():
292
  if not st.session_state.get("auth"):
293
  st.markdown("## Access Required")
294
- t = st.text_input("Token", type="password")
295
- if st.button("Unlock"):
296
- if check_token(t): st.session_state["auth"]=True; st.rerun()
297
- else: st.error("Invalid token")
 
 
 
298
  else:
299
  main_app()
300
 
301
- if __name__=="__main__":
302
  main()
 
 
1
  import os
2
  import io
3
  import zipfile
 
4
  import time
5
  import logging
6
  from concurrent.futures import ThreadPoolExecutor, as_completed
7
  from datetime import datetime, timedelta, date
8
  from typing import Dict, Any, List, Tuple, Optional
9
+ import json
10
+ import threading
11
+ from functools import lru_cache
12
+ from urllib.parse import urlparse
13
+
14
  import requests
15
  import streamlit as st
16
  from pymongo import MongoClient
17
  import boto3
18
+ import replicate # type: ignore
19
  from uuid import uuid4
20
  from dotenv import load_dotenv
 
 
 
 
21
 
22
  load_dotenv()
23
 
 
32
  MONGO_DB = os.getenv("MONGO_DB", "adgenesis_image_text")
33
  MONGO_COLLECTION = os.getenv("MONGO_COLLECTION", "creatives")
34
 
 
35
  REQUEST_TIMEOUT = 30
36
  RETRY_ATTEMPTS = 3
37
+ MAX_WORKERS = min(32, (os.cpu_count() or 1) + 4)
38
+ LIBRARY_PAGE_SIZE = 200 # aggregate many for one gallery view
39
 
40
  MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {
41
+ "imagegen-4-ultra": {"id": "google/imagen-4-ultra","aspect_ratios": ["1:1","16:9","9:16","3:4","4:3"],"param_name": "aspect_ratio"},
42
+ "imagen-4": {"id": "google/imagen-4","aspect_ratios": ["1:1","16:9","9:16","3:4","4:3"],"param_name": "aspect_ratio"},
43
+ "nano-banana": {"id": "google/nano-banana","aspect_ratios": ["1:1","16:9","9:16","3:4","4:3"],"param_name": "aspect_ratio"},
44
+ "qwen": {"id": "qwen/qwen-image","aspect_ratios": ["1:1","16:9","9:16","3:4","4:3","3:2","2:3"],"param_name": "aspect_ratio"},
45
+ "seedream-3": {"id": "bytedance/seedream-3","aspect_ratios": ["1:1","16:9","9:16","3:4","4:3","3:2","2:3","21:9"],"param_name": "aspect_ratio"},
46
+ "recraft-v3": {"id": "recraft-ai/recraft-v3","aspect_ratios": ["1:1","4:3","3:4","3:2","2:3","16:9","9:16"],"param_name": "aspect_ratio"},
47
  }
48
 
49
  _thread_local = threading.local()
50
 
51
  # ----------------------------
52
+ # Preflight/debug helpers
53
+ # ----------------------------
54
+ def show_env_warnings():
55
+ if not REPLICATE_API_TOKEN:
56
+ st.warning("Missing **REPLICATE_API_TOKEN** — generation will return ‘No URLs’.", icon="⚠️")
57
+ for v in ["R2_ENDPOINT", "R2_ACCESS_KEY", "R2_SECRET_KEY", "R2_BUCKET_NAME", "NEW_BASE"]:
58
+ if not os.getenv(v):
59
+ st.info(f"Optional: {v} not set → images won’t be copied to R2 (source URLs will be used).", icon="ℹ️")
60
+ if not MONGO_URI:
61
+ st.info("Optional: MONGO_URI not set → results won’t be saved to the Creative Library.", icon="ℹ️")
62
+
63
+ # ----------------------------
64
+ # Clients (Replicate / Mongo / S3)
65
  # ----------------------------
66
+ def get_replicate_client():
67
+ if not hasattr(_thread_local, "replicate_client"):
68
+ # Explicit client avoids env-specific issues with module-level run()
69
+ _thread_local.replicate_client = replicate.Client(api_token=REPLICATE_API_TOKEN) if REPLICATE_API_TOKEN else None
70
+ return _thread_local.replicate_client
71
+
72
  def get_mongo_collection():
73
+ if not hasattr(_thread_local, "mongo_collection"):
74
  if not MONGO_URI:
75
  _thread_local.mongo_collection = None
76
  return None
77
  try:
78
  client = MongoClient(MONGO_URI, serverSelectionTimeoutMS=3000)
79
  db = client[MONGO_DB]
80
+ coll = db[MONGO_COLLECTION]
81
+ client.admin.command("ping")
82
+ _thread_local.mongo_collection = coll
83
  except Exception as e:
84
  logger.error(f"MongoDB connection failed: {e}")
85
  _thread_local.mongo_collection = None
86
  return _thread_local.mongo_collection
87
 
88
  def get_s3_client():
89
+ if not hasattr(_thread_local, "s3_client"):
90
+ required = ["R2_ENDPOINT","R2_ACCESS_KEY","R2_SECRET_KEY","R2_BUCKET_NAME","NEW_BASE"]
91
+ if any(not os.getenv(k) for k in required):
92
  _thread_local.s3_client = None
93
  return None
94
  try:
 
100
  region_name="auto",
101
  )
102
  except Exception as e:
103
+ logger.error(f"S3 client init failed: {e}")
104
  _thread_local.s3_client = None
105
  return _thread_local.s3_client
106
 
 
 
 
 
107
  # ----------------------------
108
+ # Core ops: R2 / Generate / Fetch
109
  # ----------------------------
110
  def upload_to_r2(image_bytes: bytes) -> Optional[str]:
111
  s3 = get_s3_client()
112
+ if s3 is None:
113
  return None
114
  try:
115
  filename = f"{uuid4().hex}.png"
 
122
  )
123
  return f"{os.getenv('NEW_BASE').rstrip('/')}/{file_key}"
124
  except Exception as e:
125
+ logger.error(f"S3 upload failed: {e}")
126
  return None
127
 
128
  def generate_one(model_key: str, prompt: str, aspect_ratio: str) -> List[str]:
129
+ """
130
+ Returns: [image_url] from Replicate (or [])
131
+ """
132
+ client = get_replicate_client()
133
+ if client is None:
134
  return []
135
+ model = MODEL_REGISTRY.get(model_key)
136
+ if not model:
137
  return []
138
  try:
139
+ output = client.run(model["id"], input={"prompt": prompt, model["param_name"]: aspect_ratio})
140
  if isinstance(output, list) and output:
141
  return [str(output[0])]
142
+ if isinstance(output, str):
143
  return [output]
144
+ if hasattr(output, "url"):
145
+ return [getattr(output, "url")]
146
  return []
147
  except Exception as e:
148
  logger.error(f"Replicate error: {e}")
149
  return []
150
 
151
  def fetch_bytes(url: str) -> Optional[bytes]:
152
+ for attempt in range(RETRY_ATTEMPTS):
153
  try:
154
  r = requests.get(url, timeout=REQUEST_TIMEOUT, stream=True)
155
  r.raise_for_status()
156
  return r.content
157
  except Exception:
158
+ if attempt == RETRY_ATTEMPTS - 1:
159
+ return None
160
  time.sleep(1)
161
  return None
162
 
163
+ def process_prompt(i: int, text: str, model: str, aspect: str) -> Dict[str, Any]:
164
+ """
165
+ One image per prompt:
166
+ - generate via Replicate
167
+ - try to upload to R2
168
+ - fallback to source url if R2 not available
169
+ """
170
  urls = generate_one(model, text, aspect)
171
  if not urls:
172
  return {"idx": i, "urls": [], "error": "No URLs"}
173
+ src = urls[0]
174
+ data = fetch_bytes(src)
175
+ if data is None:
176
  return {"idx": i, "urls": [], "error": "Fetch failed"}
177
+ r2 = upload_to_r2(data)
178
+ return {"idx": i, "urls": [r2 or src], "error": None}
179
 
180
  # ----------------------------
181
  # Persistence
182
  # ----------------------------
183
  def save_record(model: str, aspect: str, prompt: str, urls: List[str]):
184
  coll = get_mongo_collection()
185
+ if coll is None:
186
  return None
187
  try:
188
  return str(coll.insert_one({
 
191
  "prompt": prompt,
192
  "urls": urls,
193
  "lob": "balraaj",
194
+ "created_at": datetime.utcnow(),
195
  }).inserted_id)
196
  except Exception as e:
197
  logger.error(f"Mongo insert failed: {e}")
 
200
  @st.cache_data(ttl=300)
201
  def query_records(start: datetime, end: datetime) -> List[Dict[str, Any]]:
202
  coll = get_mongo_collection()
203
+ # FIX: compare with None explicitly (avoids NotImplementedError)
204
+ if coll is None:
205
  return []
206
  try:
207
  return list(coll.find(
208
  {"created_at": {"$gte": start, "$lt": end}, "lob": "balraaj"}
209
  ).sort("created_at", -1).limit(LIBRARY_PAGE_SIZE))
210
+ except Exception as e:
211
+ logger.error(f"Mongo query failed: {e}")
212
  return []
213
 
214
  # ----------------------------
215
  # Gallery helpers
216
  # ----------------------------
217
  def display_gallery(urls: List[str]):
218
+ if not urls:
219
+ return
220
  cols = st.columns(4)
221
  for i, u in enumerate(urls):
222
  with cols[i % 4]:
 
224
  img = fetch_bytes(u)
225
  if img:
226
  st.image(img, use_container_width=True)
227
+ except Exception:
228
+ st.error("Failed to load image")
229
 
230
  def bulk_zip(urls: List[str]):
231
+ if not urls:
232
+ return
233
  buf = io.BytesIO()
234
+ with zipfile.ZipFile(buf, "w", compression=zipfile.ZIP_DEFLATED) as z:
235
  for i, u in enumerate(urls, 1):
236
  data = fetch_bytes(u)
237
  if data:
238
+ z.writestr(f"image_{i}.png", data)
 
239
  buf.seek(0)
240
+ st.download_button("Download All", buf, "images.zip", "application/zip", use_container_width=True)
241
 
242
  # ----------------------------
243
+ # JSON loader & batch run (parallel)
244
  # ----------------------------
245
  def load_json(file) -> List[str]:
246
  data = json.loads(file.getvalue().decode("utf-8"))
247
+ if not isinstance(data, dict) or "prompts" not in data or not isinstance(data["prompts"], list):
248
+ raise ValueError("JSON must be { 'prompts': [ '...','...' ] } (strings only).")
249
+ out = [p.strip() for p in data["prompts"] if isinstance(p, str) and p.strip()]
250
+ if not out:
251
+ raise ValueError("No valid prompts found.")
252
+ return out
253
 
254
  def run_batch(prompts: List[str], model: str, aspect: str):
255
  total = len(prompts)
256
+ rows = [st.empty() for _ in prompts]
257
+ progress = st.progress(0.0, text=f"0/{total}")
258
+
259
+ all_urls: List[str] = []
260
+ done = 0
261
+
262
+ max_workers = min(MAX_WORKERS, max(2, (os.cpu_count() or 2)))
263
+ with st.spinner("Generating images..."):
264
+ with ThreadPoolExecutor(max_workers=max_workers) as ex:
265
+ futs = {ex.submit(process_prompt, i, p, model, aspect): i for i, p in enumerate(prompts, 1)}
266
+ for fut in as_completed(futs):
267
+ i = futs[fut]
268
+ try:
269
+ res = fut.result()
270
+ except Exception as e:
271
+ res = {"idx": i, "urls": [], "error": str(e)}
272
+ if res["urls"]:
273
+ save_record(model, aspect, prompts[i-1], res["urls"])
274
+ rows[i-1].success(f"Prompt {i}/{total} ")
275
+ all_urls.extend(res["urls"])
276
+ else:
277
+ rows[i-1].error(f"Prompt {i}/{total} ✗ ({res['error']})")
278
+ done += 1
279
+ progress.progress(done/total, text=f"{done}/{total}")
280
 
281
  if all_urls:
282
  st.subheader("Gallery")
283
  display_gallery(all_urls)
284
  bulk_zip(all_urls)
285
+ else:
286
+ st.info("No images to display.")
287
 
288
  # ----------------------------
289
  # Pages
290
  # ----------------------------
291
  def render_json_page():
292
  st.subheader("Generate from JSON")
293
+ show_env_warnings()
294
  up = st.file_uploader("Upload JSON", type=["json"])
295
+ col1, col2 = st.columns([1, 1])
296
+ with col1:
297
+ model = st.selectbox("Model", list(MODEL_REGISTRY.keys()), 0)
298
+ with col2:
299
+ aspect = st.selectbox("Aspect", MODEL_REGISTRY[model]["aspect_ratios"], 0)
300
 
301
  if up:
302
+ try:
303
+ prompts = load_json(up)
304
+ with st.expander("Preview prompts", expanded=False):
305
+ st.json(prompts)
306
+ if st.button("Generate", type="primary", use_container_width=True):
307
+ run_batch(prompts, model, aspect)
308
+ except Exception as e:
309
+ st.error(str(e))
310
+ else:
311
+ st.caption('Expected: { "prompts": ["prompt 1", "prompt 2", ...] }')
312
 
313
  def render_library_page():
314
  st.subheader("Creative Library")
315
+ show_env_warnings()
316
  today = datetime.utcnow().date()
317
+ start = st.date_input("Start date", today - timedelta(days=30))
318
+ end = st.date_input("End date", today)
319
+
320
+ start_dt = datetime.combine(start, datetime.min.time())
321
+ end_dt = datetime.combine(end + timedelta(days=1), datetime.min.time())
322
+
323
+ records = query_records(start_dt, end_dt)
324
+ all_urls: List[str] = []
325
  for r in records:
326
+ all_urls.extend(r.get("urls", []) or [])
327
+
328
  if all_urls:
329
+ st.caption(f"Showing {len(all_urls)} images from {len(records)} records")
330
  display_gallery(all_urls)
331
  bulk_zip(all_urls)
332
  else:
333
+ st.info("No records found in the selected range.")
334
 
335
  # ----------------------------
336
  # Auth
337
  # ----------------------------
338
  @lru_cache(maxsize=1)
339
+ def check_token(tok: str) -> bool:
340
+ acc = os.getenv("ACCESS_TOKEN")
341
+ if not acc:
342
+ # If ACCESS_TOKEN is not configured, allow through to avoid lockout in dev.
343
+ return True
344
+ return tok == acc
345
 
346
  def main_app():
347
+ st.set_page_config(page_title="File-to-Image • Creative Library", layout="wide")
348
  st.title("File-to-Image Generator")
349
+ page = st.sidebar.radio("Menu", ["Generate from JSON", "Creative Library"])
350
+ if page == "Generate from JSON":
351
+ render_json_page()
352
+ else:
353
+ render_library_page()
354
 
355
  def main():
356
  if not st.session_state.get("auth"):
357
  st.markdown("## Access Required")
358
+ token = st.text_input("Enter Access Token", type="password")
359
+ if st.button("Unlock App"):
360
+ if check_token(token):
361
+ st.session_state["auth"] = True
362
+ st.rerun()
363
+ else:
364
+ st.error("Invalid token.")
365
  else:
366
  main_app()
367
 
368
+ if __name__ == "__main__":
369
  main()