# KLING-Virtual-Tryon / virtual_tryon.py
# Author: Abhlash
# Commit c65d70c (verified): "Create virtual_tryon.py"
from PIL import Image
import requests
import io
import base64
import jwt
import time
import logging
import sys
import asyncio
from requests.exceptions import RequestException
# Module logger. The root logger is configured at import time to emit
# DEBUG-and-above records both to 'virtual_tryon.log' (relative to the
# current working directory) and to stdout.
logger = logging.getLogger(__name__)
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
    handlers=[
        logging.FileHandler('virtual_tryon.log'),
        logging.StreamHandler(sys.stdout),
    ],
)

# Root URL of the Kling HTTP API; endpoint paths are appended to it.
API_BASE_URL = "https://api.klingai.com"

# Known option values. NOTE(review): nothing in this module validates
# against these lists when building requests — confirm before relying on them.
VALID_CLOTH_TYPES = ["upper", "lower", "full"]
VALID_IMAGE_SIZES = ["256x256", "512x512", "768x768"]

# Default generation options.
DEFAULT_IMAGE_SIZE = "512x512"
DEFAULT_NUM_STEPS = 30
DEFAULT_GUIDANCE_SCALE = 7.5
DEFAULT_SEED = 42
def generate_api_token(access_key, secret_key, *, ttl_seconds=1800, clock_skew_seconds=5):
    """Generate a short-lived HS256 JWT for authenticating to the Kling API.

    Args:
        access_key: Kling access key; sent as the ``iss`` claim.
        secret_key: Kling secret key; used as the HMAC signing key.
        ttl_seconds: Token lifetime in seconds (default 1800, i.e. 30 minutes).
        clock_skew_seconds: How far ``nbf`` is set in the past, so a server
            whose clock is slightly behind ours does not reject a freshly
            minted token as "not yet valid".

    Returns:
        The encoded token as ``str``. PyJWT 1.x returns ``bytes`` from
        ``jwt.encode``; it is decoded here so every caller can put the
        token straight into an ``Authorization`` header.

    Raises:
        Any exception raised by ``jwt.encode`` (logged, then re-raised).
    """
    current_time = int(time.time())
    payload = {
        "iss": access_key,
        "exp": current_time + ttl_seconds,
        "nbf": current_time - clock_skew_seconds,
    }
    # Log only the timing claims: the access key does not belong in a log file.
    logger.debug(f"Generating token (nbf={payload['nbf']}, exp={payload['exp']})")
    try:
        token = jwt.encode(payload, secret_key, algorithm="HS256")
    except Exception as e:
        logger.error(f"Error generating token: {str(e)}")
        raise
    if isinstance(token, bytes):
        token = token.decode("utf-8")
    logger.debug("Token generated successfully")
    return token
def encode_image_to_base64(image):
    """Serialize a PIL image as PNG and return the bytes base64-encoded.

    Args:
        image: Expected to be a ``PIL.Image.Image``.

    Returns:
        The base64 text as ``str``, or ``None`` when *image* is not a PIL
        image or encoding fails. Failures are logged rather than raised.
    """
    try:
        if not isinstance(image, Image.Image):
            logger.error("Input is not a PIL Image")
            return None
        with io.BytesIO() as png_buffer:
            image.save(png_buffer, format="PNG")
            encoded = base64.b64encode(png_buffer.getvalue()).decode('utf-8')
        logger.debug(f"Image encoded to base64 successfully. Length: {len(encoded)}")
        return encoded
    except Exception as exc:
        logger.error(f"Error encoding image to base64: {str(exc)}")
        return None
async def check_task_status(
    task_id,
    access_key,
    secret_key,
    *,
    max_attempts=3,
    wait_interval=20,
    request_timeout=30,
):
    """Poll a Kling virtual try-on task until it finishes or polling gives up.

    Before each poll the coroutine sleeps ``wait_interval`` seconds. It then
    issues ``GET {API_BASE_URL}/v1/images/kolors-virtual-try-on/{task_id}``
    with a freshly generated token. The blocking HTTP call runs in the loop's
    default executor so the event loop is not stalled.

    Args:
        task_id: Task ID returned by the submit endpoint.
        access_key: Kling access key used to mint each token.
        secret_key: Kling secret key used to mint each token.
        max_attempts: Number of polls before giving up (default 3).
        wait_interval: Seconds to sleep before each poll (default 20).
        request_timeout: Per-request HTTP timeout in seconds.

    Returns:
        ``(error_message, image_url)``; exactly one of the two is ``None``.
    """
    loop = asyncio.get_running_loop()
    url = f"{API_BASE_URL}/v1/images/kolors-virtual-try-on/{task_id}"

    for attempt in range(1, max_attempts + 1):
        await asyncio.sleep(wait_interval)
        logger.info(f"Checking task status (Attempt {attempt}/{max_attempts})...")

        # New token per poll so long waits never outlive the token's expiry.
        token = generate_api_token(access_key, secret_key)
        if isinstance(token, bytes):  # PyJWT 1.x returns bytes
            token = token.decode('utf-8')
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {token}",
        }

        try:
            # requests is blocking: run it off the event loop. TLS certificate
            # verification stays on (requests' default) because the request
            # carries a bearer token.
            response = await loop.run_in_executor(
                None,
                lambda: requests.get(url, headers=headers, timeout=request_timeout),
            )
            logger.debug(f"Status check response: {response.text}")
            result = response.json()
            if not isinstance(result, dict):
                raise ValueError("status response is not a JSON object")
        except (RequestException, ValueError) as e:
            # Transient network / malformed-body problem: retry next round.
            logger.error(f"Error checking task status: {str(e)}")
            continue

        if response.status_code != 200 or result.get('code') != 0:
            error_message = result.get('message', 'Unknown error occurred.')
            logger.error(f"Error fetching task status: {error_message}")
            continue

        # `or {}` guards against keys that are present but null.
        data = result.get('data') or {}
        task_status = (data.get('task_status') or '').lower()

        if task_status in ('completed', 'succeed'):
            images = (data.get('task_result') or {}).get('images') or []
            image_url = images[0].get('url') if images else None
            if image_url:
                return None, image_url
            return "No images found in the task result.", None

        if task_status in ('failed', 'error'):
            error_message = data.get('task_status_msg') or 'Task failed.'
            return f"Task failed: {error_message}", None

        logger.info(f"Task status: {task_status}. Waiting for next attempt...")

    return "Task did not complete within the expected time.", None
