# Spaces: Running
# (page-header text captured together with the file; not part of the module)
import asyncio
import base64
import io
import logging
import sys
import time

import jwt
import requests
from PIL import Image
from requests.exceptions import RequestException
# Logging configuration: records at DEBUG level and above are written both to
# the 'virtual_tryon.log' file and to standard output.
_log_handlers = [
    logging.FileHandler('virtual_tryon.log'),
    logging.StreamHandler(sys.stdout),
]
logging.basicConfig(
    handlers=_log_handlers,
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)
logger = logging.getLogger(__name__)

# Root URL that every API endpoint path is appended to.
API_BASE_URL = "https://api.klingai.com"

# Listed option values for cloth type and image size. Nothing in this part of
# the file checks arguments against these lists.
VALID_CLOTH_TYPES = ["upper", "lower", "full"]
VALID_IMAGE_SIZES = ["256x256", "512x512", "768x768"]

# Default option values.
DEFAULT_IMAGE_SIZE = "512x512"
DEFAULT_NUM_STEPS = 30
DEFAULT_GUIDANCE_SCALE = 7.5
DEFAULT_SEED = 42
def generate_api_token(access_key, secret_key):
    """Generate a JWT for API authentication.

    The token carries ``access_key`` as the ``iss`` claim, is valid from the
    moment of creation (``nbf``) and expires 30 minutes later (``exp``). It is
    signed with ``secret_key`` using HS256.

    Args:
        access_key: API access key, placed in the ``iss`` claim.
        secret_key: Secret used to sign the token.

    Returns:
        str: The encoded token, always as text.

    Raises:
        Exception: Any error raised while building or encoding the token is
            logged and re-raised unchanged.
    """
    try:
        current_time = int(time.time())
        payload = {
            "iss": access_key,
            "exp": current_time + 1800,  # 30 minutes expiration
            "nbf": current_time,
        }
        # Log only the timing claims: the access key is a credential and
        # should not be written to the log file.
        logger.debug(
            f"Generating token with exp={payload['exp']}, nbf={payload['nbf']}"
        )
        token = jwt.encode(payload, secret_key, algorithm="HS256")
        # Older PyJWT releases return bytes, newer ones return str. Normalise
        # so callers can interpolate the token directly into a header value.
        if isinstance(token, bytes):
            token = token.decode('utf-8')
        logger.debug("Token generated successfully")
        return token
    except Exception as e:
        logger.error(f"Error generating token: {str(e)}")
        raise
def encode_image_to_base64(image):
    """Serialise a PIL image as PNG and return the bytes base64-encoded.

    Args:
        image: The object to encode; expected to be a ``PIL.Image.Image``.

    Returns:
        The base64 text on success, or ``None`` (after logging an error) when
        ``image`` is not a PIL image or when saving/encoding raises.
    """
    try:
        # Guard clause: anything that is not a PIL image is rejected up front.
        if not isinstance(image, Image.Image):
            logger.error("Input is not a PIL Image")
            return None

        png_buffer = io.BytesIO()
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
        encoded = base64.b64encode(png_bytes).decode('utf-8')
        logger.debug(f"Image encoded to base64 successfully. Length: {len(encoded)}")
        return encoded
    except Exception as exc:
        logger.error(f"Error encoding image to base64: {str(exc)}")
        return None
async def check_task_status(
    task_id,
    access_key,
    secret_key,
    max_attempts=3,
    wait_interval=20,
    request_timeout=30,
):
    """Poll a submitted try-on task until it finishes or attempts run out.

    Each attempt first sleeps ``wait_interval`` seconds, then issues a GET to
    ``{API_BASE_URL}/v1/images/kolors-virtual-try-on/{task_id}`` with a freshly
    generated bearer token. Errors during an attempt (network failure,
    non-JSON body, API error code) are logged and count as a used attempt.

    Args:
        task_id: Identifier of the task to poll.
        access_key: API access key used to generate the token.
        secret_key: API secret key used to sign the token.
        max_attempts: Maximum number of status checks (default 3).
        wait_interval: Seconds to sleep before each check (default 20).
        request_timeout: Per-request timeout in seconds (default 30).

    Returns:
        tuple: ``(error_message, image_url)``. On success ``error_message`` is
        ``None`` and ``image_url`` is the URL of the first result image; on
        failure ``image_url`` is ``None`` and ``error_message`` says why.
    """
    for attempt in range(1, max_attempts + 1):
        await asyncio.sleep(wait_interval)
        logger.info(f"Checking task status (Attempt {attempt}/{max_attempts})...")
        try:
            # Generate new token for status check
            token = generate_api_token(access_key, secret_key)
            # The token may arrive as bytes; decode it so the header does not
            # end up as "Bearer b'...'".
            if isinstance(token, bytes):
                token = token.decode('utf-8')
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {token}",
            }
            # Status check endpoint
            url = f"{API_BASE_URL}/v1/images/kolors-virtual-try-on/{task_id}"
            # TLS verification stays on (the request carries a bearer token)
            # and a timeout keeps a stalled connection from hanging the poll.
            response = requests.get(url, headers=headers, timeout=request_timeout)
            logger.debug(f"Status check response: {response.text}")
            result = response.json()

            if response.status_code == 200 and result.get('code') == 0:
                # `or {}` / `or ''` also cover fields present but set to null.
                data = result.get('data') or {}
                task_status = (data.get('task_status') or '').lower()

                if task_status in ('completed', 'succeed'):
                    images = (data.get('task_result') or {}).get('images') or []
                    image_url = images[0].get('url') if images else None
                    if image_url:
                        return None, image_url
                    # Covers both an empty image list and an entry without a
                    # URL, so callers never receive (None, None).
                    return "No images found in the task result.", None

                if task_status in ('failed', 'error'):
                    error_message = data.get('task_status_msg', 'Task failed.')
                    return f"Task failed: {error_message}", None

                logger.info(f"Task status: {task_status}. Waiting for next attempt...")
            else:
                error_message = result.get('message', 'Unknown error occurred.')
                logger.error(f"Error fetching task status: {error_message}")
        except Exception as e:
            # Best effort: a failed check is logged and the next attempt runs.
            logger.error(f"Error checking task status: {str(e)}")

    return "Task did not complete within the expected time.", None
