|
|
import json
import logging
from dataclasses import dataclass
from io import BytesIO
from typing import Optional
from urllib.parse import urlparse

import cv2
import numpy as np
import requests
from requests_toolbelt.multipart.encoder import MultipartEncoder
|
|
|
|
|
|
|
|
|
@dataclass
class TryOnDiffusionAPIResponse:
    """Result of a single Try-On Diffusion API call.

    Attributes:
        status_code: HTTP status code of the response, or 0 when the request
            could not be sent at all.
        image: Decoded result image; populated for successful (200) responses
            when raw decoding was not requested.
        response_data: Raw response body bytes; populated for raw responses and
            for non-200 responses.
        error_details: Value of the ``detail`` key of a JSON error body, if any.
            NOTE(review): annotated as ``str`` but the server may return a
            structured value (e.g. a list) — confirm against the API.
        seed: Seed reported by the server in the ``X-Seed`` response header.
    """

    status_code: int
    image: Optional[np.ndarray] = None
    response_data: Optional[bytes] = None
    error_details: Optional[str] = None
    seed: Optional[int] = None
|
|
|
|
|
|
|
|
|
class TryOnDiffusionClient:
    """HTTP client for the Try-On Diffusion API.

    If the host of ``base_url`` ends with ``.rapidapi.com``, requests are
    authenticated with the RapidAPI headers (``X-RapidAPI-Key`` and
    ``X-RapidAPI-Host``); otherwise the key is sent as ``X-API-Key``.
    """

    # SECURITY(review): a real-looking API key is hard-coded as the default
    # below. Anything committed to source control should be considered leaked:
    # rotate it and have callers pass their own key (e.g. read from an
    # environment variable) instead of relying on this default.
    def __init__(
        self,
        base_url: str = "https://try-on-diffusion.p.rapidapi.com",
        api_key: str = "f46338f1a1msh8f27a3a69564667p1c5a31jsnbd2438a5d1c9",
        timeout: Optional[float] = None,
    ):
        """Create a client.

        Args:
            base_url: Root URL of the API; trailing slashes are removed.
            api_key: Key sent with every request (see the class docstring for
                which header carries it).
            timeout: Seconds passed to ``requests.post(timeout=...)``. ``None``
                (the default) means requests waits indefinitely.
        """
        self._logger = logging.getLogger("try_on_diffusion_client")
        # Strip trailing slashes so endpoint paths join with exactly one "/";
        # rstrip is also safe on an empty string.
        self._base_url = base_url.rstrip("/")
        self._api_key = api_key
        self._timeout = timeout

        parsed_url = urlparse(self._base_url)
        self._rapidapi_host = parsed_url.netloc if parsed_url.netloc.endswith(".rapidapi.com") else None

        if self._rapidapi_host is not None:
            self._logger.info("Using RapidAPI proxy: %s", self._rapidapi_host)

    def _endpoint_url(self, path: str) -> str:
        """Return the absolute URL of endpoint *path* under the configured base URL."""
        return f"{self._base_url}/{path.lstrip('/')}"

    @staticmethod
    def _image_to_upload_file(image: np.ndarray) -> tuple:
        """Encode *image* as a JPEG (quality 99) multipart file tuple.

        Returns:
            ``(filename, file_object, content_type)`` as accepted by
            ``MultipartEncoder`` fields.

        Raises:
            ValueError: If OpenCV reports that encoding failed.
        """
        ok, jpeg_data = cv2.imencode(".jpg", image, [int(cv2.IMWRITE_JPEG_QUALITY), 99])
        if not ok:
            raise ValueError("Failed to encode image as JPEG")
        return "image.jpg", BytesIO(jpeg_data.tobytes()), "image/jpeg"

    def _auth_headers(self) -> dict:
        """Return the authentication headers appropriate for the configured host."""
        if self._rapidapi_host is not None:
            return {"X-RapidAPI-Key": self._api_key, "X-RapidAPI-Host": self._rapidapi_host}
        return {"X-API-Key": self._api_key}

    def _decode_image(self, data: bytes) -> Optional[np.ndarray]:
        """Decode *data* into a colour image, returning None (and logging) on failure."""
        try:
            image = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR)
        except Exception:
            self._logger.exception("Error decoding image")
            return None
        # cv2.imdecode signals undecodable input by returning None, not raising.
        if image is None:
            self._logger.error("Error decoding image: response body is not a decodable image")
        return image

    def _parse_seed(self, headers) -> Optional[int]:
        """Return the integer ``X-Seed`` header value, or None if absent or malformed."""
        raw_seed = headers.get("X-Seed")
        if raw_seed is None:
            return None
        try:
            return int(raw_seed)
        except (TypeError, ValueError):
            self._logger.warning("Ignoring malformed X-Seed header: %r", raw_seed)
            return None

    def _parse_error_details(self, content: bytes) -> Optional[str]:
        """Return the ``detail`` value of a JSON object body, or None if unavailable."""
        if not content:
            return None
        try:
            payload = json.loads(content.decode("utf-8"))
        except ValueError:  # covers UnicodeDecodeError and json.JSONDecodeError
            self._logger.error("Error parsing response JSON", exc_info=True)
            return None
        # Only a JSON object can carry a "detail" key; ignore lists/scalars.
        if isinstance(payload, dict):
            return payload.get("detail")
        return None

    def try_on_file(
        self,
        clothing_image: Optional[np.ndarray] = None,
        clothing_prompt: Optional[str] = None,
        avatar_image: Optional[np.ndarray] = None,
        avatar_prompt: Optional[str] = None,
        avatar_sex: Optional[str] = None,
        background_image: Optional[np.ndarray] = None,
        background_prompt: Optional[str] = None,
        seed: int = -1,
        raw_response: bool = False,
    ) -> "TryOnDiffusionAPIResponse":
        """Send a try-on request to the ``/try-on-file`` endpoint.

        Images are uploaded as JPEG multipart files; prompts and options are
        sent as form fields. Arguments left as None are omitted from the request.

        Args:
            clothing_image, avatar_image, background_image: Images to upload.
            clothing_prompt, avatar_prompt, background_prompt: Text prompts.
            avatar_sex: Sent verbatim as the ``avatar_sex`` form field.
            seed: Sent as a string form field. NOTE(review): -1 presumably asks
                the server for a random seed — confirm with the API docs.
            raw_response: If True, do not decode a successful response; store
                its body bytes in ``response_data`` instead of ``image``.

        Returns:
            A TryOnDiffusionAPIResponse. ``status_code`` is 0 if the request
            could not be sent. On HTTP 200 either ``image`` or
            ``response_data`` is set, plus ``seed`` when the server reports
            one. Otherwise ``response_data`` holds the body and
            ``error_details`` the JSON ``detail`` field, if present.

        Raises:
            ValueError: If one of the images cannot be encoded as JPEG.
        """
        url = self._endpoint_url("try-on-file")
        self._logger.debug("API URL: %s", url)

        request_data = {"seed": str(seed)}
        # Field order is kept stable; image fields are JPEG-encoded uploads.
        optional_fields = (
            ("clothing_image", clothing_image, True),
            ("clothing_prompt", clothing_prompt, False),
            ("avatar_image", avatar_image, True),
            ("avatar_prompt", avatar_prompt, False),
            ("avatar_sex", avatar_sex, False),
            ("background_image", background_image, True),
            ("background_prompt", background_prompt, False),
        )
        for field_name, value, is_image in optional_fields:
            if value is None:
                continue
            request_data[field_name] = self._image_to_upload_file(value) if is_image else value

        multipart_data = MultipartEncoder(fields=request_data)
        headers = {"Content-Type": multipart_data.content_type, **self._auth_headers()}

        try:
            response = requests.post(url, data=multipart_data, headers=headers, timeout=self._timeout)
        except Exception:
            # Deliberate boundary: any transport failure is reported as status 0.
            self._logger.exception("Request to %s failed", url)
            return TryOnDiffusionAPIResponse(status_code=0)

        result = TryOnDiffusionAPIResponse(status_code=response.status_code)

        if response.status_code == 200:
            if raw_response:
                result.response_data = response.content
            else:
                result.image = self._decode_image(response.content)
            result.seed = self._parse_seed(response.headers)
        else:
            self._logger.warning(
                "Request failed, status code: %s, response: %s", response.status_code, response.content
            )
            result.response_data = response.content
            result.error_details = self._parse_error_details(response.content)

        return result
|
|
|
|