GLAkavya committed on
Commit
4d9294e
·
verified ·
1 Parent(s): 4e38ad3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +451 -0
app.py ADDED
@@ -0,0 +1,451 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import tempfile
4
+ import io
5
+ import math
6
+ import time
7
+ import numpy as np
8
+ import cv2
9
+ import gradio as gr
10
+ from google import genai
11
+ from google.genai import types
12
+ from PIL import Image
13
+
14
# ── ENV SETUP ────────────────────────────────────────────────────────────────
def _first_env(*names: str) -> str:
    """Return the first non-blank value among the environment variables *names*.

    Every candidate is stripped *before* it is tested, so a variable that
    contains only whitespace cannot shadow a usable value stored under a
    later name. Returns "" when none of the variables carries a value.
    """
    for name in names:
        value = os.environ.get(name, "").strip()
        if value:
            return value
    return ""


# The Gemini key is accepted under either name and exported as GOOGLE_API_KEY,
# because call_gemini() builds genai.Client() without passing a key explicitly.
gemini_key = _first_env("GEMINI_API_KEY", "GOOGLE_API_KEY")
if gemini_key:
    os.environ["GOOGLE_API_KEY"] = gemini_key
    print(f"✅ Gemini key loaded (len={len(gemini_key)})")
else:
    print("❌ No Gemini key found!")

# hf_client stays None unless a token is present AND the Hugging Face client
# could be imported and logged in; try_hf_model() treats None as "skip".
hf_client = None
hf_token = _first_env("HF_TOKEN", "HF_KEY")
if hf_token:
    try:
        from huggingface_hub import login, InferenceClient
        login(token=hf_token)
        hf_client = InferenceClient(token=hf_token)
        print("✅ HF login OK")
    except Exception as e:
        # Best effort: any import/login problem only disables the remote models.
        hf_client = None
        print(f"⚠️ HF login skipped: {e}")
else:
    print("⚠️ No HF token — will use Ken Burns fallback")

print("✅ App ready!")
43
+
44
+
45
# ── HF MODEL FALLBACK CHAIN ──────────────────────────────────────────────────
# Candidates are attempted top to bottom and the first one that yields a video
# wins. The "__ken_burns__" sentinel is not a remote model: it selects the
# local OpenCV renderer and is listed last so the chain always ends with it.

_MODEL_ROWS = (
    # (model id, display name, note shown in the progress log)
    ("Lightricks/LTX-2", "LTX-2 (Lightricks)", "Best quality, fastest inference available ⚡"),
    ("Wan-AI/Wan2.2-I2V-A14B", "Wan 2.2 14B", "High quality, slightly slower"),
    (
        "stabilityai/stable-video-diffusion-img2vid-xt",
        "Stable Video Diffusion XT",
        "136k downloads, reliable classic",
    ),
    ("KlingTeam/LivePortrait", "KlingTeam LivePortrait", "Great for portraits / faces"),
    ("Lightricks/LTX-Video", "LTX-Video (older)", "248k downloads, solid fallback"),
    # Final fallback — pure OpenCV, no API call involved
    ("__ken_burns__", "Ken Burns (local, no API)", "Always works — cinematic zoom/pan effect"),
)

HF_MODELS = [
    {"id": model_id, "name": label, "note": remark}
    for model_id, label, remark in _MODEL_ROWS
]
81
+
82
+
83
def try_hf_model(model_id: str, pil_image: Image.Image, prompt: str) -> bytes | None:
    """Try one HuggingFace model. Returns video bytes or None on failure.

    Args:
        model_id: Hub repository id passed as ``model=`` to the client.
        pil_image: Source picture; it is sent to the API as JPEG bytes.
        prompt: Text prompt forwarded to the image-to-video call.

    Returns:
        The generated video as ``bytes``, or ``None`` when no client is
        configured, the call raises, or the result has an unexpected type.
        This function never raises; failures are printed and swallowed so the
        caller can move on to the next model.
    """
    if hf_client is None:
        return None
    try:
        # JPEG cannot store alpha or palette data. Without this conversion an
        # RGBA/LA/P image makes save() raise, and because of the broad except
        # below every model would appear to "fail" for such uploads.
        rgb_image = pil_image if pil_image.mode == "RGB" else pil_image.convert("RGB")
        buf = io.BytesIO()
        rgb_image.save(buf, format="JPEG")
        image_bytes = buf.getvalue()

        print(f" 🤖 Trying {model_id} ...")
        result = hf_client.image_to_video(
            image=image_bytes,
            model=model_id,
            prompt=prompt,
        )

        if isinstance(result, (bytes, bytearray)):
            return bytes(result)
        if hasattr(result, "read"):
            # File-like result: drain it into memory.
            return result.read()

        print(f" ❌ {model_id} returned unsupported result type: {type(result).__name__}")
        return None

    except Exception as e:
        print(f" ❌ {model_id} failed: {e}")
        return None
