# editableweb / app.py
# AkashKumarave's picture
# Update app.py
# cd42783 verified
# raw
# history blame
# 7.21 kB
import os
import io
import time
import base64
import requests
from PIL import Image
import gradio as gr
# -------------------------------------------------------------------
# CONFIG (via Secrets in your HF Space)
# -------------------------------------------------------------------
# All three values are read from environment variables once, at import time,
# and stripped of surrounding whitespace. A missing variable becomes "".

# Required: your Kling secret key. In HF: Settings -> Secrets -> add KLING_SECRET_KEY
# Sent as a Bearer token on outgoing requests.
KLING_SECRET_KEY = os.getenv("KLING_SECRET_KEY", "").strip()

# One of these should be set:
# 1) If you have a proxy (e.g., your Supabase Edge Function):
# e.g. https://rzbyjtariqsnhgwfxdxa.supabase.co/functions/v1/klingai-proxy
# When both URLs are set, the proxy URL takes precedence.
KLING_PROXY_URL = os.getenv("KLING_PROXY_URL", "").strip()

# 2) Or, if you want to call Kling API directly, set this to the base URL they provide.
# No default is baked in: the value stays empty until you supply the *real* base URL.
KLING_API_URL = os.getenv("KLING_API_URL", "").strip()

# Optional tuning defaults
DEFAULT_SIZE = "1024x1024"  # used when the caller passes no size
TIMEOUT_SEC = 120  # HTTP request timeout in seconds; also bounds total job-polling time
POLL_DELAY = 2  # seconds to sleep between status polls, if your proxy supports polling
# -------------------------------------------------------------------
# Helpers
# -------------------------------------------------------------------
def _b64_image_from_pil(img: Image.Image) -> str:
    """Serialize ``img`` as PNG and return the bytes as a base64 text string."""
    with io.BytesIO() as png_buffer:
        img.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    return str(encoded, "utf-8")
def _pil_from_b64(b64_str: str) -> Image.Image:
    """Decode a base64-encoded image payload into an RGB PIL image."""
    image_bytes = base64.b64decode(b64_str)
    decoded = Image.open(io.BytesIO(image_bytes))
    # Same two-step conversion as before: go through RGBA, then to RGB.
    rgba = decoded.convert("RGBA")
    return rgba.convert("RGB")
def _ensure_ready() -> None:
    """Check that the configuration needed to reach the backend is present.

    Raises:
        RuntimeError: if ``KLING_SECRET_KEY`` is empty, or if both
            ``KLING_PROXY_URL`` and ``KLING_API_URL`` are empty.
    """
    # The arrow is written as the escape \u2192 so the user-facing message
    # cannot be corrupted again by a mis-decoded source file (the previous
    # text rendered as mojibake instead of "Settings -> Secrets").
    if not KLING_SECRET_KEY:
        raise RuntimeError(
            "Missing KLING_SECRET_KEY. In your Hugging Face Space, go to "
            "Settings \u2192 Secrets and add it."
        )
    if not (KLING_PROXY_URL or KLING_API_URL):
        raise RuntimeError(
            "Set either KLING_PROXY_URL (your proxy) OR KLING_API_URL (direct API) "
            "in Settings \u2192 Secrets."
        )
def _post_json(url: str, json_body: dict):
    """POST ``json_body`` to ``url`` with bearer auth and return the decoded JSON.

    Args:
        url: Full endpoint URL.
        json_body: Payload, sent as the JSON request body.

    Returns:
        The decoded JSON body of a 2xx response.

    Raises:
        RuntimeError: if the status is not 2xx (the message carries the status
            code and the JSON or text body), or if a 2xx response body is not
            valid JSON.
    """
    headers = {
        "Authorization": f"Bearer {KLING_SECRET_KEY}",
        "Content-Type": "application/json",
    }
    resp = requests.post(url, json=json_body, headers=headers, timeout=TIMEOUT_SEC)

    if not 200 <= resp.status_code < 300:
        # Prefer the structured error body; fall back to raw text if the
        # server did not answer with JSON.
        try:
            detail = resp.json()
        except ValueError:
            detail = resp.text
        raise RuntimeError(f"HTTP {resp.status_code}: {detail}")

    try:
        return resp.json()
    except ValueError as err:
        # A success status with a non-JSON body (e.g. an HTML page from a
        # misconfigured proxy) used to leak a bare JSON decode error.
        snippet = resp.text[:200]
        raise RuntimeError(
            f"HTTP {resp.status_code}: response body is not valid JSON: {snippet!r}"
        ) from err