async def apply_virtual_tryon_async(
    person_image,
    garment_image,
    access_key,
    secret_key,
    *,
    request_timeout=60,
):
    """Run a Kling "Kolors" virtual try-on: submit the task, poll, download.

    Args:
        person_image: PIL image of the person.
        garment_image: PIL image of the garment.
        access_key: Kling API access key.
        secret_key: Kling API secret key.
        request_timeout: Timeout in seconds for the submit and download
            HTTP requests (polling uses its own timeout).

    Returns:
        ``(PIL.Image.Image, "Success")`` on success, otherwise
        ``(None, error_message)``. Errors are reported through the second
        element instead of being raised.
    """
    try:
        logger.info("Starting virtual try-on process")
        loop = asyncio.get_running_loop()

        # --- Authentication -------------------------------------------------
        jwt_token = generate_api_token(access_key, secret_key)
        if not jwt_token:
            return None, "Failed to generate JWT token"
        if isinstance(jwt_token, bytes):  # PyJWT 1.x returns bytes
            jwt_token = jwt_token.decode('utf-8')

        # --- Image preparation ---------------------------------------------
        logger.debug("Preparing images")
        person_base64 = encode_image_to_base64(person_image)
        garment_base64 = encode_image_to_base64(garment_image)
        if not person_base64 or not garment_base64:
            logger.error("Failed to convert images to base64")
            return None, "Error converting images to base64"

        # --- Task submission -----------------------------------------------
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {jwt_token}",
        }
        payload = {
            "model_name": "kolors-virtual-try-on-v1",
            "human_image": person_base64,
            "cloth_image": garment_base64,
        }
        url = f"{API_BASE_URL}/v1/images/kolors-virtual-try-on"
        logger.debug(f"Making API request to {url}")
        # Blocking HTTP call runs off the event loop. TLS verification stays
        # on (requests' default) because the request carries a bearer token.
        response = await loop.run_in_executor(
            None,
            lambda: requests.post(url, headers=headers, json=payload, timeout=request_timeout),
        )
        try:
            result = response.json()
        except ValueError:
            # Non-JSON body, e.g. an HTML page from a gateway error.
            result = {}

        if response.status_code != 200 or result.get('code') != 0:
            error_msg = result.get('message', f'Unknown error (HTTP {response.status_code})')
            logger.error(f"API Error: {error_msg}")
            return None, f"API Error: {error_msg}"

        task_id = (result.get('data') or {}).get('task_id')
        if not task_id:
            return None, "No task ID received"
        logger.info(f"Task submitted successfully. Task ID: {task_id}")

        # --- Polling -------------------------------------------------------
        error_message, image_url = await check_task_status(task_id, access_key, secret_key)
        if error_message:
            return None, error_message
        if not image_url:
            return None, "No image URL received"

        # --- Result download -----------------------------------------------
        try:
            image_response = await loop.run_in_executor(
                None, lambda: requests.get(image_url, timeout=request_timeout)
            )
            if image_response.status_code != 200:
                return None, f"Failed to download result image: {image_response.status_code}"
            result_image = Image.open(io.BytesIO(image_response.content))
            # Force decoding now so a corrupt download is reported here
            # rather than when the caller first touches the pixels.
            result_image.load()
        except Exception as e:
            return None, f"Error downloading result image: {str(e)}"
        return result_image, "Success"

    except Exception as e:
        logger.error(f"Unexpected Error: {str(e)}")
        return None, f"Error: {str(e)}"
def apply_virtual_tryon(
    person_image,
    garment_image,
    access_key,
    secret_key,
    cloth_type="upper",
    image_size=DEFAULT_IMAGE_SIZE,
    num_steps=DEFAULT_NUM_STEPS,
    guidance_scale=DEFAULT_GUIDANCE_SCALE,
    seed=DEFAULT_SEED,
):
    """Blocking wrapper around ``apply_virtual_tryon_async``.

    Returns the same ``(image_or_None, message)`` tuple as the coroutine.

    ``cloth_type``, ``image_size``, ``num_steps``, ``guidance_scale`` and
    ``seed`` are accepted for backward compatibility only. They are not
    forwarded to ``apply_virtual_tryon_async`` (see the call below), so they
    currently have no effect.

    Call this from synchronous code only. If an event loop is already running
    in the current thread, ``asyncio.run`` raises ``RuntimeError``; await
    ``apply_virtual_tryon_async`` directly in that case.
    """
    # asyncio.run creates a fresh loop, runs the coroutine, closes the loop,
    # and resets the thread's current event loop afterwards. This means no
    # closed loop is left installed for later asyncio users in this thread.
    return asyncio.run(
        apply_virtual_tryon_async(
            person_image,
            garment_image,
            access_key,
            secret_key,
        )
    )