|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import ast |
|
|
import base64 |
|
|
import itertools |
|
|
import json |
|
|
import os |
|
|
from abc import ABC, abstractmethod |
|
|
from io import BytesIO |
|
|
from typing import Any, List, Union |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Validator(ABC):
    """Abstract data descriptor that validates every value assigned to it.

    Subclasses implement :meth:`validate`, which receives the raw value being
    assigned and returns the (possibly converted) value to store. The stored
    value lives on the owning instance under ``"_" + <attribute name>``.
    Subclasses are expected to set ``self.default``; it is returned whenever
    the owning instance has no stored value yet.
    """

    def __set_name__(self, owner, name):
        # Per-instance storage slot derived from the attribute's public name.
        self.private_name = f"_{name}"

    def __get__(self, obj, objtype=None):
        # For class-level access obj is None; getattr(None, ..., default) then
        # falls back to the default as well.
        return getattr(obj, self.private_name, self.default)

    def __set__(self, obj, value):
        setattr(obj, self.private_name, self.validate(value))

    @abstractmethod
    def validate(self, value):
        """Return the validated form of *value*, or raise if it is invalid."""

    def json(self):
        """Return a JSON-friendly description of this validator (``None`` here)."""
        return None
|
|
|
|
|
|
|
|
class MultipleOf(Validator):
    """Validator accepting only values that are multiples of a fixed integer.

    Args:
        default: Value returned before any assignment.
        multiple_of: Non-zero ``int`` that valid values must be a multiple of.
            The exact-type check also rejects ``bool``.
        type_cast: Optional callable applied to incoming values before the
            multiple check (e.g. ``int`` to accept numeric strings).
        hidden: Stored and shown in ``repr``; not otherwise used here.
        tooltip: Optional help text included in :meth:`json`.

    Raises:
        ValueError: If ``multiple_of`` is not an ``int`` or is zero.
    """

    def __init__(self, default: int, multiple_of: int, type_cast=None, hidden=False, tooltip=None):
        # Exact type check on purpose: bool is an int subclass but meaningless here.
        if type(multiple_of) is not int:
            raise ValueError(f"Expected {multiple_of!r} to be an int")
        # Zero would make every ``%`` in validate() raise ZeroDivisionError and
        # make get_range_iterator() yield 0 forever.
        if multiple_of == 0:
            raise ValueError(f"Expected {multiple_of!r} to be non-zero")
        self.multiple_of = multiple_of
        self.default = default
        self.type_cast = type_cast
        self.hidden = hidden
        self.tooltip = tooltip

    def validate(self, value):
        """Cast *value* (if a ``type_cast`` is set) and check it is a multiple.

        Raises:
            ValueError: If the cast raises ValueError, or the value is not a
                multiple of ``multiple_of``.
        """
        if self.type_cast:
            try:
                value = self.type_cast(value)
            except ValueError as err:
                # ``value`` is still the un-cast input here since the cast failed.
                raise ValueError(f"Expected {value!r} to be castable to {self.type_cast!r}") from err

        if value % self.multiple_of != 0:
            raise ValueError(f"Expected {value!r} to be a multiple of {self.multiple_of!r}")

        return value

    def get_range_iterator(self):
        """Return an unbounded iterator 0, m, 2m, ... where m is ``multiple_of``."""
        return itertools.count(0, self.multiple_of)

    def __repr__(self) -> str:
        return f"MultipleOf({self.private_name=} {self.multiple_of=} {self.hidden=})"

    def json(self):
        """Describe this validator for serialization."""
        return {
            "type": MultipleOf.__name__,
            "default": self.default,
            "multiple_of": self.multiple_of,
            "tooltip": self.tooltip,
        }
