Spaces:
Sleeping
Sleeping
| # app.py | |
| import os | |
| import time | |
| import base64 | |
| import json | |
| from pathlib import Path | |
| import gradio as gr | |
| import requests | |
| import jwt | |
| from PIL import Image | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # CONFIG β set your keys as HF Space secrets or env vars for safety. | |
| # (Falls back to the keys you shared.) | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| ACCESS_KEY_ID = os.getenv("KLING_ACCESS_KEY_ID", "AGBGmadNd9hakFYfahytyQQJtN8CJmDJ") | |
| ACCESS_KEY_SECRET = os.getenv("KLING_ACCESS_KEY_SECRET", "dp3pAe4PpdmnAHCAPgEd3PyLmBQrkMde") | |
| API_BASE = "https://api.klingai.com" | |
| ENDPOINT_KOLORS = f"{API_BASE}/v1/images/kolors" # face/subject reference modes (image-to-image) | |
| ENDPOINT_GENERATIONS = f"{API_BASE}/v1/images/generations" # listing (used as a fallback poller) | |
| ENDPOINT_TASK = lambda tid: f"{API_BASE}/v1/tasks/{tid}" # primary poller | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # AUTH β Kling uses JWT: iss / exp / nbf with HS256 (no "access_key" field) | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def make_jwt() -> str: | |
| headers = {"alg": "HS256", "typ": "JWT"} | |
| now = int(time.time()) | |
| payload = { | |
| "iss": ACCESS_KEY_ID, | |
| "exp": now + 1800, # 30 minutes | |
| "nbf": now - 5, # start now (minus small skew) | |
| } | |
| return jwt.encode(payload, ACCESS_KEY_SECRET, algorithm="HS256", headers=headers) | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # HELPERS | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def ensure_image_ok(img_path: str): | |
| with Image.open(img_path) as im: | |
| im.verify() # quick integrity check | |
| def b64_data_uri(img_path: str) -> str: | |
| mime = "image/png" if img_path.lower().endswith(".png") else "image/jpeg" | |
| with open(img_path, "rb") as f: | |
| b = base64.b64encode(f.read()).decode("utf-8") | |
| return f"data:{mime};base64,{b}" | |
| def extract_task_id(resp_json: dict) -> str | None: | |
| # Common shapes seen in the wild | |
| if not resp_json: | |
| return None | |
| for key in ("task_id", "taskId", "id"): | |
| if key in resp_json: | |
| return str(resp_json[key]) | |
| data = resp_json.get("data") or {} | |
| for key in ("task_id", "taskId", "id"): | |
| if key in data: | |
| return str(data[key]) | |
| # Sometimes nested deeper (e.g., {"task": {"id": ...}}) | |
| task = resp_json.get("task") or data.get("task") or {} | |
| if "id" in task: | |
| return str(task["id"]) | |
| return None | |
| def extract_image_urls(resp_json: dict) -> list[str]: | |
| if not resp_json: | |
| return [] | |
| data = resp_json.get("data") or {} | |
| # Typical: data.task_result.images = [{url: "..."}] | |
| task_result = data.get("task_result") or {} | |
| images = task_result.get("images") or [] | |
| urls = [img.get("url") for img in images if isinstance(img, dict) and img.get("url")] | |
| if urls: | |
| return urls | |
| # Some variants: output, image_url, result.image_url | |
| for k in ("output", "image_url"): | |
| if k in resp_json and isinstance(resp_json[k], str): | |
| return [resp_json[k]] | |
| result = resp_json.get("result") or {} | |
| if isinstance(result, dict) and result.get("image_url"): | |
| return [result["image_url"]] | |
| # Works array pattern | |
| works = resp_json.get("works") or data.get("works") or [] | |
| urls = [] | |
| for w in works: | |
| if isinstance(w, dict): | |
| u = w.get("url") or w.get("imageUrl") | |
| if u: | |
| urls.append(u) | |
| return urls | |
| def download_to_file(url: str, out_path: Path) -> Path: | |
| r = requests.get(url, timeout=60) | |
| r.raise_for_status() | |
| out_path.parent.mkdir(parents=True, exist_ok=True) | |
| with open(out_path, "wb") as f: | |
| f.write(r.content) | |
| return out_path | |
| def poll_for_result(task_id: str, headers: dict, timeout_s: int = 300, interval_s: float = 3.0) -> dict: | |
| """Poll task endpoint first; fallback to listing.""" | |
| deadline = time.time() + timeout_s | |
| last_error = None | |
| while time.time() < deadline: | |
| try: | |
| # Preferred: direct task status | |
| r = requests.get(ENDPOINT_TASK(task_id), headers=headers, timeout=30) | |
| if r.status_code == 200: | |
| j = r.json() | |
| # Either "status_name":"succeed" or "data.task_status":"succeed" | |
| status_name = (j.get("status_name") | |
| or (j.get("data") or {}).get("task_status") | |
| or (j.get("task") or {}).get("status_name")) | |
| if isinstance(status_name, dict): | |
| # Some SDKs wrap status as enum-like | |
| status_name = status_name.get("value") | |
| if status_name in ("succeed", "succeeded", "success", "SUCCEED"): | |
| return j | |
| if status_name in ("failed", "FAIL", "failed_with_error"): | |
| return j | |
| elif r.status_code in (401, 403, 404): | |
| last_error = r.text | |
| # Fallback: scan generations list | |
| r2 = requests.get(ENDPOINT_GENERATIONS, headers=headers, params={"pageSize": 200}, timeout=30) | |
| if r2.status_code == 200: | |
| j2 = r2.json() | |
| for item in (j2.get("data") or []): | |
| if str(item.get("task_id")) == str(task_id): | |
| status = item.get("task_status") | |
| if status in ("succeed", "succeeded", "success"): | |
| return item | |
| if status in ("failed",): | |
| return item | |
| except requests.RequestException as e: | |
| last_error = str(e) | |
| time.sleep(interval_s) | |
| raise TimeoutError(f"Polling timed out for task_id {task_id}. Last error: {last_error or 'n/a'}") | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # CORE CALL β Kolors face reference (single reference, faceStrength=97) | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def kling_face_reference(image_path: str, prompt: str, face_strength: int = 97, aspect_ratio: str = "1:1") -> tuple[str, str]: | |
| """ | |
| Returns (display_image_path, download_file_path) | |
| """ | |
| if not image_path: | |
| raise gr.Error("Please upload a face/reference image.") | |
| ensure_image_ok(image_path) | |
| token = make_jwt() | |
| headers_json = { | |
| "Authorization": f"Bearer {token}", | |
| "Content-Type": "application/json", | |
| } | |
| headers_multipart = { | |
| "Authorization": f"Bearer {token}", | |
| } | |
| # First try: multipart/form-data (send file as `imageReference`) | |
| data_multipart = { | |
| "prompt": (None, prompt), | |
| "reference": (None, "face"), | |
| "faceStrength": (None, str(max(1, min(100, int(face_strength))))), | |
| "faceNo": (None, "1"), # single face reference | |
| "imageCount": (None, "1"), | |
| "aspect_ratio": (None, aspect_ratio), | |
| } | |
| files = { | |
| "imageReference": (os.path.basename(image_path), open(image_path, "rb"), | |
| "image/png" if image_path.lower().endswith(".png") else "image/jpeg") | |
| } | |
| # Attempt 1 β multipart | |
| try: | |
| resp = requests.post(ENDPOINT_KOLORS, headers=headers_multipart, files=files, data=data_multipart, timeout=60) | |
| if resp.status_code == 200: | |
| j = resp.json() | |
| else: | |
| # Read JSON anyway if possible | |
| try: | |
| j = resp.json() | |
| except Exception: | |
| j = {"code": resp.status_code, "message": resp.text} | |
| finally: | |
| # Close file handle if opened | |
| try: | |
| files["imageReference"][1].close() | |
| except Exception: | |
| pass | |
| task_id = extract_task_id(j) | |
| # If Kolors rejected multipart or no task_id, try JSON with data URI | |
| if not task_id: | |
| payload = { | |
| "prompt": prompt, | |
| "reference": "face", | |
| "faceStrength": max(1, min(100, int(face_strength))), | |
| "faceNo": 1, | |
| "imageCount": 1, | |
| "aspect_ratio": aspect_ratio, | |
| "imageReference": b64_data_uri(image_path), | |
| } | |
| resp2 = requests.post(ENDPOINT_KOLORS, headers=headers_json, json=payload, timeout=60) | |
| try: | |
| j = resp2.json() | |
| except Exception: | |
| j = {"code": resp2.status_code, "message": resp2.text} | |
| task_id = extract_task_id(j) | |
| if not task_id: | |
| code = j.get("code") or j.get("service_code") or "?" | |
| msg = j.get("message") or j.get("error") or f"HTTP {resp.status_code if 'resp' in locals() else '?'}" | |
| raise gr.Error(f"Create task failed. Code: {code}. Message: {msg}") | |
| # Poll | |
| result_json = poll_for_result(task_id, headers=headers_json, timeout_s=420, interval_s=3.0) | |
| # Gather image URLs | |
| urls = extract_image_urls(result_json) | |
| if not urls: | |
| # Some APIs return the latest object on /v1/images/generations with same task_id | |
| try: | |
| listing = requests.get(ENDPOINT_GENERATIONS, headers=headers_json, params={"pageSize": 200}, timeout=30).json() | |
| for item in (listing.get("data") or []): | |
| if str(item.get("task_id")) == str(task_id): | |
| urls = extract_image_urls(item) | |
| if urls: | |
| break | |
| except Exception: | |
| pass | |
| if not urls: | |
| raise gr.Error(f"Task {task_id} succeeded but no image URL found in response.") | |
| # Download first image | |
| out_dir = Path("outputs") | |
| out_dir.mkdir(parents=True, exist_ok=True) | |
| out_path = out_dir / f"kling_face_{task_id}.png" | |
| download_to_file(urls[0], out_path) | |
| # Return same path for preview and download | |
| return str(out_path), str(out_path) | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # GRADIO UI | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| with gr.Blocks(title="Kling AI β Image to Image (Face Reference)") as demo: | |
| gr.Markdown("### Kling AI β Image-to-Image (Single Face Reference)\nUpload a face image and a prompt. Strength defaults to 97.") | |
| with gr.Row(): | |
| in_image = gr.Image(type="filepath", label="Reference Face Image (PNG/JPG)") | |
| in_prompt = gr.Textbox(label="Prompt", placeholder="e.g., Ultra-detailed portrait, soft light, studio background", lines=2) | |
| with gr.Row(): | |
| in_strength = gr.Slider(1, 100, value=97, step=1, label="Face Reference Strength") | |
| in_aspect = gr.Dropdown(choices=["1:1", "3:4", "4:3", "2:3", "3:2", "16:9", "9:16", "21:9"], value="1:1", label="Aspect Ratio") | |
| btn = gr.Button("Generate", variant="primary") | |
| out_img = gr.Image(label="Generated Image", show_download_button=False) | |
| out_file = gr.File(label="Download Image") | |
| def run(image, prompt, strength, aspect): | |
| if not prompt or not prompt.strip(): | |
| raise gr.Error("Please enter a prompt.") | |
| return kling_face_reference(image, prompt.strip(), int(strength), aspect) | |
| btn.click(fn=run, inputs=[in_image, in_prompt, in_strength, in_aspect], outputs=[out_img, out_file]) | |
| if __name__ == "__main__": | |
| # On HF Spaces, just `python app.py` is enough β no need to set host/port. | |
| demo.launch() | |