# editableweb / app.py
# AkashKumarave's picture
# Update app.py
# a594839 verified
# raw
# history blame
# 13 kB
# app.py
import os
import time
import base64
import json
from pathlib import Path
import gradio as gr
import requests
import jwt
from PIL import Image
# ──────────────────────────────────────────────────────────────────────────────
# CONFIG — set your keys as HF Space secrets or env vars for safety.
# (Falls back to the hard-coded keys below when the env vars are unset.)
# ──────────────────────────────────────────────────────────────────────────────
# SECURITY(review): the fallback values below are credentials committed to
# source control. Prefer supplying KLING_ACCESS_KEY_ID / KLING_ACCESS_KEY_SECRET
# through the environment (e.g. HF Space secrets); consider rotating these keys
# and removing the fallbacks.
ACCESS_KEY_ID = os.getenv("KLING_ACCESS_KEY_ID", "AGBGmadNd9hakFYfahytyQQJtN8CJmDJ")
ACCESS_KEY_SECRET = os.getenv("KLING_ACCESS_KEY_SECRET", "dp3pAe4PpdmnAHCAPgEd3PyLmBQrkMde")

API_BASE = "https://api.klingai.com"
ENDPOINT_KOLORS = f"{API_BASE}/v1/images/kolors"  # face/subject reference modes (image-to-image)
ENDPOINT_GENERATIONS = f"{API_BASE}/v1/images/generations"  # listing (used as a fallback poller)


def ENDPOINT_TASK(tid) -> str:
    """Return the task-status URL for task id *tid* (primary poller).

    Kept under its original upper-case name so existing call sites
    (``ENDPOINT_TASK(task_id)``) continue to work; it is a ``def`` rather
    than a lambda bound to a name.
    """
    return f"{API_BASE}/v1/tasks/{tid}"
# ──────────────────────────────────────────────────────────────────────────────
# AUTH — Kling uses JWT: iss / exp / nbf with HS256 (no "access_key" field)
# ──────────────────────────────────────────────────────────────────────────────
def make_jwt() -> str:
    """Build a short-lived HS256 JWT for authenticating against the Kling API.

    The token carries three claims: ``iss`` (the access key id), ``exp``
    (30 minutes from now) and ``nbf`` (5 seconds in the past, to tolerate
    small clock skew). It is signed with ``ACCESS_KEY_SECRET``.

    Returns:
        The encoded token as text, ready to be placed in a ``Bearer`` header.
    """
    headers = {"alg": "HS256", "typ": "JWT"}
    now = int(time.time())
    payload = {
        "iss": ACCESS_KEY_ID,
        "exp": now + 1800,  # 30 minutes
        "nbf": now - 5,  # start now (minus small skew)
    }
    token = jwt.encode(payload, ACCESS_KEY_SECRET, algorithm="HS256", headers=headers)
    # NOTE(review): assuming `jwt` is PyJWT — 1.x returns bytes from encode()
    # while 2.x returns str. Normalize so f"Bearer {token}" never renders as
    # "Bearer b'...'" on an older install.
    if isinstance(token, bytes):
        token = token.decode("ascii")
    return token
# ──────────────────────────────────────────────────────────────────────────────
# HELPERS
# ──────────────────────────────────────────────────────────────────────────────
def ensure_image_ok(img_path: str):
    """Open the file at *img_path* with Pillow and run its integrity check.

    Returns nothing. Any exception raised while opening or verifying the
    image propagates to the caller.
    """
    with Image.open(img_path) as candidate:
        # Quick integrity check of the opened image.
        candidate.verify()
def b64_data_uri(img_path: str) -> str:
    """Return the file at *img_path* encoded as a ``data:`` URI.

    The MIME type is ``image/png`` when the path ends in ``.png``
    (case-insensitive) and ``image/jpeg`` for everything else; the file
    contents are not inspected.
    """
    looks_like_png = img_path.lower().endswith(".png")
    if looks_like_png:
        mime = "image/png"
    else:
        mime = "image/jpeg"

    with open(img_path, "rb") as handle:
        raw_bytes = handle.read()

    encoded = base64.b64encode(raw_bytes).decode("utf-8")
    return "data:" + mime + ";base64," + encoded
