Spaces:
Running
Running
File size: 7,636 Bytes
c65d70c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 |
from PIL import Image
import requests
import io
import base64
import jwt
import time
import logging
import sys
import asyncio
from requests.exceptions import RequestException
# Logging: records at DEBUG and above go to both 'virtual_tryon.log' and stdout.
logging.basicConfig(
    handlers=[
        logging.FileHandler('virtual_tryon.log'),
        logging.StreamHandler(sys.stdout),
    ],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)
logger = logging.getLogger(__name__)

# Base URL of the Kling API.
API_BASE_URL = "https://api.klingai.com"

# Option value lists for cloth type and image size.
VALID_CLOTH_TYPES = ["upper", "lower", "full"]
VALID_IMAGE_SIZES = ["256x256", "512x512", "768x768"]

# Defaults for apply_virtual_tryon's optional generation parameters.
DEFAULT_IMAGE_SIZE = "512x512"
DEFAULT_NUM_STEPS = 30
DEFAULT_GUIDANCE_SCALE = 7.5
DEFAULT_SEED = 42
def generate_api_token(access_key, secret_key, ttl_seconds=1800, clock_skew_seconds=5):
    """Generate an HS256-signed JWT for Kling API authentication.

    Args:
        access_key: Kling access key; placed in the ``iss`` claim.
        secret_key: Kling secret key; used as the HMAC signing key.
        ttl_seconds: Token lifetime in seconds (default 1800, i.e. 30 minutes).
        clock_skew_seconds: How many seconds before "now" the ``nbf`` claim is
            set, so a server whose clock is slightly behind still accepts the
            token (default 5).

    Returns:
        The encoded token as ``str``. Older PyJWT releases return ``bytes``;
        those are decoded so every caller can put the token straight into an
        ``Authorization`` header.

    Raises:
        Any exception raised by ``jwt.encode``; it is logged before re-raising.
    """
    try:
        current_time = int(time.time())
        payload = {
            "iss": access_key,
            "exp": current_time + ttl_seconds,
            "nbf": current_time - clock_skew_seconds,
        }
        # Log only the timing claims: 'iss' is the access-key credential and
        # must not end up in the log file.
        logger.debug(f"Generating token (nbf={payload['nbf']}, exp={payload['exp']})")
        token = jwt.encode(payload, secret_key, algorithm="HS256")
        if isinstance(token, bytes):
            token = token.decode('utf-8')
        logger.debug("Token generated successfully")
        return token
    except Exception as e:
        logger.error(f"Error generating token: {str(e)}")
        raise
def encode_image_to_base64(image):
    """Serialize a PIL image as PNG and return the bytes base64-encoded.

    Args:
        image: Expected to be a ``PIL.Image.Image``.

    Returns:
        The base64 text as a UTF-8 ``str``, or ``None`` when ``image`` is not a
        PIL image or when saving/encoding raises. Failures are logged, never
        raised.
    """
    try:
        if not isinstance(image, Image.Image):
            logger.error("Input is not a PIL Image")
            return None
        with io.BytesIO() as png_buffer:
            image.save(png_buffer, format="PNG")
            raw_png = png_buffer.getvalue()
        encoded = base64.b64encode(raw_png).decode('utf-8')
        logger.debug(f"Image encoded to base64 successfully. Length: {len(encoded)}")
        return encoded
    except Exception as exc:
        logger.error(f"Error encoding image to base64: {str(exc)}")
        return None
async def check_task_status(
    task_id,
    access_key,
    secret_key,
    max_attempts=3,
    wait_interval=20,
    request_timeout=30,
):
    """Poll a Kolors virtual try-on task until it finishes or attempts run out.

    Each attempt first sleeps ``wait_interval`` seconds, then issues one status
    request authenticated with a freshly generated JWT.

    Args:
        task_id: Task identifier returned when the try-on job was submitted.
        access_key: Kling API access key used to sign the JWT.
        secret_key: Kling API secret key used to sign the JWT.
        max_attempts: Number of status checks before giving up (default 3).
        wait_interval: Seconds to sleep before each status check (default 20).
        request_timeout: Per-request HTTP timeout in seconds (default 30).

    Returns:
        ``(error_message, image_url)`` where exactly one element is ``None``:
        ``(None, url)`` when the task succeeded with an image URL, or
        ``(message, None)`` when it failed, produced no usable image, or did not
        finish within ``max_attempts`` checks.
    """
    url = f"{API_BASE_URL}/v1/images/kolors-virtual-try-on/{task_id}"
    loop = asyncio.get_running_loop()

    for attempt in range(1, max_attempts + 1):
        await asyncio.sleep(wait_interval)
        logger.info(f"Checking task status (Attempt {attempt}/{max_attempts})...")

        try:
            # New token per attempt so a long polling window cannot outlive it.
            token = generate_api_token(access_key, secret_key)
            if isinstance(token, bytes):
                token = token.decode('utf-8')
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {token}"
            }
            # requests is blocking: run it in the default executor so the event
            # loop is not stalled. TLS verification stays enabled because the
            # request carries a bearer token. The lambda is awaited immediately,
            # so it always sees this iteration's headers.
            response = await loop.run_in_executor(
                None,
                lambda: requests.get(url, headers=headers, timeout=request_timeout),
            )
        except RequestException as e:
            logger.error(f"Error checking task status: {str(e)}")
            continue

        logger.debug(f"Status check response: {response.text}")
        try:
            result = response.json()
        except ValueError:
            logger.error(
                f"Error fetching task status: non-JSON response (HTTP {response.status_code})"
            )
            continue

        if response.status_code != 200 or result.get('code') != 0:
            error_message = result.get('message', 'Unknown error occurred.')
            logger.error(f"Error fetching task status: {error_message}")
            continue

        # 'data', 'task_status' and 'task_result' may be present but null.
        data = result.get('data') or {}
        task_status = (data.get('task_status') or '').lower()

        if task_status in ('completed', 'succeed'):
            images = (data.get('task_result') or {}).get('images') or []
            image_url = images[0].get('url') if images else None
            if image_url:
                return None, image_url
            return "No images found in the task result.", None

        if task_status in ('failed', 'error'):
            error_message = data.get('task_status_msg', 'Task failed.')
            return f"Task failed: {error_message}", None

        logger.info(f"Task status: {task_status}. Waiting for next attempt...")

    return "Task did not complete within the expected time.", None