109
+
110
+
111
def generate_video_with_fallback(
    pil_image: Image.Image,
    prompt: str,
    style: str,
    progress_callback=None,
) -> tuple[str, str]:
    """
    Tries HF models in order. Falls back to Ken Burns if all fail.
    Returns (video_path, model_used_name).

    Args:
        pil_image: Source picture handed to every candidate.
        prompt: Text prompt for the remote image-to-video models.
        style: UI style name; lower-cased before it is passed to the local
            renderer. ``None`` is treated as an empty style.
        progress_callback: Optional callable receiving one Markdown status
            line per attempted model.

    Remote entries are skipped outright when no HF client is configured, and
    the one-second pause is only taken when another remote attempt follows.
    """
    ken_burns_id = "__ken_burns__"
    grade = (style or "").lower()
    remote_enabled = hf_client is not None
    last_index = len(HF_MODELS) - 1

    for index, model_info in enumerate(HF_MODELS):
        model_id = model_info["id"]
        model_name = model_info["name"]
        is_local = model_id == ken_burns_id

        if not is_local and not remote_enabled:
            # try_hf_model() would return None immediately; announcing the
            # attempt and sleeping afterwards would only waste time.
            continue

        if progress_callback:
            progress_callback(f"⏳ Trying: **{model_name}** — {model_info['note']}")

        # The sentinel selects the local renderer and ends the search.
        if is_local:
            print(" 🎬 Using Ken Burns (local fallback)")
            path = generate_video_ken_burns(pil_image, duration_sec=5, fps=24, style=grade)
            return path, f"🎨 {model_name}"

        # Try HF model
        video_bytes = try_hf_model(model_id, pil_image, prompt)
        if video_bytes:
            # delete=False keeps the file for the UI; the with-block closes the
            # handle so the data is flushed and no descriptor is leaked.
            with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
                tmp.write(video_bytes)
            print(f" ✅ SUCCESS with {model_name}")
            return tmp.name, f"🤖 {model_name}"

        # Small wait between remote retries to avoid hammering the API; there
        # is nothing to throttle when the next step is local or the list ends.
        if index < last_index and HF_MODELS[index + 1]["id"] != ken_burns_id:
            time.sleep(1)

    # Only reached when HF_MODELS contains no Ken Burns sentinel.
    path = generate_video_ken_burns(pil_image, duration_sec=5, fps=24, style=grade)
    return path, "🎨 Ken Burns (local)"
149
+
150
+
151
+ # ── GEMINI ────────────────────────────────────────────────────────────────────
152
def _parse_ad_json(raw: str) -> dict:
    """Extract and decode the JSON object embedded in a model reply.

    The reply may be bare JSON, wrapped in a Markdown code fence, or
    surrounded by prose; everything outside the outermost ``{ ... }`` pair is
    ignored. Raises ``ValueError`` (``json.JSONDecodeError`` is a subclass)
    when no decodable object is present.
    """
    text = raw.strip()
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end < start:
        raise ValueError("Gemini reply contains no JSON object")
    return json.loads(text[start:end + 1])