# -------------------------------------------------------------------
# Unified call: works with your proxy or direct API
# Expected proxy/API contract:
# Request JSON:
# {
# "mode": "txt2img" | "img2img",
# "prompt": "...",
# "size": "1024x1024",
# "seed": 0,
# "strength": 0.8, # only for img2img
# "image_base64": "<PNG b64>" # only for img2img
# }
# Response JSON (synchronous):
# { "image_base64": "<PNG b64>" }
# OR (async job pattern):
# { "job_id": "..." } then a GET to /result?job_id=... returns same {image_base64: ...}
# Adjust if your proxy differs.
# -------------------------------------------------------------------
def _poll_for_image(result_url: str, job_id: str) -> Image.Image:
    """Poll ``result_url`` for ``job_id`` until an image, an error, or the deadline."""
    # Same bearer token as the generate POST.
    # NOTE(review): confirm the result endpoint expects this header.
    headers = {"Authorization": f"Bearer {KLING_SECRET_KEY}"}
    # Monotonic clock: elapsed-time checks must not be affected by wall-clock jumps.
    deadline = time.monotonic() + TIMEOUT_SEC

    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            raise RuntimeError("Generation timeout. Try smaller size or check your proxy/API.")

        body = None
        try:
            # Cap the request timeout so a single slow poll cannot overshoot
            # the overall deadline.
            resp = requests.get(
                result_url,
                params={"job_id": job_id},
                headers=headers,
                timeout=remaining,
            )
            if 200 <= resp.status_code < 300:
                body = resp.json()
        except (requests.RequestException, ValueError):
            # Transient network failure or non-JSON body: retry after the delay.
            body = None

        # Results are handled outside the try block so that a reported API
        # error (or an undecodable image) propagates to the caller instead of
        # being swallowed and retried until the timeout.
        if isinstance(body, dict) and body.get("status") not in ("queued", "running"):
            if "image_base64" in body:
                return _pil_from_b64(body["image_base64"])
            if "error" in body:
                raise RuntimeError(str(body["error"]))

        # Always pause before the next attempt, including for unrecognised
        # response shapes, which previously caused a busy loop.
        time.sleep(POLL_DELAY)


def kling_generate(prompt: str,
                   size: str,
                   seed: int,
                   strength: float,
                   init_image: Image.Image | None) -> Image.Image:
    """Generate an image through the configured proxy or API.

    Uses mode "img2img" when ``init_image`` is given, otherwise "txt2img".
    The request is POSTed to ``<base>/generate``. If the response carries
    ``image_base64`` it is decoded and returned; if it carries ``job_id``,
    ``<base>/result`` is polled until the job finishes.

    Args:
        prompt: Text prompt; ``None`` is treated as empty and whitespace is stripped.
        size: Size string such as "1024x1024"; falls back to ``DEFAULT_SIZE`` when falsy.
        seed: Seed sent to the backend; ``None`` is sent as 0.
        strength: Strength value, sent only in img2img mode.
        init_image: Optional source image for img2img.

    Returns:
        The generated image as an RGB PIL image.

    Raises:
        RuntimeError: on missing configuration, a non-2xx generate response,
            an unexpected response shape, an error reported by the backend
            while polling, or when polling exceeds ``TIMEOUT_SEC``.
    """
    _ensure_ready()

    mode = "img2img" if init_image is not None else "txt2img"
    payload = {
        "mode": mode,
        "prompt": (prompt or "").strip(),
        "size": size or DEFAULT_SIZE,
        "seed": int(seed) if seed is not None else 0,
    }
    if init_image is not None:
        payload["strength"] = float(strength)
        payload["image_base64"] = _b64_image_from_pil(init_image)

    # The proxy wins when both URLs are configured. A single generic
    # /generate endpoint handles both modes.
    base = (KLING_PROXY_URL or KLING_API_URL).rstrip("/")
    data = _post_json(f"{base}/generate", payload)

    if not isinstance(data, dict):
        raise RuntimeError("Unexpected response. Neither 'image_base64' nor 'job_id' found.")

    # Synchronous contract: the image is in the first response.
    if "image_base64" in data:
        return _pil_from_b64(data["image_base64"])

    # Asynchronous contract: a job id to poll for.
    job_id = data.get("job_id")
    if not job_id:
        raise RuntimeError("Unexpected response. Neither 'image_base64' nor 'job_id' found.")

    return _poll_for_image(f"{base}/result", job_id)
# -------------------------------------------------------------------
# Gradio UI
# -------------------------------------------------------------------
# Builds the Gradio UI: inputs in the left column, generated image on the right.
# The dash and arrow are written as escapes (\u2014, \u2192) so the title and
# intro text cannot be corrupted by a mis-decoded source file; they previously
# rendered as mojibake.
with gr.Blocks(title="Kling AI \u2014 Image Generator") as demo:
    gr.Markdown(
        "\n"
        "# Kling AI \u2014 Image & Image-to-Image\n"
        "- Add your **KLING_SECRET_KEY** in Space **Settings \u2192 Secrets**.\n"
        "- Set either **KLING_PROXY_URL** (recommended) or **KLING_API_URL**.\n"
    )
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Describe the image you want...",
                lines=4,
            )
            init_image = gr.Image(label="Init Image (optional for img2img)", type="pil")
            strength = gr.Slider(0.0, 1.0, value=0.8, step=0.05, label="Strength (img2img only)")
            size = gr.Dropdown(
                ["512x512", "768x768", "1024x1024", "1024x1536", "1536x1024"],
                value=DEFAULT_SIZE,
                label="Size",
            )
            seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")
            btn = gr.Button("Generate", variant="primary")
        with gr.Column():
            out = gr.Image(label="Output", type="pil")

    def _on_click(prompt, size, seed, strength, init_image):
        """Validate the form values, run the generation, and return the image.

        Any failure is re-raised as ``gr.Error`` so it is shown in the UI.
        """
        # A whitespace-only prompt counts as empty.
        if not (prompt or "").strip() and init_image is None:
            raise gr.Error("Please enter a prompt or provide an init image.")
        # A cleared Number field yields None; treat it as the default seed 0
        # instead of failing inside int().
        seed_value = int(seed) if seed is not None else 0
        try:
            return kling_generate(prompt, size, seed_value, float(strength), init_image)
        except Exception as e:
            # UI boundary: surface every failure as a Gradio error, keeping
            # the original exception as the cause.
            raise gr.Error(str(e)) from e

    btn.click(
        _on_click,
        inputs=[prompt, size, seed, strength, init_image],
        outputs=[out],
    )
# Launch the Gradio app only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    demo.launch()