|
|
|
|
|
|
|
|
class OneOf(Validator):
    """Validator restricting values to a fixed collection of options.

    Args:
        default: Value returned before any assignment.
        options: Iterable of allowed (hashable) values. Their first-seen order
            is remembered so :meth:`json` and :meth:`get_range_iterator` are
            deterministic rather than following set hash order.
        type_cast: Optional callable applied to incoming values before the
            membership check.
        hidden: Stored and shown in ``repr``; not otherwise used here.
        tooltip: Optional help text included in :meth:`json`.
    """

    def __init__(self, default, options, type_cast=None, hidden=False, tooltip=None):
        # dict.fromkeys de-duplicates while keeping insertion order; it also
        # materializes one-shot iterators before the set is built.
        self._ordered_options = list(dict.fromkeys(options))
        # Public set kept for O(1) membership tests and backward compatibility.
        self.options = set(self._ordered_options)
        self.default = default
        self.type_cast = type_cast
        self.tooltip = tooltip
        self.hidden = hidden

    def _options_in_order(self):
        """Return the current ``self.options`` as a list in first-given order.

        Any option present in the set but not recorded at construction (e.g.
        added to ``self.options`` later) is appended at the end.
        """
        known = getattr(self, "_ordered_options", [])
        ordered = [opt for opt in known if opt in self.options]
        seen = set(ordered)
        ordered.extend(opt for opt in self.options if opt not in seen)
        return ordered

    def validate(self, value):
        """Cast *value* (if a ``type_cast`` is set) and check it is an option.

        Raises:
            ValueError: If the cast raises ValueError or the value is not one
                of ``options``.
        """
        if self.type_cast:
            try:
                value = self.type_cast(value)
            except ValueError as err:
                raise ValueError(f"Expected {value!r} to be castable to {self.type_cast!r}") from err

        if value not in self.options:
            raise ValueError(f"Expected {value!r} to be one of {self.options!r}")

        return value

    def get_range_iterator(self):
        """Return the options as a list, in the order they were given."""
        return self._options_in_order()

    def __repr__(self) -> str:
        return f"OneOf({self.private_name=} {self.options=} {self.hidden=})"

    def json(self):
        """Describe this validator; ``values`` keeps the caller's option order."""
        return {
            "type": OneOf.__name__,
            "default": self.default,
            "values": self._options_in_order(),
            "tooltip": self.tooltip,
        }
|
|
|
|
|
|
|
|
class HumanAttributes(Validator):
    """Validator for a human-attribute description string.

    Accepted values are ``"none"`` / ``"random"`` (case-insensitive), or a
    string made of exactly one label per attribute in the order emotion,
    race, gender, age group. Each label is followed by a single separator
    character, e.g. ``"happy asian female young"``. Matching is
    case-insensitive; the original (un-lowered) string is returned.
    """

    # Allowed labels per attribute; validate() expects them in this key order.
    valid_attributes = {
        "emotion": ["angry", "contemptful", "disgusted", "fearful", "happy", "neutral", "sad", "surprised"],
        "race": ["asian", "indian", "black", "white", "middle eastern", "latino hispanic"],
        "gender": ["male", "female"],
        "age group": [
            "young",
            "teen",
            "adult early twenties",
            "adult late twenties",
            "adult early thirties",
            "adult late thirties",
            "adult middle aged",
            "older adult",
        ],
    }

    def __init__(self, default, hidden=False, tooltip=None):
        self.default = default
        self.hidden = hidden
        self.tooltip = tooltip

    def get_range_iterator(self):
        """Iterate lazily over every (emotion, race, gender, age group) tuple."""
        return itertools.product(
            self.valid_attributes["emotion"],
            self.valid_attributes["race"],
            self.valid_attributes["gender"],
            self.valid_attributes["age group"],
        )

    def validate(self, value):
        """Return *value* if it names one label per attribute, else raise.

        Raises:
            TypeError: If *value* is not a ``str``.
            ValueError: If an attribute label is missing or unrecognized, or
                unrecognized text follows the age-group label.
        """
        if not isinstance(value, str):
            raise TypeError(f"Expected {value!r} to be an str")

        remainder = value.lower()
        if remainder in ("none", "random"):
            return value

        for attr_key in ("emotion", "race", "gender", "age group"):
            for attr_label in self.valid_attributes[attr_key]:
                if remainder.startswith(attr_label):
                    # Drop the label plus the one separator character after it.
                    remainder = remainder[len(attr_label) + 1 :]
                    break
            else:
                raise ValueError(f"Expected {value!r} to be one of {self.valid_attributes!r}")

        # Everything must have been consumed; leftover text used to be
        # silently accepted.
        if remainder.strip():
            raise ValueError(f"Expected {value!r} to be one of {self.valid_attributes!r}")

        return value

    def __repr__(self) -> str:
        return f"HumanAttributes({self.private_name=} {self.hidden=})"

    def json(self):
        """Describe this validator, including all allowed labels."""
        return {
            "type": HumanAttributes.__name__,
            "default": self.default,
            "values": self.valid_attributes,
            "tooltip": self.tooltip,
        }
