File size: 15,421 Bytes
4c69ac5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
"""
Ad Creative Generator (Fourth Flow)
-----------------------------------
Flow:
  1) Scrape product URL to fetch product image
  2) Run one Gemini vision call for analysis + generation prompt
  3) Generate image via nano-banana-2 with product/template/logo references
  4) Save creative + analysis payload to output_creatives/
"""

import base64
import json
import os
import shutil
import sys
import time
import uuid
from pathlib import Path

import requests
from dotenv import load_dotenv
from google import genai
from google.genai import types

# Load variables from a .env file (if any) before the os.getenv lookups below.
load_dotenv()

# Local backend imports
# The "backend" directory next to this file is put first on sys.path; the
# `app.*` imports below are presumably resolved from it, which is why they
# must come after the sys.path tweak (hence the E402 suppressions).
BACKEND_DIR = Path(__file__).resolve().parent / "backend"
if str(BACKEND_DIR) not in sys.path:
    sys.path.insert(0, str(BACKEND_DIR))

from app.replicate_image import generate_image_sync  # noqa: E402  # type: ignore[reportMissingImports]
from app.scraper import scrape_product  # noqa: E402  # type: ignore[reportMissingImports]

# Gemini credentials and model name. The placeholder default is the exact value
# analyse_images compares against to decide that no key was configured.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "YOUR_GEMINI_API_KEY")
GEMINI_MODEL = os.getenv("GEMINI_MODEL", "gemini-3.1-pro-preview")
# Model key handed to generate_image_sync on every generation call.
MODEL_KEY = "nano-banana-2"
# Replicate token: REPLICATE_API_KEY wins, then REPLICATE_API_TOKEN, else a placeholder string.
# generate_with_nano_banana exports this value as the REPLICATE_API_TOKEN environment variable.
REPLICATE_API_KEY = os.getenv("REPLICATE_API_KEY") or os.getenv("REPLICATE_API_TOKEN") or "YOUR_REPLICATE_API_KEY"
# Per requested output image: maximum number of generation attempts and the
# pause in seconds (passed to time.sleep) between failed attempts.
GENERATION_MAX_ATTEMPTS = 3
GENERATION_RETRY_DELAY_SEC = 4

# Output folder for creatives and analysis JSON; relative to the current
# working directory and created as a side effect of importing this module.
OUTPUT_DIR = Path("output_creatives")
OUTPUT_DIR.mkdir(exist_ok=True)
# Folder that receives copies of local reference images (see
# _publish_local_reference); also created at import time.
REFERENCE_UPLOAD_DIR = BACKEND_DIR / "reference_uploads"
REFERENCE_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)


# Instruction text sent as the first content part of the Gemini request in
# analyse_images. It names the five JSON keys the response must contain;
# build_base_prompt_from_analysis later reads those keys from the parsed reply.
# The image order described here (template, product, optional logo) matches the
# order in which analyse_images appends the image parts.
VISION_USER_PROMPT = """You are an expert creative director and ad-tech specialist.
Analyse these images and return ONLY valid JSON with exactly these keys:
- template_analysis
- product_description
- brand_info
- image_generation_prompt
- negative_prompt

Use first image as template/layout reference, second image as product, third image (if present) as logo/brand cue.
CRITICAL: Treat the template like a copy-paste layout lock.
- Keep the same composition, block structure, spacing, framing, and text placement zones as template.
- Keep the same visual style direction (editorial/commercial look), not a new design.
- Only change the content inside the template: product subject, brand/logo presence, and copy text.
- Do NOT invent a different layout or scene format.

Make image_generation_prompt optimized for premium photorealistic jewelry ad output"""


def load_image_as_b64(path: str) -> str:
    """Return the contents of the file at *path* as a base64 text string.

    The file is read in binary mode; the base64 output is decoded to ``str``
    so it can be embedded directly in a JSON-style request payload.
    """
    with open(path, "rb") as handle:
        encoded = base64.b64encode(handle.read())
    return encoded.decode("utf-8")


def _is_supported_image(path: Path) -> bool:
    return path.suffix.lower() in {".png", ".jpg", ".jpeg", ".webp"}


def _reference_image_base_url() -> str:
    return (os.getenv("BASE_URL") or "http://localhost:8002").rstrip("/")


