Spaces:
Sleeping
Sleeping
| import os | |
| import time | |
| import datetime | |
| import shutil | |
| import json | |
| import requests | |
| import gradio as gr | |
| import re | |
| import base64 | |
| import concurrent.futures | |
| import zipfile | |
| import random | |
| from PIL import Image | |
| import io | |
| import pandas as pd | |
# ================= Configuration =================
BASE_URL = "https://yunwu.ai"
# 1. General / Veo API key (used for the LLM analysis and for Veo video generation).
# 2. Sora-specific API key.
# Keys are read from the environment (e.g. Space secrets) rather than written
# into the source: anything hard-coded here is readable by whoever can read the
# file. Any key that was ever committed here should be rotated.
YUNWU_API_KEY = os.environ.get("YUNWU_API_KEY", "")
SORA_API_KEY = os.environ.get("SORA_API_KEY", "")
# Model configuration
TEXT_MODEL = "gemini-3-pro-preview-thinking"
MODEL_OPTIONS = ["sora-2-all", "veo_3_1-fast"]
# Output directory (a fixed relative path, wiped at the start of every batch)
OUTPUT_DIR = "Ecommerce_Vertical_Output"
# Concurrency
MAX_WORKERS = 10  # NOTE(review): not referenced by the visible code
VIDEO_WORKERS = 2  # number of videos generated in parallel
# ================= Prompt template (vertical e-commerce videos) =================
# System prompt for the LLM planning step. It is rendered with
# `.format(count=...)`, so any literal `{` or `}` added to it must be doubled
# (`{{` / `}}`). The model is asked to reply with a bare JSON list of strings.
ECOMMERCE_ANALYSIS_PROMPT = """你是一位顶级短视频/TikTok电商导演,擅长制作高转化率的竖屏产品展示视频。
我提供了一张产品图片和描述。请分析产品的核心卖点、材质和适用场景。
任务:生成 {count} 个**截然不同**的视频生成提示词 (Prompts)。
要求:
1. **画幅适配**:所有构图必须适合 9:16 竖屏播放 (Vertical, Portrait composition)。
2. **多样性**:涵盖特写(Close-up)、展示(Showcase)、场景(Lifestyle)、创意光影(Cinematic Lighting)。
3. **结构**:[镜头运动] + [主体描述] + [环境/光影] + [风格]。
【输出格式】:
严格返回一个 JSON 字符串列表,不要包含 Markdown 格式或序号。
示例:
[
"Vertical shot, camera panning up the product, golden hour lighting, 4k resolution...",
"Slow motion close-up of the texture, soft studio lighting, portrait mode...",
"Product placed on a wooden table, cozy atmosphere, steam rising, vertical frame..."
]
"""
# ================= Utilities =================
class PipelineLogger:
    """Collect timestamped log lines and render them as a single text blob.

    One module-level instance feeds the UI log box for the lifetime of the
    process, so the retained history is capped at ``max_lines``. Without the
    cap, the list (and the text returned to the browser on every call) would
    grow without bound.
    """

    def __init__(self, max_lines=1000):
        """
        Args:
            max_lines: How many of the most recent lines to keep; ``None``
                keeps everything.
        """
        self.logs = []
        self.max_lines = max_lines

    def log(self, message):
        """Print and store ``message`` with an ``[HH:MM:SS]`` prefix.

        Returns:
            The retained history joined with newlines.
        """
        timestamp = datetime.datetime.now().strftime("%H:%M:%S")
        formatted = f"[{timestamp}] {message}"
        print(formatted)
        self.logs.append(formatted)
        if self.max_lines is not None and len(self.logs) > self.max_lines:
            # Drop the oldest lines; the explicit length arithmetic (rather than
            # `[:-max_lines]`) also behaves correctly when max_lines == 0.
            del self.logs[: len(self.logs) - self.max_lines]
        return "\n".join(self.logs)