|
|
|
|
|
|
|
|
class Bool(Validator):
    """Validator coercing ints and a few strings into ``bool``.

    Ints (including bools) map to ``value != 0``. The strings ``"true"``/``"1"``
    and ``"false"``/``"0"`` (case-insensitive) map to True/False. Anything
    else is rejected.
    """

    def __init__(self, default, hidden=False, tooltip=None):
        self.default = default
        self.hidden = hidden
        self.tooltip = tooltip

    def validate(self, value):
        """Return *value* converted to ``bool``, or raise.

        Raises:
            ValueError: For a string outside the accepted spellings.
            TypeError: For any other non-int, non-str value.
        """
        # bool subclasses int, so True/False take this branch as well.
        if isinstance(value, int):
            return value != 0

        if isinstance(value, str):
            lowered = value.lower()
            if lowered in ("true", "1"):
                return True
            if lowered in ("false", "0"):
                return False
            raise ValueError(f"Expected {lowered!r} to be one of ['True', 'False', '1', '0']")

        # Every bool was already handled above, so anything reaching here is invalid.
        raise TypeError(f"Expected {value!r} to be an bool")

    def get_range_iterator(self):
        """Return both boolean values, True first."""
        return [True, False]

    def __repr__(self) -> str:
        return f"Bool({self.private_name=} {self.default=} {self.hidden=})"

    def json(self):
        """Describe this validator for serialization."""
        description = {"type": bool.__name__}
        description["default"] = self.default
        description["tooltip"] = self.tooltip
        return description
|
|
|
|
|
|
|
|
class Int(Validator):
    """Validator for integers with optional inclusive bounds.

    Args:
        default: Value returned before any assignment. It is also the
            iteration bound used by :meth:`get_range_iterator` when
            ``min``/``max`` is missing.
        min: Optional inclusive lower bound checked by :meth:`validate`.
        max: Optional inclusive upper bound checked by :meth:`validate`.
        step: Increment used by :meth:`get_range_iterator`; must be positive.
        hidden: Stored and shown in ``repr``; not otherwise used here.
        tooltip: Optional help text included in :meth:`json`.
    """

    def __init__(self, default, min=None, max=None, step=1, hidden=False, tooltip=None):
        self.min = min
        self.max = max
        self.default = default
        self.step = step
        self.hidden = hidden
        self.tooltip = tooltip

    def validate(self, value):
        """Return *value* as an int within bounds, or raise.

        Strings are parsed with ``int()``. Bools pass the ``int`` isinstance check.

        Raises:
            TypeError: If *value* is neither ``str`` nor ``int``.
            ValueError: If parsing fails or the value is out of bounds.
        """
        if isinstance(value, str):
            value = int(value)
        elif not isinstance(value, int):
            raise TypeError(f"Expected {value!r} to be an int")

        if self.min is not None and value < self.min:
            raise ValueError(f"Expected {value!r} to be at least {self.min!r}")
        if self.max is not None and value > self.max:
            raise ValueError(f"Expected {value!r} to be no more than {self.max!r}")
        return value

    def get_range_iterator(self):
        """Iterate from ``min`` to ``max`` (inclusive) in ``step`` increments.

        A missing bound falls back to ``default``.

        Raises:
            ValueError: If ``step`` is not positive. With such a step the
                count/takewhile pair below could otherwise iterate forever.
        """
        if self.step <= 0:
            raise ValueError(f"Expected step {self.step!r} to be positive")
        iter_min = self.min if self.min is not None else self.default
        iter_max = self.max if self.max is not None else self.default
        return itertools.takewhile(lambda x: x <= iter_max, itertools.count(iter_min, self.step))

    def __repr__(self) -> str:
        return f"Int({self.private_name=} {self.default=}, {self.min=}, {self.max=} {self.hidden=})"

    def json(self):
        """Describe this validator for serialization."""
        return {
            "type": int.__name__,
            "default": self.default,
            "min": self.min,
            "max": self.max,
            "step": self.step,
            "tooltip": self.tooltip,
        }