def extract_task_id(resp_json: dict) -> str | None:
# Common shapes seen in the wild
if not resp_json:
return None
for key in ("task_id", "taskId", "id"):
if key in resp_json:
return str(resp_json[key])
data = resp_json.get("data") or {}
for key in ("task_id", "taskId", "id"):
if key in data:
return str(data[key])
# Sometimes nested deeper (e.g., {"task": {"id": ...}})
task = resp_json.get("task") or data.get("task") or {}
if "id" in task:
return str(task["id"])
return None
def extract_image_urls(resp_json: dict) -> list[str]:
    """Collect generated-image URLs from a task response or a listing item.

    Shapes are tried in this order, returning the first that yields URLs:

    1. ``data.task_result.images[*].url`` (task response), then
       ``task_result.images[*].url`` at the top level (the shape of an item
       taken from a listing, where there is no ``data`` wrapper).
    2. A non-empty string under top-level ``output`` or ``image_url``.
    3. ``result.image_url``.
    4. ``works[*].url`` / ``works[*].imageUrl`` at the top level or under
       ``data``.

    Unexpected container types (e.g. ``data`` being a list) are ignored
    rather than raising.

    Args:
        resp_json: Parsed JSON object; may be ``None`` or empty.

    Returns:
        A list of URL strings, empty when none are found.
    """
    if not isinstance(resp_json, dict) or not resp_json:
        return []

    data = resp_json.get("data")
    if not isinstance(data, dict):
        data = {}

    # Typical: data.task_result.images = [{url: "..."}]; listing items carry
    # task_result directly on the item.
    for holder in (data, resp_json):
        task_result = holder.get("task_result")
        images = task_result.get("images") if isinstance(task_result, dict) else None
        if isinstance(images, list):
            urls = [img["url"] for img in images if isinstance(img, dict) and img.get("url")]
            if urls:
                return urls

    # Some variants: output, image_url, result.image_url
    for key in ("output", "image_url"):
        value = resp_json.get(key)
        if isinstance(value, str) and value:
            return [value]
    result = resp_json.get("result")
    if isinstance(result, dict) and result.get("image_url"):
        return [result["image_url"]]

    # Works array pattern
    works = resp_json.get("works") or data.get("works") or []
    if not isinstance(works, list):
        return []
    return [
        url
        for work in works
        if isinstance(work, dict) and (url := work.get("url") or work.get("imageUrl"))
    ]
def download_to_file(url: str, out_path: Path) -> Path:
    """Fetch *url* and store the response body at *out_path*.

    The request uses a 60-second timeout and a non-success HTTP status
    raises via ``raise_for_status()``. Missing parent directories of
    *out_path* are created before writing.

    Returns:
        *out_path*, unchanged.
    """
    response = requests.get(url, timeout=60)
    response.raise_for_status()

    target_dir = out_path.parent
    target_dir.mkdir(parents=True, exist_ok=True)

    out_path.write_bytes(response.content)
    return out_path
# Task states that end polling, compared case-insensitively. Success and
# failure are both terminal: the caller receives the payload either way.
_POLL_SUCCESS_STATES = frozenset({"succeed", "succeeded", "success"})
_POLL_FAILURE_STATES = frozenset({"failed", "fail", "failed_with_error"})


def _poll_is_terminal(status) -> bool:
    """Return True when *status* names a finished (succeeded or failed) task."""
    if isinstance(status, dict):
        # Some SDKs wrap status as enum-like
        status = status.get("value")
    if not isinstance(status, str):
        return False
    normalized = status.strip().lower()
    return normalized in _POLL_SUCCESS_STATES or normalized in _POLL_FAILURE_STATES