def _publish_local_reference(path: str) -> str:
    src = Path(path)
    if not src.exists():
        raise FileNotFoundError(f"reference image not found: {path}")
    if not _is_supported_image(src):
        raise ValueError(f"unsupported reference image format: {path}")
    name = f"{uuid.uuid4().hex}{src.suffix.lower()}"
    dst = REFERENCE_UPLOAD_DIR / name
    shutil.copy2(src, dst)
    return f"{_reference_image_base_url()}/api/serve-reference/{name}"


def resolve_template_path(template_path: str | None, examples_dir: str, example_file: str | None) -> str:
    if template_path:
        p = Path(template_path)
        if not p.exists():
            raise FileNotFoundError(f"template image not found: {template_path}")
        if not _is_supported_image(p):
            raise ValueError(f"unsupported template image format: {template_path}")
        return str(p)
    if example_file:
        p = Path(examples_dir) / example_file
        if not p.exists():
            raise FileNotFoundError(f"example template not found: {p}")
        if not _is_supported_image(p):
            raise ValueError(f"unsupported example template format: {p}")
        return str(p)
    raise ValueError("Provide either --template or --example-file.")


def list_template_paths(examples_dir: str, bulk_limit: int = 0) -> list[str]:
    """List the supported template images directly inside *examples_dir*.

    Results are ordered by lower-cased file name. A positive *bulk_limit*
    keeps only that many entries from the start of the ordered list; zero or
    a negative value keeps them all.

    Raises:
        FileNotFoundError: If *examples_dir* is missing or is not a directory.
        ValueError: If the directory holds no supported image files.
    """
    folder = Path(examples_dir)
    if not (folder.exists() and folder.is_dir()):
        raise FileNotFoundError(f"examples dir not found: {examples_dir}")

    candidates = [
        entry
        for entry in folder.iterdir()
        if entry.is_file() and _is_supported_image(entry)
    ]
    if not candidates:
        raise ValueError(f"No supported template images found in: {examples_dir}")

    candidates.sort(key=lambda entry: entry.name.lower())
    selected = candidates[:bulk_limit] if bulk_limit > 0 else candidates
    return [str(entry) for entry in selected]


def scrape_product_image_url(product_url: str) -> tuple[str, dict]:
    """Scrape *product_url* and return its first usable image URL.

    The scraper result's "product_images" value is treated as a
    comma-separated string; the first entry that starts with http:// or
    https:// (after trimming whitespace) is selected.

    Returns:
        A ``(image_url, scraped_data)`` tuple, where ``scraped_data`` is the
        unmodified mapping returned by ``scrape_product``.

    Raises:
        ValueError: If no entry looks like an http(s) URL.
    """
    scraped = scrape_product(product_url)
    raw_field = scraped.get("product_images") or ""
    for piece in raw_field.split(","):
        candidate = piece.strip()
        if candidate.startswith(("http://", "https://")):
            return candidate, scraped
    raise ValueError("No valid product image URL found after scraping.")


def analyse_images(template_path: str, product_image_url: str, logo_path: str | None = None) -> dict:
    """Run one Gemini vision request and return the parsed JSON analysis.

    The request carries, in order: VISION_USER_PROMPT, the template image
    (inline, base64), the product image (referenced by URL) and, when given,
    the logo image (inline, base64).

    Args:
        template_path: Local path of the template/layout reference image.
        product_image_url: URL of the product image, passed as ``file_uri``.
        logo_path: Optional local path of a logo image.

    Returns:
        The JSON object decoded from the model's text response.

    Raises:
        RuntimeError: If GEMINI_API_KEY still holds its placeholder value.
        ValueError: If the response text is not valid JSON
            (``json.JSONDecodeError`` is a ``ValueError`` subclass) or does
            not decode to a JSON object.
    """
    if GEMINI_API_KEY == "YOUR_GEMINI_API_KEY":
        raise RuntimeError("GEMINI_API_KEY is not set.")
    client = genai.Client(api_key=GEMINI_API_KEY)

    # Suffix -> MIME type for inline images. .webp is an accepted input format
    # elsewhere in this file, so it must not be mislabelled as JPEG; any other
    # suffix keeps the previous JPEG fallback.
    mime_by_suffix = {".png": "image/png", ".webp": "image/webp"}

    def _inline_image_part(image_path: str) -> dict:
        """Build an inline_data part for a local image file."""
        mime = mime_by_suffix.get(Path(image_path).suffix.lower(), "image/jpeg")
        return {"inline_data": {"mime_type": mime, "data": load_image_as_b64(image_path)}}

    parts: list = [
        VISION_USER_PROMPT,
        _inline_image_part(template_path),
        # NOTE(review): the product image is always declared as image/jpeg,
        # whatever the URL actually serves — confirm this is acceptable.
        {"file_data": {"mime_type": "image/jpeg", "file_uri": product_image_url}},
    ]
    if logo_path:
        parts.append(_inline_image_part(logo_path))

    response = client.models.generate_content(
        model=GEMINI_MODEL,
        contents=parts,
        config=types.GenerateContentConfig(
            temperature=0.35,
            response_mime_type="application/json",
            thinking_config=types.ThinkingConfig(thinking_level="medium"),
        ),
    )

    raw = (response.text or "").strip()
    # Tolerate a reply wrapped in a Markdown code fence (``` or ```json).
    if raw.startswith("```"):
        raw = raw.strip("`")
        if raw.lower().startswith("json"):
            raw = raw[4:].strip()

    result = json.loads(raw)
    # Callers use dict.get on the result, so reject arrays/scalars here with a
    # clear message instead of failing later with an AttributeError.
    if not isinstance(result, dict):
        raise ValueError(
            f"Gemini analysis must be a JSON object, got {type(result).__name__}."
        )
    return result