|
|
|
|
|
|
|
|
class Float(Validator):
    """Validator for floats with optional inclusive bounds.

    Args:
        default: Value returned before any assignment. It is also the
            iteration bound used by :meth:`get_range_iterator` when
            ``min``/``max`` is missing.
        min: Optional inclusive lower bound checked by :meth:`validate`.
        max: Optional inclusive upper bound checked by :meth:`validate`.
        step: Increment used by :meth:`get_range_iterator`; must be positive.
        hidden: Stored and shown in ``repr``; not otherwise used here.
        tooltip: Optional help text included in :meth:`json`.
    """

    def __init__(self, default=0.0, min=None, max=None, step=0.5, hidden=False, tooltip=None):
        self.min = min
        self.max = max
        self.default = default
        self.step = step
        self.hidden = hidden
        self.tooltip = tooltip

    def validate(self, value):
        """Return *value* as a float within bounds, or raise.

        Strings and ints (including bools) are converted with ``float()``.

        Raises:
            TypeError: If *value* is not ``str``, ``int`` or ``float``.
            ValueError: If conversion fails or the value is out of bounds.
        """
        if isinstance(value, (str, int)):
            value = float(value)
        elif not isinstance(value, float):
            raise TypeError(f"Expected {value!r} to be float")

        if self.min is not None and value < self.min:
            raise ValueError(f"Expected {value!r} to be at least {self.min!r}")
        if self.max is not None and value > self.max:
            raise ValueError(f"Expected {value!r} to be no more than {self.max!r}")
        return value

    def get_range_iterator(self):
        """Iterate from ``min`` to ``max`` (inclusive) in ``step`` increments.

        A missing bound falls back to ``default``. Points are computed as
        ``min + i * step`` rather than by repeated addition, so rounding error
        does not accumulate. The max endpoint is kept despite rounding, and no
        point exceeds ``max``, so every yielded value passes :meth:`validate`.

        Raises:
            ValueError: If ``step`` is not positive.
        """
        if self.step <= 0:
            raise ValueError(f"Expected step {self.step!r} to be positive")
        iter_min = self.min if self.min is not None else self.default
        iter_max = self.max if self.max is not None else self.default
        if iter_max < iter_min:
            return iter(())
        # Number of whole steps that fit. The epsilon absorbs quotients such as
        # (0.3 - 0.0) / 0.1 == 2.9999999999999996.
        n_steps = int((iter_max - iter_min) / self.step + 1e-9)
        return (float(min(iter_min + i * self.step, iter_max)) for i in range(n_steps + 1))

    def __repr__(self) -> str:
        return f"Float({self.private_name=} {self.default=}, {self.min=}, {self.max=} {self.hidden=})"

    def json(self):
        """Describe this validator for serialization."""
        return {
            "type": float.__name__,
            "default": self.default,
            "min": self.min,
            "max": self.max,
            "step": self.step,
            "tooltip": self.tooltip,
        }
|
|
|
|
|
|
|
|
class String(Validator):
    """Validator for strings with optional length bounds and a custom predicate.

    ``min``/``max`` bound ``len(value)`` inclusively. ``predicate``, when
    given, must return a truthy value for the string to be accepted.
    """

    def __init__(self, default="", min=None, max=None, predicate=None, hidden=False, tooltip=None):
        self.default = default
        self.min = min
        self.max = max
        self.predicate = predicate
        self.hidden = hidden
        self.tooltip = tooltip

    def validate(self, value):
        """Return *value* unchanged if it is an acceptable string, else raise.

        Raises:
            TypeError: If *value* is not a ``str``.
            ValueError: If a length bound or the predicate rejects it.
        """
        if not isinstance(value, str):
            raise TypeError(f"Expected {value!r} to be an str")

        length = len(value)
        if self.min is not None:
            if length < self.min:
                raise ValueError(f"Expected {value!r} to be no smaller than {self.min!r}")
        if self.max is not None:
            if length > self.max:
                raise ValueError(f"Expected {value!r} to be no bigger than {self.max!r}")
        if self.predicate is not None:
            if not self.predicate(value):
                raise ValueError(f"Expected {self.predicate} to be true for {value!r}")

        return value

    def get_range_iterator(self):
        """Yield only the default value."""
        return iter((self.default,))

    def __repr__(self) -> str:
        return f"String({self.private_name=} {self.default=}, {self.min=}, {self.max=} {self.hidden=})"

    def json(self):
        """Describe this validator for serialization."""
        description = {"type": str.__name__}
        description["default"] = self.default
        description["tooltip"] = self.tooltip
        return description