def clean_json_text(text):
    """Strip a Markdown code fence from an LLM reply and return the JSON inside.

    Handles both ```` ```json ```` and bare ```` ``` ```` fences (the language
    tag is matched case-insensitively), surrounding prose, and an unterminated
    opening fence. Text without any fence is returned stripped.
    """
    if "```" not in text:
        return text.strip()
    match = re.search(r"```(?:json)?[ \t]*\n?(.*?)```", text, re.DOTALL | re.IGNORECASE)
    if match:
        return match.group(1).strip()
    # Only an opening fence: keep everything after it, minus an optional "json" tag.
    tail = text.split("```", 1)[1]
    if tail[:4].lower() == "json":
        tail = tail[4:]
    return tail.strip()
def image_to_base64(image_path):
    """Return the file at ``image_path`` as a base64 ASCII string.

    A falsy path (``None`` or ``""``) yields ``None`` without touching the disk.
    """
    if image_path:
        with open(image_path, "rb") as handle:
            raw_bytes = handle.read()
        return base64.b64encode(raw_bytes).decode("utf-8")
    return None
def image_to_data_uri(image_path):
    """Return the image at ``image_path`` as a ``data:<mime>;base64,...`` URI.

    The MIME type is guessed from the file extension, so JPEG/WEBP uploads are
    no longer mislabelled as PNG. It falls back to ``image/png`` when the
    extension is unknown or not an image type. A falsy path yields ``None``.
    """
    if not image_path:
        return None
    import mimetypes

    mime, _ = mimetypes.guess_type(image_path)
    if not mime or not mime.startswith("image/"):
        mime = "image/png"
    b64 = image_to_base64(image_path)
    return f"data:{mime};base64,{b64}"
def download_file(url, retries=5, delay=2):
    """Download ``url`` and return the response body as bytes.

    Sends a browser-like User-Agent. Non-200 responses, empty bodies and
    network errors are retried up to ``retries`` times, sleeping ``delay``
    seconds between attempts. Each failure is printed instead of being
    silently swallowed.

    Returns:
        The body bytes, or ``None`` if every attempt failed.
    """
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
    }
    for attempt in range(1, retries + 1):
        try:
            # Context manager guarantees the connection is released on every path.
            with requests.get(url, headers=headers, timeout=60) as r:
                if r.status_code == 200 and r.content:
                    return r.content
                print(f"[Download] HTTP {r.status_code} (attempt {attempt}/{retries}): {url}")
        except requests.RequestException as e:
            print(f"[Download] Error (attempt {attempt}/{retries}): {e}")
        if attempt < retries:
            time.sleep(delay)
    return None
def create_zip(source_dir, output_filename):
    """Bundle every ``.mp4`` file under ``source_dir`` into ``output_filename + ".zip"``.

    Archive member names are paths relative to ``source_dir``, so files with
    the same name in different sub-directories do not collide. The extension
    check is case-insensitive.

    Returns:
        The archive path, or ``None`` when no video was found. In that case
        the empty archive is removed so no stray file is left behind.
    """
    zip_path = output_filename + ".zip"
    has_file = False
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
        for root, _dirs, files in os.walk(source_dir):
            for name in sorted(files):
                if name.lower().endswith(".mp4"):
                    full_path = os.path.join(root, name)
                    zipf.write(full_path, os.path.relpath(full_path, source_dir))
                    has_file = True
    if not has_file:
        os.remove(zip_path)
        return None
    return zip_path
def clear_output_dir():
    """Delete ``OUTPUT_DIR`` with all its contents (if present) and recreate it empty."""
    target_dir = OUTPUT_DIR
    already_present = os.path.exists(target_dir)
    if already_present:
        shutil.rmtree(target_dir)
    os.makedirs(target_dir, exist_ok=True)