def poll_for_result(task_id: str, headers: dict, timeout_s: int = 300, interval_s: float = 3.0) -> dict:
    """Poll until the task reaches a terminal state; task endpoint first, listing as fallback.

    Each iteration queries ``ENDPOINT_TASK(task_id)`` and, if that did not
    produce a terminal status, scans ``ENDPOINT_GENERATIONS`` for an item
    whose ``task_id`` matches. Network errors and undecodable JSON bodies
    are remembered as the last error and polling continues.

    Args:
        task_id: Id of the task to wait for.
        headers: HTTP headers (including authorization) sent with each request.
        timeout_s: Overall time budget in seconds.
        interval_s: Sleep between polling rounds, in seconds.

    Returns:
        The JSON object from the task endpoint, or the matching listing item,
        once its status is a success *or* failure state. Callers must inspect
        the payload to tell the two apart.

    Raises:
        TimeoutError: If no terminal status is seen before the deadline.
    """
    deadline = time.time() + timeout_s
    last_error = None
    wanted_id = str(task_id)

    while time.time() < deadline:
        try:
            # Preferred: direct task status
            task_resp = requests.get(ENDPOINT_TASK(task_id), headers=headers, timeout=30)
            if task_resp.status_code == 200:
                body = task_resp.json()
                if isinstance(body, dict):
                    data = body.get("data")
                    task = body.get("task")
                    # Either "status_name":"succeed" or "data.task_status":"succeed"
                    status = (
                        body.get("status_name")
                        or (data.get("task_status") if isinstance(data, dict) else None)
                        or (task.get("status_name") if isinstance(task, dict) else None)
                    )
                    if _poll_is_terminal(status):
                        return body
            elif task_resp.status_code in (401, 403, 404):
                last_error = task_resp.text

            # Fallback: scan generations list
            list_resp = requests.get(
                ENDPOINT_GENERATIONS, headers=headers, params={"pageSize": 200}, timeout=30
            )
            if list_resp.status_code == 200:
                listing = list_resp.json()
                items = listing.get("data") if isinstance(listing, dict) else None
                for item in items if isinstance(items, list) else ():
                    if not isinstance(item, dict):
                        continue
                    if str(item.get("task_id")) == wanted_id and _poll_is_terminal(item.get("task_status")):
                        return item
        except (requests.RequestException, ValueError) as exc:
            # ValueError covers a non-JSON body from .json() on requests
            # versions whose decode error is not a RequestException.
            last_error = str(exc)
        time.sleep(interval_s)

    raise TimeoutError(f"Polling timed out for task_id {task_id}. Last error: {last_error or 'n/a'}")
# ──────────────────────────────────────────────────────────────────────────────
# CORE CALL — Kolors face reference (single reference, face_strength defaults to 97)
# ──────────────────────────────────────────────────────────────────────────────
# Task states (lower-cased) that mean the generation did not succeed.
_KLING_FAILED_STATES = frozenset({"failed", "fail", "failed_with_error"})


def _kling_response_dict(resp) -> dict:
    """Return the JSON object in *resp*, or a ``{"code", "message"}`` stand-in."""
    try:
        parsed = resp.json()
    except ValueError:
        parsed = None
    if isinstance(parsed, dict):
        return parsed
    return {"code": resp.status_code, "message": resp.text}


def _kling_final_status(result) -> str:
    """Return the lower-cased status found in a polling result, or ``""``."""
    if not isinstance(result, dict):
        return ""
    data = result.get("data")
    task = result.get("task")
    status = (
        result.get("status_name")
        or result.get("task_status")
        or (data.get("task_status") if isinstance(data, dict) else None)
        or (task.get("status_name") if isinstance(task, dict) else None)
    )
    if isinstance(status, dict):
        status = status.get("value")
    return status.strip().lower() if isinstance(status, str) else ""


