Spaces:
Running
Running
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +59 -25
src/streamlit_app.py
CHANGED
|
@@ -35,7 +35,12 @@ MONGO_COLLECTION = os.getenv("MONGO_COLLECTION", "creatives")
|
|
| 35 |
REQUEST_TIMEOUT = 30
|
| 36 |
RETRY_ATTEMPTS = 3
|
| 37 |
MAX_WORKERS = min(32, (os.cpu_count() or 1) + 4)
|
| 38 |
-
LIBRARY_PAGE_SIZE = 200 #
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {
|
| 41 |
"imagegen-4-ultra": {"id": "google/imagen-4-ultra","aspect_ratios": ["1:1","16:9","9:16","3:4","4:3"],"param_name": "aspect_ratio"},
|
|
@@ -65,7 +70,6 @@ def show_env_warnings():
|
|
| 65 |
# ----------------------------
|
| 66 |
def get_replicate_client():
|
| 67 |
if not hasattr(_thread_local, "replicate_client"):
|
| 68 |
-
# Explicit client avoids env-specific issues with module-level run()
|
| 69 |
_thread_local.replicate_client = replicate.Client(api_token=REPLICATE_API_TOKEN) if REPLICATE_API_TOKEN else None
|
| 70 |
return _thread_local.replicate_client
|
| 71 |
|
|
@@ -160,22 +164,46 @@ def fetch_bytes(url: str) -> Optional[bytes]:
|
|
| 160 |
time.sleep(1)
|
| 161 |
return None
|
| 162 |
|
| 163 |
-
def
|
| 164 |
"""
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
- try to upload to R2
|
| 168 |
-
- fallback to source url if R2 not available
|
| 169 |
"""
|
| 170 |
-
|
|
|
|
| 171 |
if not urls:
|
| 172 |
-
return
|
| 173 |
src = urls[0]
|
| 174 |
data = fetch_bytes(src)
|
| 175 |
if data is None:
|
| 176 |
-
return
|
| 177 |
r2 = upload_to_r2(data)
|
| 178 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
|
| 180 |
# ----------------------------
|
| 181 |
# Persistence
|
|
@@ -190,6 +218,7 @@ def save_record(model: str, aspect: str, prompt: str, urls: List[str]):
|
|
| 190 |
"aspect_ratio": aspect,
|
| 191 |
"prompt": prompt,
|
| 192 |
"urls": urls,
|
|
|
|
| 193 |
"lob": "balraaj",
|
| 194 |
"created_at": datetime.utcnow(),
|
| 195 |
}).inserted_id)
|
|
@@ -200,7 +229,6 @@ def save_record(model: str, aspect: str, prompt: str, urls: List[str]):
|
|
| 200 |
@st.cache_data(ttl=300)
|
| 201 |
def query_records(start: datetime, end: datetime) -> List[Dict[str, Any]]:
|
| 202 |
coll = get_mongo_collection()
|
| 203 |
-
# FIX: compare with None explicitly (avoids NotImplementedError)
|
| 204 |
if coll is None:
|
| 205 |
return []
|
| 206 |
try:
|
|
@@ -240,7 +268,7 @@ def bulk_zip(urls: List[str]):
|
|
| 240 |
st.download_button("Download All", buf, "images.zip", "application/zip", use_container_width=True)
|
| 241 |
|
| 242 |
# ----------------------------
|
| 243 |
-
# JSON loader & batch run (parallel)
|
| 244 |
# ----------------------------
|
| 245 |
def load_json(file) -> List[str]:
|
| 246 |
data = json.loads(file.getvalue().decode("utf-8"))
|
|
@@ -251,7 +279,7 @@ def load_json(file) -> List[str]:
|
|
| 251 |
raise ValueError("No valid prompts found.")
|
| 252 |
return out
|
| 253 |
|
| 254 |
-
def run_batch(prompts: List[str], model: str, aspect: str):
|
| 255 |
total = len(prompts)
|
| 256 |
rows = [st.empty() for _ in prompts]
|
| 257 |
progress = st.progress(0.0, text=f"0/{total}")
|
|
@@ -259,22 +287,25 @@ def run_batch(prompts: List[str], model: str, aspect: str):
|
|
| 259 |
all_urls: List[str] = []
|
| 260 |
done = 0
|
| 261 |
|
| 262 |
-
|
| 263 |
with st.spinner("Generating images..."):
|
| 264 |
-
with ThreadPoolExecutor(max_workers=
|
| 265 |
-
futs = {ex.submit(
|
| 266 |
for fut in as_completed(futs):
|
| 267 |
i = futs[fut]
|
| 268 |
try:
|
| 269 |
res = fut.result()
|
| 270 |
except Exception as e:
|
| 271 |
-
res = {"idx": i, "urls": [], "
|
|
|
|
| 272 |
if res["urls"]:
|
| 273 |
save_record(model, aspect, prompts[i-1], res["urls"])
|
| 274 |
-
rows[i-1].success(f"Prompt {i}/{total} β")
|
| 275 |
all_urls.extend(res["urls"])
|
| 276 |
else:
|
| 277 |
-
|
|
|
|
|
|
|
| 278 |
done += 1
|
| 279 |
progress.progress(done/total, text=f"{done}/{total}")
|
| 280 |
|
|
@@ -291,12 +322,15 @@ def run_batch(prompts: List[str], model: str, aspect: str):
|
|
| 291 |
def render_json_page():
|
| 292 |
st.subheader("Generate from JSON")
|
| 293 |
show_env_warnings()
|
|
|
|
| 294 |
up = st.file_uploader("Upload JSON", type=["json"])
|
| 295 |
-
|
| 296 |
-
with
|
| 297 |
model = st.selectbox("Model", list(MODEL_REGISTRY.keys()), 0)
|
| 298 |
-
with
|
| 299 |
aspect = st.selectbox("Aspect", MODEL_REGISTRY[model]["aspect_ratios"], 0)
|
|
|
|
|
|
|
| 300 |
|
| 301 |
if up:
|
| 302 |
try:
|
|
@@ -304,7 +338,7 @@ def render_json_page():
|
|
| 304 |
with st.expander("Preview prompts", expanded=False):
|
| 305 |
st.json(prompts)
|
| 306 |
if st.button("Generate", type="primary", use_container_width=True):
|
| 307 |
-
run_batch(prompts, model, aspect)
|
| 308 |
except Exception as e:
|
| 309 |
st.error(str(e))
|
| 310 |
else:
|
|
@@ -339,7 +373,7 @@ def render_library_page():
|
|
| 339 |
def check_token(tok: str) -> bool:
|
| 340 |
acc = os.getenv("ACCESS_TOKEN")
|
| 341 |
if not acc:
|
| 342 |
-
#
|
| 343 |
return True
|
| 344 |
return tok == acc
|
| 345 |
|
|
|
|
| 35 |
REQUEST_TIMEOUT = 30
|
| 36 |
RETRY_ATTEMPTS = 3
|
| 37 |
MAX_WORKERS = min(32, (os.cpu_count() or 1) + 4)
|
| 38 |
+
LIBRARY_PAGE_SIZE = 200 # aggregated view
|
| 39 |
+
|
| 40 |
+
# Global throttle so we don't overload Replicate / network
|
| 41 |
+
# This caps total in-flight image generations across all prompts.
|
| 42 |
+
GLOBAL_CONCURRENCY = max(4, min(16, MAX_WORKERS))
|
| 43 |
+
_GEN_SEMAPHORE = threading.Semaphore(GLOBAL_CONCURRENCY)
|
| 44 |
|
| 45 |
MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {
|
| 46 |
"imagegen-4-ultra": {"id": "google/imagen-4-ultra","aspect_ratios": ["1:1","16:9","9:16","3:4","4:3"],"param_name": "aspect_ratio"},
|
|
|
|
| 70 |
# ----------------------------
|
| 71 |
def get_replicate_client():
|
| 72 |
if not hasattr(_thread_local, "replicate_client"):
|
|
|
|
| 73 |
_thread_local.replicate_client = replicate.Client(api_token=REPLICATE_API_TOKEN) if REPLICATE_API_TOKEN else None
|
| 74 |
return _thread_local.replicate_client
|
| 75 |
|
|
|
|
| 164 |
time.sleep(1)
|
| 165 |
return None
|
| 166 |
|
| 167 |
+
def _generate_single_image_full(prompt: str, model: str, aspect: str) -> Tuple[Optional[str], Optional[str]]:
|
| 168 |
"""
|
| 169 |
+
Throttled: generate -> fetch -> upload -> return final_url or src_url
|
| 170 |
+
Returns (final_url_or_src, error_text)
|
|
|
|
|
|
|
| 171 |
"""
|
| 172 |
+
with _GEN_SEMAPHORE: # throttle
|
| 173 |
+
urls = generate_one(model, prompt, aspect)
|
| 174 |
if not urls:
|
| 175 |
+
return None, "No URLs"
|
| 176 |
src = urls[0]
|
| 177 |
data = fetch_bytes(src)
|
| 178 |
if data is None:
|
| 179 |
+
return None, "Fetch failed"
|
| 180 |
r2 = upload_to_r2(data)
|
| 181 |
+
return (r2 or src), None
|
| 182 |
+
|
| 183 |
+
def process_prompt_batch(idx: int, prompt: str, model: str, aspect: str, num_images: int) -> Dict[str, Any]:
|
| 184 |
+
"""
|
| 185 |
+
Generate N images for one prompt.
|
| 186 |
+
Uses a small threadpool but still obeys GLOBAL throttle.
|
| 187 |
+
"""
|
| 188 |
+
n = max(1, int(num_images))
|
| 189 |
+
urls: List[str] = []
|
| 190 |
+
errs: List[str] = []
|
| 191 |
+
|
| 192 |
+
# limit inner parallelism to keep global pressure sane
|
| 193 |
+
inner_workers = min( min(4, n), GLOBAL_CONCURRENCY )
|
| 194 |
+
if n == 1:
|
| 195 |
+
u, e = _generate_single_image_full(prompt, model, aspect)
|
| 196 |
+
if u: urls.append(u)
|
| 197 |
+
if e: errs.append(e)
|
| 198 |
+
else:
|
| 199 |
+
with ThreadPoolExecutor(max_workers=inner_workers) as ex:
|
| 200 |
+
futures = [ex.submit(_generate_single_image_full, prompt, model, aspect) for _ in range(n)]
|
| 201 |
+
for fut in as_completed(futures):
|
| 202 |
+
u, e = fut.result()
|
| 203 |
+
if u: urls.append(u)
|
| 204 |
+
if e: errs.append(e)
|
| 205 |
+
|
| 206 |
+
return {"idx": idx, "urls": urls, "errors": errs}
|
| 207 |
|
| 208 |
# ----------------------------
|
| 209 |
# Persistence
|
|
|
|
| 218 |
"aspect_ratio": aspect,
|
| 219 |
"prompt": prompt,
|
| 220 |
"urls": urls,
|
| 221 |
+
"num_images": len(urls),
|
| 222 |
"lob": "balraaj",
|
| 223 |
"created_at": datetime.utcnow(),
|
| 224 |
}).inserted_id)
|
|
|
|
| 229 |
@st.cache_data(ttl=300)
|
| 230 |
def query_records(start: datetime, end: datetime) -> List[Dict[str, Any]]:
|
| 231 |
coll = get_mongo_collection()
|
|
|
|
| 232 |
if coll is None:
|
| 233 |
return []
|
| 234 |
try:
|
|
|
|
| 268 |
st.download_button("Download All", buf, "images.zip", "application/zip", use_container_width=True)
|
| 269 |
|
| 270 |
# ----------------------------
|
| 271 |
+
# JSON loader & batch run (parallel across prompts)
|
| 272 |
# ----------------------------
|
| 273 |
def load_json(file) -> List[str]:
|
| 274 |
data = json.loads(file.getvalue().decode("utf-8"))
|
|
|
|
| 279 |
raise ValueError("No valid prompts found.")
|
| 280 |
return out
|
| 281 |
|
| 282 |
+
def run_batch(prompts: List[str], model: str, aspect: str, num_images: int):
|
| 283 |
total = len(prompts)
|
| 284 |
rows = [st.empty() for _ in prompts]
|
| 285 |
progress = st.progress(0.0, text=f"0/{total}")
|
|
|
|
| 287 |
all_urls: List[str] = []
|
| 288 |
done = 0
|
| 289 |
|
| 290 |
+
outer_workers = min(MAX_WORKERS, max(2, (os.cpu_count() or 2)))
|
| 291 |
with st.spinner("Generating images..."):
|
| 292 |
+
with ThreadPoolExecutor(max_workers=outer_workers) as ex:
|
| 293 |
+
futs = {ex.submit(process_prompt_batch, i, p, model, aspect, num_images): i for i, p in enumerate(prompts, 1)}
|
| 294 |
for fut in as_completed(futs):
|
| 295 |
i = futs[fut]
|
| 296 |
try:
|
| 297 |
res = fut.result()
|
| 298 |
except Exception as e:
|
| 299 |
+
res = {"idx": i, "urls": [], "errors": [str(e)]}
|
| 300 |
+
|
| 301 |
if res["urls"]:
|
| 302 |
save_record(model, aspect, prompts[i-1], res["urls"])
|
| 303 |
+
rows[i-1].success(f"Prompt {i}/{total} β ({len(res['urls'])} images)")
|
| 304 |
all_urls.extend(res["urls"])
|
| 305 |
else:
|
| 306 |
+
err_msg = ", ".join(res.get("errors") or ["No images"])
|
| 307 |
+
rows[i-1].error(f"Prompt {i}/{total} β ({err_msg})")
|
| 308 |
+
|
| 309 |
done += 1
|
| 310 |
progress.progress(done/total, text=f"{done}/{total}")
|
| 311 |
|
|
|
|
| 322 |
def render_json_page():
|
| 323 |
st.subheader("Generate from JSON")
|
| 324 |
show_env_warnings()
|
| 325 |
+
|
| 326 |
up = st.file_uploader("Upload JSON", type=["json"])
|
| 327 |
+
c1, c2, c3 = st.columns([1, 1, 1])
|
| 328 |
+
with c1:
|
| 329 |
model = st.selectbox("Model", list(MODEL_REGISTRY.keys()), 0)
|
| 330 |
+
with c2:
|
| 331 |
aspect = st.selectbox("Aspect", MODEL_REGISTRY[model]["aspect_ratios"], 0)
|
| 332 |
+
with c3:
|
| 333 |
+
num_images = st.slider("Images per prompt", min_value=1, max_value=12, value=1, step=1)
|
| 334 |
|
| 335 |
if up:
|
| 336 |
try:
|
|
|
|
| 338 |
with st.expander("Preview prompts", expanded=False):
|
| 339 |
st.json(prompts)
|
| 340 |
if st.button("Generate", type="primary", use_container_width=True):
|
| 341 |
+
run_batch(prompts, model, aspect, num_images)
|
| 342 |
except Exception as e:
|
| 343 |
st.error(str(e))
|
| 344 |
else:
|
|
|
|
| 373 |
def check_token(tok: str) -> bool:
|
| 374 |
acc = os.getenv("ACCESS_TOKEN")
|
| 375 |
if not acc:
|
| 376 |
+
# No token configured β allow dev access
|
| 377 |
return True
|
| 378 |
return tok == acc
|
| 379 |
|