|
|
|
|
|
|
|
|
class Path(Validator):
    """Validator for a filesystem path that must already exist.

    There is no ``json`` override, so the base-class implementation is used.
    """

    def __init__(self, default="", hidden=False, tooltip=None):
        self.default = default
        self.hidden = hidden
        self.tooltip = tooltip

    def validate(self, value):
        """Return *value* if it is a ``str`` naming an existing path.

        Raises:
            TypeError: If *value* is not a ``str``.
            ValueError: If ``os.path.exists`` reports the path missing.
        """
        if not isinstance(value, str):
            raise TypeError(f"Expected {value!r} to be an str")
        if not os.path.exists(value):
            raise ValueError(f"Expected {value!r} to be a valid path")

        return value

    def get_range_iterator(self):
        """Yield only the default value."""
        return iter([self.default])

    def __repr__(self) -> str:
        # Previously copy-pasted as "String(", which misidentified the validator.
        return f"Path({self.private_name=} {self.default=}, {self.hidden=})"
|
|
|
|
|
|
|
|
class InputImage(Validator):
    """Validator for a path to an existing image file with a known extension.

    Accepted extensions (case-insensitive) are listed in ``valid_formats``,
    grouped by format name.
    """

    # Format name -> file extensions (without the leading dot).
    valid_formats = {
        "JPEG": ["jpeg", "jpg"],
        "JPEG2000": ["jp2"],
        "PNG": ["png"],
        "GIF": ["gif"],
        "BMP": ["bmp"],
    }

    # Reverse lookup: extension (without dot) -> format name.
    valid_extensions = {ext: fmt for fmt, exts in valid_formats.items() for ext in exts}

    def __init__(self, default="", hidden=False, tooltip=None):
        self.default = default
        self.hidden = hidden
        self.tooltip = tooltip

    def validate(self, value):
        """Return *value* if it is an existing file with an image extension.

        Raises:
            TypeError: If *value* is not a ``str``.
            ValueError: If the extension is not in ``valid_extensions`` or the
                path does not exist.
        """
        # Type check first: the extension parsing below needs a string.
        if not isinstance(value, str):
            raise TypeError(f"Expected {value!r} to be an str")

        # splitext returns (root, ext) with a leading dot on ext (".png");
        # the lookup table stores bare extensions. The original code called
        # .lower() on the tuple itself and always crashed.
        ext = os.path.splitext(value)[1].lower()[1:]
        if ext not in InputImage.valid_extensions:
            raise ValueError(
                f"Expected {value!r} to have one of the extensions {sorted(InputImage.valid_extensions)!r}"
            )

        if not os.path.exists(value):
            raise ValueError(f"Expected {value!r} to be a valid path")
        return value

    def get_range_iterator(self):
        """Yield only the default value."""
        return iter([self.default])

    def __repr__(self) -> str:
        return f"InputImage({self.private_name=} {self.default=} {self.hidden=})"

    def json(self):
        """Describe this validator, including the accepted formats."""
        return {
            "type": InputImage.__name__,
            "default": self.default,
            "values": self.valid_formats,
            "tooltip": self.tooltip,
        }
