|
|
import os, time, base64, mimetypes, requests, traceback, glob |
|
|
from dataclasses import dataclass |
|
|
from typing import Optional, Dict, Any, Generator, List |
|
|
|
|
|
import gradio as gr |
|
|
from dotenv import load_dotenv |
|
|
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type |
|
|
|
|
|
|
|
|
# The OpenAI SDK is loaded defensively: if anything goes wrong while importing
# it, the module still loads and the names below are bound to placeholders.
try:
    from openai import (
        OpenAI,
        RateLimitError,
        APIConnectionError,
        APIStatusError,
    )
except Exception:
    # No usable SDK: mark the client class as missing and let the error
    # names resolve to the generic Exception class.
    OpenAI = None
    RateLimitError = Exception
    APIConnectionError = Exception
    APIStatusError = Exception
|
|
|
|
|
# Pull variables from a .env file (when one exists) into the environment
# before the API key is read.
load_dotenv()

# Key taken from the OPENAI_API_KEY environment variable, whitespace-trimmed;
# an empty string when the variable is unset.
ENV_FALLBACK_KEY = os.getenv("OPENAI_API_KEY", "").strip()
|
|
|
|
|
# Model identifiers accepted by this app.
ALLOWED_MODELS = [
    "sora-2",
    "sora-2-pro",
    "sora",
]

# Resolution strings accepted by this app.
ALLOWED_SIZES = [
    "1280x720",
    "720x1280",
    "1792x1024",
    "1024x1792",
    "1920x1080",
    "1080x1920",
]

# Default on-disk location (under /tmp) for generated video output.
TMP_PATH = "/tmp/sora_output.mp4"
|
|
|
|
|
@dataclass
class JobStatus:
    """Flattened view of a video job response.

    Only ``status`` is required; every other field defaults to ``None``.
    """

    # State string reported for the job.
    status: str
    # Failure detail text, when the response carried an error.
    error: Optional[str] = None
    # URL of the produced video, when the response carried one.
    output_url: Optional[str] = None
    # Base64-encoded video payload, when the response carried one.
    output_b64: Optional[str] = None
|
|
|
|
|
|
|
|
def _startup_cleanup_tmp(pattern: str = "/tmp/*.mp4", max_age_seconds: float = 3600) -> int:
    """Delete stale video files left behind by earlier runs (best effort).

    Args:
        pattern: Glob pattern selecting candidate files. Defaults to every
            ``.mp4`` directly under ``/tmp``.
        max_age_seconds: A file is removed only when its modification time
            is more than this many seconds in the past.

    Returns:
        The number of files that were actually removed.

    Filesystem errors on an individual file (already deleted, permission
    denied, ...) are skipped so that one bad entry never aborts the sweep.
    Only ``OSError`` is tolerated; anything else indicates a programming
    error and is allowed to surface.
    """
    now = time.time()
    removed = 0
    for path in glob.glob(pattern):
        try:
            if now - os.path.getmtime(path) > max_age_seconds:
                os.remove(path)
                removed += 1
        except OSError:
            # The file vanished or is not ours to delete; move on.
            continue
    return removed


# Run one sweep as soon as the module is imported.
_startup_cleanup_tmp()
|
|
|
|
|
|
|
|
def _file_to_b64(path: str) -> str:
    """Read the file at *path* in binary mode and return it base64-encoded as text."""
    with open(path, "rb") as handle:
        raw = handle.read()
    encoded = base64.b64encode(raw)
    return encoded.decode("utf-8")
|
|
|
|
|
def _sanitize_prompt(p: str) -> str:
    """Trim the prompt and cap it at 8000 characters.

    Raises:
        ValueError: If the prompt is missing or blank after trimming.
    """
    text = p.strip() if p else ""
    if not text:
        raise ValueError("Prompt is required.")
    # Slicing is a no-op for prompts already within the limit.
    return text[:8000]
