Spaces:
Running
Running
| import logging | |
| from fastapi import FastAPI, UploadFile, File, HTTPException, Form | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import FileResponse | |
| import requests | |
| import base64 | |
| import os | |
| import time | |
| import jwt | |
| from pathlib import Path | |
| from typing import List | |
| import io | |
# ----- Logging -----
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ----- FastAPI application -----
app = FastAPI(title="Kling AI Multi-Image Generator API")

# Browsers may call this API cross-origin only from the frontend listed here.
_ALLOWED_ORIGINS = ["https://hivili.web.app"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ===== API CONFIGURATION =====
# Kling credentials are read from the environment only (e.g. Space secrets).
# Never give secrets in-source fallbacks; any key that has been committed to
# source should be treated as leaked and rotated.
ACCESS_KEY_ID = os.getenv("ACCESS_KEY_ID")
ACCESS_KEY_SECRET = os.getenv("ACCESS_KEY_SECRET")
if not ACCESS_KEY_ID or not ACCESS_KEY_SECRET:
    # Warn instead of crashing at import so the service still starts;
    # authenticated Kling calls cannot succeed until both are set.
    logger.warning("ACCESS_KEY_ID and/or ACCESS_KEY_SECRET are not set; Kling API calls will fail")
API_BASE_URL = "https://api-singapore.klingai.com"
CREATE_TASK_ENDPOINT = f"{API_BASE_URL}/v1/images/multi-image2image"
# ===== AUTHENTICATION =====
def generate_jwt_token(ttl_seconds: int = 1800):
    """Generate a short-lived HS256 JWT for Kling API authentication.

    The token's issuer (``iss``) is ACCESS_KEY_ID and it is signed with
    ACCESS_KEY_SECRET.

    Args:
        ttl_seconds: Token lifetime in seconds; defaults to 30 minutes.

    Returns:
        The encoded token exactly as returned by ``jwt.encode``.

    Raises:
        HTTPException: 500 if either credential is missing or empty.
    """
    if not ACCESS_KEY_ID or not ACCESS_KEY_SECRET:
        raise HTTPException(status_code=500, detail="Kling API credentials are not configured")
    # Read the clock once so ``exp`` and ``nbf`` derive from the same instant.
    now = int(time.time())
    payload = {
        "iss": ACCESS_KEY_ID,
        "exp": now + ttl_seconds,
        "nbf": now - 5,  # valid from 5 seconds ago (presumably clock-skew leeway)
    }
    return jwt.encode(payload, ACCESS_KEY_SECRET, algorithm="HS256")
# ===== IMAGE PROCESSING =====
def prepare_image_base64(image_content: bytes):
    """Encode raw image bytes as a plain base64 string (no ``data:`` URI prefix).

    Raises:
        HTTPException: 500 if encoding fails; the failure is also logged.
    """
    try:
        encoded_bytes = base64.b64encode(image_content)
        encoded_text = encoded_bytes.decode('utf-8')
    except Exception as exc:
        message = f"Image processing failed: {str(exc)}"
        logger.error(message)
        raise HTTPException(status_code=500, detail=message)
    return encoded_text
def validate_image(image_content: bytes, max_size_mb: float = 10):
    """Validate that an uploaded image meets the API's size limit.

    Args:
        image_content: Raw image bytes.
        max_size_mb: Largest accepted size in MiB (default 10).

    Returns:
        ``(True, "")`` when the image is acceptable (kept for compatibility).

    Raises:
        HTTPException: 400 if the image is too large, or if its size cannot
            be determined.
    """
    # Only the size computation is guarded. The "too large" HTTPException is
    # raised outside the try so it is not re-wrapped by the generic handler.
    try:
        size_mb = len(image_content) / (1024 * 1024)
    except TypeError as e:
        raise HTTPException(status_code=400, detail=f"Image validation error: {str(e)}") from e
    if size_mb > max_size_mb:
        raise HTTPException(status_code=400, detail=f"Image too large (max {max_size_mb}MB)")
    return True, ""
# ===== API FUNCTIONS =====
def create_multi_image_task(subject_images: List[bytes], prompt: str, timeout: float = 60):
    """Submit a multi-image-to-image generation task to the Kling API.

    Args:
        subject_images: Raw image bytes; empty entries are skipped.
        prompt: Text prompt sent alongside the images.
        timeout: Seconds passed to ``requests.post`` as its timeout, so a
            stalled upstream cannot hang the worker indefinitely.

    Returns:
        The decoded JSON body of the create-task response.

    Raises:
        HTTPException: 400 if fewer than two non-empty images are supplied;
            500 if the HTTP request fails or returns an error status.
    """
    subject_image_list = [
        {"subject_image": prepare_image_base64(img_content)}
        for img_content in subject_images
        if img_content
    ]
    if len(subject_image_list) < 2:
        raise HTTPException(status_code=400, detail="At least 2 subject images required")

    payload = {
        "model_name": "kling-v2",
        "prompt": prompt,
        "subject_image_list": subject_image_list,
        "n": 1,
        "aspect_ratio": "1:1",
    }
    # Only sign a token once the request is known to be valid.
    headers = {
        "Authorization": f"Bearer {generate_jwt_token()}",
        "Content-Type": "application/json",
    }
    try:
        response = requests.post(CREATE_TASK_ENDPOINT, json=payload, headers=headers, timeout=timeout)
        response.raise_for_status()
        return response.json()
    except requests.exceptions.RequestException as e:
        logger.error(f"API request failed: {str(e)}")
        # A requests.Response is falsy for 4xx/5xx (its __bool__ returns .ok).
        # Compare against None, or exactly the error bodies would go unlogged.
        error_response = getattr(e, "response", None)
        if error_response is not None:
            logger.error(f"API response: {error_response.text}")
        raise HTTPException(status_code=500, detail=f"API Error: {str(e)}") from e
def check_task_status(task_id: str, timeout: float = 30):
    """Fetch the current status of a multi-image generation task.

    Args:
        task_id: Task id from the create-task response.
        timeout: Seconds passed to ``requests.get`` as its timeout.

    Returns:
        The decoded JSON body of the status response.

    Raises:
        HTTPException: 500 if the HTTP request fails or returns an error status.
    """
    headers = {"Authorization": f"Bearer {generate_jwt_token()}"}
    # Same resource as task creation, addressed by task id.
    status_url = f"{CREATE_TASK_ENDPOINT}/{task_id}"
    try:
        response = requests.get(status_url, headers=headers, timeout=timeout)
        response.raise_for_status()
        return response.json()
    except requests.exceptions.RequestException as e:
        logger.error(f"Status check failed for task {task_id}: {str(e)}")
        # Response objects are falsy on error statuses; test against None.
        error_response = getattr(e, "response", None)
        if error_response is not None:
            logger.error(f"API response: {error_response.text}")
        raise HTTPException(status_code=500, detail=f"Status check failed: {str(e)}") from e
# ===== MAIN PROCESSING =====
def _download_result(image_url: str, task_id: str, timeout: float = 60):
    """Download a generated image and save it as /tmp/kling_output_<task_id>.png.

    Returns:
        Path of the written file.

    Raises:
        HTTPException: 500 if the download or the write fails.
    """
    try:
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()
        output_dir = Path("/tmp")
        output_dir.mkdir(exist_ok=True)
        output_path = output_dir / f"kling_output_{task_id}.png"
        output_path.write_bytes(response.content)
        return output_path
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to download result: {str(e)}") from e


async def generate_image(
    subject_images: List[bytes],
    prompt: str,
    poll_interval: float = 10,
    max_polls: int = 60,
):
    """Run the full workflow: validate images, create a task, poll it, download.

    Blocking HTTP and file work runs in the default thread-pool executor, and
    the wait between polls uses ``asyncio.sleep``. This keeps the event loop
    (and every other request on this worker) responsive while a task runs.

    Args:
        subject_images: Raw image bytes; empty entries are skipped.
        prompt: Text prompt for the generation.
        poll_interval: Seconds to wait between status checks.
        max_polls: Maximum number of status checks before timing out.

    Returns:
        Path to the downloaded result image.

    Raises:
        HTTPException: on validation failure, API error, task failure,
            missing result, download failure, or timeout.
    """
    import asyncio  # local import: asyncio is not imported at module level

    loop = asyncio.get_running_loop()

    for img_content in subject_images:
        if img_content:
            validate_image(img_content)

    task_response = await loop.run_in_executor(None, create_multi_image_task, subject_images, prompt)
    if task_response.get("code") != 0:
        raise HTTPException(status_code=500, detail=f"API error: {task_response.get('message', 'Unknown error')}")

    task_id = task_response["data"]["task_id"]
    logger.info(f"Task created: {task_id}")

    for _ in range(max_polls):
        status_response = await loop.run_in_executor(None, check_task_status, task_id)
        task_data = status_response["data"]
        status = task_data["task_status"]
        if status == "succeed":
            result_images = task_data["task_result"]["images"]
            if not result_images:
                raise HTTPException(status_code=500, detail="Task succeeded but returned no images")
            return await loop.run_in_executor(None, _download_result, result_images[0]["url"], task_id)
        if status in ("failed", "canceled"):
            error_msg = task_data.get("task_status_msg", "Unknown error")
            raise HTTPException(status_code=500, detail=f"Task failed: {error_msg}")
        # Non-blocking wait; time.sleep here would stall the whole server.
        await asyncio.sleep(poll_interval)

    raise HTTPException(
        status_code=500,
        detail=f"Task timed out after {poll_interval * max_polls / 60:g} minutes",
    )
# ===== API ENDPOINTS =====
# NOTE(review): the "/generate" path is inferred from this handler's own log
# label; confirm it matches the route the frontend calls.
@app.post("/generate")
async def generate_image_endpoint(
    prompt: str = Form(...),
    images: List[UploadFile] = File(...),
):
    """Generate one image from 2-4 uploaded images and a text prompt.

    Returns:
        A FileResponse with the generated image (served as image/png).

    Raises:
        HTTPException: errors already raised as HTTPException (e.g. 400 for a
            bad image count) keep their status and detail; any other failure
            becomes a 500.
    """
    try:
        if len(images) < 2:
            raise HTTPException(status_code=400, detail="At least 2 images are required")
        if len(images) > 4:
            raise HTTPException(status_code=400, detail="Maximum 4 images allowed")
        image_contents = [await image.read() for image in images]
        output_path = await generate_image(image_contents, prompt)
        return FileResponse(
            path=output_path,
            media_type="image/png",
            # The saved file is already named kling_output_<task_id>.png;
            # reuse it instead of prefixing "kling_output_" a second time.
            filename=Path(output_path).name,
        )
    except Exception as e:
        logger.error(f"Error in /generate: {str(e)}")
        if isinstance(e, HTTPException):
            # Re-wrapping would turn client errors (400) into 500s.
            raise
        raise HTTPException(status_code=500, detail=str(e)) from e
# NOTE(review): without a route decorator this handler was unreachable; the
# "/" path is assumed for this status endpoint. Confirm before deploying.
@app.get("/")
async def index():
    """Report that the service is up."""
    return {"status": "Kling AI Multi-Image Generator API is running"}
if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when this module is run directly.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)