Spaces:
Configuration error
Configuration error
import functools
import itertools
from typing import List, Tuple, Union

import numpy as np
import open_clip
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from tqdm import tqdm
def _build_num_to_word() -> dict:
    """Return a mapping from the decimal strings "0".."100" to their English spelling.

    Keys are inserted in ascending numeric order: "0" -> "zero", ...,
    "21" -> "twenty-one", ..., "100" -> "one hundred".
    """
    below_twenty = (
        "zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine",
        "ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen", "sixteen",
        "seventeen", "eighteen", "nineteen",
    )
    decades = ("twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety")

    table = {str(value): word for value, word in enumerate(below_twenty)}
    for tens_digit, decade in enumerate(decades, start=2):
        table[str(10 * tens_digit)] = decade
        for unit in range(1, 10):
            # Compound numbers are hyphenated, e.g. "forty-seven".
            table[str(10 * tens_digit + unit)] = f"{decade}-{below_twenty[unit]}"
    table["100"] = "one hundred"
    return table


# Decimal string -> English words for 0..100.
num_to_word = _build_num_to_word()

# Sentence openers for the prompt templates ("" means no opener).
prefixes = [
    "",
    "A photo of",
    "A block of",
    "An image of",
    "A picture of",
    "There are",
    "The image contains",
    "The photo contains",
    "The picture contains",
    "The image shows",
    "The photo shows",
    "The picture shows",
]

# True -> counts are written as digits; False -> counts are spelled out in words.
arabic_numeral = [True, False]

# Comparison phrases used for open-ended (lower_bound, inf) bins.
compares = [
    "more than",
    "greater than",
    "higher than",
    "larger than",
    "bigger than",
    "greater than or equal to",
    "at least",
    "no less than",
    "not less than",
    "not fewer than",
    "not lower than",
    "not smaller than",
    "not less than or equal to",
    "over",
    "above",
    "beyond",
    "exceeding",
    "surpassing",
]

# Nouns naming what is counted ("" means no noun).
suffixes = [
    "people",
    "persons",
    "individuals",
    "humans",
    "faces",
    "heads",
    "figures",
    "",
]
def num2word(num: Union[int, str]) -> str:
    """
    Spell out a number in English words, e.g. 1 -> "one", 42 -> "forty-two".

    The input is first coerced with ``int()``. Numeric strings such as "7" are
    accepted, and a non-integer string raises ``ValueError``. Values without an
    entry in ``num_to_word`` (which covers 0 through 100) are returned as their
    decimal string instead.
    """
    key = f"{int(num)}"
    return num_to_word[key] if key in num_to_word else key
def _format_bin(
    bin_: Union[float, Tuple[float, float]],
    prefix: str,
    numeral: bool,
    compare: str,
    suffix: str,
) -> str:
    """Render a single bin as one sentence for one (prefix, numeral, compare, suffix) template.

    A bin is either a single count, a ``(lower_bound, inf)`` open-ended range,
    or a ``(lower_bound, upper_bound)`` closed range.
    """
    def render(value) -> str:
        # Bounds are truncated to int, then written as digits or spelled out.
        n = int(value)
        return str(n) if numeral else num2word(n)

    # NumPy scalars (e.g. np.int64 from an array of bins) are single counts too.
    # They are not ``int`` subclasses, so they must be listed explicitly.
    if isinstance(bin_, (int, float, np.integer, np.floating)):
        count = int(bin_)
        if count in (0, 1):
            # Singular wording: "There is", and the suffix without its plural "s"
            # ("people" -> "person").
            # NOTE(review): 0 is also rendered in the singular ("zero person").
            # This is kept as-is because changing prompt text changes which
            # template gets selected downstream.
            prefix_ = "There is" if prefix == "There are" else prefix
            suffix_ = "person" if suffix == "people" else suffix[:-1]
            prompt = f"{prefix_} {render(count)} {suffix_}"
        else:
            prompt = f"{prefix} {render(count)} {suffix}"
    elif bin_[1] == float("inf"):  # open-ended bin: (lower_bound, inf)
        prompt = f"{prefix} {compare} {render(bin_[0])} {suffix}"
    else:  # closed bin: (lower_bound, upper_bound)
        prompt = f"{prefix} between {render(bin_[0])} and {render(bin_[1])} {suffix}"
    # An empty prefix/suffix leaves a leading/trailing space; strip it before the period.
    return prompt.strip() + "."