|
|
|
|
|
def _validate_duration(d: int) -> int:
    """Coerce *d* to an integer number of seconds clamped to 1..30.

    Values that cannot be converted with ``int()`` (``None``, non-numeric
    text, NaN, infinity) fall back to 10 before clamping.

    Only conversion errors are handled; ``KeyboardInterrupt`` and
    ``SystemExit`` are no longer swallowed as they were by a bare ``except``.
    """
    try:
        seconds = int(d)
    except (TypeError, ValueError, OverflowError):
        # TypeError: None/unsupported type; ValueError: bad text or NaN;
        # OverflowError: infinity.
        seconds = 10
    return max(1, min(seconds, 30))
|
|
|
|
|
def _validate_guidance(g: float) -> float:
    """Coerce *g* to a float clamped to 0.0..20.0.

    Values that cannot be converted with ``float()``, and NaN, fall back to
    7.5 before clamping.

    Only conversion errors are handled; ``KeyboardInterrupt`` and
    ``SystemExit`` are no longer swallowed as they were by a bare ``except``.
    """
    try:
        value = float(g)
    except (TypeError, ValueError, OverflowError):
        # OverflowError covers integers too large to represent as a float.
        value = 7.5
    # NaN is the only float unequal to itself. Without this check it would
    # pass through min()/max() and come out as 0.0 rather than the default.
    if value != value:
        value = 7.5
    return max(0.0, min(value, 20.0))
|
|
|
|
|
def _validate_model(m: str) -> str:
    """Return *m* if it is listed in ALLOWED_MODELS, otherwise ``"sora-2"``."""
    if m in ALLOWED_MODELS:
        return m
    return "sora-2"
|
|
|
|
|
def _validate_size(s: str) -> str:
    """Return *s* if it is listed in ALLOWED_SIZES, otherwise ``"1280x720"``."""
    if s in ALLOWED_SIZES:
        return s
    return "1280x720"
|
|
|
|
|
def _make_client(user_key: Optional[str]) -> OpenAI:
    """Build an OpenAI client from the user's key or the environment fallback.

    Args:
        user_key: Key supplied by the caller; surrounding whitespace is
            ignored. When blank or ``None``, ENV_FALLBACK_KEY is used.

    Raises:
        RuntimeError: If the OpenAI SDK could not be imported.
        ValueError: If neither the user key nor the fallback key is set.
    """
    if OpenAI is None:
        raise RuntimeError("OpenAI SDK failed to import. Check requirements.txt and rebuild.")

    supplied = (user_key or "").strip()
    chosen = supplied if supplied else ENV_FALLBACK_KEY
    if chosen:
        return OpenAI(api_key=chosen)
    raise ValueError("Missing API key. Paste a valid OpenAI API key.")
|
|
|
|
|
|
|
|
from requests import RequestException as ReqErr |
|
|
# Tuple of exception classes used by the retry policies and the API-error
# handlers in this module. Entries that are not classes are dropped; if
# nothing is left, plain Exception is used.
_OAI_EXC = tuple(
    filter(
        lambda candidate: isinstance(candidate, type),
        (RateLimitError, APIConnectionError, APIStatusError),
    )
) or (Exception,)
|
|
|
|
|
@retry(
    retry=retry_if_exception_type(_OAI_EXC),
    wait=wait_exponential(multiplier=1, min=1, max=8),
    stop=stop_after_attempt(5),
    reraise=True,
)
def _videos_generate(client: OpenAI, **kwargs) -> Any:
    """Submit a video job through whichever endpoint the client exposes.

    ``client.videos.generate`` is tried first. If it is missing or raises,
    ``client.videos.jobs.create`` is tried with the same keyword arguments.

    Args:
        client: SDK client object to probe for video endpoints.
        **kwargs: Request fields forwarded unchanged to the endpoint.

    Returns:
        Whatever the endpoint that succeeded returned.

    Raises:
        RuntimeError: If no endpoint exists, or every available endpoint
            failed. The message carries the underlying error text and the
            original exception is chained as ``__cause__``.

    NOTE(review): failures are re-raised here as RuntimeError, so the retry
    decorator only re-invokes this function when _OAI_EXC matches
    RuntimeError (e.g. when it has fallen back to plain Exception). Confirm
    that is the intended retry behavior.
    """
    videos = getattr(client, "videos", None)

    # Remember why the primary endpoint failed so the final error can say so.
    primary_error: Optional[Exception] = None
    if hasattr(videos, "generate"):
        try:
            return videos.generate(**kwargs)
        except Exception as exc:
            primary_error = exc

    jobs = getattr(videos, "jobs", None)
    if hasattr(jobs, "create"):
        try:
            return jobs.create(**kwargs)
        except Exception as exc:
            first = primary_error if primary_error is not None else "client.videos.generate not found"
            raise RuntimeError(
                f"videos.generate failed/absent: {first}\njobs.create failed: {exc}"
            ) from exc

    if primary_error is not None:
        # The endpoint exists but the call failed: report the real error
        # instead of claiming that no endpoint is available.
        raise RuntimeError(
            f"videos.generate failed: {primary_error}\n"
            "videos.jobs.create is not available on this SDK/account."
        ) from primary_error

    raise RuntimeError("No videos endpoints on this SDK/account. Update openai package or check org access.")