def build_base_prompt_from_analysis(analysis: dict) -> str:
    """Compose the final image-generation prompt from the vision analysis.

    The prompt is a lead sentence followed by three fixed guard clauses
    (product lock, logical-realism lock, layout lock). The lead is the
    analysis' "image_generation_prompt" when that is non-empty; otherwise it
    is synthesized from "template_analysis", "product_description",
    "brand_info" and (if present) "negative_prompt", and a warning is printed.
    """
    product_lock = (
        "CRITICAL PRODUCT LOCK: Keep the exact same product from the reference image. "
        "Do not change product type, silhouette, geometry, metal tone, gemstone colors, gemstone count, "
        "stone shapes, setting style, proportions, or signature details. "
        "No redesign, no substitutions, no style drift."
    )
    logic_lock = (
        "LOGICAL REALISM LOCK: Final creative must be physically and contextually believable. "
        "Jewelry placement must be natural (ring on a finger or realistic display surface), "
        "hand anatomy must be correct, perspective/scale must be coherent, lighting and shadows must match scene geometry, "
        "materials must look real (metal reflectance and gemstone refraction), and text placement must feel intentional/readable. "
        "Avoid impossible poses, floating objects, mismatched reflections, or visually confusing composition."
    )
    layout_lock = (
        "Strictly preserve the template layout and composition one-to-one; "
        "only replace content (product/copy/logo) inside the same structure."
    )

    lead = (analysis.get("image_generation_prompt") or "").strip()
    if not lead:
        # No ready-made prompt from the model: assemble one from the other
        # analysis fields. Note the negative prompt is only used on this path.
        lead = (
            "Create a premium photorealistic jewelry ad. "
            f"Template guidance: {analysis.get('template_analysis', '')}. "
            f"Product guidance: {analysis.get('product_description', '')}. "
            f"Brand guidance: {analysis.get('brand_info', '')}. "
            "Luxury tone, clear hierarchy, readable text, clean composition."
        )
        avoid = (analysis.get("negative_prompt") or "").strip()
        if avoid:
            lead += f" Avoid: {avoid}."
        print("  ⚠️  image_generation_prompt missing; synthesized from analysis.")

    return " ".join((lead, product_lock, logic_lock, layout_lock)).strip()


