userIdc2024 commited on
Commit
6fb57f4
·
verified ·
1 Parent(s): c66ef08

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +306 -100
src/streamlit_app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  import io
3
  import zipfile
@@ -23,29 +24,38 @@ load_dotenv()
23
  logging.basicConfig(level=logging.INFO)
24
  logger = logging.getLogger("imagegen_app")
25
 
 
 
 
26
  REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
27
  MONGO_URI = os.getenv("MONGO_URI")
28
  MONGO_DB = os.getenv("MONGO_DB", "adgenesis_image_text")
29
  MONGO_COLLECTION = os.getenv("MONGO_COLLECTION", "creatives")
 
30
  MAX_WORKERS = min(32, (os.cpu_count() or 1) + 4)
31
  REQUEST_TIMEOUT = 30
32
  RETRY_ATTEMPTS = 3
33
  LIBRARY_PAGE_SIZE = 20
34
 
 
35
  MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {
36
- "imagegen-4-ultra": {"id": "google/imagen-4-ultra","aspect_ratios": ["1:1","16:9","9:16","3:4","4:3"],"param_name": "aspect_ratio"},
37
- "imagen-4": {"id": "google/imagen-4","aspect_ratios": ["1:1","16:9","9:16","3:4","4:3"],"param_name": "aspect_ratio"},
38
- "qwen": {"id": "qwen/qwen-image","aspect_ratios": ["1:1","16:9","9:16","3:4","4:3","3:2","2:3"],"param_name": "aspect_ratio"},
39
- "seedream-3": {"id": "bytedance/seedream-3","aspect_ratios": ["1:1","16:9","9:16","3:4","4:3","3:2","2:3","21:9"],"param_name": "aspect_ratio"},
40
- "recraft-v3": {"id": "recraft-ai/recraft-v3","aspect_ratios": ["1:1","4:3","3:4","3:2","2:3","16:9","9:16"],"param_name": "aspect_ratio"},
41
- "photon": {"id": "luma/photon","aspect_ratios": ["1:1","3:4","4:3","9:16","16:9","21:9"],"param_name": "aspect_ratio"},
42
- "ideogram-v3-quality": {"id": "ideogram-ai/ideogram-v3-quality","aspect_ratios": ["1:1","16:9","9:16","2:3","3:2","4:5","5:4"],"param_name": "aspect_ratio"},
 
43
  }
44
 
45
  _thread_local = threading.local()
46
 
 
 
 
47
  def get_mongo_collection():
48
- if not hasattr(_thread_local, "mongo_collection"):
49
  if not MONGO_URI:
50
  _thread_local.mongo_collection = None
51
  return None
@@ -53,7 +63,7 @@ def get_mongo_collection():
53
  client = MongoClient(MONGO_URI, serverSelectionTimeoutMS=3000)
54
  db = client[MONGO_DB]
55
  collection = db[MONGO_COLLECTION]
56
- client.admin.command("ping")
57
  _thread_local.mongo_collection = collection
58
  except Exception as e:
59
  logger.error(f"MongoDB connection failed: {e}")
@@ -61,8 +71,8 @@ def get_mongo_collection():
61
  return _thread_local.mongo_collection
62
 
63
  def get_s3_client():
64
- if not hasattr(_thread_local, "s3_client"):
65
- required_vars = ["R2_ENDPOINT","R2_ACCESS_KEY","R2_SECRET_KEY","R2_BUCKET_NAME"]
66
  missing = [var for var in required_vars if not os.getenv(var)]
67
  if missing:
68
  _thread_local.s3_client = None
@@ -84,24 +94,31 @@ def get_s3_client():
84
  def get_model_config(model_key: str) -> Optional[Dict[str, Any]]:
85
  return MODEL_REGISTRY.get(model_key)
86
 
 
 
 
87
  def upload_to_r2_optimized(image_bytes: bytes) -> Optional[str]:
88
  s3_client = get_s3_client()
89
  if not s3_client:
90
  return None
91
  try:
92
  filename = f"{uuid4().hex}.png"
93
- file_key = f"adgenesis_image_file/json/images/{filename}"
94
  s3_client.put_object(
95
  Bucket=os.getenv("R2_BUCKET_NAME"),
96
  Key=file_key,
97
  Body=image_bytes,
98
  ContentType="image/png",
99
  )