|
|
|
|
|
@retry(
    retry=retry_if_exception_type(_OAI_EXC),
    wait=wait_exponential(multiplier=1, min=1, max=8),
    stop=stop_after_attempt(120),
    reraise=True,
)
def _videos_retrieve(client: OpenAI, job_id: str) -> Any:
    """Fetch the current state of a video job.

    ``client.videos.retrieve`` is tried first. If it is missing or raises,
    ``client.videos.jobs.retrieve`` is used instead, and any exception from
    that fallback call propagates unchanged.

    Args:
        client: SDK client object to probe for retrieve endpoints.
        job_id: Identifier of the job to look up.

    Returns:
        Whatever the endpoint that succeeded returned.

    Raises:
        RuntimeError: If neither endpoint exists, or the primary endpoint
            failed and no fallback exists. In the latter case the message
            includes the underlying error, which is chained as ``__cause__``.
    """
    videos = getattr(client, "videos", None)

    # Keep the primary failure so it is not silently lost.
    primary_error: Optional[Exception] = None
    if hasattr(videos, "retrieve"):
        try:
            return videos.retrieve(job_id)
        except Exception as exc:
            primary_error = exc

    jobs = getattr(videos, "jobs", None)
    if hasattr(jobs, "retrieve"):
        return jobs.retrieve(job_id)

    if primary_error is not None:
        # The endpoint exists but the call failed: surface the real cause
        # rather than reporting that no endpoint is available.
        raise RuntimeError(
            f"videos.retrieve failed: {primary_error} "
            "(videos.jobs.retrieve is not available as a fallback)."
        ) from primary_error

    raise RuntimeError("No videos.retrieve or videos.jobs.retrieve on this SDK/account.")
|
|
|
|
|
def _extract_status(resp: Any) -> JobStatus:
    """Condense a job response object into a JobStatus.

    All fields are read with ``getattr`` so a missing attribute is simply
    skipped. The state comes from ``status`` or ``state`` (else "unknown").
    The video payload comes from ``output`` or ``result``, and then from the
    entries of its ``artifacts`` list for anything still missing. The error
    text comes from ``error.message`` or ``str(error)``.
    """

    def first_truthy(obj: Any, *names: str) -> Any:
        # Mirrors ``getattr(a) or getattr(b)``: the first truthy attribute
        # wins, otherwise the last value looked up is returned.
        found = None
        for name in names:
            found = getattr(obj, name, None)
            if found:
                break
        return found

    state = first_truthy(resp, "status", "state") or "unknown"

    video_b64 = None
    video_url = None
    payload = first_truthy(resp, "output", "result")
    if payload:
        video_b64 = first_truthy(payload, "b64_mp4", "b64_video")
        video_url = first_truthy(payload, "url", "video_url")

        items = getattr(payload, "artifacts", None)
        if isinstance(items, list) and items:
            for item in items:
                # Only fill in what the top-level payload did not provide.
                if not video_b64:
                    video_b64 = first_truthy(item, "b64_mp4", "b64_video")
                if not video_url:
                    video_url = getattr(item, "url", None)

    message = None
    failure = getattr(resp, "error", None)
    if failure:
        message = getattr(failure, "message", None) or str(failure)

    return JobStatus(
        status=state,
        error=message,
        output_url=video_url,
        output_b64=video_b64,
    )