def generate_with_nano_banana(base_prompt: str, reference_image_urls: list[str], width: int, height: int, num_outputs: int) -> list[str]:
    """Generate *num_outputs* images and return their URLs.

    Each output is requested through ``generate_image_sync`` with MODEL_KEY,
    retrying up to GENERATION_MAX_ATTEMPTS times and sleeping
    GENERATION_RETRY_DELAY_SEC seconds between failed attempts.

    Args:
        base_prompt: Prompt text sent unchanged on every call.
        reference_image_urls: Reference image URLs; empty entries and
            duplicates are dropped while keeping first-seen order.
        width: Requested image width, forwarded as-is.
        height: Requested image height, forwarded as-is.
        num_outputs: Number of images to generate (one call sequence each).

    Returns:
        One URL per generated image, in generation order.

    Raises:
        RuntimeError: If no Replicate token is configured, or if any single
            output still fails after the last attempt.
    """
    # Fail fast on the placeholder instead of exporting a bogus token and
    # spending every retry (plus sleeps) on requests that cannot succeed.
    if REPLICATE_API_KEY == "YOUR_REPLICATE_API_KEY":
        raise RuntimeError("REPLICATE_API_KEY is not set.")
    # Exported as an environment variable, presumably because the backend
    # helper reads the token from there — TODO confirm.
    os.environ["REPLICATE_API_TOKEN"] = REPLICATE_API_KEY

    # dict.fromkeys keeps insertion order, giving an order-preserving dedup.
    refs = list(dict.fromkeys(u for u in reference_image_urls if u))

    urls: list[str] = []
    for _ in range(num_outputs):
        final_url = None
        final_err = "Image generation failed."
        for attempt in range(1, GENERATION_MAX_ATTEMPTS + 1):
            url, err = generate_image_sync(
                prompt=base_prompt,
                model_key=MODEL_KEY,
                width=width,
                height=height,
                reference_image_urls=refs,
            )
            # A URL accompanied by an error is still treated as a failure.
            if url and not err:
                final_url = url
                break
            final_err = err or "Image generation returned no URL."
            print(f"  ⚠️  Attempt {attempt}/{GENERATION_MAX_ATTEMPTS} failed: {final_err}")
            if attempt < GENERATION_MAX_ATTEMPTS:
                time.sleep(GENERATION_RETRY_DELAY_SEC)
        if not final_url:
            # One exhausted output aborts the whole batch; URLs collected so
            # far are not returned.
            raise RuntimeError(f"Image generation failed after {GENERATION_MAX_ATTEMPTS} attempts: {final_err}")
        urls.append(final_url)
    return urls


def save_image_from_url(url: str, filename: str) -> Path:
    """Download *url* and store the response body as OUTPUT_DIR / *filename*.

    The GET request uses a 60-second timeout and raises (via
    ``raise_for_status``) on an HTTP error status. Returns the written path.
    """
    response = requests.get(url, timeout=60)
    response.raise_for_status()
    destination = OUTPUT_DIR / filename
    with open(destination, "wb") as sink:
        sink.write(response.content)
    return destination


def generate_ad_creative(template_path: str, product_url: str, logo_path: str | None, num_outputs: int, width: int, height: int) -> list[Path]:
    """Run the full flow for one template and return the saved image paths.

    Steps, in order: scrape the product page for an image URL, publish the
    local template (and optional logo) as reference URLs, run the Gemini
    analysis, write the analysis payload to OUTPUT_DIR, generate the images
    and download each one into OUTPUT_DIR. Progress is printed throughout.

    Args:
        template_path: Local path of the template image.
        product_url: Product page URL handed to the scraper.
        logo_path: Optional local path of a logo image.
        num_outputs: Number of creatives to generate.
        width: Requested image width, forwarded to the generator.
        height: Requested image height, forwarded to the generator.

    Returns:
        Paths of the downloaded creatives, in generation order.

    Raises:
        Nothing is caught here; any exception raised by the helper functions
        called below propagates to the caller.
    """
    print("\n" + "═" * 56)
    print("   AD CREATIVE GENERATOR  •  Gemini + Nano Banana 2")
    print("═" * 56)
    print(f"  🧩  Template image: {template_path}")
    print(f"  🌍  Reference base URL: {_reference_image_base_url()}")

    # NOTE(review): the step labels printed in this function read 0/3, 1/3
    # and 3/3 — there is no 2/3; confirm whether the numbering is intended.
    print("\n[0/3] 🌐  Scraping product page …")
    product_image_url, product_data = scrape_product_image_url(product_url)
    print(f"  ✅  Product: {product_data.get('product_name', '')}")
    print(f"  ✅  Product image: {product_image_url}")

    # Local files are copied into the reference upload folder so the image
    # generator can be given URLs for them.
    template_ref = _publish_local_reference(template_path)
    print(f"  ✅  Published template reference: {template_ref}")
    logo_ref = None
    if logo_path:
        logo_ref = _publish_local_reference(logo_path)
        print(f"  ✅  Published logo reference: {logo_ref}")

    print(f"\n[1/3] 🔍  Analysing images + building prompt with {GEMINI_MODEL} …")
    # The analysis reads the local template/logo files directly; only the
    # product image is passed by URL.
    analysis = analyse_images(template_path, product_image_url, logo_path)
    print("  ✅  Analysis complete.")
    base_prompt = build_base_prompt_from_analysis(analysis)

    # Whole-second timestamp shared by the analysis file and all image files
    # of this run.
    ts = int(time.time())
    payload = {
        "analysis": analysis,
        "meta": {
            "product_url": product_url,
            "selected_product_image_url": product_image_url,
            "template_reference_url": template_ref,
            "logo_reference_url": logo_ref,
            "used_template_image": template_path,
            "product_name": product_data.get("product_name", ""),
            "model_key": MODEL_KEY,
            "template_path": template_path,
            "logo_path": logo_path,
            "timestamp": ts,
        },
    }
    # The payload is persisted before generation starts, so it is on disk
    # even if image generation fails afterwards.
    analysis_file = OUTPUT_DIR / f"analysis_{ts}.json"
    analysis_file.write_text(json.dumps(payload, indent=2))
    print("  🧾  Payload:")
    print(json.dumps(payload, indent=2))
    print(f"  📄  Analysis JSON → {analysis_file}")

    print("\n[3/3] 🚀  Generating with nano-banana-2 …")
    # Reference order: product image first, then template, then optional logo.
    refs = [product_image_url, template_ref]
    if logo_ref:
        refs.append(logo_ref)
    print(f"  📦  Generation references ({len(refs)}): {refs}")
    out_urls = generate_with_nano_banana(base_prompt, refs, width, height, num_outputs)

    saved: list[Path] = []
    for i, url in enumerate(out_urls, start=1):
        # Files are always named with a .png suffix, whatever the server returns.
        p = save_image_from_url(url, f"ad_creative_{ts}_{i}.png")
        saved.append(p)
        print(f"  ✅  Saved → {p}")

    print("\n" + "═" * 56)
    print(f"   ✨  {len(saved)} creative(s) ready in ./{OUTPUT_DIR}/")
    print("═" * 56 + "\n")
    return saved


