| from functools import partial |
| from itertools import islice |
| from typing import Callable, List, Optional, Sequence, Union |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
|
|
def batched(iterable, n):
    """Yield successive lists of up to *n* items drawn from *iterable*.

    Every chunk holds *n* items except possibly the final one, which carries
    whatever is left over.
    NOTE: modelled on the more-itertools implementation; python 3.12's
    itertools.batched is the intended replacement.
    """
    source = iter(iterable)
    # Two-argument iter() keeps calling the chunk builder until it hands back
    # the sentinel (an empty list), i.e. until the source is exhausted.
    yield from iter(lambda: list(islice(source, n)), [])
|
|
|
|
def build_zero_shot_classifier(
        model,
        tokenizer,
        classnames: Sequence[str],
        templates: Sequence[Union[Callable, str]],
        num_classes_per_batch: Optional[int] = 10,
        device: Union[str, torch.device] = 'cpu',
        use_tqdm: bool = False,
):
    """ Build zero-shot classifier weights by iterating over class names in batches
    Args:
        model: CLIP model instance
        tokenizer: CLIP tokenizer instance
        classnames: A sequence of class (label) names
        templates: A sequence of callables or format() friendly strings to produce templates per class name.
            Strings and callables may be mixed freely; each entry is dispatched on its own type.
        num_classes_per_batch: The number of classes to batch together in each forward, all if None
        device: Device the tokenized text is moved to before encoding.
        use_tqdm: Enable TQDM progress bar (only shown when classes are processed in batches).
    Returns:
        Tensor of shape (embed_dim, num_classes); column i is the unit-norm mean of the
        per-template text embeddings for classnames[i].
    """
    assert isinstance(templates, Sequence) and len(templates) > 0
    assert isinstance(classnames, Sequence) and len(classnames) > 0
    # Resolve how each template is rendered once, up front. Checking only templates[0]
    # would break on a sequence that mixes format strings with callables.
    renderers = [t.format if isinstance(t, str) else t for t in templates]
    num_templates = len(templates)
    num_classes = len(classnames)

    def _process_batch(batch_classnames):
        num_batch_classes = len(batch_classnames)
        # Class-major order: all templates of class 0, then all templates of class 1, ...
        # The reshape below relies on this ordering.
        texts = [render(c) for c in batch_classnames for render in renderers]
        texts = tokenizer(texts).to(device)
        class_embeddings = model.encode_text(texts, normalize=True)
        # Average over the template axis, then re-normalize the averaged vector.
        class_embeddings = class_embeddings.reshape(num_batch_classes, num_templates, -1).mean(dim=1)
        class_embeddings = class_embeddings / class_embeddings.norm(dim=1, keepdim=True)
        # Transpose so classes lie along dim 1 (one column per class).
        return class_embeddings.T

    with torch.no_grad():
        if not num_classes_per_batch:
            # None (or 0) means a single forward pass over every class; there is
            # nothing to iterate over, so no progress bar is set up.
            return _process_batch(classnames)

        batches = batched(classnames, num_classes_per_batch)
        if use_tqdm:
            # Imported lazily so tqdm is only required when a bar is actually shown.
            import tqdm
            num_iter = (num_classes - 1) // num_classes_per_batch + 1
            batches = tqdm.tqdm(batches, total=num_iter, unit_scale=num_classes_per_batch)
        batched_embeds = [_process_batch(batch) for batch in batches]
        return torch.cat(batched_embeds, dim=1)
|
|
|
|
def build_zero_shot_classifier_legacy(
        model,
        tokenizer,
        classnames: Sequence[str],
        templates: Sequence[Union[Callable, str]],
        device: Union[str, torch.device] = 'cpu',
        use_tqdm: bool = False,
):
    """ Build zero-shot classifier weights by iterating over class names 1 by 1
    Args:
        model: CLIP model instance
        tokenizer: CLIP tokenizer instance
        classnames: A sequence of class (label) names
        templates: A sequence of callables or format() friendly strings to produce templates per class name.
            Strings and callables may be mixed freely; each entry is dispatched on its own type.
        device: Device the tokenized text and the final weights are moved to.
        use_tqdm: Enable TQDM progress bar.
    Returns:
        Tensor with one unit-norm column per class name, stacked along dim 1
        (presumably (embed_dim, num_classes), assuming encode_text returns
        (num_templates, embed_dim) -- confirm against the model).
    """
    assert isinstance(templates, Sequence) and len(templates) > 0
    assert isinstance(classnames, Sequence) and len(classnames) > 0
    # Resolve how each template is rendered once, up front. Checking only templates[0]
    # would break on a sequence that mixes format strings with callables.
    renderers = [t.format if isinstance(t, str) else t for t in templates]

    if use_tqdm:
        import tqdm
        iter_wrap = tqdm.tqdm
    else:
        iter_wrap = iter

    with torch.no_grad():
        zeroshot_weights = []
        for classname in iter_wrap(classnames):
            texts = [render(classname) for render in renderers]
            texts = tokenizer(texts).to(device)
            class_embeddings = model.encode_text(texts)
            # Normalize each template embedding, average them, then re-normalize the mean.
            class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        # Stack along dim 1 so each class occupies one column.
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)

    return zeroshot_weights
|
|
|
|