|
|
|
|
|
|
|
|
def generate_video_stream(
    api_key: str,
    prompt: str,
    model: str,
    duration: int,
    size: str,
    seed: int,
    audio: str,
    guidance: float,
    init_image: Optional[str],
) -> Generator[List[Any], None, None]:
    """Submit a video job and stream progress until it finishes.

    Each yielded value is a two-element list ``[video, status_text]``:
    ``video`` is ``gr.update()`` (no change) until a result is available,
    then either the result URL or the path of a local ``.mp4`` file.

    Args:
        api_key: Key passed to ``_make_client``.
        prompt: Text prompt; validated by ``_sanitize_prompt``.
        model: Model name; normalized by ``_validate_model``.
        duration: Requested length; normalized by ``_validate_duration``.
        size: Resolution string; normalized by ``_validate_size``.
        seed: Sent with the request only when it is greater than zero.
        audio: Anything other than ``"on"`` is sent as ``"off"``.
        guidance: Normalized by ``_validate_guidance``.
        init_image: Optional path to an image file; sent base64-encoded.

    The generator never raises for setup, validation, submit, poll or write
    failures. Each is reported as a status message and the generator returns.
    Polling gives up after 1800 seconds.
    """
    # Local import so this function does not depend on a file-level import.
    import tempfile

    yield [gr.update(), "Starting…"]

    # --- client setup -----------------------------------------------------
    try:
        client = _make_client(api_key)
    except Exception as e_init:
        yield [gr.update(), f"Setup error: {e_init}"]
        return

    # --- input validation and request assembly ----------------------------
    try:
        prompt = _sanitize_prompt(prompt)
        model = _validate_model(model)
        duration = _validate_duration(duration)
        size = _validate_size(size)
        guidance = _validate_guidance(guidance)
        audio = "on" if audio == "on" else "off"

        if init_image:
            # Only the file name is inspected here, not the file content.
            mt, _ = mimetypes.guess_type(init_image)
            if not (mt and mt.startswith("image/")):
                yield [gr.update(), "Provided conditioning file isn’t an image."]
                return

        req: Dict[str, Any] = {
            "model": model,
            "prompt": prompt,
            "duration": duration,
            "size": size,
            "audio": audio,
            "guidance": guidance,
        }
        if seed and int(seed) > 0:
            req["seed"] = int(seed)
        if init_image:
            req["image"] = {"b64": _file_to_b64(init_image)}
    except Exception as e_val:
        yield [gr.update(), f"Validation error: {e_val}"]
        return

    # --- job submission ---------------------------------------------------
    try:
        yield [gr.update(), "Submitting job…"]
        job = _videos_generate(client, **req)
        job_id = getattr(job, "id", None) or getattr(job, "job_id", None)
        if not job_id:
            yield [gr.update(), f"Could not get a job id. Raw job object: {repr(job)}"]
            return
        yield [gr.update(), f"Job accepted → id={job_id}"]
    except _OAI_EXC as oe:
        yield [gr.update(), f"OpenAI API issue on submit: {oe}"]
        return
    except Exception as e_submit:
        yield [gr.update(), f"Submit error: {e_submit}\n{traceback.format_exc(limit=2)}"]
        return

    # --- polling ----------------------------------------------------------
    start = time.time()
    last_emit = 0.0  # zero so the first poll always reports progress
    while True:
        try:
            status_obj = _videos_retrieve(client, job_id)
            js = _extract_status(status_obj)
        except _OAI_EXC as oe:
            yield [gr.update(), f"OpenAI API issue on poll: {oe}"]
            return
        except Exception as e_poll:
            yield [gr.update(), f"Polling error: {e_poll}\n{traceback.format_exc(limit=2)}"]
            return

        now = time.time()
        if now - last_emit > 5:
            # Throttle progress messages to roughly one every five seconds.
            last_emit = now
            yield [gr.update(), f"Rendering… status={js.status}"]

        # Compare case-insensitively so a status such as "Completed" is
        # recognized instead of polling until the timeout.
        state = str(js.status).strip().lower()

        if state in ("succeeded", "completed", "complete"):
            if js.output_url:
                yield [js.output_url, f"Ready (URL). Done with {model} ({size}, {duration}s)."]
                return
            if js.output_b64:
                try:
                    # Decode first so a bad payload leaves no empty file.
                    payload = base64.b64decode(js.output_b64)
                    # Prune old outputs so per-job files cannot pile up.
                    _startup_cleanup_tmp()
                    # One file per job: a single shared path would let
                    # concurrent sessions overwrite each other's video.
                    fd, out_path = tempfile.mkstemp(
                        prefix="sora_",
                        suffix=".mp4",
                        dir=os.path.dirname(TMP_PATH) or None,
                    )
                    with os.fdopen(fd, "wb") as f:
                        f.write(payload)
                except Exception as werr:
                    yield [gr.update(), f"Write error: {werr}"]
                    return
                yield [out_path, f"Done with {model} ({size}, {duration}s)."]
                return
            yield [gr.update(), "Job succeeded but no video payload was returned."]
            return

        if state in ("failed", "error", "canceled", "cancelled"):
            detail = f"Status: {js.status}."
            if js.error:
                detail += f" Error: {js.error}"
            yield [gr.update(), detail]
            return

        if now - start > 1800:
            yield [gr.update(), "Timed out waiting for the video. Try shorter duration."]
            return

        time.sleep(2)
