| |
| |
| |
| |
|
|
| from enum import Enum, EnumMeta |
| from typing import List |
|
|
|
|
class StrEnumMeta(EnumMeta):
    """Enum metaclass whose ``isinstance`` check is based on the type's text.

    An object is reported as an instance when the string form of its type
    contains the substring ``"enum"``; the class being checked against is
    not consulted at all.
    NOTE(review): presumably a workaround for identity-based instance checks
    failing in some environments -- confirm the motivation with the authors.
    """

    @classmethod
    def __instancecheck__(cls, other):
        """Return True when ``str(type(other))`` contains ``"enum"``."""
        # Because this is a classmethod on the metaclass, ``cls`` is the
        # metaclass itself and plays no part in the decision.
        type_text = str(type(other))
        return "enum" in type_text
|
|
|
|
class StrEnum(Enum, metaclass=StrEnumMeta):
    """Enum whose members present themselves as their underlying value.

    ``str()`` and ``repr()`` both yield the member's value, equality compares
    the value against the other operand, and hashing uses the hash of the
    string form, so ``__eq__`` and ``__hash__`` stay consistent.
    """

    def __hash__(self):
        """Hash the member via its string form."""
        as_text = str(self)
        return hash(as_text)

    def __eq__(self, other: str):
        """Compare this member's value with ``other``."""
        return self.value == other

    def __repr__(self):
        """Return the raw value rather than the default ``<Class.NAME: value>``."""
        return self.value

    def __str__(self):
        """Return the member's value."""
        return self.value
|
|
|
|
def ChoiceEnum(choices: List[str]):
    """Build the Enum class used to enforce a list of choices.

    Each entry of ``choices`` becomes a member whose name and value are both
    that entry. The generated class is always named ``"Choices"``.
    """
    members = {option: option for option in choices}
    return StrEnum("Choices", members)
|
|
|
|
# Each constant below is a separate "Choices" enum class built by ChoiceEnum;
# every member's name and value is the string listed for it.
LOG_FORMAT_CHOICES = ChoiceEnum(
    ["json", "none", "simple", "tqdm"]
)
DDP_BACKEND_CHOICES = ChoiceEnum(
    ["c10d", "fully_sharded", "legacy_ddp", "no_c10d", "pytorch_ddp", "slowmo"]
)
DDP_COMM_HOOK_CHOICES = ChoiceEnum(
    ["none", "fp16"]
)
DATASET_IMPL_CHOICES = ChoiceEnum(
    ["raw", "lazy", "cached", "mmap", "fasta", "huffman"]
)
GENERATION_CONSTRAINTS_CHOICES = ChoiceEnum(
    ["ordered", "unordered"]
)
GENERATION_DECODING_FORMAT_CHOICES = ChoiceEnum(
    ["unigram", "ensemble", "vote", "dp", "bs"]
)
ZERO_SHARDING_CHOICES = ChoiceEnum(
    ["none", "os"]
)
PIPELINE_CHECKPOINT_CHOICES = ChoiceEnum(
    ["always", "never", "except_last"]
)
PRINT_ALIGNMENT_CHOICES = ChoiceEnum(
    ["hard", "soft"]
)
|
|