Spaces:
Sleeping
Sleeping
import base64
import logging
import os
import tempfile
import time
from pathlib import Path

import gradio as gr
import jwt
import requests
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ===== API CONFIGURATION =====
API_BASE_URL = "https://api-singapore.klingai.com"
CREATE_TASK_ENDPOINT = f"{API_BASE_URL}/v1/images/multi-image2image"

# Kling API credentials consumed by generate_jwt_token(). They are read from
# the environment so secrets are never hard-coded; each is None when unset.
# NOTE(review): the environment variable names are assumed — confirm they
# match the secrets configured for this deployment.
ACCESS_KEY_ID = os.environ.get("ACCESS_KEY_ID")
ACCESS_KEY_SECRET = os.environ.get("ACCESS_KEY_SECRET")
| # ===== AUTHENTICATION ===== | |
def generate_jwt_token(ttl_seconds=1800):
    """Generate an HS256 JWT for authenticating Kling API requests.

    The issuer claim (``iss``) is the module-level ACCESS_KEY_ID, and the
    token is signed with the module-level ACCESS_KEY_SECRET.

    Args:
        ttl_seconds: Token lifetime in seconds (default 1800 = 30 minutes).

    Returns:
        The encoded token as a ``str``.

    Raises:
        RuntimeError: If either credential is missing or empty.
    """
    if not ACCESS_KEY_ID or not ACCESS_KEY_SECRET:
        raise RuntimeError(
            "Kling API credentials are not configured "
            "(ACCESS_KEY_ID / ACCESS_KEY_SECRET)"
        )
    # Take one timestamp so exp and nbf are computed from the same instant.
    now = int(time.time())
    payload = {
        "iss": ACCESS_KEY_ID,
        "exp": now + ttl_seconds,
        "nbf": now - 5,  # valid from 5 seconds ago to tolerate clock skew
    }
    token = jwt.encode(payload, ACCESS_KEY_SECRET, algorithm="HS256")
    # PyJWT < 2.0 returns bytes. Callers interpolate the token into an
    # f-string header, where bytes would render as "b'...'".
    if isinstance(token, bytes):
        token = token.decode("utf-8")
    return token
| # ===== IMAGE PROCESSING ===== | |
def prepare_image_base64(image_path):
    """Read an image file and return its raw contents base64-encoded.

    The result carries no ``data:image/...;base64,`` prefix.

    Args:
        image_path: Path (str or os.PathLike) of the image file on disk.

    Returns:
        The base64 text as a ``str``, or ``None`` if the file could not be
        read. The read error is logged.
    """
    try:
        raw_bytes = Path(image_path).read_bytes()
    except OSError as exc:
        # Only I/O failures are expected here. Any other exception indicates a
        # bug and is allowed to surface instead of being silently swallowed.
        logger.error("Image processing failed: %s", exc)
        return None
    return base64.b64encode(raw_bytes).decode("utf-8")
def validate_image(image_path, max_size_mb=10):
    """Check that an image file meets the upload size constraints.

    Only the file size is checked: the file must be non-empty and at most
    ``max_size_mb`` MiB. Format and pixel dimensions are NOT checked here;
    that would require an image decoder.

    Args:
        image_path: Path of the image file on disk.
        max_size_mb: Maximum allowed size in MiB (default 10).

    Returns:
        ``(True, "")`` if the image passes, otherwise ``(False, message)``.
    """
    try:
        size_bytes = os.path.getsize(image_path)
    except (OSError, TypeError, ValueError) as exc:
        # Missing or unreadable file, or an unusable path value.
        return False, f"Image validation error: {exc}"
    if size_bytes == 0:
        # An empty file would otherwise be silently dropped during encoding.
        return False, "Image file is empty"
    if size_bytes > max_size_mb * 1024 * 1024:
        return False, f"Image too large (max {max_size_mb}MB)"
    return True, ""