def call_gemini(pil_image: Image.Image, user_desc: str, language: str, style: str) -> dict:
    """Ask Gemini for ad copy describing *pil_image* and return it as a dict.

    Args:
        pil_image: Product picture; sent to the model as JPEG bytes.
        user_desc: Optional free-text product description (blank/None is ignored).
        language: Key of the language rule; unknown values fall back to English.
        style: Key of the tone rule; unknown values fall back to "Fun".

    Returns:
        The decoded JSON object. The prompt requests the keys ``hook``,
        ``script``, ``cta`` and ``video_prompt``; their presence is not verified.

    Raises:
        ValueError: the response carries no text or no decodable JSON object.
        Exception: whatever the Gemini client raises is propagated unchanged.
    """
    client = genai.Client()

    lang_map = {
        "English": "Write everything in English.",
        "Hindi": "सब कुछ हिंदी में लिखें।",
        "Hinglish": "Write in Hinglish (mix of Hindi and English).",
    }
    style_map = {
        "Fun": "tone: playful, witty, youthful",
        "Premium": "tone: luxurious, sophisticated, aspirational",
        "Energetic": "tone: high-energy, bold, action-packed",
    }

    description = (user_desc or "").strip()
    description_line = f"Product description: {description}" if description else ""
    language_rule = lang_map.get(language, lang_map["English"])
    style_rule = style_map.get(style, style_map["Fun"])

    prompt = f"""You are an expert ad copywriter. Analyze this product image and create a compelling social-media video ad.

{description_line}
Language rule : {language_rule}
Style rule : {style_rule}

CRITICAL: Return ONLY raw JSON. No markdown. No ```json. No explanation. Pure JSON only.
{{
"hook": "attention-grabbing opening line (1-2 sentences)",
"script": "full 15-20 second voiceover script",
"cta": "call-to-action phrase",
"video_prompt": "detailed cinematic advertising scene description for image-to-video AI"
}}"""

    # JPEG cannot encode alpha/palette images, so normalise the mode first.
    rgb_image = pil_image if pil_image.mode == "RGB" else pil_image.convert("RGB")
    buf = io.BytesIO()
    rgb_image.save(buf, format="JPEG")
    image_bytes = buf.getvalue()

    response = client.models.generate_content(
        model="gemini-2.5-flash",
        contents=[
            types.Part.from_bytes(data=image_bytes, mime_type="image/jpeg"),
            types.Part.from_text(text=prompt),
        ],
    )

    raw = response.text
    if not raw:
        # Guard against a response without a text part; calling .strip() on
        # None would otherwise surface as an opaque AttributeError.
        raise ValueError("Gemini returned an empty response")

    return _parse_ad_json(raw)
200
+
201
+
202
+ # ── KEN BURNS VIDEO (local fallback) ─────────────────────────────────────────
203
def ease_in_out(t):
    """Cubic smoothstep easing: maps t to 3t² − 2t³ (0 → 0, 0.5 → 0.5, 1 → 1).

    The input is not clamped; callers are expected to pass a progress value.
    """
    squared = t * t
    return squared * (3 - 2 * t)
205
+
206
def ease_out_bounce(t):
    """Piecewise-quadratic "bounce" easing for a progress value *t*.

    The curve consists of four parabolic arcs of decreasing height; each arc
    is the same parabola shifted along t and lifted by a constant offset.
    """
    gain = 7.5625
    span = 2.75

    # First arc: plain parabola starting at the origin.
    if t < 1 / span:
        return gain * t * t

    # Middle arcs: (upper bound, shift along t, vertical offset).
    for upper, shift, lift in (
        (2 / span, 1.5 / span, 0.75),
        (2.5 / span, 2.25 / span, 0.9375),
    ):
        if t < upper:
            u = t - shift
            return gain * u * u + lift

    # Final, smallest arc covers everything that remains.
    u = t - 2.625 / span
    return gain * u * u + 0.984375
218
+
219
def apply_vignette(frame, strength=0.6):
    """Darken *frame* towards its borders and return the result as uint8.

    A per-pixel factor ``1 - strength * d**1.5`` (clipped to [0, 1]) is
    applied, where ``d`` is the distance from the frame centre normalised so
    that it equals 1 at the middle of each edge. The frame must have a third
    (channel) axis because the factor is broadcast across it.
    """
    rows, cols = frame.shape[:2]
    half_w, half_h = cols / 2, rows / 2

    ys, xs = np.ogrid[:rows, :cols]
    norm_x = (xs - half_w) / half_w
    norm_y = (ys - half_h) / half_h
    radius = np.sqrt(norm_x ** 2 + norm_y ** 2)

    falloff = np.clip(1.0 - strength * (radius ** 1.5), 0, 1)
    darkened = frame * falloff[:, :, np.newaxis]
    return darkened.astype(np.uint8)