# ================= Core API client =================
class EcommerceDirector:
    """Client for the gateway at ``base_url``: prompt planning and video generation.

    * ``analyze_and_plan`` asks ``TEXT_MODEL`` (chat-completions endpoint,
      key ``YUNWU_API_KEY``) for a list of vertical-video prompts.
    * ``generate_video`` routes to Sora (JSON body with a base64 image, key
      ``SORA_API_KEY``) or Veo (multipart upload, key ``YUNWU_API_KEY``). It
      then polls the created task until a video URL appears.
    """

    # Task states reported by the polling endpoints (compared lower-cased).
    # NOTE(review): "failure"/"error" are defensive additions; confirm against the gateway docs.
    SUCCESS_STATUSES = ("succeeded", "success", "completed")
    FAILURE_STATUSES = ("failed", "failure", "error")

    def __init__(self, base_url):
        self.base_url = base_url

    # ---- Step 1: analysis (LLM) ----
    @staticmethod
    def _fallback_prompts(count):
        """Generic prompts used when the LLM call or its output is unusable."""
        return [f"Showcase video of product {i+1} vertical style" for i in range(count)]

    @classmethod
    def _normalize_prompts(cls, parsed, count):
        """Keep at most ``count`` non-blank string prompts from the parsed LLM JSON.

        Falls back to ``_fallback_prompts`` when ``parsed`` is not a list or
        contains no usable string (e.g. the model returned an object).
        """
        if isinstance(parsed, list):
            prompts = [p.strip() for p in parsed if isinstance(p, str) and p.strip()]
            if prompts:
                return prompts[:count]
        return cls._fallback_prompts(count)

    def analyze_and_plan(self, image_path, description, count):
        """Ask the LLM for up to ``count`` vertical-video prompts for the product.

        Args:
            image_path: Optional path of the product image sent with the request.
            description: Free-text product description.
            count: Number of prompts requested (coerced to ``int``).

        Returns:
            A list of prompt strings. Generic fallback prompts are returned on
            any HTTP, network or parsing failure.
        """
        count = int(count)
        headers = {
            "Authorization": f"Bearer {YUNWU_API_KEY}",
            "Content-Type": "application/json"
        }
        content_payload = [
            {"type": "text", "text": f"Product Description: {description}\nTarget: Generate {count} vertical video prompts."}
        ]
        if image_path:
            # Same data-URI helper the Sora path uses.
            content_payload.append({"type": "image_url", "image_url": {"url": image_to_data_uri(image_path)}})
        payload = {
            "model": TEXT_MODEL,
            "messages": [
                {"role": "system", "content": ECOMMERCE_ANALYSIS_PROMPT.format(count=count)},
                {"role": "user", "content": content_payload}
            ],
            "temperature": 0.7
        }
        try:
            resp = requests.post(f"{self.base_url}/v1/chat/completions", headers=headers, json=payload, timeout=60)
            if resp.status_code == 200:
                content = resp.json()['choices'][0]['message']['content']
                return self._normalize_prompts(json.loads(clean_json_text(content)), count)
            print(f"Analysis Failed: {resp.text}")
        except Exception as e:
            # Boundary: network errors, unexpected response shape, invalid JSON.
            print(f"API Error: {e}")
        return self._fallback_prompts(count)

    # ---- Step 2: routing ----
    def generate_video(self, model_name, prompt, ref_image_path):
        """Generate one video and return ``(video_url_or_None, status_message)``.

        Models whose name contains "sora" (case-insensitive) go to Sora;
        everything else goes to Veo. A reference image is mandatory.
        """
        if not ref_image_path:
            return None, "Error: Reference image is mandatory."
        if "sora" in model_name.lower():
            return self._generate_sora(model_name, prompt, ref_image_path)
        return self._generate_veo(model_name, prompt, ref_image_path)

    @staticmethod
    def _guess_image_mime(path):
        """Best-effort image MIME type from the file extension, defaulting to PNG."""
        import mimetypes

        mime, _ = mimetypes.guess_type(path)
        return mime if mime and mime.startswith("image/") else "image/png"

    # === Veo (9x16, multipart upload) ===
    def _generate_veo(self, model_name, prompt, ref_image_path):
        """Submit a Veo task (up to 3 attempts) and poll it to completion."""
        url = f"{self.base_url}/v1/videos"
        headers = {"Authorization": f"Bearer {YUNWU_API_KEY}"}
        data = {
            'model': model_name,
            'prompt': prompt,
            'seconds': '5',
            'size': '9x16',  # vertical
            'watermark': 'false'
        }
        mime = self._guess_image_mime(ref_image_path)
        for attempt in range(1, 4):
            try:
                with open(ref_image_path, 'rb') as f_img:
                    files = [('input_reference', (os.path.basename(ref_image_path), f_img, mime))]
                    resp = requests.post(url, headers=headers, data=data, files=files, timeout=120)
                if resp.status_code == 200:
                    task_id = resp.json().get('id')
                    # Without an id there is nothing to poll; treat it as a failed submit.
                    if task_id:
                        return self._poll_veo(task_id)
                print(f"[Veo] Submit Fail ({attempt}): {resp.text}")
            except Exception as e:
                print(f"[Veo] Error: {e}")
            time.sleep(2)
        return None, "Veo Failed"

    # === Sora (portrait, JSON + base64 image, separate query endpoint) ===
    def _generate_sora(self, model_name, prompt, ref_image_path):
        """Submit a Sora task (up to 3 attempts) and poll it to completion."""
        url = f"{self.base_url}/v1/video/create"
        headers = {
            "Authorization": f"Bearer {SORA_API_KEY}",
            "Content-Type": "application/json",
            "Accept": "application/json"
        }
        # Encode the reference image once instead of on every retry.
        try:
            data_uri = image_to_data_uri(ref_image_path)
        except OSError as e:
            print(f"[Sora] Error: {e}")
            return None, "Sora Failed"
        payload = {
            "model": model_name,
            "orientation": "portrait",  # vertical
            "prompt": prompt,
            "size": "large",
            "duration": 5,
            "watermark": False,
            "images": [data_uri]  # reference image is always sent
        }
        for attempt in range(1, 4):
            try:
                resp = requests.post(url, headers=headers, json=payload, timeout=120)
                if resp.status_code == 200:
                    task_id = resp.json().get('id')
                    if task_id:
                        return self._poll_sora(task_id)
                print(f"[Sora] Submit Fail ({attempt}): {resp.text}")
            except Exception as e:
                print(f"[Sora] Error: {e}")
            time.sleep(2)
        return None, "Sora Failed"

    # --- Veo polling (task id in the path) ---
    def _poll_veo(self, task_id):
        url = f"{self.base_url}/v1/videos/{task_id}"
        headers = {"Authorization": f"Bearer {YUNWU_API_KEY}"}
        return self._do_poll(url, headers)

    # --- Sora polling (task id as a query parameter: /v1/video/query?id=...) ---
    def _poll_sora(self, task_id):
        url = f"{self.base_url}/v1/video/query"
        headers = {
            "Authorization": f"Bearer {SORA_API_KEY}",
            "Accept": "application/json"
        }
        return self._do_poll(url, headers, params={"id": task_id})

    # --- Shared poller ---
    def _do_poll(self, url, headers, params=None, max_polls=60, interval=3):
        """Poll a task endpoint until it succeeds with a URL, fails, or times out.

        With the defaults this waits roughly 3 minutes (60 polls x 3 s, plus
        request time).

        Returns:
            ``(url, "OK")`` on success, ``(None, "Remote Fail: ...")`` when the
            task reports failure, or ``(None, "Timeout")``.
        """
        for _ in range(max_polls):
            time.sleep(interval)
            try:
                # A per-request timeout keeps a stalled connection from hanging the worker forever.
                resp = requests.get(url, headers=headers, params=params, timeout=30)
            except requests.RequestException as e:
                print(f"[Poll] Request error: {e}")
                continue
            if resp.status_code != 200:
                continue
            try:
                data = resp.json()
            except ValueError:
                continue
            if not isinstance(data, dict):
                continue
            status = str(data.get('status') or '').lower()
            if status in self.SUCCESS_STATUSES:
                final_url = self._deep_find_url(data)
                if final_url:
                    return final_url, "OK"
                # Reported success but no URL yet: keep polling.
            elif status in self.FAILURE_STATUSES:
                return None, f"Remote Fail: {data.get('error') or 'Unknown'}"
        return None, "Timeout"

    def _deep_find_url(self, data):
        """Depth-first search of a JSON-like value for the first video URL.

        Matches either an http(s) string containing ".mp4", ".mov" or "oss",
        or any http(s) value stored under a ``url`` / ``video_url`` /
        ``output_url`` key.
        """
        if isinstance(data, str) and data.startswith('http') and ('.mp4' in data or '.mov' in data or 'oss' in data):
            return data
        if isinstance(data, dict):
            for key, value in data.items():
                if key in ['url', 'video_url', 'output_url'] and isinstance(value, str) and value.startswith('http'):
                    return value
                res = self._deep_find_url(value)
                if res:
                    return res
        elif isinstance(data, list):
            for item in data:
                res = self._deep_find_url(item)
                if res:
                    return res
        return None