def format_count(
    bins: List[Union[float, Tuple[float, float]]],
) -> List[List[str]]:
    """Render every bin under every prompt template.

    Templates are the Cartesian product ``prefixes x arabic_numeral x compares x suffixes``,
    enumerated with ``suffixes`` varying fastest.

    Args:
        bins: Each entry is a single count (int/float, including NumPy scalars),
            a ``(lower_bound, inf)`` tuple, or a ``(lower_bound, upper_bound)`` tuple.

    Returns:
        One list per template. Each list holds ``len(bins)`` prompts in the same
        order as ``bins``.
    """
    templates = itertools.product(prefixes, arabic_numeral, compares, suffixes)
    return [
        [_format_bin(bin_, prefix, numeral, compare, suffix) for bin_ in bins]
        for prefix, numeral, compare, suffix in templates
    ]
| def encode_text( | |
| model_name: str, | |
| weight_name: str, | |
| text: List[str] | |
| ) -> Tensor: | |
| if torch.cuda.is_available(): | |
| device = torch.device("cuda") | |
| elif torch.mps.is_available(): | |
| device = torch.device("mps") | |
| else: | |
| device = torch.device("cpu") | |
| text = open_clip.get_tokenizer(model_name)(text).to(device) | |
| model = open_clip.create_model_from_pretrained(model_name, weight_name, return_transform=False).to(device) | |
| model.eval() | |
| with torch.no_grad(): | |
| text_feats = model.encode_text(text) | |
| text_feats = F.normalize(text_feats, p=2, dim=-1).detach().cpu() | |
| return text_feats | |
def optimize_text_prompts(
    model_name: str,
    weight_name: str,
    flat_bins: List[Union[float, Tuple[float, float]]],
    batch_size: int = 1024,
) -> List[str]:
    """Choose the prompt template whose bin prompts are least similar to each other.

    Every template produced by ``format_count(flat_bins)`` is scored by the mean
    pairwise similarity between the embeddings of its prompts for *distinct*
    bins. ``encode_text`` L2-normalises the embeddings, so these dot products are
    cosine similarities. The template with the smallest score is returned.

    Args:
        model_name: OpenCLIP model architecture name, forwarded to ``encode_text``.
        weight_name: Pretrained weight identifier, forwarded to ``encode_text``.
        flat_bins: Bins to describe (see ``format_count``).
        batch_size: Number of prompts encoded per ``encode_text`` call.

    Returns:
        The chosen template's prompts, one per bin, in ``flat_bins`` order.

    Raises:
        ValueError: If ``flat_bins`` is empty or ``batch_size`` is not positive.
    """
    if len(flat_bins) == 0:
        raise ValueError("flat_bins must contain at least one bin.")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}.")

    text_prompts = format_count(flat_bins)
    num_bins = len(text_prompts[0])
    if num_bins == 1:
        # With one bin there are no pairs to compare. The mean over an empty
        # off-diagonal would be NaN, so return the first template directly.
        optimal_prompts = text_prompts[0]
        print(f"Only one bin; using the first text prompts: {optimal_prompts}")
        return optimal_prompts

    # Find the template that has the smallest average similarity of bin prompts.
    print("Finding the best setup for text prompts...")
    flat_prompts = [prompt for prompts in text_prompts for prompt in prompts]
    text_feats = torch.cat(
        [
            encode_text(model_name, weight_name, flat_prompts[start: start + batch_size])
            for start in tqdm(range(0, len(flat_prompts), batch_size))
        ],
        dim=0,
    )
    # format_count emits exactly num_bins prompts per template, so consecutive
    # rows group by template: shape (num_templates, num_bins, embed_dim).
    text_feats = text_feats.reshape(len(text_prompts), num_bins, -1)

    # Loop-invariant mask selecting pairs of distinct bins (excludes self-similarity).
    off_diagonal = ~torch.eye(num_bins, dtype=torch.bool)
    sims = [(feats @ feats.T)[off_diagonal].mean().item() for feats in text_feats]

    best = int(np.argmin(sims))
    optimal_prompts = text_prompts[best]
    print(f"Found the best text prompts: {optimal_prompts} (similarity: {sims[best]:.2f})")
    return optimal_prompts