async def apply_virtual_tryon_async(
    person_image,
    garment_image,
    access_key,
    secret_key,
    request_timeout=60,
):
    """Run a Kolors virtual try-on job on the Kling API and download the result.

    Steps:
        1. Sign a JWT.
        2. Base64-encode both images as PNG.
        3. Submit the task.
        4. Poll it with ``check_task_status``.
        5. Download the resulting image.

    Args:
        person_image: PIL image of the person (passed to
            ``encode_image_to_base64``).
        garment_image: PIL image of the garment (passed to
            ``encode_image_to_base64``).
        access_key: Kling API access key.
        secret_key: Kling API secret key.
        request_timeout: HTTP timeout in seconds for the submit and download
            requests (default 60).

    Returns:
        ``(PIL.Image.Image, "Success")`` on success, otherwise
        ``(None, error_message)``. Exceptions are caught and reported through
        the message instead of being raised.
    """
    try:
        logger.info("Starting virtual try-on process")
        loop = asyncio.get_running_loop()

        jwt_token = generate_api_token(access_key, secret_key)
        if not jwt_token:
            return None, "Failed to generate JWT token"
        # Older PyJWT versions return bytes; the header needs str.
        if isinstance(jwt_token, bytes):
            jwt_token = jwt_token.decode('utf-8')

        logger.debug("Preparing images")
        person_base64 = encode_image_to_base64(person_image)
        garment_base64 = encode_image_to_base64(garment_image)
        if not person_base64 or not garment_base64:
            logger.error("Failed to convert images to base64")
            return None, "Error converting images to base64"

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {jwt_token}"
        }
        payload = {
            "model_name": "kolors-virtual-try-on-v1",
            "human_image": person_base64,
            "cloth_image": garment_base64
        }

        url = f"{API_BASE_URL}/v1/images/kolors-virtual-try-on"
        logger.debug(f"Making API request to {url}")
        # Blocking HTTP call moved off the event loop. TLS verification stays on:
        # the request carries the bearer token and the user's images.
        response = await loop.run_in_executor(
            None,
            lambda: requests.post(url, headers=headers, json=payload, timeout=request_timeout),
        )

        # Error pages (e.g. gateway HTML) are not JSON; report the HTTP status
        # instead of an opaque JSON decode error.
        try:
            result = response.json()
        except ValueError:
            logger.error(f"API Error: non-JSON response (HTTP {response.status_code})")
            return None, f"API Error: HTTP {response.status_code}"

        if response.status_code != 200 or result.get('code') != 0:
            error_msg = result.get('message', 'Unknown error')
            logger.error(f"API Error: {error_msg}")
            return None, f"API Error: {error_msg}"

        # 'data' may be present but null.
        task_id = (result.get('data') or {}).get('task_id')
        if not task_id:
            return None, "No task ID received"
        logger.info(f"Task submitted successfully. Task ID: {task_id}")

        error_message, image_url = await check_task_status(task_id, access_key, secret_key)
        if error_message:
            return None, error_message
        if not image_url:
            return None, "No image URL in task result"

        try:
            image_response = await loop.run_in_executor(
                None,
                lambda: requests.get(image_url, timeout=request_timeout),
            )
            if image_response.status_code != 200:
                return None, f"Failed to download result image: {image_response.status_code}"
            image = Image.open(io.BytesIO(image_response.content))
            # Image.open is lazy; decode now so a corrupt download is reported
            # here rather than when the caller first touches the pixels.
            image.load()
            return image, "Success"
        except Exception as e:
            return None, f"Error downloading result image: {str(e)}"
    except Exception as e:
        logger.exception(f"Unexpected Error: {str(e)}")
        return None, f"Error: {str(e)}"
def apply_virtual_tryon(
    person_image,
    garment_image,
    access_key,
    secret_key,
    cloth_type="upper",
    image_size=DEFAULT_IMAGE_SIZE,
    num_steps=DEFAULT_NUM_STEPS,
    guidance_scale=DEFAULT_GUIDANCE_SCALE,
    seed=DEFAULT_SEED
):
    """Synchronous wrapper around ``apply_virtual_tryon_async``.

    Runs the coroutine to completion in a fresh event loop, so it can be called
    from plain synchronous code. It must not be called from a thread whose
    event loop is already running.

    Args:
        person_image, garment_image, access_key, secret_key: Forwarded
            unchanged to ``apply_virtual_tryon_async``.
        cloth_type, image_size, num_steps, guidance_scale, seed: Accepted for
            interface compatibility only. They are NOT forwarded to the API
            request and have no effect on the result.

    Returns:
        Whatever ``apply_virtual_tryon_async`` returns.
    """
    # asyncio.run creates the loop, runs the coroutine, and shuts down the
    # default executor. It then closes the loop and resets the thread's current
    # loop. The previous new_event_loop/set_event_loop/close sequence left a
    # closed loop installed as this thread's event loop.
    return asyncio.run(
        apply_virtual_tryon_async(
            person_image,
            garment_image,
            access_key,
            secret_key
        )
    )