|
|
import requests |
|
|
from requests.models import PreparedRequest |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
import torch |
|
|
from torchvision.transforms import ToPILImage |
|
|
from io import BytesIO |
|
|
import os |
|
|
import time |
|
|
|
|
|
# Clarity AI API key. Taken from the CAI_API_KEY environment variable, or else
# from a cai_platform_key.txt file next to this module. Surrounding whitespace
# is stripped, and a blank value is treated as "not set" (None), so request
# code can detect a missing key instead of sending an empty Authorization.
API_KEY = (os.environ.get("CAI_API_KEY") or "").strip() or None

# Directory containing this module, where the fallback key file is looked up.
# Defined unconditionally: the warning below interpolates it in every case.
dir_path = os.path.dirname(os.path.realpath(__file__))

if API_KEY is None:
    try:
        with open(os.path.join(dir_path, "cai_platform_key.txt"), "r", encoding="utf-8") as f:
            API_KEY = f.read().strip() or None
    except (OSError, UnicodeDecodeError):
        # A missing or unreadable key file is not fatal at import time; a key
        # can still be supplied per call via 'api_key_override'.
        API_KEY = None

if API_KEY is None:
    print(f"\n\n***API Key is required to use Clarity AI. Please set the CAI_API_KEY environment variable to your API key or place in {dir_path}/cai_platform_key.txt.***\n\n")


# Base URL that every endpoint path is appended to.
ROOT_API = "https://v1-upscale-endpoint-oak26mtdga-ey.a.run.app"
|
|
|
|
|
|
|
|
class ClarityBase:
    """Shared ComfyUI node logic for calling the Clarity AI HTTP API.

    Subclasses configure the request through the class attributes below and
    must define ``INPUT_SPEC``, which ``INPUT_TYPES`` returns.
    """

    # Path appended to ROOT_API for the initial POST request.
    API_ENDPOINT = ""
    # Path prefix (the job id is appended) polled for the result. "" means no
    # polling: the POST response body itself is decoded as an image.
    POLL_ENDPOINT = ""
    # Sent as the Accept header; also selects how a polled result is decoded.
    ACCEPT = ""

    # Overall polling budget in seconds (also the per-request timeout).
    POLL_TIMEOUT = 550
    # Seconds to wait between polls while the API answers 202.
    POLL_INTERVAL = 10

    @classmethod
    def INPUT_TYPES(cls):
        """Return the ComfyUI input specification declared by the subclass."""
        return cls.INPUT_SPEC

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "call"
    CATEGORY = "Clarity AI"

    def call(self, *args, **kwargs):
        """Send the node inputs to the Clarity AI API and return its result.

        ``kwargs`` are the node inputs. Everything except ``image`` is sent
        as form data. When ``image`` is given it is PNG-encoded and uploaded
        as the ``image`` file field.

        Returns:
            A one-element tuple with the decoded result (see ``_return_image``
            and ``_return_video``). Returns None when a polled result arrives
            and ``ACCEPT`` is neither "image/*" nor "video/*".

        Raises:
            Exception: no API key is available, the API answers with an error
                status or without a job id, or polling exceeds
                ``POLL_TIMEOUT`` seconds.
        """
        buffered = BytesIO()
        # NOTE(review): a None-valued file entry appears intended to force a
        # multipart body even without an image; verify that requests skips
        # None-valued entries when encoding files.
        files = {'none': None}

        image = kwargs.get('image', None)
        if image is not None:
            kwargs["mode"] = "image-to-image"
            kwargs.pop("aspect_ratio", None)
            # NOTE(review): assumes a single-image batch shaped (1, H, W, C);
            # squeeze(0) + permute(2, 0, 1) yields the (C, H, W) layout passed
            # to ToPILImage. Larger batches are not handled here.
            image = ToPILImage()(image.squeeze(0).permute(2, 0, 1))
            image.save(buffered, format="PNG")
            files = self._get_files(buffered, **kwargs)
        else:
            kwargs.pop("strength", None)

        # 'style_preset' is only forwarded when a 'style' input is present.
        if kwargs.get('style', False) is False:
            kwargs.pop('style_preset', None)

        kwargs['comfyui'] = True

        headers = self._build_headers(kwargs)
        # NOTE(review): 'api_key_override' (if given) is also forwarded in the
        # form data, because _get_data only drops 'image'.
        data = self._get_data(**kwargs)

        req = PreparedRequest()
        req.prepare_method('POST')
        req.prepare_url(f"{ROOT_API}{self.API_ENDPOINT}", None)
        req.prepare_headers(headers)
        req.prepare_body(data=data, files=files)
        # Close the session (and its pooled connections) once the response is read.
        with requests.Session() as session:
            response = session.send(req)

        if response.status_code != 200:
            self._raise_api_error(response)

        if self.POLL_ENDPOINT == "":
            return self._return_image(response)

        job_id = response.json().get("id")
        if job_id is None:
            raise Exception(f"Clarity AI API Error: no job id in response: {response.text}")

        result = self._poll_result(job_id, headers)
        if self.ACCEPT == "image/*":
            return self._return_image(result)
        if self.ACCEPT == "video/*":
            return self._return_video(result)
        return None

    def _build_headers(self, kwargs):
        """Return the request headers, preferring a non-blank 'api_key_override'.

        Raises:
            Exception: neither the override nor the module-level API_KEY holds
                a non-blank key.
        """
        override = (kwargs.get("api_key_override") or "").strip()
        api_key = override or (API_KEY or "").strip()
        # A blank key is treated like a missing one rather than being sent.
        if not api_key:
            raise Exception("No Clarity AI key set.\n\nUse your Clarity AI API key by:\n1. Setting the CAI_API_KEY environment variable to your API key\n3. Placing inside cai_platform_key.txt\n4. Passing the API key as an argument to the function with the key 'api_key_override'")
        return {"Authorization": api_key, "Accept": self.ACCEPT}

    def _poll_result(self, job_id, headers):
        """Poll ``POLL_ENDPOINT`` for ``job_id`` until it answers 200; return that response.

        Raises:
            Exception: more than ``POLL_TIMEOUT`` seconds have elapsed, or the
                API answered with a status other than 200 or 202.
        """
        url = f"{ROOT_API}{self.POLL_ENDPOINT}{job_id}"
        start_time = time.monotonic()
        while True:
            response = requests.get(url, headers=headers, timeout=self.POLL_TIMEOUT)
            if response.status_code == 200:
                print("took time: ", time.monotonic() - start_time)
                return response
            # The deadline is checked for every non-200 answer, including
            # 202 "still pending", so a job that never finishes cannot make
            # this loop spin forever.
            if time.monotonic() - start_time > self.POLL_TIMEOUT:
                raise Exception("Clarity AI API Timeout: Request took too long to complete")
            if response.status_code != 202:
                raise Exception(f"Clarity AI API Error: {self._error_details(response)}")
            time.sleep(self.POLL_INTERVAL)

    @staticmethod
    def _error_details(response):
        """Return the response's JSON error body, falling back to its raw text."""
        try:
            return response.json()
        except ValueError:
            # The body was not valid JSON; report it verbatim instead.
            return response.text

    @staticmethod
    def _raise_api_error(response):
        """Log a failed initial request and raise an Exception describing it.

        401, 402 and 400 get dedicated messages; any other status reports the
        raw response text.
        """
        # "Fehler" is German for "error".
        print("Fehler!! Status Code:", response.status_code)
        error_info = response.text
        print("error_info: " + error_info)
        if response.status_code == 401:
            raise Exception("Clarity AI API Error: Unauthorized.\n\nUse your Clarity AI API key by:\n1. Setting the CAI_API_KEY environment variable to your API key\n3. Placing inside cai_platform_key.txt\n4. Passing the API key as an argument to the function with the key 'api_key_override' \n\n \n\n")
        if response.status_code == 402:
            raise Exception("Clarity AI API Error: Not enough credits.\n\nPlease ensure your Clarity AI API account has enough credits to complete this action. \n\n \n\n")
        if response.status_code == 400:
            raise Exception(f"Clarity AI API Error: Bad request.\n\n{error_info} \n\n \n\n")
        raise Exception(f"Clarity AI API Error: {error_info}")

    def _return_image(self, response):
        """Decode the response body as an image.

        Returns:
            A one-element tuple holding a (1, H, W, 4) float32 RGBA tensor
            scaled to [0, 1].
        """
        result_image = Image.open(BytesIO(response.content)).convert("RGBA")
        result_image = np.array(result_image).astype(np.float32) / 255.0
        return (torch.from_numpy(result_image)[None,],)

    def _return_video(self, response):
        """Return the raw response bytes as a one-element tuple."""
        return (response.content,)

    def _get_files(self, buffered, **kwargs):
        """Return the multipart file mapping holding the encoded input image."""
        return {"image": buffered.getvalue()}

    def _get_data(self, **kwargs):
        """Return the form fields to send: every input except 'image'."""
        return {k: v for k, v in kwargs.items() if k != "image"}
