Spaces:
Running
Running
import asyncio
import base64
import io
import logging
import os
import time
from pathlib import Path
from typing import List

import jwt
import razorpay
import requests
from fastapi import FastAPI, UploadFile, File, HTTPException, Form, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from razorpay.errors import SignatureVerificationError
# Logging: INFO-level root configuration plus a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI application instance.
app = FastAPI(title="Kling AI Multi-Image Generator API with Razorpay")

# CORS settings for the browser frontend, kept in one mapping so the
# allowed origin list is easy to find and edit.
_cors_options = {
    "allow_origins": ["https://hivili.web.app"],  # Update with your frontend URL
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# ===== API CONFIGURATION =====
# SECURITY: the literal fallbacks below are live-looking credentials committed
# to source. They are kept only so deployments that rely on them keep working;
# they should be rotated and supplied exclusively through the environment.
ACCESS_KEY_ID = os.getenv("ACCESS_KEY_ID", "AFyHfnQATghFdCMyAG3gRPbNY4TNKFGB")
ACCESS_KEY_SECRET = os.getenv("ACCESS_KEY_SECRET", "TTepeLyBterLNM3brYPGmdndBnnyKJBA")
if "ACCESS_KEY_ID" not in os.environ or "ACCESS_KEY_SECRET" not in os.environ:
    # Make the silent fallback visible in the logs instead of hiding it.
    logger.warning(
        "ACCESS_KEY_ID / ACCESS_KEY_SECRET not set in the environment; "
        "falling back to credentials embedded in the source file"
    )
API_BASE_URL = "https://api-singapore.klingai.com"
CREATE_TASK_ENDPOINT = f"{API_BASE_URL}/v1/images/multi-image2image"

# ===== RAZORPAY CONFIGURATION =====
RAZORPAY_KEY_ID = os.getenv("RAZORPAY_KEY_ID")  # Set in environment variables
RAZORPAY_KEY_SECRET = os.getenv("RAZORPAY_KEY_SECRET")  # Set in environment variables
if not RAZORPAY_KEY_ID or not RAZORPAY_KEY_SECRET:
    # The client is still constructed so the module imports cleanly, but
    # requests made with it cannot be authenticated.
    logger.warning(
        "RAZORPAY_KEY_ID / RAZORPAY_KEY_SECRET are not set; "
        "Razorpay requests cannot be authenticated"
    )
razorpay_client = razorpay.Client(auth=(RAZORPAY_KEY_ID, RAZORPAY_KEY_SECRET))
| # ===== AUTHENTICATION ===== | |
def generate_jwt_token():
    """Generate the HS256-signed JWT used as the Bearer token for the Kling API.

    Returns:
        The encoded token as a ``str``.
    """
    # Sample the clock once so "exp" and "nbf" are derived from the same instant.
    now = int(time.time())
    payload = {
        "iss": ACCESS_KEY_ID,
        "exp": now + 1800,  # 30 minutes expiration
        "nbf": now - 5,  # Not before 5 seconds ago
    }
    token = jwt.encode(payload, ACCESS_KEY_SECRET, algorithm="HS256")
    # PyJWT 1.x returns bytes while 2.x returns str; callers interpolate the
    # token into an "Authorization: Bearer ..." header, where bytes would
    # render as "b'...'" and be rejected.
    if isinstance(token, bytes):
        token = token.decode("utf-8")
    return token
| # ===== IMAGE PROCESSING ===== | |
def prepare_image_base64(image_content: bytes):
    """Return *image_content* as a base64 string with no data-URI prefix.

    Raises:
        HTTPException: 500 if the bytes cannot be encoded.
    """
    try:
        encoded = base64.b64encode(image_content)
        return str(encoded, "utf-8")
    except Exception as exc:
        failure = f"Image processing failed: {str(exc)}"
        logger.error(failure)
        raise HTTPException(status_code=500, detail=failure)
def validate_image(image_content: bytes, max_size_mb: int = 10):
    """Validate that an image meets the API size requirement.

    Args:
        image_content: Raw image bytes.
        max_size_mb: Largest accepted size in mebibytes (default 10).

    Returns:
        The tuple ``(True, "")`` when the image is acceptable.

    Raises:
        HTTPException: 400 if the size cannot be determined or exceeds the limit.
    """
    # Only the length computation is guarded. Previously the "too large"
    # HTTPException was raised inside a blanket ``except Exception`` and got
    # re-wrapped as a generic "Image validation error", mangling its detail.
    try:
        size_mb = len(image_content) / (1024 * 1024)
    except TypeError as e:
        raise HTTPException(status_code=400, detail=f"Image validation error: {str(e)}") from e

    if size_mb > max_size_mb:
        raise HTTPException(status_code=400, detail=f"Image too large (max {max_size_mb}MB)")
    return True, ""
| # ===== API FUNCTIONS ===== | |
def create_multi_image_task(subject_images: List[bytes], prompt: str, timeout: float = 60):
    """Create a multi-image generation task on the Kling API.

    Args:
        subject_images: Raw bytes of each subject image; empty entries are skipped.
        prompt: Text prompt sent with the images.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The decoded JSON body of the API response.

    Raises:
        HTTPException: 400 if fewer than two usable images remain,
            500 if the HTTP request fails.
    """
    headers = {
        "Authorization": f"Bearer {generate_jwt_token()}",
        "Content-Type": "application/json",
    }

    # Encode every non-empty image; drop anything that encodes to an empty string.
    encoded_images = (prepare_image_base64(img) for img in subject_images if img)
    subject_image_list = [
        {"subject_image": encoded} for encoded in encoded_images if encoded
    ]
    if len(subject_image_list) < 2:
        raise HTTPException(status_code=400, detail="At least 2 subject images required")

    payload = {
        "model_name": "kling-v2",
        "prompt": prompt,
        "subject_image_list": subject_image_list,
        "n": 1,
        "aspect_ratio": "1:1",
    }
    try:
        # Without a timeout, requests waits indefinitely on a stalled connection.
        response = requests.post(
            CREATE_TASK_ENDPOINT, json=payload, headers=headers, timeout=timeout
        )
        response.raise_for_status()
        return response.json()
    except requests.exceptions.RequestException as e:
        logger.error(f"API request failed: {str(e)}")
        # A Response is falsy when its status is 4xx/5xx, so a truthiness test
        # would skip exactly the error bodies we want to log; compare to None.
        error_response = getattr(e, "response", None)
        if error_response is not None:
            logger.error(f"API response: {error_response.text}")
        raise HTTPException(status_code=500, detail=f"API Error: {str(e)}") from e
def check_task_status(task_id: str, timeout: float = 30):
    """Fetch the current status of a multi-image generation task.

    Args:
        task_id: Identifier returned when the task was created.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The decoded JSON body of the status response.

    Raises:
        HTTPException: 500 if the HTTP request fails.
    """
    headers = {"Authorization": f"Bearer {generate_jwt_token()}"}
    status_url = f"{API_BASE_URL}/v1/images/multi-image2image/{task_id}"
    try:
        # Bound the wait so one stalled poll cannot hang the whole workflow.
        response = requests.get(status_url, headers=headers, timeout=timeout)
        response.raise_for_status()
        return response.json()
    except requests.exceptions.RequestException as e:
        # Log the failure, matching how task-creation errors are reported.
        logger.error(f"Status check failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Status check failed: {str(e)}") from e
| # ===== RAZORPAY FUNCTIONS ===== | |
def create_razorpay_order(amount: int, currency: str = "INR"):
    """Create a Razorpay order.

    Args:
        amount: Order amount in major currency units (e.g. rupees); it is
            multiplied by 100 before being sent.
        currency: Currency code, "INR" by default.

    Returns:
        The order object returned by the Razorpay client.

    Raises:
        HTTPException: 400 if *amount* is not positive,
            500 if the Razorpay call fails.
    """
    # Reject obviously invalid amounts up front as a client error instead of
    # sending them to Razorpay and reporting the rejection as a server error.
    if amount <= 0:
        raise HTTPException(status_code=400, detail="Amount must be greater than zero")

    try:
        order_data = {
            "amount": amount * 100,  # Amount in paise (e.g., 500 INR = 50000 paise)
            "currency": currency,
            "payment_capture": 1,  # Auto-capture payment
        }
        return razorpay_client.order.create(data=order_data)
    except Exception as e:
        logger.error(f"Failed to create Razorpay order: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to create order: {str(e)}") from e
def verify_payment_signature(order_id: str, payment_id: str, signature: str):
    """Check a Razorpay payment signature.

    Returns:
        True when the Razorpay client accepts the signature, False when it
        raises SignatureVerificationError.

    Raises:
        HTTPException: 500 for any other error raised during verification.
    """
    signed_fields = {
        "razorpay_order_id": order_id,
        "razorpay_payment_id": payment_id,
        "razorpay_signature": signature
    }
    try:
        razorpay_client.utility.verify_payment_signature(signed_fields)
    except SignatureVerificationError as sig_err:
        logger.error(f"Payment signature verification failed: {str(sig_err)}")
        return False
    except Exception as err:
        logger.error(f"Error verifying payment signature: {str(err)}")
        raise HTTPException(status_code=500, detail=f"Verification error: {str(err)}")
    return True
| # ===== MAIN PROCESSING ===== | |
def _download_generated_image(image_url: str, task_id: str, timeout: float = 60):
    """Download the generated image and save it as /tmp/kling_output_<task_id>.png."""
    response = requests.get(image_url, timeout=timeout)
    response.raise_for_status()
    output_dir = Path("/tmp")
    output_dir.mkdir(exist_ok=True)
    output_path = output_dir / f"kling_output_{task_id}.png"
    with open(output_path, "wb") as f:
        f.write(response.content)
    return output_path


async def generate_image(subject_images: List[bytes], prompt: str):
    """Handle the complete image generation workflow.

    Validates the images, creates the generation task, polls its status
    (up to 60 polls, 10 seconds apart) and downloads the first result image.

    Args:
        subject_images: Raw bytes of each subject image.
        prompt: Text prompt for the generation task.

    Returns:
        Path of the downloaded PNG file.

    Raises:
        HTTPException: if validation fails, the API reports an error, the
            task fails or is canceled, the download fails, or polling runs out.
    """
    for img_content in subject_images:
        if img_content:
            validate_image(img_content)

    # The HTTP helpers are synchronous (requests). Run them in the default
    # executor so this coroutine does not block the event loop while waiting.
    loop = asyncio.get_running_loop()

    task_response = await loop.run_in_executor(
        None, create_multi_image_task, subject_images, prompt
    )
    if task_response.get("code") != 0:
        raise HTTPException(status_code=500, detail=f"API error: {task_response.get('message', 'Unknown error')}")
    task_id = task_response["data"]["task_id"]
    logger.info(f"Task created: {task_id}")

    for _ in range(60):
        task_data = await loop.run_in_executor(None, check_task_status, task_id)
        status = task_data["data"]["task_status"]
        if status == "succeed":
            image_url = task_data["data"]["task_result"]["images"][0]["url"]
            try:
                return await loop.run_in_executor(
                    None, _download_generated_image, image_url, task_id
                )
            except Exception as e:
                raise HTTPException(status_code=500, detail=f"Failed to download result: {str(e)}") from e
        if status in ("failed", "canceled"):
            error_msg = task_data["data"].get("task_status_msg", "Unknown error")
            raise HTTPException(status_code=500, detail=f"Task failed: {error_msg}")
        # asyncio.sleep yields to the event loop; time.sleep would stall every
        # other request for the full polling interval.
        await asyncio.sleep(10)

    raise HTTPException(status_code=500, detail="Task timed out after 10 minutes")
| # ===== API ENDPOINTS ===== | |
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered with the app (the log messages refer to /generate).
async def generate_image_endpoint(
    prompt: str = Form(...),
    images: List[UploadFile] = File(...)
):
    """Endpoint to generate an image from multiple input images and a prompt.

    Args:
        prompt: Text prompt from the submitted form.
        images: Between 2 and 4 uploaded image files.

    Returns:
        A FileResponse streaming the generated PNG.

    Raises:
        HTTPException: 400 for an invalid image count; otherwise whatever
            status the generation workflow raised, or 500 for unexpected errors.
    """
    try:
        if len(images) < 2:
            raise HTTPException(status_code=400, detail="At least 2 images are required")
        if len(images) > 4:
            raise HTTPException(status_code=400, detail="Maximum 4 images allowed")
        image_contents = [await image.read() for image in images]
        output_path = await generate_image(image_contents, prompt)
        return FileResponse(
            path=output_path,
            media_type="image/png",
            # The saved file is already named "kling_output_<task_id>.png";
            # prefixing its stem again produced "kling_output_kling_output_...".
            filename=Path(output_path).name
        )
    except HTTPException as e:
        # Keep the original status code and detail (e.g. the 400s above)
        # instead of flattening every failure into a 500.
        logger.error(f"Error in /generate: {e.detail}")
        raise
    except Exception as e:
        logger.error(f"Error in /generate: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered with the app.
async def create_order(amount: int = Form(...)):
    """Create a Razorpay order for payment.

    Args:
        amount: Order amount from the submitted form, passed to
            create_razorpay_order.

    Returns:
        A JSONResponse with the order id, amount, currency and public key id.

    Raises:
        HTTPException: 500 if Razorpay is not configured or an unexpected
            error occurs; errors raised by create_razorpay_order pass through
            with their own status and detail.
    """
    if not RAZORPAY_KEY_ID or not RAZORPAY_KEY_SECRET:
        raise HTTPException(status_code=500, detail="Razorpay configuration missing")
    try:
        order = create_razorpay_order(amount)
        return JSONResponse(content={
            "order_id": order["id"],
            "amount": order["amount"],
            "currency": order["currency"],
            "key_id": RAZORPAY_KEY_ID
        })
    except HTTPException as e:
        # Already a well-formed HTTP error; re-wrapping it would replace its
        # detail with the exception's string form.
        logger.error(f"Error creating order: {e.detail}")
        raise
    except Exception as e:
        logger.error(f"Error creating order: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered with the app.
async def verify_payment(
    razorpay_order_id: str = Form(...),
    razorpay_payment_id: str = Form(...),
    razorpay_signature: str = Form(...)
):
    """Verify a Razorpay payment signature submitted by the client.

    Returns:
        A JSONResponse with a success status when the signature is valid.

    Raises:
        HTTPException: 400 if the signature is invalid, 500 if Razorpay is
            not configured or verification fails unexpectedly.
    """
    # Same configuration guard as create_order: without the key secret the
    # signature cannot be checked at all.
    if not RAZORPAY_KEY_ID or not RAZORPAY_KEY_SECRET:
        raise HTTPException(status_code=500, detail="Razorpay configuration missing")

    try:
        is_valid = verify_payment_signature(razorpay_order_id, razorpay_payment_id, razorpay_signature)
    except HTTPException as e:
        logger.error(f"Error verifying payment: {e.detail}")
        raise
    except Exception as e:
        logger.error(f"Error verifying payment: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    if not is_valid:
        # Raised outside the try block so the 400 is not converted into a 500.
        logger.error("Error verifying payment: Payment verification failed")
        raise HTTPException(status_code=400, detail="Payment verification failed")

    # Here you can update user subscription status in your database
    # For example, update a user's plan to "premium" or increment credits
    return JSONResponse(content={"status": "success", "message": "Payment verified successfully"})
async def index():
    """Return a static status payload saying the API is running."""
    status_message = "Kling AI Multi-Image Generator API with Razorpay is running"
    return {"status": status_message}
# Script entry point: serve the app with uvicorn on all interfaces, port 7860.
if __name__ == "__main__":
    from uvicorn import run as run_server

    run_server(app, host="0.0.0.0", port=7860)