# ================= Business logic =================
# Process-wide singletons used by the Gradio handlers below.
director = EcommerceDirector(base_url=BASE_URL)
logger = PipelineLogger()
def run_analysis_step(image, desc, count):
    """Step 1 click handler: plan ``count`` prompts and show them in the table.

    Returns:
        ``(log_text, table_rows, generate_button_update)``. ``table_rows`` is a
        list of ``[id, prompt]`` pairs numbered from 1. When neither an image
        nor a description was supplied, it returns a warning, ``None`` and a
        hidden-button update instead.
    """
    if image or desc:
        logger.log(f"🕵️ Analyzing Product... Target: {count} videos")
        prompts = director.analyze_and_plan(image, desc, count)
        rows = [[number, text] for number, text in enumerate(prompts, start=1)]
        logger.log(f"✅ Generated {len(prompts)} prompts.")
        return logger.log("Ready to generate."), rows, gr.update(visible=True)
    return "⚠️ 请上传图片或填写描述", None, gr.update(visible=False)
def _coerce_video_id(raw_id, fallback):
    """Return ``raw_id`` as an ``int``, or ``fallback`` when it is blank, NaN or non-numeric."""
    try:
        return int(raw_id)
    except (TypeError, ValueError, OverflowError):
        return fallback


def _collect_jobs(rows):
    """Turn table rows into ``(video_id, prompt)`` pairs, skipping rows without a usable prompt."""
    jobs = []
    for position, row in enumerate(rows, start=1):
        if len(row) < 2:
            continue
        prompt = row[1]
        # Blank cells or NaN (as pandas gives for emptied cells) must not be sent to the API.
        if not isinstance(prompt, str) or not prompt.strip():
            continue
        jobs.append((_coerce_video_id(row[0], position), prompt.strip()))
    return jobs


