Spaces:
Running
on
Zero
Running
on
Zero
| import json | |
| def get_prompt_list(file_path: str) -> list[list[str]]: | |
| return json.load(open(file_path, "r")) | |
| class GetPromptList(object): | |
| _SUPPORTED_SOURCE = {'Sachit-descriptors',} | |
| def __init__(self, file_path: str, name2idx: dict[str: int] = None, class_names: list[str] = None) -> None: | |
| self.class_names = class_names | |
| self.file_path = file_path | |
| self.desc = get_prompt_list(file_path) | |
| if isinstance(self.desc, dict): | |
| self.__get_parts() | |
| if name2idx is not None: | |
| self.name2idx = name2idx | |
| elif class_names is not None: | |
| self.name2idx = {cls_name: idx for idx, cls_name in enumerate(class_names)} | |
| else: | |
| self.name2idx = {cls_name: idx for idx, cls_name in enumerate(self.desc.keys())} if isinstance(self.desc, dict) else None | |
| # def __get_sachit_desc(self, file_path: str): | |
| # params = get_sachit_hparams(file_path) | |
| # return load_gpt_descriptions(params) | |
| def __get_parts(self, ): | |
| # get part names from one of the descriptions | |
| self.part_names = [d.split(":")[0].strip() for d in self.desc[list(self.desc.keys())[0]]] | |
| def replace_class_names(self, descs: dict, target_class: list[str], new_classes: list[str]): | |
| new_descs = [] | |
| for desc, cls_name, new_name in zip(descs, target_class, new_classes): | |
| temp = [d.replace(cls_name, new_name) for d in desc] | |
| new_descs.extend(temp) | |
| return new_descs | |
| def __call__(self, source: str, pad: bool = False, max_len: int = 15, pad_text: str = "", target_classes: list[int] = None, pad_neg_index: bool = True): | |
| """ | |
| This function will return a list of prompts based on the source (format) and file_path provided. | |
| If name2idx is provided, the prompts will be mapped based on the provied class indexes. Otherwise, | |
| the prompts will be mapped based on the order of class name in the file. | |
| Note: this function is will apply trucation when padding is True to make sure to have fixed length prompts. | |
| Args: | |
| source (str): The sorce (format) of the prompts. Supported sources are: {self._SUPPORTED_SOURCE} | |
| file_path (str): The file that contains the original prompts. | |
| pad (bool, optional): Whether to pad the prompts to the same length. Defaults to False. | |
| max_len (int, optional): The maximum length of the prompts. Defaults to 15. | |
| pad_text (str, optional): The text to pad the prompts. Defaults to "Padding". | |
| target_classes (list[int], optional): A list of class indexes to include in the prompts. Defaults to None (include all classes). | |
| Returns: | |
| prompts (list[str]): A list of engineered prompts. | |
| class_idxs (list[int]): A list of class indexes for each prompt. | |
| class_mapping (dict[int: str]): A mapping of class indexes to class names. | |
| """ | |
| org_desc_mapper = None | |
| match source: | |
| case 'Sachit-descriptors': | |
| desc, org_dict = self.__get_sachit_desc(self.file_path) | |
| case 'Sachit-no-template': | |
| desc = self.desc | |
| case 'Sachit-CLIP-template-5': | |
| desc, org_dict = self.__get_sachit_desc(self.file_path) | |
| desc = {k: [f'a photo of a {d}.' for d in v] for k, v in desc.items()} | |
| case 'cub-12-parts': | |
| return self.desc, None, None, None | |
| case 'chatgpt-no-template': | |
| desc = self.desc | |
| case 'chatgpt-template-0': | |
| # convert 'part: features' to 'a features part' | |
| template = 'a {} {}.' | |
| desc = {k: [template.format(d.split(":")[1].strip(), d.split(":")[0].strip()) for d in v] for k, v in self.desc.items()} | |
| case 'chatgpt-template-8': | |
| # convert '{part}: {features}' to 'a {features} {part} of {class_name}' | |
| template = 'a {} {} of {}.' | |
| desc = {k: [template.format(d.split(":")[1].strip(), d.split(":")[0].strip(), k) for d in v] for k, v in self.desc.items()} | |
| case 'chatgpt-template-5': | |
| # convert '{part}: {features}' to 'a photo of {class_name}, which is/has/etc {descriptor} | |
| desc, org_dict = self.__get_sachit_desc(self.file_path) | |
| template = 'a photo of a {}' | |
| desc = {k: [template.format(d) for d in v] for k, v in desc.items()} | |
| case 'chatgpt-template-x': | |
| # convert '{part}: {features}' to '{features}. {part}. {class_name}' | |
| desc = {k: [f'{d.split(":")[1].strip()}. {d.split(":")[0].strip()}. {k}' for d in v] for k, v in self.desc.items()} | |
| case 'chatgpt-template-x-2': # no class name | |
| # convert '{part}: {features}' to 'a {features} {part}' | |
| desc = {k: [f'a {d.split(":")[1].strip()} {d.split(":")[0].strip()}' for d in v] for k, v in self.desc.items()} | |
| case 'chatgpt-template-x-3': # no class name | |
| # convert '{part}: {features}' to '{features}. {part}.' | |
| desc = {k: [f'{d.split(":")[1].strip()}. {d.split(":")[0].strip()}.' for d in v] for k, v in self.desc.items()} | |
| case 'chatgpt-template-x-4': | |
| # convert '{part}: {features}' to 'a {part} of {class_name}: {features}' | |
| desc = {k: [f'a {d.split(":")[0].strip()} of {k}: {d.split(":")[1].strip()}' for d in v] for k, v in self.desc.items()} | |
| case _: | |
| raise ValueError(f"Source {source} is not supported. Check {self._SUPPORTED_SOURCE}") | |
| # get the subset of descriptrions that match the target classes | |
| if len(self.name2idx) < len(desc): | |
| desc = {k: desc[k] for k in self.name2idx} | |
| prompts, class_idxs, class_list = [], [], [] | |
| class_mapping = {v: k for k, v in self.name2idx.items()} | |
| for class_name, class_idx in self.name2idx.items(): | |
| descriptions = desc[class_name] | |
| if target_classes is not None and class_idx not in target_classes: | |
| continue | |
| if pad: | |
| pad_id = -1 if pad_neg_index else class_idx | |
| ids = [class_idx] * len(descriptions) + [pad_id] * (max_len - len(descriptions)) if len(descriptions) < max_len else [class_idx] * max_len | |
| if len(descriptions) < max_len: | |
| descriptions.extend([pad_text] * (max_len - len(descriptions))) | |
| else: | |
| descriptions = descriptions[:max_len] | |
| else: | |
| ids = [class_idx] * len(descriptions) | |
| prompts.extend(descriptions) | |
| class_idxs.extend(ids) | |
| class_list.append(class_name) | |
| if org_desc_mapper is not None: | |
| org_desc_mapper = {des: org_dict[class_name][des] for des in descriptions} | |
| return prompts, class_idxs, class_mapping, org_desc_mapper, class_list | |
# Prompt templates; each entry has a single '{}' placeholder to be filled in
# via ``str.format``.
imagenet_templates = [
    "a bad photo of a {}.",
    "a photo of many {}.",
    "a sculpture of a {}.",
    "a photo of the hard to see {}.",
    "a low resolution photo of the {}.",
    "a rendering of a {}.",
    "graffiti of a {}.",
    "a bad photo of the {}.",
    "a cropped photo of the {}.",
    "a tattoo of a {}.",
    "the embroidered {}.",
    "a photo of a hard to see {}.",
    "a bright photo of a {}.",
    "a photo of a clean {}.",
    "a photo of a dirty {}.",
    "a dark photo of the {}.",
    "a drawing of a {}.",
    "a photo of my {}.",
    "the plastic {}.",
    "a photo of the cool {}.",
    "a close-up photo of a {}.",
    "a black and white photo of the {}.",
    "a painting of the {}.",
    "a painting of a {}.",
    "a pixelated photo of the {}.",
    "a sculpture of the {}.",
    "a bright photo of the {}.",
    "a cropped photo of a {}.",
    "a plastic {}.",
    "a photo of the dirty {}.",
    "a jpeg corrupted photo of a {}.",
    "a blurry photo of the {}.",
    "a photo of the {}.",
    "a good photo of the {}.",
    "a rendering of the {}.",
    "a {} in a video game.",
    "a photo of one {}.",
    "a doodle of a {}.",
    "a close-up photo of the {}.",
    "a photo of a {}.",
    "the origami {}.",
    "the {} in a video game.",
    "a sketch of a {}.",
    "a doodle of the {}.",
    "a origami {}.",
    "a low resolution photo of a {}.",
    "the toy {}.",
    "a rendition of the {}.",
    "a photo of the clean {}.",
    "a photo of a large {}.",
    "a rendition of a {}.",
    "a photo of a nice {}.",
    "a photo of a weird {}.",
    "a blurry photo of a {}.",
    "a cartoon {}.",
    "art of a {}.",
    "a sketch of the {}.",
    "a embroidered {}.",
    "a pixelated photo of a {}.",
    "itap of the {}.",
    "a jpeg corrupted photo of the {}.",
    "a good photo of a {}.",
    "a plushie {}.",
    "a photo of the nice {}.",
    "a photo of the small {}.",
    "a photo of the weird {}.",
    "the cartoon {}.",
    "art of the {}.",
    "a drawing of the {}.",
    "a photo of the large {}.",
    "a black and white photo of a {}.",
    "the plushie {}.",
    "a dark photo of a {}.",
    "itap of a {}.",
    "graffiti of the {}.",
    "a toy {}.",
    "itap of my {}.",
    "a photo of a cool {}.",
    "a photo of a small {}.",
    "a tattoo of the {}.",
]