| # ===== API FUNCTIONS ===== | |
def create_multi_image_task(subject_images, prompt, *, model_name="kling-v2",
                            aspect_ratio="1:1", n=1, timeout=60):
    """Submit a multi-image-to-image generation task to the Kling API.

    Args:
        subject_images: Iterable of image file paths. Falsy entries are
            skipped, as are files that cannot be read or are empty.
        prompt: Text prompt describing how to combine the images.
        model_name: Model identifier sent to the API (default "kling-v2").
        aspect_ratio: Requested output aspect ratio (default "1:1").
        n: Number of images to request (default 1).
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        ``(response_json, None)`` on success, or ``(None, error_message)``.
    """
    # Encode images first. This fails fast, without minting a token or
    # touching the network, when fewer than two usable images are supplied.
    subject_image_list = []
    for img_path in subject_images:
        if not img_path:  # skip empty upload slots
            continue
        base64_img = prepare_image_base64(img_path)
        # Unreadable files (None, logged by prepare_image_base64) and empty
        # files ("") are both skipped.
        if base64_img:
            subject_image_list.append({"subject_image": base64_img})
    if len(subject_image_list) < 2:
        return None, "At least 2 subject images required"

    headers = {
        "Authorization": f"Bearer {generate_jwt_token()}",
        "Content-Type": "application/json",
    }
    payload = {
        "model_name": model_name,
        "prompt": prompt,
        "subject_image_list": subject_image_list,
        "n": n,
        "aspect_ratio": aspect_ratio,
    }
    try:
        response = requests.post(
            CREATE_TASK_ENDPOINT, json=payload, headers=headers, timeout=timeout
        )
        response.raise_for_status()
        return response.json(), None
    except requests.exceptions.RequestException as e:
        logger.error("API request failed: %s", e)
        # bool(Response) is Response.ok, which is False for 4xx/5xx — exactly
        # the responses worth logging. So compare against None, not truthiness.
        if getattr(e, "response", None) is not None:
            logger.error("API response: %s", e.response.text)
        return None, f"API Error: {str(e)}"
    except ValueError as e:
        # Non-JSON body: older requests versions raise a plain ValueError here.
        logger.error("API returned invalid JSON: %s", e)
        return None, f"API Error: {str(e)}"
def check_task_status(task_id, *, timeout=30):
    """Query the Kling API for the current state of a generation task.

    Args:
        task_id: Identifier returned when the task was created.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        ``(response_json, None)`` on success, or ``(None, error_message)`` if
        the request fails or the body is not valid JSON.
    """
    headers = {"Authorization": f"Bearer {generate_jwt_token()}"}
    status_url = f"{API_BASE_URL}/v1/images/multi-image2image/{task_id}"
    try:
        # Without a timeout, a stalled connection would block the polling
        # loop indefinitely.
        response = requests.get(status_url, headers=headers, timeout=timeout)
        response.raise_for_status()
        return response.json(), None
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers non-JSON bodies on older requests versions.
        logger.warning("Status check for task %s failed: %s", task_id, e)
        return None, f"Status check failed: {str(e)}"
| # ===== MAIN PROCESSING ===== | |
def _download_result_image(image_url, task_id, timeout=60):
    """Download a generated image and save it in the system temp directory.

    Args:
        image_url: URL of the generated image returned by the API.
        task_id: Task identifier, used to build the output filename.
        timeout: Seconds to wait for the download before giving up.

    Returns:
        ``(output_path_str, None)`` on success, or ``(None, error_message)``.
    """
    try:
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        return None, f"Failed to download result: {str(e)}"
    # task_id comes from the remote API. Keep only filename-safe characters
    # so it cannot inject path separators into the output path.
    safe_id = "".join(ch for ch in str(task_id) if ch.isalnum() or ch in "-_") or "result"
    # NOTE(review): the .png suffix is assumed; the API may return another format.
    output_path = Path(tempfile.gettempdir()) / f"kling_output_{safe_id}.png"
    try:
        output_path.write_bytes(response.content)
    except OSError as e:
        return None, f"Failed to download result: {str(e)}"
    return str(output_path), None