100
- return f"{os.getenv('NEW_BASE').rstrip('/')}/{file_key}"
 
101
  except Exception as e:
102
  logger.error(f"S3 upload failed: {e}")
103
  return None
104
 
 
 
 
105
  def generate_one_image_optimized(model_key: str, prompt: str, aspect_ratio: str) -> List[str]:
106
  if not REPLICATE_API_TOKEN:
107
  return []
@@ -113,15 +130,18 @@ def generate_one_image_optimized(model_key: str, prompt: str, aspect_ratio: str)
113
  ar_param = config["param_name"]
114
  inputs = {"prompt": prompt, ar_param: aspect_ratio}
115
  output = replicate.run(model_id, input=inputs)
 
116
  if isinstance(output, list) and output:
117
- return [str(output[0])]
 
118
  elif isinstance(output, str):
119
  return [output]
120
  elif hasattr(output, "url"):
121
  return [getattr(output, "url")]
 
122
  except Exception as e:
123
  logger.error(f"Replicate error: {e}")
124
- return []
125
 
126
  def fetch_image_bytes_optimized(url: str) -> Optional[bytes]:
127
  for attempt in range(RETRY_ATTEMPTS):
@@ -135,12 +155,12 @@ def fetch_image_bytes_optimized(url: str) -> Optional[bytes]:
135
  time.sleep(1)
136
  return None
137
 
138
- def process_single_image(args: Tuple[str,str,str,int]) -> Dict[str,Any]:
139
  model_key, prompt, aspect_ratio, index = args
140
- result = {"index": index,"success": False,"source_url": None,"r2_url": None,"error": None}
141
  urls = generate_one_image_optimized(model_key, prompt, aspect_ratio)
142
  if not urls:
143
- result["error"] = "No URLs returned"
144
  return result
145
  source_url = urls[0]
146
  result["source_url"] = source_url
@@ -156,136 +176,322 @@ def process_single_image(args: Tuple[str,str,str,int]) -> Dict[str,Any]:
156
  result["error"] = "Failed to upload to R2"
157
  return result
158
 
159
- def generate_images_parallel(model_key: str, aspect_ratio: str, prompt: str) -> Tuple[List[str],List[str],List[str]]:
160
- result = process_single_image((model_key, prompt, aspect_ratio, 0))
161
- if result["success"]:
162
- return [result["r2_url"]], [result["source_url"]], []
 
163
  else:
164
- return [], [], [result["error"] or "Generation failed"]
165
 
166
- def save_creative_record_optimized(model_key:str,aspect_ratio:str,prompt:str,urls:List[str]) -> Optional[str]:
 
 
 
167
  collection = get_mongo_collection()
168
  if collection is None:
169
  return None
170
  try:
171
- doc = {"model": model_key,"aspect_ratio": aspect_ratio,"prompt": prompt,"urls": urls,"num_images": len(urls),"lob": "json_batch","created_at": datetime.utcnow()}
172
- return str(collection.insert_one(doc).inserted_id)
 
 
 
 
 
 
 
 
 
173
  except Exception as e:
174
  logger.error(f"Mongo insert failed: {e}")
175
  return None
176
 
177
  @st.cache_data(ttl=300)
178
- def query_creatives_optimized(start_dt:datetime,end_dt:datetime,page:int=0)->Tuple[List[Dict[str,Any]],int]:
179
  collection = get_mongo_collection()
180
  if collection is None:
181
- return [],0
182
  try:
183
- total_count = collection.count_documents({"created_at":{"$gte":start_dt,"$lt":end_dt},"lob":"json_batch"})
184
- cursor = collection.find({"created_at":{"$gte":start_dt,"$lt":end_dt},"lob":"json_batch"}).sort("created_at",-1).skip(page*LIBRARY_PAGE_SIZE).limit(LIBRARY_PAGE_SIZE)
 
 
 
 
 
185
  return list(cursor), total_count
186
  except Exception:
187
- return [],0
 
 
 
 
 
 
 
188
 
189
- def display_image_with_download_optimized(url:str):
190
  try:
191
- img_bytes = fetch_image_bytes_optimized(url)
192
  if not img_bytes:
193
  st.error("Failed to load image")
194
  return
195
  st.image(img_bytes, use_container_width=True)