226
+
227
def apply_color_grade(frame, style="premium"):
    """Return a colour-graded uint8 copy of *frame* for the given *style*.

    Work is done in float32 and clipped to [0, 255] after every step.
    Channel indices refer to the last axis (0/1/2 are red/green/blue when the
    frame is RGB, which is how generate_video_ken_burns calls this function).

    * ``"premium"``   — channel 0 ×1.05, channel 2 ×1.08, then everything ×1.05
    * ``"energetic"`` — push each channel 1.4× away from the per-pixel mean,
      then everything ×1.1
    * ``"fun"``       — channel 0 ×1.1, channel 1 ×1.05
    * anything else   — pixels are returned unchanged
    """
    def scaled(values, gain):
        # Multiply and keep the result inside the displayable range.
        return np.clip(values * gain, 0, 255)

    graded = frame.astype(np.float32)

    if style == "fun":
        graded[:, :, 0] = scaled(graded[:, :, 0], 1.1)
        graded[:, :, 1] = scaled(graded[:, :, 1], 1.05)
    elif style == "energetic":
        mean_level = np.mean(graded, axis=2, keepdims=True)
        graded = np.clip(mean_level + 1.4 * (graded - mean_level), 0, 255)
        graded = scaled(graded, 1.1)
    elif style == "premium":
        graded[:, :, 0] = scaled(graded[:, :, 0], 1.05)
        graded[:, :, 2] = scaled(graded[:, :, 2], 1.08)
        graded = scaled(graded, 1.05)

    return graded.astype(np.uint8)
241
+
242
def _ken_burns_camera(i: int, cuts: tuple[int, int, int, int], pad: int) -> tuple[float, int, int]:
    """Return ``(zoom, pan_x, pan_y)`` for frame *i* of the four-shot camera move.

    *cuts* holds the exclusive end frame of each shot; *pad* is the margin (in
    pixels) around the target frame that panning may use.
    """
    s1_end, s2_end, s3_end, s4_end = cuts

    if i < s1_end:
        # Shot 1: bouncy zoom-out from 1.35× with a slight drift.
        t = i / s1_end
        te = ease_out_bounce(min(t * 1.1, 1.0))
        return 1.35 - 0.25 * te, int(pad * 0.1 * t), int(-pad * 0.15 * t)

    if i < s2_end:
        # Shot 2: slow zoom with a small sinusoidal hand-held shake.
        t = (i - s1_end) / (s2_end - s1_end)
        te = ease_in_out(t)
        shake_x = int(3 * math.sin(i * 0.8))
        shake_y = int(2 * math.cos(i * 1.1))
        pan_x = int(pad * 0.1 + shake_x)
        pan_y = int(-pad * 0.15 - pad * 0.20 * te + shake_y)
        return 1.10 - 0.05 * te, pan_x, pan_y

    if i < s3_end:
        # Shot 3: ease the pan back towards the centre.
        t = (i - s2_end) / (s3_end - s2_end)
        te = ease_in_out(t)
        return 1.05 - 0.04 * te, int(pad * 0.1 * (1 - te)), int(-pad * 0.35 * (1 - te))

    # Shot 4: centred, gentle push-in until the end.
    t = (i - s3_end) / (s4_end - s3_end)
    te = ease_in_out(t)
    return 1.01 + 0.03 * te, 0, 0