def run_generation_step(image, prompt_data, model_name):
    """Step 2 click handler: render one video per prompt row, then zip the results.

    Args:
        image: Path of the reference image (required).
        prompt_data: The table contents. This may be a list of ``[id, prompt]``
            rows, a pandas DataFrame, or a dict with a ``"data"`` list,
            depending on how Gradio delivers it.
        model_name: Video model name, routed by ``director.generate_video``.

    Returns:
        ``(log_text, zip_path_or_None, status_text)`` for the click outputs.
    """
    if prompt_data is None:
        return "⚠️ 无提示词", None, "Failed"
    if isinstance(prompt_data, list):
        data_list = prompt_data
    elif isinstance(prompt_data, dict):
        data_list = prompt_data.get("data") or []
    elif hasattr(prompt_data, 'values'):
        if prompt_data.empty:
            return "⚠️ 提示词为空", None, "Failed"
        data_list = prompt_data.values.tolist()
    else:
        data_list = []
    jobs = _collect_jobs(data_list)
    if not jobs:
        return "⚠️ 列表为空", None, "Failed"
    if not image:
        return logger.log("⚠️ Error: 必须提供垫图 (Reference Image)"), None, "Error"
    clear_output_dir()
    total = len(jobs)
    logger.log(f"🎬 Batch Start. Model: {model_name} | Count: {total}")
    completed = 0
    # The context manager shuts the pool down even if result handling raises.
    with concurrent.futures.ThreadPoolExecutor(max_workers=VIDEO_WORKERS) as video_executor:
        futures = []
        for idx, prompt in jobs:
            logger.log(f"➕ Queueing Video {idx} ({model_name})...")
            futures.append(video_executor.submit(process_single_video, idx, prompt, image, model_name))
        for future in concurrent.futures.as_completed(futures):
            idx, status = future.result()
            if status == "OK":
                completed += 1
                logger.log(f"✅ Video {idx}/{total} Finished.")
            else:
                logger.log(f"❌ Video {idx}/{total} Failed.")
    logger.log("📦 Zipping videos...")
    zip_path = create_zip(OUTPUT_DIR, "Ecommerce_Videos")
    return logger.log("🎉 All Done!"), zip_path, f"Completed {completed}/{total}"