|
|
|
|
|
|
|
|
def build_ui():
    """Create the Gradio Blocks app and connect the Generate button.

    Returns:
        The ``gr.Blocks`` instance. Clicking "Generate" runs
        ``generate_video_stream`` and feeds its output to the video and
        status components.
    """
    with gr.Blocks(title="ZEN — Sora / Sora-2 / Sora-2-Pro") as app:
        gr.Markdown("## ZEN — Sora / Sora-2 / Sora-2-Pro (OpenAI Videos API)")
        gr.Markdown("Paste an OpenAI API key (not stored). Zero persistent downloads; /tmp is reused/cleaned.")

        # Row 1: credentials and model selection.
        with gr.Row():
            key_box = gr.Textbox(
                label="OpenAI API key (not stored)",
                type="password",
                placeholder="sk-...",
                value="",
            )
            model_dd = gr.Dropdown(ALLOWED_MODELS, value="sora-2", label="Model")
            size_dd = gr.Dropdown(ALLOWED_SIZES, value="1280x720", label="Resolution")

        # Row 2: generation parameters.
        with gr.Row():
            duration_sl = gr.Slider(1, 30, value=10, step=1, label="Duration (seconds)")
            seed_num = gr.Number(value=0, precision=0, label="Seed (0 = random)")
            guidance_sl = gr.Slider(0.0, 20.0, value=7.5, step=0.5, label="Guidance")
            audio_dd = gr.Dropdown(["on", "off"], value="on", label="Audio")

        prompt_box = gr.Textbox(
            label="Prompt",
            lines=8,
            placeholder="Cinematic wide drone shot at sunrise…",
        )
        image_in = gr.Image(label="Optional image (conditioning)", type="filepath")

        run_btn = gr.Button("Generate", variant="primary")
        video_out = gr.Video(label="Result", autoplay=True)
        status_box = gr.Textbox(label="Status / Logs", interactive=False)

        # Input order must match generate_video_stream's parameter order.
        run_btn.click(
            fn=generate_video_stream,
            inputs=[
                key_box,
                prompt_box,
                model_dd,
                duration_sl,
                size_dd,
                seed_num,
                audio_dd,
                guidance_sl,
                image_in,
            ],
            outputs=[video_out, status_box],
        )
    return app
|
|
|
|
|
# Build the interface at import time and expose it as a module-level name.
demo = build_ui()


if __name__ == "__main__":
    # Run as a script: launch the app with Gradio's default settings.
    demo.launch()
|
|
|