async def apply_virtual_tryon_async(
    person_image,
    garment_image,
    access_key,
    secret_key,
    request_timeout=60,
):
    """Apply virtual try-on using the Kling API asynchronously.

    Encodes both images as base64 PNG and submits them to
    ``{API_BASE_URL}/v1/images/kolors-virtual-try-on``. It then waits for the
    task via ``check_task_status`` and downloads the resulting image.

    Args:
        person_image: PIL image of the person.
        garment_image: PIL image of the garment.
        access_key: API access key used to generate the token.
        secret_key: API secret key used to sign the token.
        request_timeout: Timeout in seconds applied to the submit request and
            to the result download (default 60).

    Returns:
        tuple: ``(image, message)``. On success ``image`` is a PIL image and
        ``message`` is ``"Success"``; on any failure ``image`` is ``None`` and
        ``message`` describes the error. This function does not raise for
        ordinary errors; they are reported through ``message``.
    """
    try:
        logger.info("Starting virtual try-on process")

        # Generate API token
        jwt_token = generate_api_token(access_key, secret_key)
        if not jwt_token:
            return None, "Failed to generate JWT token"
        # Ensure token is string
        if isinstance(jwt_token, bytes):
            jwt_token = jwt_token.decode('utf-8')

        # Prepare images
        logger.debug("Preparing images")
        person_base64 = encode_image_to_base64(person_image)
        garment_base64 = encode_image_to_base64(garment_image)
        if not person_base64 or not garment_base64:
            logger.error("Failed to convert images to base64")
            return None, "Error converting images to base64"

        # Prepare request
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {jwt_token}",
        }
        # Payload structure
        payload = {
            "model_name": "kolors-virtual-try-on-v1",
            "human_image": person_base64,
            "cloth_image": garment_base64,
        }

        # Submit task. TLS verification stays on because the request carries
        # credentials, and the timeout prevents an indefinite hang.
        url = f"{API_BASE_URL}/v1/images/kolors-virtual-try-on"
        logger.debug(f"Making API request to {url}")
        response = requests.post(
            url, headers=headers, json=payload, timeout=request_timeout
        )
        result = response.json()

        if response.status_code != 200 or result.get('code') != 0:
            error_msg = result.get('message', 'Unknown error')
            logger.error(f"API Error: {error_msg}")
            return None, f"API Error: {error_msg}"

        # `or {}` also covers a "data" field that is present but null.
        task_id = (result.get('data') or {}).get('task_id')
        if not task_id:
            return None, "No task ID received"
        logger.info(f"Task submitted successfully. Task ID: {task_id}")

        # Check task status
        error_message, image_url = await check_task_status(
            task_id, access_key, secret_key
        )
        if error_message:
            return None, error_message

        # Download result image
        try:
            image_response = requests.get(image_url, timeout=request_timeout)
            if image_response.status_code == 200:
                return Image.open(io.BytesIO(image_response.content)), "Success"
            return None, f"Failed to download result image: {image_response.status_code}"
        except Exception as e:
            return None, f"Error downloading result image: {str(e)}"
    except Exception as e:
        # Top-level boundary: report anything unexpected through the message.
        logger.error(f"Unexpected Error: {str(e)}")
        return None, f"Error: {str(e)}"
def apply_virtual_tryon(
    person_image,
    garment_image,
    access_key,
    secret_key,
    cloth_type="upper",
    image_size="512x512",
    num_steps=DEFAULT_NUM_STEPS,
    guidance_scale=DEFAULT_GUIDANCE_SCALE,
    seed=DEFAULT_SEED
):
    """Synchronous wrapper for the async virtual try-on function.

    Runs ``apply_virtual_tryon_async`` to completion on a fresh event loop and
    returns its result.

    Args:
        person_image: PIL image of the person.
        garment_image: PIL image of the garment.
        access_key: API access key.
        secret_key: API secret key.
        cloth_type: Accepted but not forwarded to the API call.
        image_size: Accepted but not forwarded to the API call.
        num_steps: Accepted but not forwarded to the API call.
        guidance_scale: Accepted but not forwarded to the API call.
        seed: Accepted but not forwarded to the API call.

    Returns:
        tuple: The ``(image, message)`` pair returned by
        ``apply_virtual_tryon_async``.

    Raises:
        RuntimeError: If called from a thread that already has a running
            event loop.
    """
    # asyncio.run creates the loop, runs the coroutine and tears the loop down
    # again. Unlike a manual new_event_loop/set_event_loop/close sequence, it
    # does not leave a closed loop installed as the thread's current loop.
    return asyncio.run(
        apply_virtual_tryon_async(
            person_image,
            garment_image,
            access_key,
            secret_key,
        )
    )