196
  base = os.path.basename(urlparse(url).path) or "image.png"
197
  if not os.path.splitext(base)[1]:
198
- base += ".png"
199
- st.download_button("Download image", data=img_bytes, file_name=base, mime="image/png", use_container_width=True)
 
 
 
 
 
 
200
  except Exception as e:
201
  st.error(f"Failed to display image: {e}")
202
 
203
  def display_image_gallery_optimized(urls: List[str]):
204
- if not urls: return
205
- cols = st.columns(min(4,len(urls)))
206
- for i,url in enumerate(urls):
207
- with cols[i % len(cols)]:
 
 
208
  display_image_with_download_optimized(url)
209
 
210
- def bulk_download_button(urls: List[str], filename="images_bundle.zip"):
211
- if not urls: return
 
212
  zip_buffer = io.BytesIO()
213
- with zipfile.ZipFile(zip_buffer,"w",compression=zipfile.ZIP_DEFLATED) as zipf:
214
- for i,url in enumerate(urls,1):
215
- img = fetch_image_bytes_optimized(url)
216
- if img:
217
- base = os.path.basename(urlparse(url).path) or f"img_{i}.png"
218
- if not os.path.splitext(base)[1]:
219
- base += ".png"
220
- zipf.writestr(base,img)
 
 
 
 
221
  zip_buffer.seek(0)
222
- st.download_button("Download All Images",data=zip_buffer,file_name=filename,mime="application/zip",use_container_width=True)
 
 
 
 
 
 
223
 
224
- def load_json_prompts(file)->List[Dict[str,str]]:
225
- raw=file.getvalue().decode("utf-8")
226
- data=json.loads(raw)
227
- if not isinstance(data,dict) or "prompts" not in data or not isinstance(data["prompts"],list):
228
- raise ValueError("JSON must be { 'prompts': [ 'string', ... ] }")
229
- return [{"id":f"p{i}","content":p.strip()} for i,p in enumerate(data["prompts"],1) if isinstance(p,str) and p.strip()]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
 
231
  def render_json_page():
232
  st.subheader("Generate from JSON Prompts")
233
- up=st.file_uploader("Upload prompts JSON",type=["json"])
234
- col1,col2=st.columns([1,1])
235
- with col1: default_model=st.selectbox("Default Model",list(MODEL_REGISTRY.keys()),0)
236
- with col2: default_aspect=st.selectbox("Default Aspect Ratio",MODEL_REGISTRY[default_model]["aspect_ratios"],0)
237
- debug=st.checkbox("Debug Mode",False)
 
 
 
 
 
 
238
  if up:
239
  try:
240
- prompts=load_json_prompts(up)
241
- st.json(prompts)
242
- if st.button("Generate for All Prompts",type="primary",use_container_width=True):
243
- handle_bulk_json_generation(prompts,default_model,default_aspect,debug)
 
 
 
 
244
  except Exception as e:
245
- st.error(str(e))
246
-
247
- def handle_bulk_json_generation(prompts:List[Dict[str,str]],default_model:str,default_aspect:str,debug:bool):
248
- total=len(prompts)
249
- all_urls=[]
250
- for i,p in enumerate(prompts,1):
251
- st.markdown(f"**Prompt {i}/{total}** {p['content']}")
252
- r2,src,errs=generate_images_parallel(default_model,default_aspect,p["content"])
253
- if r2:
254
- save_creative_record_optimized(default_model,default_aspect,p["content"],r2)
255
- display_image_gallery_optimized(r2)
256
- all_urls.extend(r2)
257
- elif src:
258
- display_image_gallery_optimized(src)
259
- all_urls.extend(src)
260
- else:
261
- st.error("No image generated")
262
- if errs and debug:
263
- st.error(errs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
  if all_urls:
265
- st.subheader("Download All")
266
- bulk_download_button(all_urls,"all_prompts_images.zip")
 
 
267
 
 
 
 
268
  def render_library_page():
269
  st.subheader("Creative Library")
270
- today=datetime.utcnow().date()
271
- start_date=st.date_input("Start date",today-timedelta(days=30))
272
- end_date=st.date_input("End date",today)
273
- start_dt=datetime.combine(start_date,datetime.min.time())
274
- end_dt=datetime.combine(end_date+timedelta(days=1),datetime.min.time())
275
- records,total=query_creatives_optimized(start_dt,end_dt,0)
276
- st.caption(f"{total} items")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
  for rec in records:
278
- urls=rec.get("urls",[])
279
- if urls: display_image_gallery_optimized(urls)
 
 
 
 
 
 
 
 
 
 
 
 
 
280
 
 
 
 
281
  def main_app():
 
282
  st.title("File-to-Image Generator")
283
- page=st.sidebar.radio("Navigation",["Generate from JSON","Creative Library"])
284
- if page=="Generate from JSON": render_json_page()
285
- else: render_library_page()
 
 
 
286
 
287
  def main():
288
- main_app()
 
 
 
 
 
 
 
 
 
 
 
 
 
289
 
290
- if __name__=="__main__":
291
  main()
 
1
+ # main.py
2
  import os
3
  import io
4
  import zipfile
 
24
  logging.basicConfig(level=logging.INFO)
25
  logger = logging.getLogger("imagegen_app")
26
 
27
+ # ----------------------------
28
+ # Config / Constants
29
+ # ----------------------------
30
  REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
31
  MONGO_URI = os.getenv("MONGO_URI")
32
  MONGO_DB = os.getenv("MONGO_DB", "adgenesis_image_text")
33
  MONGO_COLLECTION = os.getenv("MONGO_COLLECTION", "creatives")
34
+
35
  MAX_WORKERS = min(32, (os.cpu_count() or 1) + 4)
36
  REQUEST_TIMEOUT = 30
37
  RETRY_ATTEMPTS = 3
38
  LIBRARY_PAGE_SIZE = 20
39
 
40
+ # Model registry (subset with common ARs)
41
  MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {
42
+ "imagegen-4-ultra": {"id": "google/imagen-4-ultra", "aspect_ratios": ["1:1","16:9","9:16","3:4","4:3"], "param_name": "aspect_ratio"},
43
+ "imagen-4": {"id": "google/imagen-4", "aspect_ratios": ["1:1","16:9","9:16","3:4","4:3"], "param_name": "aspect_ratio"},
44
+ "nano-banana": {"id": "google/nano-banana", "aspect_ratios": ["1:1","16:9","9:16","3:4","4:3"], "param_name": "aspect_ratio"},
45
+ "qwen": {"id": "qwen/qwen-image", "aspect_ratios": ["1:1","16:9","9:16","3:4","4:3","3:2","2:3"], "param_name": "aspect_ratio"},
46
+ "seedream-3": {"id": "bytedance/seedream-3", "aspect_ratios": ["1:1","16:9","9:16","3:4","4:3","3:2","2:3","21:9"], "param_name": "aspect_ratio"},
47
+ "recraft-v3": {"id": "recraft-ai/recraft-v3", "aspect_ratios": ["1:1","4:3","3:4","3:2","2:3","16:9","9:16"], "param_name": "aspect_ratio"},
48
+ "photon": {"id": "luma/photon", "aspect_ratios": ["1:1","3:4","4:3","9:16","16:9","21:9"], "param_name": "aspect_ratio"},
49
+ "ideogram-v3-quality":{"id": "ideogram-ai/ideogram-v3-quality", "aspect_ratios": ["1:1","16:9","9:16","2:3","3:2","4:5","5:4"], "param_name": "aspect_ratio"},
50
  }
51
 
52
  _thread_local = threading.local()
53
 
54
+ # ----------------------------
55
+ # Infra helpers (Mongo / S3)
56
+ # ----------------------------
57
  def get_mongo_collection():
58
+ if not hasattr(_thread_local, 'mongo_collection'):
59
  if not MONGO_URI:
60
  _thread_local.mongo_collection = None
61
  return None
 
63
  client = MongoClient(MONGO_URI, serverSelectionTimeoutMS=3000)
64
  db = client[MONGO_DB]
65
  collection = db[MONGO_COLLECTION]
66
+ client.admin.command('ping')
67
  _thread_local.mongo_collection = collection
68
  except Exception as e:
69
  logger.error(f"MongoDB connection failed: {e}")
 
71
  return _thread_local.mongo_collection
72
 
73
  def get_s3_client():
74
+ if not hasattr(_thread_local, 's3_client'):
75
+ required_vars = ["R2_ENDPOINT", "R2_ACCESS_KEY", "R2_SECRET_KEY", "R2_BUCKET_NAME"]
76
  missing = [var for var in required_vars if not os.getenv(var)]
77
  if missing:
78
  _thread_local.s3_client = None
 
94
  def get_model_config(model_key: str) -> Optional[Dict[str, Any]]:
95
  return MODEL_REGISTRY.get(model_key)
96
 
97
+ # ----------------------------
98
+ # R2 upload
99
+ # ----------------------------
100
  def upload_to_r2_optimized(image_bytes: bytes) -> Optional[str]:
101
  s3_client = get_s3_client()
102
  if not s3_client:
103
  return None
104
  try:
105
  filename = f"{uuid4().hex}.png"
106
+ file_key = f"adgenesis_image_file/balraaj/images/{filename}"
107
  s3_client.put_object(
108
  Bucket=os.getenv("R2_BUCKET_NAME"),
109
  Key=file_key,
110
  Body=image_bytes,
111
  ContentType="image/png",
112
  )
113
+ r2_url = f'{os.getenv("NEW_BASE").rstrip("/")}/{file_key}'
114
+ return r2_url
115
  except Exception as e:
116
  logger.error(f"S3 upload failed: {e}")
117
  return None
118
 
119
+ # ----------------------------
120
+ # Generation & fetching
121
+ # ----------------------------
122
  def generate_one_image_optimized(model_key: str, prompt: str, aspect_ratio: str) -> List[str]:
123
  if not REPLICATE_API_TOKEN:
124
  return []
 
130
  ar_param = config["param_name"]
131
  inputs = {"prompt": prompt, ar_param: aspect_ratio}
132
  output = replicate.run(model_id, input=inputs)
133
+ # Normalize to list[str]
134
  if isinstance(output, list) and output:
135
+ first = output[0]
136
+ return [getattr(first, "url", str(first))]
137
  elif isinstance(output, str):
138
  return [output]
139
  elif hasattr(output, "url"):
140
  return [getattr(output, "url")]
141
+ return []
142
  except Exception as e:
143
  logger.error(f"Replicate error: {e}")
144
+ return []
145
 
146
  def fetch_image_bytes_optimized(url: str) -> Optional[bytes]:
147
  for attempt in range(RETRY_ATTEMPTS):
 
155
  time.sleep(1)
156
  return None
157
 
158
+ def process_single_image(args: Tuple[str, str, str, int]) -> Dict[str, Any]:
159
  model_key, prompt, aspect_ratio, index = args
160
+ result = {"index": index, "success": False, "source_url": None, "r2_url": None, "error": None}
161
  urls = generate_one_image_optimized(model_key, prompt, aspect_ratio)
162
  if not urls:
163
+ result["error"] = "No URLs returned from generation"
164
  return result
165
  source_url = urls[0]
166
  result["source_url"] = source_url
 
176
  result["error"] = "Failed to upload to R2"
177
  return result
178
 
179
+ def generate_one_per_prompt(model_key: str, aspect_ratio: str, prompt: str) -> Tuple[List[str], List[str], List[str]]:
180
+ """One image per prompt (no parallel within a prompt)."""
181
+ res = process_single_image((model_key, prompt, aspect_ratio, 0))
182
+ if res["success"]:
183
+ return [res["r2_url"]], [res["source_url"]], []
184
  else:
185
+ return [], [], [res["error"] or "Generation failed"]
186
 
187
+ # ----------------------------
188
+ # Persistence
189
+ # ----------------------------
190
+ def save_creative_record_optimized(model_key: str, aspect_ratio: str, prompt: str, urls: List[str]) -> Optional[str]:
191
  collection = get_mongo_collection()
192
  if collection is None:
193
  return None
194
  try:
195
+ doc = {
196
+ "model": model_key,
197
+ "aspect_ratio": aspect_ratio,
198
+ "prompt": prompt,
199
+ "urls": urls,
200
+ "num_images": len(urls),
201
+ "lob": "balraaj",
202
+ "created_at": datetime.utcnow()
203
+ }
204
+ ins = collection.insert_one(doc)
205
+ return str(ins.inserted_id)
206
  except Exception as e:
207
  logger.error(f"Mongo insert failed: {e}")
208
  return None
209
 
210
  @st.cache_data(ttl=300)
211
+ def query_creatives_optimized(start_dt: datetime, end_dt: datetime, page: int = 0) -> Tuple[List[Dict[str, Any]], int]:
212
  collection = get_mongo_collection()
213
  if collection is None:
214
+ return [], 0
215
  try:
216
+ total_count = collection.count_documents({"created_at": {"$gte": start_dt, "$lt": end_dt}, "lob": "balraaj"})
217
+ cursor = (
218
+ collection.find({"created_at": {"$gte": start_dt, "$lt": end_dt}, "lob": "balraaj"})
219
+ .sort("created_at", -1)
220
+ .skip(page * LIBRARY_PAGE_SIZE)
221
+ .limit(LIBRARY_PAGE_SIZE)
222
+ )
223
  return list(cursor), total_count
224
  except Exception:
225
+ return [], 0
226
+
227
+ # ----------------------------
228
+ # UI helpers: images
229
+ # ----------------------------
230
+ @st.cache_data(ttl=3600)
231
+ def get_image_bytes_cached(url: str) -> Optional[bytes]:
232
+ return fetch_image_bytes_optimized(url)
233
 
234
+ def display_image_with_download_optimized(url: str):
235
  try:
236
+ img_bytes = get_image_bytes_cached(url)
237
  if not img_bytes:
238
  st.error("Failed to load image")
239
  return
240
  st.image(img_bytes, use_container_width=True)
241
  base = os.path.basename(urlparse(url).path) or "image.png"
242
  if not os.path.splitext(base)[1]:
243
+ base = f"{base}.png"
244
+ st.download_button(
245
+ label="Download image",
246
+ data=img_bytes,
247
+ file_name=base,
248
+ mime="image/png",
249
+ use_container_width=True
250
+ )
251
  except Exception as e:
252
  st.error(f"Failed to display image: {e}")
253
 
254
  def display_image_gallery_optimized(urls: List[str]):
255
+ if not urls:
256
+ return
257
+ num_cols = min(4, max(1, len(urls)))
258
+ cols = st.columns(num_cols)
259
+ for i, url in enumerate(urls):
260
+ with cols[i % num_cols]:
261
  display_image_with_download_optimized(url)
262
 
263
+ def bulk_download_button(urls: List[str], filename: str = "images_bundle.zip"):
264
+ if not urls:
265
+ return
266
  zip_buffer = io.BytesIO()
267
+ with zipfile.ZipFile(zip_buffer, "w", compression=zipfile.ZIP_DEFLATED) as zip_file:
268
+ for idx, url in enumerate(urls, 1):
269
+ try:
270
+ img_bytes = fetch_image_bytes_optimized(url)
271
+ if img_bytes:
272
+ path = urlparse(url).path
273
+ base = os.path.basename(path) or f"image_{idx}.png"
274
+ if not os.path.splitext(base)[1]:
275
+ base = f"{base}.png"
276
+ zip_file.writestr(base, img_bytes)
277
+ except Exception:
278
+ pass
279
  zip_buffer.seek(0)
280
+ st.download_button(
281
+ "Download All Images",
282
+ data=zip_buffer,
283
+ file_name=filename,
284
+ mime="application/zip",
285
+ use_container_width=True
286
+ )
287
 
288
+ # ----------------------------
289
+ # JSON loader (STRICT)
290
+ # ----------------------------
291
+ def load_json_prompts(file) -> List[Dict[str, Any]]:
292
+ raw = file.getvalue().decode("utf-8", errors="replace")
293
+ data = json.loads(raw)
294
+ if not isinstance(data, dict) or "prompts" not in data or not isinstance(data["prompts"], list):
295
+ raise ValueError("Invalid JSON. Expected an object with a 'prompts' array of strings.")
296
+ prompts_out: List[Dict[str, Any]] = []
297
+ for i, item in enumerate(data["prompts"], 1):
298
+ if not isinstance(item, str) or not item.strip():
299
+ raise ValueError(f"'prompts[{i-1}]' must be a non-empty string.")
300
+ prompts_out.append({"id": f"p{i}", "content": item.strip()})
301
+ return prompts_out
302
+
303
+ # ----------------------------
304
+ # JSON page (parallel across prompts)
305
+ # ----------------------------
306
+ def _run_single_prompt(idx: int, prompt_text: str, model_key: str, aspect_ratio: str):
307
+ r2_urls, src_urls, gen_errors = generate_one_per_prompt(model_key, aspect_ratio, prompt_text)
308
+ rec_id = None
309
+ if r2_urls:
310
+ rec_id = save_creative_record_optimized(model_key, aspect_ratio, prompt_text, r2_urls)
311
+ return {
312
+ "idx": idx,
313
+ "prompt": prompt_text,
314
+ "r2_urls": r2_urls,
315
+ "src_urls": src_urls,
316
+ "errors": gen_errors,
317
+ "rec_id": rec_id,
318
+ }
319
 
320
  def render_json_page():
321
  st.subheader("Generate from JSON Prompts")
322
+ up = st.file_uploader("Upload prompts JSON", type=["json"])
323
+
324
+ col1, col2 = st.columns([1, 1])
325
+ with col1:
326
+ default_model = st.selectbox("Default Model", list(MODEL_REGISTRY.keys()), index=0)
327
+ with col2:
328
+ aspect_options = MODEL_REGISTRY[default_model]["aspect_ratios"]
329
+ default_aspect = st.selectbox("Default Aspect Ratio", aspect_options, index=0, key="json_default_ar")
330
+
331
+ debug_mode = st.checkbox("Debug Mode", value=False, key="json_debug")
332
+
333
  if up:
334
  try:
335
+ prompts_list = load_json_prompts(up)
336
+ with st.expander("Preview normalized prompts", expanded=False):
337
+ st.json(prompts_list, expanded=False)
338
+
339
+ if st.button("Generate for All Prompts", type="primary", use_container_width=True):
340
+ handle_bulk_json_generation_parallel(prompts_list, default_model, default_aspect, debug_mode)
341
+ except json.JSONDecodeError as e:
342
+ st.error(f"Invalid JSON: {e}")
343
  except Exception as e:
344
+ st.error(f"Failed to read prompts: {e}")
345
+ else:
346
+ st.caption('Expected format: { "prompts": ["prompt 1", "prompt 2", ...] }')
347
+
348
+ def handle_bulk_json_generation_parallel(prompts: List[Dict[str, str]], default_model: str, default_aspect: str, debug: bool):
349
+ if not REPLICATE_API_TOKEN:
350
+ st.error("Missing REPLICATE_API_TOKEN. Set it as an environment variable.")
351
+ return
352
+ total = len(prompts)
353
+ if total == 0:
354
+ st.info("No prompts to process.")
355
+ return
356
+
357
+ # Placeholders for stable on-page order
358
+ blocks = [st.container(border=True) for _ in range(total)]
359
+ progress = st.progress(0, text=f"Starting batch • 0/{total}")
360
+
361
+ all_urls: List[str] = []
362
+ completed = 0
363
+
364
+ max_workers = min(MAX_WORKERS, max(2, (os.cpu_count() or 2)))
365
+
366
+ with st.spinner("Generating images..."):
367
+ futures = {}
368
+ with ThreadPoolExecutor(max_workers=max_workers) as ex:
369
+ for i, p in enumerate(prompts, 1):
370
+ prompt_text = p.get("content", "").strip()
371
+ if not prompt_text:
372
+ # render immediately as invalid
373
+ with blocks[i-1]:
374
+ st.markdown(f"**Prompt {i}/{total}** — (empty)")
375
+ st.error("Prompt text is empty. Skipping.")
376
+ completed += 1
377
+ progress.progress(completed / total, text=f"Processed {completed}/{total}")
378
+ continue
379
+ futures[ex.submit(_run_single_prompt, i, prompt_text, default_model, default_aspect)] = i
380
+
381
+ for fut in as_completed(futures):
382
+ i = futures[fut]
383
+ try:
384
+ res = fut.result()
385
+ except Exception as e:
386
+ res = {"idx": i, "prompt": "", "r2_urls": [], "src_urls": [], "errors": [str(e)], "rec_id": None}
387
+
388
+ with blocks[i-1]:
389
+ st.markdown(f"**Prompt {i}/{total}** — Model: `{default_model}` • Aspect: `{default_aspect}` • Num: `1`")
390
+ st.code(res.get("prompt") or "(empty)", language="markdown")
391
+
392
+ if res["r2_urls"]:
393
+ st.success(f"Generated 1 image. DB: {res['rec_id'] or 'N/A'}")
394
+ display_image_gallery_optimized(res["r2_urls"])
395
+ bulk_download_button(res["r2_urls"], filename=f"prompt_{i}_image.zip")
396
+ all_urls.extend(res["r2_urls"])
397
+ elif res["src_urls"]:
398
+ st.warning("Image generated but R2 upload failed. Showing original:")
399
+ display_image_gallery_optimized(res["src_urls"])
400
+ bulk_download_button(res["src_urls"], filename=f"prompt_{i}_image.zip")
401
+ all_urls.extend(res["src_urls"])
402
+ else:
403
+ st.error("No image was generated for this prompt.")
404
+
405
+ if res.get("errors") and debug:
406
+ for e in res["errors"]:
407
+ st.error(e)
408
+
409
+ completed += 1
410
+ progress.progress(completed / total, text=f"Processed {completed}/{total}")
411
+
412
+ # Final all-images gallery & ZIP
413
  if all_urls:
414
+ st.subheader("All Images Gallery")
415
+ display_image_gallery_optimized(all_urls)
416
+ st.subheader("Download All Generated")
417
+ bulk_download_button(all_urls, filename="all_prompts_images.zip")
418
 
419
+ # ----------------------------
420
+ # Creative Library page
421
+ # ----------------------------
422
  def render_library_page():
423
  st.subheader("Creative Library")
424
+ if "library_page" not in st.session_state:
425
+ st.session_state.library_page = 0
426
+
427
+ today_utc = datetime.utcnow().date()
428
+ default_start = today_utc - timedelta(days=30)
429
+
430
+ c1, c2, c3 = st.columns([1, 1, 1])
431
+ with c1:
432
+ start_date: date = st.date_input("Start date", value=default_start)
433
+ with c2:
434
+ end_date: date = st.date_input("End date", value=today_utc)
435
+ with c3:
436
+ if st.button("Apply Filters", use_container_width=True):
437
+ st.session_state.library_page = 0
438
+ st.cache_data.clear()
439
+
440
+ start_dt = datetime.combine(start_date, datetime.min.time())
441
+ end_dt = datetime.combine(end_date + timedelta(days=1), datetime.min.time())
442
+
443
+ records, total_count = query_creatives_optimized(start_dt, end_dt, st.session_state.library_page)
444
+ if not records and st.session_state.library_page == 0:
445
+ st.info("No creatives found for the selected dates.")
446
+ return
447
+
448
+ st.caption(f"Total items: {total_count}")
449
+ # simple gallery by record
450
  for rec in records:
451
+ urls = rec.get("urls", []) or []
452
+ if urls:
453
+ display_image_gallery_optimized(urls)
454
+
455
+ # ----------------------------
456
+ # Auth
457
+ # ----------------------------
458
+ @lru_cache(maxsize=1)
459
+ def check_token_cached(user_token: str) -> Tuple[bool, str]:
460
+ ACCESS_TOKEN = os.getenv("ACCESS_TOKEN")
461
+ if not ACCESS_TOKEN:
462
+ return False, "Server error: Access token not configured."
463
+ if user_token == ACCESS_TOKEN:
464
+ return True, ""
465
+ return False, "Invalid token."
466
 
467
+ # ----------------------------
468
+ # App shell
469
+ # ----------------------------
470
  def main_app():
471
+ st.set_page_config(page_title="File-to-Image • Creative Library", layout="wide")
472
  st.title("File-to-Image Generator")
473
+ with st.sidebar:
474
+ page = st.radio("Navigation", ["Generate from JSON", "Creative Library"], index=0)
475
+ if page == "Generate from JSON":
476
+ render_json_page()
477
+ else:
478
+ render_library_page()
479
 
480
  def main():
481
+ if "authenticated" not in st.session_state:
482
+ st.session_state["authenticated"] = False
483
+ if not st.session_state["authenticated"]:
484
+ st.markdown("## Access Required")
485
+ token_input = st.text_input("Enter Access Token", type="password")
486
+ if st.button("Unlock App"):
487
+ ok, error_msg = check_token_cached(token_input)
488
+ if ok:
489
+ st.session_state["authenticated"] = True
490
+ st.rerun()
491
+ else:
492
+ st.error(error_msg)
493
+ else:
494
+ main_app()
495
 
496
+ if __name__ == "__main__":
497
  main()