|
|
import torch |
|
|
from torch import Tensor, nn |
|
|
import torch.nn.functional as F |
|
|
import open_clip |
|
|
from tqdm import tqdm |
|
|
import numpy as np |
|
|
from typing import Union, Tuple, List |
|
|
|
|
|
|
|
|
def _build_num_to_word() -> dict:
    """
    Build the lookup table from decimal strings "0".."100" to English words.

    Keys are inserted in ascending numeric order. Compound numbers from 21 to 99
    are hyphenated ("twenty-one"), and 100 is spelled "one hundred".
    """
    units = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
    teens = [
        "ten", "eleven", "twelve", "thirteen", "fourteen",
        "fifteen", "sixteen", "seventeen", "eighteen", "nineteen",
    ]
    tens = ["twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety"]

    # 0..19 are irregular, so they are listed explicitly.
    table = {str(value): word for value, word in enumerate(units + teens)}

    # 20..99 follow the regular "<tens>" / "<tens>-<unit>" pattern.
    for decade, tens_word in enumerate(tens, start=2):
        table[str(decade * 10)] = tens_word
        for unit in range(1, 10):
            table[str(decade * 10 + unit)] = f"{tens_word}-{units[unit]}"

    table["100"] = "one hundred"
    return table


# Maps the decimal string of every integer in [0, 100] to its English spelling.
num_to_word = _build_num_to_word()
|
|
|
|
|
# Sentence openers tried in front of the count. The empty string means
# "no opener", so the prompt starts directly with the count.
prefixes = [
    "",
    "A photo of",
    "A block of",
    "An image of",
    "A picture of",
    "There are",
    "The image contains",
    "The photo contains",
    "The picture contains",
    "The image shows",
    "The photo shows",
    "The picture shows",
]

# Whether the count is written with digits (True) or spelled out in words (False).
arabic_numeral = [
    True,
    False,
]

# Comparison phrases used for bins whose upper bound is infinity.
compares = [
    "more than",
    "greater than",
    "higher than",
    "larger than",
    "bigger than",
    "greater than or equal to",
    "at least",
    "no less than",
    "not less than",
    "not fewer than",
    "not lower than",
    "not smaller than",
    "not less than or equal to",
    "over",
    "above",
    "beyond",
    "exceeding",
    "surpassing",
]

# Nouns placed after the count. The empty string means the count is left bare.
suffixes = [
    "people",
    "persons",
    "individuals",
    "humans",
    "faces",
    "heads",
    "figures",
    "",
]
|
|
|
|
|
|
|
|
def num2word(num: Union[int, str]) -> str:
    """
    Spell out an integer as an English word, e.g. 1 -> "one", 21 -> "twenty-one".

    The input is first coerced with ``int``, so both integers and integer-like
    strings are accepted. Values that have no entry in ``num_to_word`` are
    returned as their plain decimal string instead of raising.
    """
    key = f"{int(num)}"
    if key in num_to_word:
        return num_to_word[key]
    # No spelled-out form is known: fall back to the digits.
    return key
|
|
|
|
|
|
|
|
def _render_prompt(
    bin_spec: Union[float, Tuple[float, float]],
    prefix: str,
    numeral: bool,
    compare: str,
    suffix: str,
) -> str:
    """
    Build the sentence describing a single bin under one prompt setup.

    ``bin_spec`` is either a single number (an exact count) or a pair
    ``(low, high)``; a pair whose second element is infinity is treated as an
    open-ended bin and phrased with ``compare``.
    """

    def show(value: int) -> Union[int, str]:
        # Digits when `numeral` is set, spelled-out words otherwise.
        return value if numeral else num2word(value)

    if isinstance(bin_spec, (int, float)):
        exact = int(bin_spec)
        if exact in (0, 1):
            # Counts of 0 and 1 use the singular wording: "There are" becomes
            # "There is", "people" becomes "person", and any other suffix
            # simply loses its last character.
            lead = "There is" if prefix == "There are" else prefix
            noun = "person" if suffix == "people" else suffix[:-1]
            sentence = f"{lead} {show(exact)} {noun}"
        else:
            sentence = f"{prefix} {show(exact)} {suffix}"
    elif bin_spec[1] == float("inf"):
        sentence = f"{prefix} {compare} {show(int(bin_spec[0]))} {suffix}"
    else:
        low, high = int(bin_spec[0]), int(bin_spec[1])
        sentence = f"{prefix} between {show(low)} and {show(high)} {suffix}"

    # Empty prefixes/suffixes leave stray spaces at the ends; trim them before
    # closing the sentence.
    return sentence.strip() + "."