if __name__ == "__main__":
    import argparse

    # Command-line entry point: either one template (--template or
    # --example-file) or every template in --examples-dir (--bulk-templates).
    cli = argparse.ArgumentParser(description="Generate ad creatives — Gemini Vision + Nano Banana 2")
    cli.add_argument("--template", default=None, help="Direct template path")
    cli.add_argument("--examples-dir", default="backend/data/creativity_examples", help="Examples dir")
    cli.add_argument("--example-file", default=None, help="Template filename inside examples dir")
    cli.add_argument("--product-url", default="https://amalfa.in/products/thalia-prism-ring", help="Product URL")
    cli.add_argument("--logo", default=None, help="Optional logo path")
    cli.add_argument("--num", type=int, default=1, help="Number of outputs")
    cli.add_argument("--width", type=int, default=1024)
    cli.add_argument("--height", type=int, default=1024)
    cli.add_argument("--bulk-templates", action="store_true", help="Run generation for all templates in --examples-dir")
    cli.add_argument("--bulk-limit", type=int, default=0, help="Optional cap on number of templates in bulk mode")
    args = cli.parse_args()

    # Keyword arguments that are identical for every generate_ad_creative call.
    shared_kwargs = {
        "product_url": args.product_url,
        "logo_path": args.logo,
        "num_outputs": args.num,
        "width": args.width,
        "height": args.height,
    }

    try:
        files: list[Path] = []
        if not args.bulk_templates:
            # Single-template mode: any failure aborts the run.
            template = resolve_template_path(args.template, args.examples_dir, args.example_file)
            files = generate_ad_creative(template_path=template, **shared_kwargs)
        else:
            templates = list_template_paths(args.examples_dir, args.bulk_limit)
            total = len(templates)
            print(f"Running bulk mode for {total} template(s) from {args.examples_dir}")
            failed: list[str] = []
            for idx, template in enumerate(templates, start=1):
                print(f"\n--- [{idx}/{total}] Template: {template} ---")
                # Best effort: a failing template is recorded and the loop
                # moves on to the next one.
                try:
                    files.extend(generate_ad_creative(template_path=template, **shared_kwargs))
                except Exception as ex:
                    print(f"  ❌ Template failed: {template} | {ex}")
                    failed.append(template)
            if not failed:
                print("\nBulk completed with no template failures.")
            else:
                print(f"\nBulk completed with {len(failed)} failed template(s).")
        print("Output files:")
        for p in files:
            print(f"  → {p}")
    except Exception as e:
        # Top-level boundary: report the error and exit with status 1.
        print(f"\n❌ Flow failed: {e}")
        raise SystemExit(1)