# Spaces: Sleeping
import json
import logging
from dataclasses import dataclass
from io import BytesIO
from typing import Optional
from urllib.parse import urlparse

import cv2
import numpy as np
import requests
from requests_toolbelt.multipart.encoder import MultipartEncoder
@dataclass
class TryOnDiffusionAPIResponse:
    """Result of a Try-On Diffusion API call.

    Only ``status_code`` is required; every other field defaults to ``None``
    and is filled in depending on the outcome of the request.
    """

    # HTTP status code of the response; the client uses 0 when the request
    # itself raised and no HTTP response was received.
    status_code: int
    # Decoded result image (filled only for a 200 response when the caller did
    # not ask for the raw response body).
    image: Optional[np.ndarray] = None
    # Raw response body bytes (raw mode, or any non-200 response).
    response_data: Optional[bytes] = None
    # Value of the "detail" key from a JSON error body, if present.
    error_details: Optional[str] = None
    # Seed reported by the server in the X-Seed header of a 200 response.
    seed: Optional[int] = None
class TryOnDiffusionClient:
    """Client for the Try-On Diffusion ``/try-on-file`` HTTP endpoint.

    If the host of ``base_url`` ends with ``.rapidapi.com`` the API key is sent
    with RapidAPI headers (``X-RapidAPI-Key`` / ``X-RapidAPI-Host``); otherwise
    it is sent as ``X-API-Key``.
    """

    def __init__(self, base_url: str = "http://localhost:8000/", api_key: str = ""):
        """Create a client.

        Args:
            base_url: Root URL of the service. Trailing slashes are removed so
                endpoint paths can be appended directly.
            api_key: API key attached to every request.
        """
        self._logger = logging.getLogger("try_on_diffusion_client")
        # rstrip instead of inspecting base_url[-1]: indexing raised IndexError
        # for an empty string.
        self._base_url = base_url.rstrip("/")
        self._api_key = api_key

        parsed_url = urlparse(self._base_url)
        self._rapidapi_host = parsed_url.netloc if parsed_url.netloc.endswith(".rapidapi.com") else None
        if self._rapidapi_host is not None:
            self._logger.info("Using RapidAPI proxy: %s", self._rapidapi_host)

    @staticmethod
    def _image_to_upload_file(image: np.ndarray) -> tuple:
        """Encode *image* as a JPEG (quality 99) multipart file tuple.

        Returns:
            ``(filename, file_object, content_type)`` as accepted by
            ``MultipartEncoder`` fields.

        Raises:
            ValueError: if OpenCV reports that encoding failed.
        """
        ok, jpeg_data = cv2.imencode(".jpg", image, [int(cv2.IMWRITE_JPEG_QUALITY), 99])
        if not ok:
            # Without this check a failed encode silently uploads an empty file.
            raise ValueError("Failed to encode image as JPEG")
        return "image.jpg", BytesIO(jpeg_data.tobytes()), "image/jpeg"

    def try_on_file(
        self,
        clothing_image: Optional[np.ndarray] = None,
        clothing_prompt: Optional[str] = None,
        avatar_image: Optional[np.ndarray] = None,
        avatar_prompt: Optional[str] = None,
        avatar_sex: Optional[str] = None,
        background_image: Optional[np.ndarray] = None,
        background_prompt: Optional[str] = None,
        seed: int = -1,
        raw_response: bool = False,
        timeout: Optional[float] = None,
    ) -> "TryOnDiffusionAPIResponse":
        """POST a try-on request to ``<base_url>/try-on-file``.

        Images are uploaded as JPEG files; prompts and ``avatar_sex`` are sent
        as plain form fields. Arguments left as ``None`` are omitted from the
        request.

        Args:
            clothing_image, avatar_image, background_image: Images to upload
                (passed to ``cv2.imencode``).
            clothing_prompt, avatar_prompt, background_prompt: Text prompts.
            avatar_sex: Sent verbatim as the ``avatar_sex`` field.
            seed: Sent as a string form field. The meaning of -1 is defined by
                the server (presumably "random" — confirm against the API).
            raw_response: When True, return the body bytes in
                ``response_data`` instead of decoding an image.
            timeout: Passed to ``requests.post``; ``None`` (the default) waits
                indefinitely, as before.

        Returns:
            A ``TryOnDiffusionAPIResponse``. ``status_code`` is 0 if the HTTP
            request itself raised.

        Raises:
            ValueError: if one of the input images cannot be JPEG-encoded.
        """
        url = self._base_url + "/try-on-file"

        # Field order matches the original request layout.
        optional_fields = (
            ("clothing_image", clothing_image, True),
            ("clothing_prompt", clothing_prompt, False),
            ("avatar_image", avatar_image, True),
            ("avatar_prompt", avatar_prompt, False),
            ("avatar_sex", avatar_sex, False),
            ("background_image", background_image, True),
            ("background_prompt", background_prompt, False),
        )
        request_data = {"seed": str(seed)}
        for name, value, is_image in optional_fields:
            if value is None:
                continue
            request_data[name] = self._image_to_upload_file(value) if is_image else value

        multipart_data = MultipartEncoder(fields=request_data)
        headers = self._build_headers(multipart_data.content_type)

        try:
            response = requests.post(url, data=multipart_data, headers=headers, timeout=timeout)
        except Exception:
            # Deliberate best-effort: log with traceback and report status 0.
            self._logger.exception("Request to %s failed", url)
            return TryOnDiffusionAPIResponse(status_code=0)

        return self._parse_response(response, raw_response)

    def _build_headers(self, content_type: str) -> dict:
        """Return request headers: the content type plus API-key header(s)."""
        headers = {"Content-Type": content_type}
        if self._rapidapi_host is not None:
            headers["X-RapidAPI-Key"] = self._api_key
            headers["X-RapidAPI-Host"] = self._rapidapi_host
        else:
            headers["X-API-Key"] = self._api_key
        return headers

    def _parse_response(self, response: "requests.Response", raw_response: bool) -> "TryOnDiffusionAPIResponse":
        """Convert an HTTP response into a ``TryOnDiffusionAPIResponse``."""
        if response.status_code != 200:
            self._logger.warning(f"Request failed, status code: {response.status_code}, response: {response.content}")

        result = TryOnDiffusionAPIResponse(status_code=response.status_code)
        if not raw_response and response.status_code == 200:
            try:
                # imdecode returns None for undecodable data; it raises on e.g. an empty buffer.
                result.image = cv2.imdecode(np.frombuffer(response.content, np.uint8), cv2.IMREAD_COLOR)
            except Exception:
                self._logger.exception("Failed to decode response image")
                result.image = None
        else:
            result.response_data = response.content

        if result.status_code == 200:
            # requests' header mapping is case-insensitive, like the original `in` check.
            result.seed = self._parse_seed(response.headers.get("X-Seed"))
        else:
            result.error_details = self._parse_error_details(result.response_data)
        return result

    def _parse_seed(self, raw_seed: Optional[str]) -> Optional[int]:
        """Parse the X-Seed header value; return None if missing or malformed."""
        if raw_seed is None:
            return None
        try:
            return int(raw_seed)
        except ValueError:
            self._logger.warning("Ignoring malformed X-Seed header: %r", raw_seed)
            return None

    @staticmethod
    def _parse_error_details(response_data: Optional[bytes]):
        """Return the ``detail`` value of a JSON-object error body, else None."""
        if not response_data:
            return None
        try:
            payload = json.loads(response_data.decode("utf-8"))
        except (UnicodeDecodeError, ValueError):  # JSONDecodeError subclasses ValueError
            return None
        # Only a JSON object can carry a "detail" key; lists/strings/null are ignored.
        if isinstance(payload, dict):
            return payload.get("detail")
        return None