def kling_face_reference(image_path: str, prompt: str, face_strength: int = 97, aspect_ratio: str = "1:1") -> tuple[str, str]:
    """Generate an image from a face reference and a prompt via the Kolors endpoint.

    The task is created with a multipart upload first; if that yields no
    task id, a JSON request carrying the image as a base64 data URI is
    tried. The task is then polled, and the first resulting image is
    downloaded to ``outputs/kling_face_<task_id>.png``.

    Args:
        image_path: Path of the reference image on disk.
        prompt: Text prompt for the generation.
        face_strength: Reference strength; clamped to the range 1..100.
        aspect_ratio: Aspect ratio string forwarded to the API.

    Returns:
        ``(display_image_path, download_file_path)`` — the same path twice.

    Raises:
        gr.Error: If no image was given, task creation failed, or the
            finished task produced no image URL.
        TimeoutError: Propagated from ``poll_for_result``.
        Other exceptions from image verification, the HTTP requests or the
        download propagate unchanged.
    """
    if not image_path:
        raise gr.Error("Please upload a face/reference image.")
    ensure_image_ok(image_path)

    token = make_jwt()
    headers_json = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }
    # No Content-Type here: requests sets the multipart boundary itself.
    headers_multipart = {
        "Authorization": f"Bearer {token}",
    }
    strength = max(1, min(100, int(face_strength)))
    mime = "image/png" if image_path.lower().endswith(".png") else "image/jpeg"

    # Attempt 1 — multipart/form-data (send file as `imageReference`).
    # Plain string values: the (None, value) tuple form belongs to `files=`,
    # not `data=`.
    form_fields = {
        "prompt": prompt,
        "reference": "face",
        "faceStrength": str(strength),
        "faceNo": "1",  # single face reference
        "imageCount": "1",
        "aspect_ratio": aspect_ratio,
    }
    with open(image_path, "rb") as image_file:
        files = {"imageReference": (os.path.basename(image_path), image_file, mime)}
        last_resp = requests.post(
            ENDPOINT_KOLORS, headers=headers_multipart, files=files, data=form_fields, timeout=60
        )
    j = _kling_response_dict(last_resp)
    task_id = extract_task_id(j)

    # If Kolors rejected multipart or no task_id, try JSON with data URI
    if not task_id:
        payload = {
            "prompt": prompt,
            "reference": "face",
            "faceStrength": strength,
            "faceNo": 1,
            "imageCount": 1,
            "aspect_ratio": aspect_ratio,
            "imageReference": b64_data_uri(image_path),
        }
        last_resp = requests.post(ENDPOINT_KOLORS, headers=headers_json, json=payload, timeout=60)
        j = _kling_response_dict(last_resp)
        task_id = extract_task_id(j)

    if not task_id:
        code = j.get("code") or j.get("service_code") or "?"
        # Report the status of the request actually made last, not always
        # the multipart attempt.
        msg = j.get("message") or j.get("error") or f"HTTP {last_resp.status_code}"
        raise gr.Error(f"Create task failed. Code: {code}. Message: {msg}")

    # Poll
    result_json = poll_for_result(task_id, headers=headers_json, timeout_s=420, interval_s=3.0)

    # Gather image URLs
    urls = extract_image_urls(result_json)
    if not urls:
        # Some APIs return the latest object on /v1/images/generations with same task_id
        try:
            listing = requests.get(
                ENDPOINT_GENERATIONS, headers=headers_json, params={"pageSize": 200}, timeout=30
            ).json()
            items = listing.get("data") if isinstance(listing, dict) else None
            for item in items if isinstance(items, list) else ():
                if isinstance(item, dict) and str(item.get("task_id")) == str(task_id):
                    urls = extract_image_urls(item)
                    if urls:
                        break
        except (requests.RequestException, ValueError, AttributeError, TypeError):
            # Best-effort lookup: a failure here falls through to the
            # "no image URL" error below.
            pass

    if not urls:
        # poll_for_result also returns on failure states, so do not claim
        # success when the payload says otherwise.
        status = _kling_final_status(result_json)
        if status in _KLING_FAILED_STATES:
            raise gr.Error(f"Task {task_id} failed (status: {status}); no image was generated.")
        raise gr.Error(f"Task {task_id} succeeded but no image URL found in response.")

    # Download first image
    out_dir = Path("outputs")
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / f"kling_face_{task_id}.png"
    download_to_file(urls[0], out_path)

    # Return same path for preview and download
    return str(out_path), str(out_path)
# ──────────────────────────────────────────────────────────────────────────────
# GRADIO UI
# ──────────────────────────────────────────────────────────────────────────────
# NOTE(review): the original indentation was lost, so the grouping of
# components into the two rows below is reconstructed — confirm the layout.
with gr.Blocks(title="Kling AI — Image to Image (Face Reference)") as demo:
    gr.Markdown("### Kling AI — Image-to-Image (Single Face Reference)\nUpload a face image and a prompt. Strength defaults to 97.")

    # Inputs: reference image + prompt, then strength + aspect ratio.
    with gr.Row():
        in_image = gr.Image(type="filepath", label="Reference Face Image (PNG/JPG)")
        in_prompt = gr.Textbox(label="Prompt", placeholder="e.g., Ultra-detailed portrait, soft light, studio background", lines=2)
    with gr.Row():
        in_strength = gr.Slider(1, 100, value=97, step=1, label="Face Reference Strength")
        in_aspect = gr.Dropdown(choices=["1:1", "3:4", "4:3", "2:3", "3:2", "16:9", "9:16", "21:9"], value="1:1", label="Aspect Ratio")

    btn = gr.Button("Generate", variant="primary")

    # Outputs: an inline preview plus a separate file component for download.
    out_img = gr.Image(label="Generated Image", show_download_button=False)
    out_file = gr.File(label="Download Image")

    def run(image, prompt, strength, aspect):
        """Validate the prompt and delegate to ``kling_face_reference``.

        Raises ``gr.Error`` when the prompt is empty or whitespace-only;
        otherwise returns the ``(preview_path, download_path)`` pair.
        """
        if not prompt or not prompt.strip():
            raise gr.Error("Please enter a prompt.")
        return kling_face_reference(image, prompt.strip(), int(strength), aspect)

    btn.click(fn=run, inputs=[in_image, in_prompt, in_strength, in_aspect], outputs=[out_img, out_file])

if __name__ == "__main__":
    # On HF Spaces, just `python app.py` is enough — no need to set host/port.
    demo.launch()