def process_single_video(idx, prompt, img_path, model_name):
    """Generate, download and save one video into ``OUTPUT_DIR``.

    Called on an executor worker thread. Failures are printed with their
    reason rather than raised, so one bad job does not abort the batch.

    Returns:
        ``(idx, "OK")`` when the file was written, otherwise ``(idx, "Fail")``.
    """
    try:
        url, msg = director.generate_video(model_name, prompt, img_path)
        if not url:
            print(f"[Video {idx}] Generation failed: {msg}")
            return idx, "Fail"
        vid_bytes = download_file(url)
        if not vid_bytes:
            print(f"[Video {idx}] Download failed: {url}")
            return idx, "Fail"
        # IDs edited in the table can arrive as floats or strings; `:02d` needs an int.
        try:
            filename = f"Product_Video_{int(idx):02d}.mp4"
        except (TypeError, ValueError, OverflowError):
            filename = f"Product_Video_{idx}.mp4"
        save_path = os.path.join(OUTPUT_DIR, filename)
        with open(save_path, "wb") as f:
            f.write(vid_bytes)
        return idx, "OK"
    except Exception as e:
        print(f"Worker Error: {e}")
    return idx, "Fail"
# ================= UI (default theme) =================
# Two-step Gradio app: step 1 plans prompts from the product image/description;
# step 2 renders one video per (editable) prompt row using the same image.
# NOTE(review): the original indentation was lost; the "4. 结果" widgets are
# assumed to sit in the right-hand column. Confirm against the deployed layout.
with gr.Blocks(title="Ecommerce Video Generator") as demo:
    gr.Markdown("## 🛍️ 电商竖屏视频批量生成 (Sora-2 & Veo)")
    gr.Markdown("单图生视频模式:Step 1 分析并生成分镜脚本 -> Step 2 使用主图批量生成视频")
    with gr.Row():
        # Left column: product inputs and generation settings.
        with gr.Column(scale=1):
            gr.Markdown("### 1. 商品信息 (必填)")
            # type="filepath": handlers receive the uploaded image as a file path string.
            input_image = gr.Image(label="商品主图 (必须上传,用于垫图)", type="filepath", height=250)
            input_desc = gr.Textbox(label="商品描述", placeholder="输入商品卖点...", lines=4)
            gr.Markdown("### 2. 生成配置")
            count_slider = gr.Slider(minimum=1, maximum=100, value=5, step=1, label="生成数量")
            model_dropdown = gr.Dropdown(choices=MODEL_OPTIONS, value="sora-2-all", label="视频模型")
            analyze_btn = gr.Button("🔍 1. 分析并生成脚本", variant="primary")
        # Right column: prompt review, generation trigger and results.
        with gr.Column(scale=2):
            gr.Markdown("### 3. 脚本确认")
            # Editable [ID, Prompt] table: filled by run_analysis_step, read by run_generation_step.
            prompt_dataframe = gr.Dataframe(
                headers=["ID", "Prompt"],
                datatype=["number", "str"],
                col_count=(2, "fixed"),
                interactive=True,
                label="生成的分镜提示词 (可修改)",
                value=[[1, "等待分析..."]]
            )
            # Hidden until run_analysis_step's third output makes it visible.
            generate_btn = gr.Button("🎬 2. 开始批量生成 (使用主图)", variant="primary", visible=False)
            gr.Markdown("### 4. 结果")
            log_box = gr.TextArea(label="日志", lines=8, interactive=False)
            status_box = gr.Textbox(label="状态", interactive=False)
            download_zip = gr.File(label="下载视频包")
    # Event wiring: each handler's return tuple fills `outputs` positionally.
    analyze_btn.click(
        fn=run_analysis_step,
        inputs=[input_image, input_desc, count_slider],
        outputs=[log_box, prompt_dataframe, generate_btn]
    )
    generate_btn.click(
        fn=run_generation_step,
        inputs=[input_image, prompt_dataframe, model_dropdown],
        outputs=[log_box, download_zip, status_box]
    )
if __name__ == "__main__":
    # Enable request queuing, then serve on all network interfaces.
    queued_app = demo.queue()
    queued_app.launch(server_name="0.0.0.0", inbrowser=True)