def generate_video_ken_burns(pil_image: Image.Image, duration_sec: int = 5, fps: int = 24, style: str = "premium") -> str:
    """Render a local zoom/pan ("Ken Burns") clip from a still image.

    Args:
        pil_image: Source picture; stretched to a 720x1280 portrait canvas
            (plus a pan margin) without preserving the aspect ratio.
        duration_sec: Clip length in seconds. The shot boundaries and the
            fade-out are scaled with it; at the default of 5 they match the
            original fixed timeline (cuts at 1.5 s / 3.0 s / 4.2 s).
        fps: Frames per second written to the file.
        style: Colour-grade name forwarded to ``apply_color_grade``.

    Returns:
        Path of the written ``.mp4`` file (a temp file that is not deleted).

    Raises:
        ValueError: ``duration_sec * fps`` yields no frames.
        RuntimeError: OpenCV could not open the video writer.
    """
    total_frames = int(duration_sec * fps)
    if total_frames <= 0:
        raise ValueError("duration_sec * fps must yield at least one frame")

    target_w, target_h = 720, 1280
    pad = 160
    big_h, big_w = target_h + pad * 2, target_w + pad * 2

    # Resize straight to the padded canvas; going through the target size
    # first would throw detail away only to upscale it again.
    big_img = np.array(pil_image.convert("RGB").resize((big_w, big_h), Image.LANCZOS))

    # Timeline, expressed relative to the 5-second reference edit.
    scale = duration_sec / 5
    s1_end = int(fps * 1.5 * scale)
    s2_end = int(fps * 3.0 * scale)
    s3_end = int(fps * 4.2 * scale)
    cuts = (s1_end, s2_end, s3_end, total_frames)
    fade_in_end = int(fps * 0.4 * scale)
    fade_out_sta = int(fps * 4.4 * scale)

    # White flash on the first two frames after each of the first two cuts.
    # The strong entries are inserted last so they win if frames coincide.
    flash_strengths = {s1_end + 1: 0.15, s2_end + 1: 0.15, s1_end: 0.35, s2_end: 0.35}
    white = np.full((target_h, target_w, 3), 255, dtype=np.uint8)

    # Only a unique path is needed; close our handle immediately so that
    # OpenCV is the sole writer (an open handle blocks it on some platforms).
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        video_path = tmp.name

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(video_path, fourcc, fps, (target_w, target_h))
    if not out.isOpened():
        out.release()
        os.unlink(video_path)
        raise RuntimeError("OpenCV could not open an mp4v video writer")

    try:
        for i in range(total_frames):
            zoom, pan_x, pan_y = _ken_burns_camera(i, cuts, pad)

            # Crop window centred on the (panned) middle of the padded canvas.
            crop_w = int(target_w / zoom)
            crop_h = int(target_h / zoom)
            cx = big_w // 2 + pan_x
            cy = big_h // 2 + pan_y
            x1 = max(0, cx - crop_w // 2)
            y1 = max(0, cy - crop_h // 2)
            x2 = min(big_w, x1 + crop_w)
            y2 = min(big_h, y1 + crop_h)

            # Safety net against a degenerate crop.
            if x2 - x1 < 10 or y2 - y1 < 10:
                x1, y1, x2, y2 = 0, 0, target_w, target_h

            cropped = big_img[y1:y2, x1:x2]
            frame = cv2.resize(cropped, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
            frame = apply_color_grade(frame, style)
            frame = apply_vignette(frame, strength=0.55)

            # Global fade from/to black at both ends of the clip.
            if i < fade_in_end:
                alpha = ease_in_out(i / fade_in_end)
            elif i >= fade_out_sta:
                alpha = ease_in_out(1.0 - (i - fade_out_sta) / (total_frames - fade_out_sta))
            else:
                alpha = 1.0

            flash_strength = flash_strengths.get(i)
            if flash_strength is not None:
                frame = cv2.addWeighted(frame, 1 - flash_strength, white, flash_strength, 0)

            frame = np.clip(frame.astype(np.float32) * alpha, 0, 255).astype(np.uint8)
            # Frames are processed as RGB; OpenCV writers expect BGR.
            out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        # Always finalise the container, even if a frame fails to render.
        out.release()

    return video_path
328
+
329
+
330
+ # ── MAIN PIPELINE ─────────────────────────────────────────────────────────────
331
# Legacy module-level log. Nothing in this file reads or appends to it; the
# name is kept only so that external references keep resolving.
_status_log = []


def generate_ad(image, user_desc, language, style, progress=gr.Progress()):
    """Run the whole pipeline for one click of the "Generate Ad" button.

    Args:
        image: Uploaded picture (PIL image, or an array accepted by
            ``Image.fromarray``); ``None`` when nothing was uploaded.
        user_desc: Optional product description typed by the user.
        language: Language choice forwarded to ``call_gemini``.
        style: Style choice used for both the copy and the video.
        progress: Gradio progress tracker.

    Returns:
        ``(video_path, hook, script, cta, model_log)`` matching the five
        output components. On failure ``video_path`` is ``None`` and the text
        fields carry the error message instead of raising.
    """
    if image is None:
        return None, "⚠️ Please upload a product image.", "", "", "❌ No image"

    pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)

    # STEP 1 — Gemini ad copy
    progress(0.1, desc="🧠 Gemini generating ad copy...")
    try:
        ad_data = call_gemini(pil_image, user_desc or "", language, style)
        if not isinstance(ad_data, dict):
            # Valid JSON that is not an object (e.g. a list) cannot be used.
            raise TypeError(f"expected a JSON object, got {type(ad_data).__name__}")
    except Exception as e:
        return None, f"❌ Gemini error: {e}", "", "", "❌ Gemini failed"

    def field(key):
        # Missing/null values become "", anything else is shown as text.
        value = ad_data.get(key)
        return "" if value is None else str(value)

    hook = field("hook")
    script = field("script")
    cta = field("cta")
    # Fall back to the hook when the prompt is absent OR empty.
    video_prompt = field("video_prompt") or hook

    # STEP 2 — Video with fallback chain
    progress(0.3, desc="🎬 Generating video (trying AI models)...")

    status_lines = []

    def log_progress(msg):
        status_lines.append(msg)
        # Advance 10% per attempt, but never claim completion early even if
        # the model list grows.
        fraction = min(0.3 + len(status_lines) * 0.1, 0.95)
        progress(fraction, desc=msg.replace("**", "").replace("*", ""))

    try:
        video_path, model_used = generate_video_with_fallback(
            pil_image,
            prompt=video_prompt,
            style=style,
            progress_callback=log_progress,
        )
    except Exception as e:
        return None, hook, f"❌ Video error: {e}\n\n{script}", cta, "❌ All models failed"

    progress(1.0, desc="✅ Done!")

    model_log = "\n".join(status_lines) + f"\n\n✅ **Used:** {model_used}"
    return video_path, hook, script, cta, model_log
377
+
378
+
379
# ── GRADIO UI ─────────────────────────────────────────────────────────────────
# Static text lives in constants so the layout code below reads as pure
# structure: left column = inputs, right column = outputs.
css = """
#title { text-align:center; font-size:2.2rem; font-weight:800; margin-bottom:.2rem; }
#sub { text-align:center; color:#888; margin-bottom:1.5rem; }
.model-chain { font-size:.85rem; line-height:1.7; }
"""

_SUBTITLE_MD = (
    "Upload a product image → Gemini writes ad copy → "
    "AI generates cinematic 5-sec reel (5-model fallback chain)."
)

_MODEL_CHAIN_MD = "\n".join(
    [
        "**🔗 Model Fallback Chain:**",
        "1. 🤖 Lightricks/LTX-2 ⚡",
        "2. 🤖 Wan 2.2 I2V-A14B",
        "3. 🤖 Stable Video Diffusion XT",
        "4. 🤖 KlingTeam/LivePortrait",
        "5. 🤖 Lightricks/LTX-Video",
        "6. 🎨 Ken Burns (local, always works)",
    ]
)

_HOW_IT_WORKS_MD = " ".join(
    [
        "---\n**How it works:**",
        "1️⃣ Gemini 2.5 Flash → hook + script + CTA + video prompt.",
        "2️⃣ Tries 5 HuggingFace image-to-video models in order.",
        "3️⃣ First success wins → downloads video.",
        "4️⃣ If all API calls fail → Ken Burns cinematic effect (local, always works).",
        "⚡ With HF token + inference-available model: ~10-30 seconds total!",
    ]
)

_theme = gr.themes.Soft(primary_hue="violet")

with gr.Blocks(css=css, theme=_theme) as demo:

    gr.Markdown("# 🎬 AI Reel Generator", elem_id="title")
    gr.Markdown(_SUBTITLE_MD, elem_id="sub")

    with gr.Row():
        # ── LEFT COLUMN: inputs ──────────────────────────────────────────────
        with gr.Column(scale=1):
            image_input = gr.Image(
                label="📸 Upload Product Image",
                type="pil",
                height=300,
            )
            desc_input = gr.Textbox(
                label="📝 Describe your product (optional)",
                placeholder="e.g. Premium sneakers with star design …",
                lines=3,
            )
            with gr.Row():
                lang_dropdown = gr.Dropdown(
                    choices=["English", "Hindi", "Hinglish"],
                    value="English",
                    label="🌐 Language",
                )
                style_dropdown = gr.Dropdown(
                    choices=["Fun", "Premium", "Energetic"],
                    value="Fun",
                    label="🎨 Style",
                )
            gen_btn = gr.Button("🚀 Generate Ad", variant="primary", size="lg")

            # Model chain info box
            gr.Markdown(_MODEL_CHAIN_MD, elem_classes="model-chain")

        # ── RIGHT COLUMN: outputs ────────────────────────────────────────────
        with gr.Column(scale=1):
            video_out = gr.Video(label="🎥 5-Second Ad Reel", height=400)
            hook_out = gr.Textbox(label="⚡ Hook", lines=2, interactive=False)
            script_out = gr.Textbox(label="📄 Script", lines=5, interactive=False)
            cta_out = gr.Textbox(label="🎯 CTA", lines=1, interactive=False)
            status_out = gr.Textbox(label="📊 Model Log", lines=6, interactive=False)

    # The order of both lists mirrors generate_ad's parameters and return tuple.
    _click_inputs = [image_input, desc_input, lang_dropdown, style_dropdown]
    _click_outputs = [video_out, hook_out, script_out, cta_out, status_out]
    gen_btn.click(fn=generate_ad, inputs=_click_inputs, outputs=_click_outputs)

    gr.Markdown(_HOW_IT_WORKS_MD)

if __name__ == "__main__":
    demo.launch()