def format_count(
    bins: List[Union[float, Tuple[float, float]]],
) -> List[List[str]]:
    """
    Generate every candidate set of text prompts for the given bins.

    One prompt list is produced for each combination of prefix, numeral style,
    comparison phrase and suffix (iterated in that nesting order). Each list
    holds one sentence per bin, in the order the bins were given.
    """
    return [
        [_render_prompt(bin_spec, prefix, numeral, compare, suffix) for bin_spec in bins]
        for prefix in prefixes
        for numeral in arabic_numeral
        for compare in compares
        for suffix in suffixes
    ]
|
|
|
|
|
|
|
|
def encode_text(
    model_name: str,
    weight_name: str,
    text: List[str]
) -> Tensor:
    """
    Encode text prompts with a pretrained OpenCLIP model and L2-normalize the result.

    Args:
        model_name: Model identifier passed to ``open_clip.get_tokenizer`` and
            ``open_clip.create_model_from_pretrained``.
        weight_name: Pretrained-weights identifier passed to
            ``open_clip.create_model_from_pretrained``.
        text: The prompts to encode.

    Returns:
        The text features, normalized to unit L2 norm along the last dimension,
        detached and moved to the CPU (presumably one row per prompt -- confirm
        against the OpenCLIP model in use).

    Note:
        The model is created and moved to the device on every call, so callers
        that encode many batches pay the loading cost each time.
    """
    # Pick the best available accelerator. `torch.backends.mps.is_available()`
    # is the documented MPS check; `torch.mps.is_available()` is missing from
    # older PyTorch releases and would raise AttributeError on machines
    # without CUDA.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    tokens = open_clip.get_tokenizer(model_name)(text).to(device)
    model = open_clip.create_model_from_pretrained(model_name, weight_name, return_transform=False).to(device)
    model.eval()

    with torch.no_grad():
        text_feats = model.encode_text(tokens)
        text_feats = F.normalize(text_feats, p=2, dim=-1)

    return text_feats.detach().cpu()
|
|
|
|
|
|
|
|
def optimize_text_prompts(
    model_name: str,
    weight_name: str,
    flat_bins: List[Union[float, Tuple[float, float]]],
    batch_size: int = 1024,
) -> List[str]:
    """
    Pick the prompt setup whose per-bin sentences are least similar to each other.

    Every candidate prompt list from ``format_count`` is encoded with
    ``encode_text``. For each candidate, the mean of the off-diagonal entries
    of its feature Gram matrix is computed, and the candidate with the smallest
    mean is returned.

    Args:
        model_name: Forwarded to ``encode_text``.
        weight_name: Forwarded to ``encode_text``.
        flat_bins: The bins to describe; forwarded to ``format_count``.
        batch_size: Number of prompts sent to ``encode_text`` per call.

    Returns:
        The selected list of prompts, one per bin.
    """
    candidates = format_count(flat_bins)

    print("Finding the best setup for text prompts...")
    # Flatten all candidates so they can be encoded in fixed-size batches.
    flattened = [sentence for group in candidates for sentence in group]
    total = len(flattened)
    encoded = [
        encode_text(model_name, weight_name, flattened[start: min(start + batch_size, total)])
        for start in tqdm(range(0, total, batch_size))
    ]
    text_feats = torch.cat(encoded, dim=0)

    sims = []
    for idx, group in enumerate(candidates):
        size = len(group)
        # Rows belonging to this candidate within the flattened feature matrix.
        feats = text_feats[idx * size: (idx + 1) * size]
        gram = torch.mm(feats, feats.T)
        # Drop the diagonal (each sentence compared with itself) before averaging.
        off_diagonal = ~torch.eye(gram.shape[0], dtype=bool)
        sims.append(gram[off_diagonal].mean().item())

    best = np.argmin(sims)
    optimal_prompts = candidates[best]
    sim = sims[best]
    print(f"Found the best text prompts: {optimal_prompts} (similarity: {sim:.2f})")
    return optimal_prompts
|
|
|