def generate_image(subject_images, prompt, *, max_polls=60, poll_interval=10):
    """Run the full workflow: validate, create a task, poll, then download.

    Args:
        subject_images: Image file paths. Falsy entries are ignored.
        prompt: Text prompt forwarded to the API.
        max_polls: Maximum number of status checks (default 60).
        poll_interval: Seconds to wait between status checks (default 10).
            The defaults give a budget of roughly 10 minutes.

    Returns:
        ``(output_path, None)`` with the downloaded image path on success, or
        ``(None, error_message)`` on any failure.
    """
    for img in subject_images:
        if img:  # only validate non-empty upload slots
            is_valid, error_msg = validate_image(img)
            if not is_valid:
                return None, error_msg

    task_response, error = create_multi_image_task(subject_images, prompt)
    if error:
        return None, error
    if task_response.get("code") != 0:
        return None, f"API error: {task_response.get('message', 'Unknown error')}"
    # Use .get() chains so a malformed response yields an error message
    # rather than an uncaught KeyError inside the UI callback.
    task_id = (task_response.get("data") or {}).get("task_id")
    if not task_id:
        return None, "API error: response did not include a task_id"
    logger.info("Task created: %s", task_id)

    for attempt in range(max_polls):
        task_data, error = check_task_status(task_id)
        if error:
            return None, error
        if task_data.get("code") != 0:
            return None, f"API error: {task_data.get('message', 'Unknown error')}"
        data = task_data.get("data") or {}
        status = data.get("task_status")
        if status == "succeed":
            images = (data.get("task_result") or {}).get("images") or []
            image_url = images[0].get("url") if images else None
            if not image_url:
                return None, "Task succeeded but returned no image URL"
            return _download_result_image(image_url, task_id)
        if status in ("failed", "canceled"):
            error_msg = data.get("task_status_msg", "Unknown error")
            return None, f"Task failed: {error_msg}"
        # Skip the sleep after the final check; we are about to give up.
        if attempt < max_polls - 1:
            time.sleep(poll_interval)

    return None, f"Task timed out after {max_polls * poll_interval / 60:g} minutes"
| # ===== GRADIO INTERFACE ===== | |
def process_interface(subject_image1, subject_image2, subject_image3, subject_image4, prompt):
    """Gradio click handler that runs generation on the uploaded images.

    Empty upload slots are discarded before anything else happens.

    Returns:
        A 3-tuple matching (output_image, output_file, status_output): the
        result path twice plus a success message, or ``(None, None, error)``.
    """
    uploads = (subject_image1, subject_image2, subject_image3, subject_image4)
    provided = list(filter(None, uploads))
    if len(provided) >= 2:
        result_path, failure = generate_image(provided, prompt)
        if not failure:
            return result_path, result_path, "Generation successful!"
        return None, None, failure
    return None, None, "Please upload at least 2 subject images"
_REQUIREMENTS_MD = """
- **At least 2 subject images** (marked with *)
- Max 4 images total
- Max size: 10MB per image
- Formats: JPG, PNG
- Min dimensions: 300x300px
"""

with gr.Blocks(title="Kling AI Multi-Image Generator") as app:
    gr.Markdown("## 🖼️ Kling AI Multi-Image to Image")
    gr.Markdown("Combine features from multiple images into one result")

    with gr.Row():
        # Left column: image uploads, prompt, and the trigger button.
        with gr.Column():
            gr.Markdown("### Input Settings")
            with gr.Row():
                subject_image1 = gr.Image(label="Subject Image 1 *", type="filepath")
                subject_image2 = gr.Image(label="Subject Image 2 *", type="filepath")
            with gr.Row():
                subject_image3 = gr.Image(label="Subject Image 3 (Optional)", type="filepath")
                subject_image4 = gr.Image(label="Subject Image 4 (Optional)", type="filepath")
            prompt_input = gr.Textbox(
                placeholder="Describe how to combine these images",
                label="Transformation Prompt",
            )
            generate_btn = gr.Button("Generate", variant="primary")
            gr.Markdown("### Requirements (* = required)")
            gr.Markdown(_REQUIREMENTS_MD)

        # Right column: generated image, downloadable file, and status text.
        with gr.Column():
            gr.Markdown("### Output")
            output_image = gr.Image(height=400, interactive=False, label="Generated Image")
            output_file = gr.File(label="Download Result")
            status_output = gr.Textbox(interactive=False, label="Status")

    # The four image slots and the prompt map positionally onto
    # process_interface's parameters. Its 3-tuple result fills the outputs.
    _subject_inputs = [subject_image1, subject_image2, subject_image3, subject_image4]
    generate_btn.click(
        fn=process_interface,
        inputs=_subject_inputs + [prompt_input],
        outputs=[output_image, output_file, status_output],
    )
if __name__ == "__main__":
    # Launch the Gradio server when this file is run as a script.
    launch_kwargs = {
        "server_name": "0.0.0.0",  # bind to all network interfaces
        "server_port": 7860,
        "share": False,  # do not create a public share link
    }
    app.launch(**launch_kwargs)