|
|
|
|
|
|
|
|
class ClarityAIUpscaler(ClarityBase):
    """ComfyUI node that sends an image to the Clarity AI upscale service.

    The request is POSTed to ROOT_API itself (empty API_ENDPOINT). No result
    polling is configured, and an image is requested back.
    """

    # Routing: no extra endpoint path and no polling endpoint.
    API_ENDPOINT = ""
    POLL_ENDPOINT = ""
    # Accept header for the response.
    ACCEPT = "image/*"

    # ComfyUI input declarations: the image is mandatory, tuning knobs optional.
    INPUT_SPEC = dict(
        required=dict(
            image=("IMAGE",),
        ),
        optional=dict(
            prompt=("STRING", dict(multiline=True)),
            creativity=("FLOAT", dict(default=0, min=-10, max=10, step=1)),
            resemblance=("FLOAT", dict(default=0, min=-10, max=10, step=1)),
            dynamic=("FLOAT", dict(default=0, min=-10, max=10, step=1)),
            fractality=("FLOAT", dict(default=0, min=-10, max=10, step=1)),
            style=(["default", "portrait", "anime"],),
            # Even factors 2..16, exposed as strings: "2", "4", ..., "16".
            scale_factor=([str(factor) for factor in range(2, 17, 2)],),
            api_key_override=("STRING", dict(multiline=False)),
        ),
    )
|
|
|