|
|
|
|
|
|
|
|
class MeshFormat(Validator):
    """
    Validator class for mesh formats. Valid inputs are either:
    - single valid format such as "glb", "obj" (returned as the string)
    - or a list of valid formats such as "[obj, ply, usdz]" or '["obj", "ply"]'
      (returned as a non-empty list of format names)
    """

    valid_formats = {"glb", "usdz", "obj", "ply"}

    def __init__(self, default="glb", hidden=False, tooltip=None):
        self.default = default
        self.hidden = hidden
        self.tooltip = tooltip

    @staticmethod
    def _parse_format_list(value: str) -> List[str]:
        """Split a bracketed, comma-separated string into bare names.

        Surrounding whitespace and optional single/double quotes are removed
        from each item, and empty items (e.g. from a trailing comma) are dropped.
        """
        items = (item.strip().strip("'\"").strip() for item in value[1:-1].split(","))
        return [item for item in items if item]

    def validate(self, value: str) -> Union[str, List[str]]:
        """Return the format name, or the list of names for bracketed input.

        Raises:
            TypeError: If *value* is not a ``str``.
            ValueError: If any name is not in ``valid_formats``, or a bracketed
                list is empty.
        """
        if not isinstance(value, str):
            raise TypeError(f"Expected {value!r} to be an str")
        try:
            if value.startswith("[") and value.endswith("]"):
                # Parsed by hand: ast.literal_eval rejected the documented
                # unquoted form "[obj, ply, usdz]".
                formats = self._parse_format_list(value)
                if not formats:
                    raise ValueError(f"Expected at least one of {MeshFormat.valid_formats}")
                if not all(fmt in MeshFormat.valid_formats for fmt in formats):
                    raise ValueError(f"Each item must be one of {MeshFormat.valid_formats}")
                return formats
            if value in MeshFormat.valid_formats:
                return value
            raise ValueError(f"Expected {value!r} to be one of {MeshFormat.valid_formats} or a list of them")
        except ValueError as e:
            raise ValueError(f"Invalid format specification: {value}. Error: {str(e)}") from e

    def __repr__(self) -> str:
        return f"MeshFormat(default={self.default}, hidden={self.hidden})"

    def json(self):
        """Describe this validator; ``values`` is a sorted list so it is JSON-serializable."""
        return {
            "type": MeshFormat.__name__,
            "default": self.default,
            "values": sorted(self.valid_formats),
            "tooltip": self.tooltip,
        }
|
|
|
|
|
|
|
|
class JsonDict(Validator):
    """
    JSON stringified version of a python dict.
    Example: '{"ema_customization_iter.pt": "ema_customization_iter.pt"}'

    Falsy input yields an empty dict, and an already-parsed dict is returned
    unchanged. There is no ``json`` override, so the base-class implementation
    is used.
    """

    def __init__(self, default="", hidden=False):
        self.default = default
        self.hidden = hidden

    def validate(self, value):
        """Return the dict encoded by *value*.

        Raises:
            ValueError: If *value* is not valid JSON, or decodes to something
                other than a JSON object.
        """
        if not value:
            return {}
        # json.loads would raise TypeError for an already-decoded dict.
        if isinstance(value, dict):
            return value
        try:
            parsed = json.loads(value)
        except json.JSONDecodeError as e:
            raise ValueError(f"Expected {value!r} to be json stringified dict. Error: {str(e)}") from e
        # Arrays, strings and numbers are valid JSON but not a dict.
        if not isinstance(parsed, dict):
            raise ValueError(f"Expected {value!r} to be json stringified dict, got {type(parsed).__name__}")
        return parsed

    def __repr__(self) -> str:
        return f"Dict({self.default=} {self.hidden=})"
|
|
|
|
|
|
|
|
class BytesIOType(Validator):
    """
    Validator class for BytesIO. Valid inputs are either:
    - bytes-like objects (bytes, bytearray, memoryview)
    - objects of class BytesIO (returned unchanged)
    - str which can be successfully Base64-decoded into BytesIO
    """

    def __init__(self, default=None, hidden=False, tooltip=None):
        self.default = default
        self.hidden = hidden
        self.tooltip = tooltip

    def validate(self, value: Any) -> BytesIO:
        """Return *value* as a ``BytesIO`` stream.

        Raises:
            ValueError: If a string is not decodable Base64.
            TypeError: For any other unsupported type.
        """
        if isinstance(value, str):
            try:
                decoded_bytes = base64.b64decode(value)
            except ValueError as e:
                # binascii.Error (bad padding/characters) subclasses ValueError,
                # so the undocumented base64.binascii attribute is not needed.
                raise ValueError(f"Invalid Base64 encoded string: {e}") from e
            return BytesIO(decoded_bytes)
        if isinstance(value, (bytes, bytearray, memoryview)):
            return BytesIO(value)
        if isinstance(value, BytesIO):
            return value
        raise TypeError(f"Expected {value!r} to be a Base64 encoded string, bytes, or BytesIO")

    def __repr__(self) -> str:
        return f"BytesIOValidator({self.default=}, {self.hidden=})"

    def json(self):
        """Describe this validator for serialization."""
        return {
            "type": BytesIO.__name__,
            "default": self.default,
            "tooltip": self.tooltip,
        }
|
|
|