Spaces:
Running
Running
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +306 -100
src/streamlit_app.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import os
|
| 2 |
import io
|
| 3 |
import zipfile
|
|
@@ -23,29 +24,38 @@ load_dotenv()
|
|
| 23 |
logging.basicConfig(level=logging.INFO)
|
| 24 |
logger = logging.getLogger("imagegen_app")
|
| 25 |
|
|
|
|
|
|
|
|
|
|
| 26 |
REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
|
| 27 |
MONGO_URI = os.getenv("MONGO_URI")
|
| 28 |
MONGO_DB = os.getenv("MONGO_DB", "adgenesis_image_text")
|
| 29 |
MONGO_COLLECTION = os.getenv("MONGO_COLLECTION", "creatives")
|
|
|
|
| 30 |
MAX_WORKERS = min(32, (os.cpu_count() or 1) + 4)
|
| 31 |
REQUEST_TIMEOUT = 30
|
| 32 |
RETRY_ATTEMPTS = 3
|
| 33 |
LIBRARY_PAGE_SIZE = 20
|
| 34 |
|
|
|
|
| 35 |
MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {
|
| 36 |
-
"imagegen-4-ultra": {"id": "google/imagen-4-ultra","aspect_ratios": ["1:1","16:9","9:16","3:4","4:3"],"param_name": "aspect_ratio"},
|
| 37 |
-
"imagen-4":
|
| 38 |
-
"
|
| 39 |
-
"
|
| 40 |
-
"
|
| 41 |
-
"
|
| 42 |
-
"
|
|
|
|
| 43 |
}
|
| 44 |
|
| 45 |
_thread_local = threading.local()
|
| 46 |
|
|
|
|
|
|
|
|
|
|
| 47 |
def get_mongo_collection():
|
| 48 |
-
if not hasattr(_thread_local,
|
| 49 |
if not MONGO_URI:
|
| 50 |
_thread_local.mongo_collection = None
|
| 51 |
return None
|
|
@@ -53,7 +63,7 @@ def get_mongo_collection():
|
|
| 53 |
client = MongoClient(MONGO_URI, serverSelectionTimeoutMS=3000)
|
| 54 |
db = client[MONGO_DB]
|
| 55 |
collection = db[MONGO_COLLECTION]
|
| 56 |
-
client.admin.command(
|
| 57 |
_thread_local.mongo_collection = collection
|
| 58 |
except Exception as e:
|
| 59 |
logger.error(f"MongoDB connection failed: {e}")
|
|
@@ -61,8 +71,8 @@ def get_mongo_collection():
|
|
| 61 |
return _thread_local.mongo_collection
|
| 62 |
|
| 63 |
def get_s3_client():
|
| 64 |
-
if not hasattr(_thread_local,
|
| 65 |
-
required_vars = ["R2_ENDPOINT","R2_ACCESS_KEY","R2_SECRET_KEY","R2_BUCKET_NAME"]
|
| 66 |
missing = [var for var in required_vars if not os.getenv(var)]
|
| 67 |
if missing:
|
| 68 |
_thread_local.s3_client = None
|
|
@@ -84,24 +94,31 @@ def get_s3_client():
|
|
| 84 |
def get_model_config(model_key: str) -> Optional[Dict[str, Any]]:
|
| 85 |
return MODEL_REGISTRY.get(model_key)
|
| 86 |
|
|
|
|
|
|
|
|
|
|
| 87 |
def upload_to_r2_optimized(image_bytes: bytes) -> Optional[str]:
|
| 88 |
s3_client = get_s3_client()
|
| 89 |
if not s3_client:
|
| 90 |
return None
|
| 91 |
try:
|
| 92 |
filename = f"{uuid4().hex}.png"
|
| 93 |
-
file_key = f"adgenesis_image_file/
|
| 94 |
s3_client.put_object(
|
| 95 |
Bucket=os.getenv("R2_BUCKET_NAME"),
|
| 96 |
Key=file_key,
|
| 97 |
Body=image_bytes,
|
| 98 |
ContentType="image/png",
|
| 99 |
)
|
| 100 |
-
|
|
|
|
| 101 |
except Exception as e:
|
| 102 |
logger.error(f"S3 upload failed: {e}")
|
| 103 |
return None
|
| 104 |
|
|
|
|
|
|
|
|
|
|
| 105 |
def generate_one_image_optimized(model_key: str, prompt: str, aspect_ratio: str) -> List[str]:
|
| 106 |
if not REPLICATE_API_TOKEN:
|
| 107 |
return []
|
|
@@ -113,15 +130,18 @@ def generate_one_image_optimized(model_key: str, prompt: str, aspect_ratio: str)
|
|
| 113 |
ar_param = config["param_name"]
|
| 114 |
inputs = {"prompt": prompt, ar_param: aspect_ratio}
|
| 115 |
output = replicate.run(model_id, input=inputs)
|
|
|
|
| 116 |
if isinstance(output, list) and output:
|
| 117 |
-
|
|
|
|
| 118 |
elif isinstance(output, str):
|
| 119 |
return [output]
|
| 120 |
elif hasattr(output, "url"):
|
| 121 |
return [getattr(output, "url")]
|
|
|
|
| 122 |
except Exception as e:
|
| 123 |
logger.error(f"Replicate error: {e}")
|
| 124 |
-
|
| 125 |
|
| 126 |
def fetch_image_bytes_optimized(url: str) -> Optional[bytes]:
|
| 127 |
for attempt in range(RETRY_ATTEMPTS):
|
|
@@ -135,12 +155,12 @@ def fetch_image_bytes_optimized(url: str) -> Optional[bytes]:
|
|
| 135 |
time.sleep(1)
|
| 136 |
return None
|
| 137 |
|
| 138 |
-
def process_single_image(args: Tuple[str,str,str,int]) -> Dict[str,Any]:
|
| 139 |
model_key, prompt, aspect_ratio, index = args
|
| 140 |
-
result = {"index": index,"success": False,"source_url": None,"r2_url": None,"error": None}
|
| 141 |
urls = generate_one_image_optimized(model_key, prompt, aspect_ratio)
|
| 142 |
if not urls:
|
| 143 |
-
result["error"] = "No URLs returned"
|
| 144 |
return result
|
| 145 |
source_url = urls[0]
|
| 146 |
result["source_url"] = source_url
|
|
@@ -156,136 +176,322 @@ def process_single_image(args: Tuple[str,str,str,int]) -> Dict[str,Any]:
|
|
| 156 |
result["error"] = "Failed to upload to R2"
|
| 157 |
return result
|
| 158 |
|
| 159 |
-
def
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
|
|
|
| 163 |
else:
|
| 164 |
-
return [], [], [
|
| 165 |
|
| 166 |
-
|
|
|
|
|
|
|
|
|
|
| 167 |
collection = get_mongo_collection()
|
| 168 |
if collection is None:
|
| 169 |
return None
|
| 170 |
try:
|
| 171 |
-
doc = {
|
| 172 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
except Exception as e:
|
| 174 |
logger.error(f"Mongo insert failed: {e}")
|
| 175 |
return None
|
| 176 |
|
| 177 |
@st.cache_data(ttl=300)
|
| 178 |
-
def query_creatives_optimized(start_dt:datetime,end_dt:datetime,page:int=0)->Tuple[List[Dict[str,Any]],int]:
|
| 179 |
collection = get_mongo_collection()
|
| 180 |
if collection is None:
|
| 181 |
-
return [],0
|
| 182 |
try:
|
| 183 |
-
total_count = collection.count_documents({"created_at":{"$gte":start_dt,"$lt":end_dt},"lob":"
|
| 184 |
-
cursor =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
return list(cursor), total_count
|
| 186 |
except Exception:
|
| 187 |
-
return [],0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
|
| 189 |
-
def display_image_with_download_optimized(url:str):
|
| 190 |
try:
|
| 191 |
-
img_bytes =
|
| 192 |
if not img_bytes:
|
| 193 |
st.error("Failed to load image")
|
| 194 |
return
|
| 195 |
st.image(img_bytes, use_container_width=True)
|
| 196 |
base = os.path.basename(urlparse(url).path) or "image.png"
|
| 197 |
if not os.path.splitext(base)[1]:
|
| 198 |
-
base
|
| 199 |
-
st.download_button(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
except Exception as e:
|
| 201 |
st.error(f"Failed to display image: {e}")
|
| 202 |
|
| 203 |
def display_image_gallery_optimized(urls: List[str]):
|
| 204 |
-
if not urls:
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
|
|
|
|
|
|
| 208 |
display_image_with_download_optimized(url)
|
| 209 |
|
| 210 |
-
def bulk_download_button(urls: List[str], filename="images_bundle.zip"):
|
| 211 |
-
if not urls:
|
|
|
|
| 212 |
zip_buffer = io.BytesIO()
|
| 213 |
-
with zipfile.ZipFile(zip_buffer,"w",compression=zipfile.ZIP_DEFLATED) as
|
| 214 |
-
for
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
base
|
| 220 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
zip_buffer.seek(0)
|
| 222 |
-
st.download_button(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
|
| 231 |
def render_json_page():
|
| 232 |
st.subheader("Generate from JSON Prompts")
|
| 233 |
-
up=st.file_uploader("Upload prompts JSON",type=["json"])
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
with
|
| 237 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
if up:
|
| 239 |
try:
|
| 240 |
-
|
| 241 |
-
st.
|
| 242 |
-
|
| 243 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
except Exception as e:
|
| 245 |
-
st.error(
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
st.
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
if all_urls:
|
| 265 |
-
st.subheader("
|
| 266 |
-
|
|
|
|
|
|
|
| 267 |
|
|
|
|
|
|
|
|
|
|
| 268 |
def render_library_page():
|
| 269 |
st.subheader("Creative Library")
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
st.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
for rec in records:
|
| 278 |
-
urls=rec.get("urls",[])
|
| 279 |
-
if urls:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
|
|
|
|
|
|
|
|
|
|
| 281 |
def main_app():
|
|
|
|
| 282 |
st.title("File-to-Image Generator")
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
|
|
|
|
|
|
|
|
|
| 286 |
|
| 287 |
def main():
|
| 288 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
|
| 290 |
-
if __name__=="__main__":
|
| 291 |
main()
|
|
|
|
| 1 |
+
# main.py
|
| 2 |
import os
|
| 3 |
import io
|
| 4 |
import zipfile
|
|
|
|
| 24 |
logging.basicConfig(level=logging.INFO)
|
| 25 |
logger = logging.getLogger("imagegen_app")
|
| 26 |
|
| 27 |
+
# ----------------------------
|
| 28 |
+
# Config / Constants
|
| 29 |
+
# ----------------------------
|
| 30 |
REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
|
| 31 |
MONGO_URI = os.getenv("MONGO_URI")
|
| 32 |
MONGO_DB = os.getenv("MONGO_DB", "adgenesis_image_text")
|
| 33 |
MONGO_COLLECTION = os.getenv("MONGO_COLLECTION", "creatives")
|
| 34 |
+
|
| 35 |
MAX_WORKERS = min(32, (os.cpu_count() or 1) + 4)
|
| 36 |
REQUEST_TIMEOUT = 30
|
| 37 |
RETRY_ATTEMPTS = 3
|
| 38 |
LIBRARY_PAGE_SIZE = 20
|
| 39 |
|
| 40 |
+
# Model registry (subset with common ARs)
|
| 41 |
MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {
|
| 42 |
+
"imagegen-4-ultra": {"id": "google/imagen-4-ultra", "aspect_ratios": ["1:1","16:9","9:16","3:4","4:3"], "param_name": "aspect_ratio"},
|
| 43 |
+
"imagen-4": {"id": "google/imagen-4", "aspect_ratios": ["1:1","16:9","9:16","3:4","4:3"], "param_name": "aspect_ratio"},
|
| 44 |
+
"nano-banana": {"id": "google/nano-banana", "aspect_ratios": ["1:1","16:9","9:16","3:4","4:3"], "param_name": "aspect_ratio"},
|
| 45 |
+
"qwen": {"id": "qwen/qwen-image", "aspect_ratios": ["1:1","16:9","9:16","3:4","4:3","3:2","2:3"], "param_name": "aspect_ratio"},
|
| 46 |
+
"seedream-3": {"id": "bytedance/seedream-3", "aspect_ratios": ["1:1","16:9","9:16","3:4","4:3","3:2","2:3","21:9"], "param_name": "aspect_ratio"},
|
| 47 |
+
"recraft-v3": {"id": "recraft-ai/recraft-v3", "aspect_ratios": ["1:1","4:3","3:4","3:2","2:3","16:9","9:16"], "param_name": "aspect_ratio"},
|
| 48 |
+
"photon": {"id": "luma/photon", "aspect_ratios": ["1:1","3:4","4:3","9:16","16:9","21:9"], "param_name": "aspect_ratio"},
|
| 49 |
+
"ideogram-v3-quality":{"id": "ideogram-ai/ideogram-v3-quality", "aspect_ratios": ["1:1","16:9","9:16","2:3","3:2","4:5","5:4"], "param_name": "aspect_ratio"},
|
| 50 |
}
|
| 51 |
|
| 52 |
_thread_local = threading.local()
|
| 53 |
|
| 54 |
+
# ----------------------------
|
| 55 |
+
# Infra helpers (Mongo / S3)
|
| 56 |
+
# ----------------------------
|
| 57 |
def get_mongo_collection():
|
| 58 |
+
if not hasattr(_thread_local, 'mongo_collection'):
|
| 59 |
if not MONGO_URI:
|
| 60 |
_thread_local.mongo_collection = None
|
| 61 |
return None
|
|
|
|
| 63 |
client = MongoClient(MONGO_URI, serverSelectionTimeoutMS=3000)
|
| 64 |
db = client[MONGO_DB]
|
| 65 |
collection = db[MONGO_COLLECTION]
|
| 66 |
+
client.admin.command('ping')
|
| 67 |
_thread_local.mongo_collection = collection
|
| 68 |
except Exception as e:
|
| 69 |
logger.error(f"MongoDB connection failed: {e}")
|
|
|
|
| 71 |
return _thread_local.mongo_collection
|
| 72 |
|
| 73 |
def get_s3_client():
|
| 74 |
+
if not hasattr(_thread_local, 's3_client'):
|
| 75 |
+
required_vars = ["R2_ENDPOINT", "R2_ACCESS_KEY", "R2_SECRET_KEY", "R2_BUCKET_NAME"]
|
| 76 |
missing = [var for var in required_vars if not os.getenv(var)]
|
| 77 |
if missing:
|
| 78 |
_thread_local.s3_client = None
|
|
|
|
| 94 |
def get_model_config(model_key: str) -> Optional[Dict[str, Any]]:
|
| 95 |
return MODEL_REGISTRY.get(model_key)
|
| 96 |
|
| 97 |
+
# ----------------------------
|
| 98 |
+
# R2 upload
|
| 99 |
+
# ----------------------------
|
| 100 |
def upload_to_r2_optimized(image_bytes: bytes) -> Optional[str]:
|
| 101 |
s3_client = get_s3_client()
|
| 102 |
if not s3_client:
|
| 103 |
return None
|
| 104 |
try:
|
| 105 |
filename = f"{uuid4().hex}.png"
|
| 106 |
+
file_key = f"adgenesis_image_file/balraaj/images/{filename}"
|
| 107 |
s3_client.put_object(
|
| 108 |
Bucket=os.getenv("R2_BUCKET_NAME"),
|
| 109 |
Key=file_key,
|
| 110 |
Body=image_bytes,
|
| 111 |
ContentType="image/png",
|
| 112 |
)
|
| 113 |
+
r2_url = f'{os.getenv("NEW_BASE").rstrip("/")}/{file_key}'
|
| 114 |
+
return r2_url
|
| 115 |
except Exception as e:
|
| 116 |
logger.error(f"S3 upload failed: {e}")
|
| 117 |
return None
|
| 118 |
|
| 119 |
+
# ----------------------------
|
| 120 |
+
# Generation & fetching
|
| 121 |
+
# ----------------------------
|
| 122 |
def generate_one_image_optimized(model_key: str, prompt: str, aspect_ratio: str) -> List[str]:
|
| 123 |
if not REPLICATE_API_TOKEN:
|
| 124 |
return []
|
|
|
|
| 130 |
ar_param = config["param_name"]
|
| 131 |
inputs = {"prompt": prompt, ar_param: aspect_ratio}
|
| 132 |
output = replicate.run(model_id, input=inputs)
|
| 133 |
+
# Normalize to list[str]
|
| 134 |
if isinstance(output, list) and output:
|
| 135 |
+
first = output[0]
|
| 136 |
+
return [getattr(first, "url", str(first))]
|
| 137 |
elif isinstance(output, str):
|
| 138 |
return [output]
|
| 139 |
elif hasattr(output, "url"):
|
| 140 |
return [getattr(output, "url")]
|
| 141 |
+
return []
|
| 142 |
except Exception as e:
|
| 143 |
logger.error(f"Replicate error: {e}")
|
| 144 |
+
return []
|
| 145 |
|
| 146 |
def fetch_image_bytes_optimized(url: str) -> Optional[bytes]:
|
| 147 |
for attempt in range(RETRY_ATTEMPTS):
|
|
|
|
| 155 |
time.sleep(1)
|
| 156 |
return None
|
| 157 |
|
| 158 |
+
def process_single_image(args: Tuple[str, str, str, int]) -> Dict[str, Any]:
|
| 159 |
model_key, prompt, aspect_ratio, index = args
|
| 160 |
+
result = {"index": index, "success": False, "source_url": None, "r2_url": None, "error": None}
|
| 161 |
urls = generate_one_image_optimized(model_key, prompt, aspect_ratio)
|
| 162 |
if not urls:
|
| 163 |
+
result["error"] = "No URLs returned from generation"
|
| 164 |
return result
|
| 165 |
source_url = urls[0]
|
| 166 |
result["source_url"] = source_url
|
|
|
|
| 176 |
result["error"] = "Failed to upload to R2"
|
| 177 |
return result
|
| 178 |
|
| 179 |
+
def generate_one_per_prompt(model_key: str, aspect_ratio: str, prompt: str) -> Tuple[List[str], List[str], List[str]]:
|
| 180 |
+
"""One image per prompt (no parallel within a prompt)."""
|
| 181 |
+
res = process_single_image((model_key, prompt, aspect_ratio, 0))
|
| 182 |
+
if res["success"]:
|
| 183 |
+
return [res["r2_url"]], [res["source_url"]], []
|
| 184 |
else:
|
| 185 |
+
return [], [], [res["error"] or "Generation failed"]
|
| 186 |
|
| 187 |
+
# ----------------------------
|
| 188 |
+
# Persistence
|
| 189 |
+
# ----------------------------
|
| 190 |
+
def save_creative_record_optimized(model_key: str, aspect_ratio: str, prompt: str, urls: List[str]) -> Optional[str]:
|
| 191 |
collection = get_mongo_collection()
|
| 192 |
if collection is None:
|
| 193 |
return None
|
| 194 |
try:
|
| 195 |
+
doc = {
|
| 196 |
+
"model": model_key,
|
| 197 |
+
"aspect_ratio": aspect_ratio,
|
| 198 |
+
"prompt": prompt,
|
| 199 |
+
"urls": urls,
|
| 200 |
+
"num_images": len(urls),
|
| 201 |
+
"lob": "balraaj",
|
| 202 |
+
"created_at": datetime.utcnow()
|
| 203 |
+
}
|
| 204 |
+
ins = collection.insert_one(doc)
|
| 205 |
+
return str(ins.inserted_id)
|
| 206 |
except Exception as e:
|
| 207 |
logger.error(f"Mongo insert failed: {e}")
|
| 208 |
return None
|
| 209 |
|
| 210 |
@st.cache_data(ttl=300)
|
| 211 |
+
def query_creatives_optimized(start_dt: datetime, end_dt: datetime, page: int = 0) -> Tuple[List[Dict[str, Any]], int]:
|
| 212 |
collection = get_mongo_collection()
|
| 213 |
if collection is None:
|
| 214 |
+
return [], 0
|
| 215 |
try:
|
| 216 |
+
total_count = collection.count_documents({"created_at": {"$gte": start_dt, "$lt": end_dt}, "lob": "balraaj"})
|
| 217 |
+
cursor = (
|
| 218 |
+
collection.find({"created_at": {"$gte": start_dt, "$lt": end_dt}, "lob": "balraaj"})
|
| 219 |
+
.sort("created_at", -1)
|
| 220 |
+
.skip(page * LIBRARY_PAGE_SIZE)
|
| 221 |
+
.limit(LIBRARY_PAGE_SIZE)
|
| 222 |
+
)
|
| 223 |
return list(cursor), total_count
|
| 224 |
except Exception:
|
| 225 |
+
return [], 0
|
| 226 |
+
|
| 227 |
+
# ----------------------------
|
| 228 |
+
# UI helpers: images
|
| 229 |
+
# ----------------------------
|
| 230 |
+
@st.cache_data(ttl=3600)
|
| 231 |
+
def get_image_bytes_cached(url: str) -> Optional[bytes]:
|
| 232 |
+
return fetch_image_bytes_optimized(url)
|
| 233 |
|
| 234 |
+
def display_image_with_download_optimized(url: str):
|
| 235 |
try:
|
| 236 |
+
img_bytes = get_image_bytes_cached(url)
|
| 237 |
if not img_bytes:
|
| 238 |
st.error("Failed to load image")
|
| 239 |
return
|
| 240 |
st.image(img_bytes, use_container_width=True)
|
| 241 |
base = os.path.basename(urlparse(url).path) or "image.png"
|
| 242 |
if not os.path.splitext(base)[1]:
|
| 243 |
+
base = f"{base}.png"
|
| 244 |
+
st.download_button(
|
| 245 |
+
label="Download image",
|
| 246 |
+
data=img_bytes,
|
| 247 |
+
file_name=base,
|
| 248 |
+
mime="image/png",
|
| 249 |
+
use_container_width=True
|
| 250 |
+
)
|
| 251 |
except Exception as e:
|
| 252 |
st.error(f"Failed to display image: {e}")
|
| 253 |
|
| 254 |
def display_image_gallery_optimized(urls: List[str]):
|
| 255 |
+
if not urls:
|
| 256 |
+
return
|
| 257 |
+
num_cols = min(4, max(1, len(urls)))
|
| 258 |
+
cols = st.columns(num_cols)
|
| 259 |
+
for i, url in enumerate(urls):
|
| 260 |
+
with cols[i % num_cols]:
|
| 261 |
display_image_with_download_optimized(url)
|
| 262 |
|
| 263 |
+
def bulk_download_button(urls: List[str], filename: str = "images_bundle.zip"):
|
| 264 |
+
if not urls:
|
| 265 |
+
return
|
| 266 |
zip_buffer = io.BytesIO()
|
| 267 |
+
with zipfile.ZipFile(zip_buffer, "w", compression=zipfile.ZIP_DEFLATED) as zip_file:
|
| 268 |
+
for idx, url in enumerate(urls, 1):
|
| 269 |
+
try:
|
| 270 |
+
img_bytes = fetch_image_bytes_optimized(url)
|
| 271 |
+
if img_bytes:
|
| 272 |
+
path = urlparse(url).path
|
| 273 |
+
base = os.path.basename(path) or f"image_{idx}.png"
|
| 274 |
+
if not os.path.splitext(base)[1]:
|
| 275 |
+
base = f"{base}.png"
|
| 276 |
+
zip_file.writestr(base, img_bytes)
|
| 277 |
+
except Exception:
|
| 278 |
+
pass
|
| 279 |
zip_buffer.seek(0)
|
| 280 |
+
st.download_button(
|
| 281 |
+
"Download All Images",
|
| 282 |
+
data=zip_buffer,
|
| 283 |
+
file_name=filename,
|
| 284 |
+
mime="application/zip",
|
| 285 |
+
use_container_width=True
|
| 286 |
+
)
|
| 287 |
|
| 288 |
+
# ----------------------------
|
| 289 |
+
# JSON loader (STRICT)
|
| 290 |
+
# ----------------------------
|
| 291 |
+
def load_json_prompts(file) -> List[Dict[str, Any]]:
|
| 292 |
+
raw = file.getvalue().decode("utf-8", errors="replace")
|
| 293 |
+
data = json.loads(raw)
|
| 294 |
+
if not isinstance(data, dict) or "prompts" not in data or not isinstance(data["prompts"], list):
|
| 295 |
+
raise ValueError("Invalid JSON. Expected an object with a 'prompts' array of strings.")
|
| 296 |
+
prompts_out: List[Dict[str, Any]] = []
|
| 297 |
+
for i, item in enumerate(data["prompts"], 1):
|
| 298 |
+
if not isinstance(item, str) or not item.strip():
|
| 299 |
+
raise ValueError(f"'prompts[{i-1}]' must be a non-empty string.")
|
| 300 |
+
prompts_out.append({"id": f"p{i}", "content": item.strip()})
|
| 301 |
+
return prompts_out
|
| 302 |
+
|
| 303 |
+
# ----------------------------
|
| 304 |
+
# JSON page (parallel across prompts)
|
| 305 |
+
# ----------------------------
|
| 306 |
+
def _run_single_prompt(idx: int, prompt_text: str, model_key: str, aspect_ratio: str):
|
| 307 |
+
r2_urls, src_urls, gen_errors = generate_one_per_prompt(model_key, aspect_ratio, prompt_text)
|
| 308 |
+
rec_id = None
|
| 309 |
+
if r2_urls:
|
| 310 |
+
rec_id = save_creative_record_optimized(model_key, aspect_ratio, prompt_text, r2_urls)
|
| 311 |
+
return {
|
| 312 |
+
"idx": idx,
|
| 313 |
+
"prompt": prompt_text,
|
| 314 |
+
"r2_urls": r2_urls,
|
| 315 |
+
"src_urls": src_urls,
|
| 316 |
+
"errors": gen_errors,
|
| 317 |
+
"rec_id": rec_id,
|
| 318 |
+
}
|
| 319 |
|
| 320 |
def render_json_page():
|
| 321 |
st.subheader("Generate from JSON Prompts")
|
| 322 |
+
up = st.file_uploader("Upload prompts JSON", type=["json"])
|
| 323 |
+
|
| 324 |
+
col1, col2 = st.columns([1, 1])
|
| 325 |
+
with col1:
|
| 326 |
+
default_model = st.selectbox("Default Model", list(MODEL_REGISTRY.keys()), index=0)
|
| 327 |
+
with col2:
|
| 328 |
+
aspect_options = MODEL_REGISTRY[default_model]["aspect_ratios"]
|
| 329 |
+
default_aspect = st.selectbox("Default Aspect Ratio", aspect_options, index=0, key="json_default_ar")
|
| 330 |
+
|
| 331 |
+
debug_mode = st.checkbox("Debug Mode", value=False, key="json_debug")
|
| 332 |
+
|
| 333 |
if up:
|
| 334 |
try:
|
| 335 |
+
prompts_list = load_json_prompts(up)
|
| 336 |
+
with st.expander("Preview normalized prompts", expanded=False):
|
| 337 |
+
st.json(prompts_list, expanded=False)
|
| 338 |
+
|
| 339 |
+
if st.button("Generate for All Prompts", type="primary", use_container_width=True):
|
| 340 |
+
handle_bulk_json_generation_parallel(prompts_list, default_model, default_aspect, debug_mode)
|
| 341 |
+
except json.JSONDecodeError as e:
|
| 342 |
+
st.error(f"Invalid JSON: {e}")
|
| 343 |
except Exception as e:
|
| 344 |
+
st.error(f"Failed to read prompts: {e}")
|
| 345 |
+
else:
|
| 346 |
+
st.caption('Expected format: { "prompts": ["prompt 1", "prompt 2", ...] }')
|
| 347 |
+
|
| 348 |
+
def handle_bulk_json_generation_parallel(prompts: List[Dict[str, str]], default_model: str, default_aspect: str, debug: bool):
|
| 349 |
+
if not REPLICATE_API_TOKEN:
|
| 350 |
+
st.error("Missing REPLICATE_API_TOKEN. Set it as an environment variable.")
|
| 351 |
+
return
|
| 352 |
+
total = len(prompts)
|
| 353 |
+
if total == 0:
|
| 354 |
+
st.info("No prompts to process.")
|
| 355 |
+
return
|
| 356 |
+
|
| 357 |
+
# Placeholders for stable on-page order
|
| 358 |
+
blocks = [st.container(border=True) for _ in range(total)]
|
| 359 |
+
progress = st.progress(0, text=f"Starting batch • 0/{total}")
|
| 360 |
+
|
| 361 |
+
all_urls: List[str] = []
|
| 362 |
+
completed = 0
|
| 363 |
+
|
| 364 |
+
max_workers = min(MAX_WORKERS, max(2, (os.cpu_count() or 2)))
|
| 365 |
+
|
| 366 |
+
with st.spinner("Generating images..."):
|
| 367 |
+
futures = {}
|
| 368 |
+
with ThreadPoolExecutor(max_workers=max_workers) as ex:
|
| 369 |
+
for i, p in enumerate(prompts, 1):
|
| 370 |
+
prompt_text = p.get("content", "").strip()
|
| 371 |
+
if not prompt_text:
|
| 372 |
+
# render immediately as invalid
|
| 373 |
+
with blocks[i-1]:
|
| 374 |
+
st.markdown(f"**Prompt {i}/{total}** — (empty)")
|
| 375 |
+
st.error("Prompt text is empty. Skipping.")
|
| 376 |
+
completed += 1
|
| 377 |
+
progress.progress(completed / total, text=f"Processed {completed}/{total}")
|
| 378 |
+
continue
|
| 379 |
+
futures[ex.submit(_run_single_prompt, i, prompt_text, default_model, default_aspect)] = i
|
| 380 |
+
|
| 381 |
+
for fut in as_completed(futures):
|
| 382 |
+
i = futures[fut]
|
| 383 |
+
try:
|
| 384 |
+
res = fut.result()
|
| 385 |
+
except Exception as e:
|
| 386 |
+
res = {"idx": i, "prompt": "", "r2_urls": [], "src_urls": [], "errors": [str(e)], "rec_id": None}
|
| 387 |
+
|
| 388 |
+
with blocks[i-1]:
|
| 389 |
+
st.markdown(f"**Prompt {i}/{total}** — Model: `{default_model}` • Aspect: `{default_aspect}` • Num: `1`")
|
| 390 |
+
st.code(res.get("prompt") or "(empty)", language="markdown")
|
| 391 |
+
|
| 392 |
+
if res["r2_urls"]:
|
| 393 |
+
st.success(f"Generated 1 image. DB: {res['rec_id'] or 'N/A'}")
|
| 394 |
+
display_image_gallery_optimized(res["r2_urls"])
|
| 395 |
+
bulk_download_button(res["r2_urls"], filename=f"prompt_{i}_image.zip")
|
| 396 |
+
all_urls.extend(res["r2_urls"])
|
| 397 |
+
elif res["src_urls"]:
|
| 398 |
+
st.warning("Image generated but R2 upload failed. Showing original:")
|
| 399 |
+
display_image_gallery_optimized(res["src_urls"])
|
| 400 |
+
bulk_download_button(res["src_urls"], filename=f"prompt_{i}_image.zip")
|
| 401 |
+
all_urls.extend(res["src_urls"])
|
| 402 |
+
else:
|
| 403 |
+
st.error("No image was generated for this prompt.")
|
| 404 |
+
|
| 405 |
+
if res.get("errors") and debug:
|
| 406 |
+
for e in res["errors"]:
|
| 407 |
+
st.error(e)
|
| 408 |
+
|
| 409 |
+
completed += 1
|
| 410 |
+
progress.progress(completed / total, text=f"Processed {completed}/{total}")
|
| 411 |
+
|
| 412 |
+
# Final all-images gallery & ZIP
|
| 413 |
if all_urls:
|
| 414 |
+
st.subheader("All Images Gallery")
|
| 415 |
+
display_image_gallery_optimized(all_urls)
|
| 416 |
+
st.subheader("Download All Generated")
|
| 417 |
+
bulk_download_button(all_urls, filename="all_prompts_images.zip")
|
| 418 |
|
| 419 |
+
# ----------------------------
|
| 420 |
+
# Creative Library page
|
| 421 |
+
# ----------------------------
|
| 422 |
def render_library_page():
|
| 423 |
st.subheader("Creative Library")
|
| 424 |
+
if "library_page" not in st.session_state:
|
| 425 |
+
st.session_state.library_page = 0
|
| 426 |
+
|
| 427 |
+
today_utc = datetime.utcnow().date()
|
| 428 |
+
default_start = today_utc - timedelta(days=30)
|
| 429 |
+
|
| 430 |
+
c1, c2, c3 = st.columns([1, 1, 1])
|
| 431 |
+
with c1:
|
| 432 |
+
start_date: date = st.date_input("Start date", value=default_start)
|
| 433 |
+
with c2:
|
| 434 |
+
end_date: date = st.date_input("End date", value=today_utc)
|
| 435 |
+
with c3:
|
| 436 |
+
if st.button("Apply Filters", use_container_width=True):
|
| 437 |
+
st.session_state.library_page = 0
|
| 438 |
+
st.cache_data.clear()
|
| 439 |
+
|
| 440 |
+
start_dt = datetime.combine(start_date, datetime.min.time())
|
| 441 |
+
end_dt = datetime.combine(end_date + timedelta(days=1), datetime.min.time())
|
| 442 |
+
|
| 443 |
+
records, total_count = query_creatives_optimized(start_dt, end_dt, st.session_state.library_page)
|
| 444 |
+
if not records and st.session_state.library_page == 0:
|
| 445 |
+
st.info("No creatives found for the selected dates.")
|
| 446 |
+
return
|
| 447 |
+
|
| 448 |
+
st.caption(f"Total items: {total_count}")
|
| 449 |
+
# simple gallery by record
|
| 450 |
for rec in records:
|
| 451 |
+
urls = rec.get("urls", []) or []
|
| 452 |
+
if urls:
|
| 453 |
+
display_image_gallery_optimized(urls)
|
| 454 |
+
|
| 455 |
+
# ----------------------------
|
| 456 |
+
# Auth
|
| 457 |
+
# ----------------------------
|
| 458 |
+
@lru_cache(maxsize=1)
|
| 459 |
+
def check_token_cached(user_token: str) -> Tuple[bool, str]:
|
| 460 |
+
ACCESS_TOKEN = os.getenv("ACCESS_TOKEN")
|
| 461 |
+
if not ACCESS_TOKEN:
|
| 462 |
+
return False, "Server error: Access token not configured."
|
| 463 |
+
if user_token == ACCESS_TOKEN:
|
| 464 |
+
return True, ""
|
| 465 |
+
return False, "Invalid token."
|
| 466 |
|
| 467 |
+
# ----------------------------
|
| 468 |
+
# App shell
|
| 469 |
+
# ----------------------------
|
| 470 |
def main_app():
|
| 471 |
+
st.set_page_config(page_title="File-to-Image • Creative Library", layout="wide")
|
| 472 |
st.title("File-to-Image Generator")
|
| 473 |
+
with st.sidebar:
|
| 474 |
+
page = st.radio("Navigation", ["Generate from JSON", "Creative Library"], index=0)
|
| 475 |
+
if page == "Generate from JSON":
|
| 476 |
+
render_json_page()
|
| 477 |
+
else:
|
| 478 |
+
render_library_page()
|
| 479 |
|
| 480 |
def main():
|
| 481 |
+
if "authenticated" not in st.session_state:
|
| 482 |
+
st.session_state["authenticated"] = False
|
| 483 |
+
if not st.session_state["authenticated"]:
|
| 484 |
+
st.markdown("## Access Required")
|
| 485 |
+
token_input = st.text_input("Enter Access Token", type="password")
|
| 486 |
+
if st.button("Unlock App"):
|
| 487 |
+
ok, error_msg = check_token_cached(token_input)
|
| 488 |
+
if ok:
|
| 489 |
+
st.session_state["authenticated"] = True
|
| 490 |
+
st.rerun()
|
| 491 |
+
else:
|
| 492 |
+
st.error(error_msg)
|
| 493 |
+
else:
|
| 494 |
+
main_app()
|
| 495 |
|
| 496 |
+
if __name__ == "__main__":
|
| 497 |
main()
|