diff --git a/open_clip/src/open_clip/__init__.py b/open_clip/src/open_clip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d0419b4d7887b5af810f6251c9e4b3c18971b59a --- /dev/null +++ b/open_clip/src/open_clip/__init__.py @@ -0,0 +1,18 @@ +from .version import __version__ + +from .coca_model import CoCa +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss +from .factory import list_models, add_model_config, get_model_config, load_checkpoint +from .loss import ClipLoss, DistillClipLoss, CoCaLoss +from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \ + convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype, \ + get_model_tokenize_cfg, get_model_preprocess_cfg, set_model_preprocess_cfg +from .openai import load_openai_model, list_openai_models +from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \ + get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained +from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub +from .tokenizer import SimpleTokenizer, tokenize, decode +from .transform import image_transform, AugmentationCfg +from .zero_shot_classifier import build_zero_shot_classifier, build_zero_shot_classifier_legacy +from .zero_shot_metadata import OPENAI_IMAGENET_TEMPLATES, SIMPLE_IMAGENET_TEMPLATES, IMAGENET_CLASSNAMES diff --git a/open_clip/src/open_clip/coca_model.py b/open_clip/src/open_clip/coca_model.py new file mode 100644 index 0000000000000000000000000000000000000000..539616332aa25f2008a7d09ee2e40d79b35ff351 --- /dev/null +++ b/open_clip/src/open_clip/coca_model.py @@ -0,0 +1,500 @@ +from typing import Optional + +import torch +from torch import nn +from torch.nn import functional as F +import numpy as np +from dataclasses import dataclass + +from .transformer import ( + LayerNormFp32, + LayerNorm, + QuickGELU, + MultimodalTransformer, +) +from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower + +try: + from transformers import ( + BeamSearchScorer, + LogitsProcessorList, + TopPLogitsWarper, + TopKLogitsWarper, + RepetitionPenaltyLogitsProcessor, + MinLengthLogitsProcessor, + MaxLengthCriteria, + StopStringCriteria, + EosTokenCriteria, + StoppingCriteriaList + ) + + GENERATION_TYPES = { + "top_k": TopKLogitsWarper, + "top_p": TopPLogitsWarper, + "beam_search": "beam_search" + } + _has_transformers = True +except ImportError as e: + GENERATION_TYPES = { + "top_k": None, + "top_p": None, + "beam_search": "beam_search" + } + _has_transformers = False + + +@dataclass +class MultimodalCfg(CLIPTextCfg): + mlp_ratio: int = 4 + dim_head: int = 64 + heads: int = 8 + n_queries: int = 256 + attn_pooler_heads: int = 8 + + +def _build_text_decoder_tower( + embed_dim, + multimodal_cfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, +): + multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg + act_layer = QuickGELU if quick_gelu else nn.GELU + norm_layer = ( + LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + ) + + decoder = MultimodalTransformer( + context_length=multimodal_cfg.context_length, + width=multimodal_cfg.width, + heads=multimodal_cfg.heads, + layers=multimodal_cfg.layers, + ls_init_value=multimodal_cfg.ls_init_value, 
+ output_dim=embed_dim, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + return decoder + + +def _token_to_tensor(token_id, device: str = "cpu") -> torch.Tensor: + if not isinstance(token_id, torch.Tensor): + if isinstance(token_id, int): + token_id = [token_id] + token_id = torch.tensor(token_id, device=device) + return token_id + + +class CoCa(nn.Module): + def __init__( + self, + embed_dim, + multimodal_cfg: MultimodalCfg, + text_cfg: CLIPTextCfg, + vision_cfg: CLIPVisionCfg, + quick_gelu: bool = False, + init_logit_scale: float = np.log(1 / 0.07), + init_logit_bias: Optional[float] = None, + nonscalar_logit_scale: bool = False, + cast_dtype: Optional[torch.dtype] = None, + pad_id: int = 0, + ): + super().__init__() + multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg + text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg + vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg + + self.text = _build_text_tower( + embed_dim=embed_dim, + text_cfg=text_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + + vocab_size = ( + text_cfg.vocab_size # for hf models + if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None + else text_cfg.vocab_size + ) + + self.visual = _build_vision_tower( + embed_dim=embed_dim, + vision_cfg=vision_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + + self.text_decoder = _build_text_decoder_tower( + vocab_size, + multimodal_cfg=multimodal_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + + lshape = [1] if nonscalar_logit_scale else [] + self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale) + if init_logit_bias is not None: + self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias) + else: + self.logit_bias = None + self.pad_id = pad_id + + self.context_length = multimodal_cfg.context_length + + @torch.jit.ignore + def set_grad_checkpointing(self, enable: bool = True): + self.visual.set_grad_checkpointing(enable) + self.text.set_grad_checkpointing(enable) + self.text_decoder.set_grad_checkpointing(enable) + + def _encode_image(self, images, normalize: bool = True): + image_latent, tokens_embs = self.visual(images) + image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent + return image_latent, tokens_embs + + def _encode_text(self, text, normalize: bool = True): + text_latent, token_emb = self.text(text) + text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent + return text_latent, token_emb + + def encode_image(self, images, normalize: bool = True): + image_latent, _ = self._encode_image(images, normalize=normalize) + return image_latent + + def encode_text(self, text, normalize: bool = True): + text_latent, _ = self._encode_text(text, normalize=normalize) + return text_latent + + def forward( + self, + image, + text: Optional[torch.Tensor] = None, + image_latent: Optional[torch.Tensor] = None, + image_embs: Optional[torch.Tensor] = None, + output_labels: bool = True, + ): + if image_latent is None or image_embs is None: + image_latent, image_embs = self._encode_image(image) + + if text is None: + return {"image_features": image_latent, "image_embs": image_embs} + + text_latent, token_embs = self._encode_text(text) + + # FIXME this isn't an ideal solution, would like to improve -RW + labels: Optional[torch.Tensor] = text[:, 1:] if output_labels else None + if output_labels: + # align text_embs and thus logits with labels 
for teacher-forcing caption loss + token_embs = token_embs[:, :-1] + + logits = self.text_decoder(image_embs, token_embs) + out_dict = { + "image_features": image_latent, + "text_features": text_latent, + "logits": logits, + "logit_scale": self.logit_scale.exp() + } + if labels is not None: + out_dict["labels"] = labels + if self.logit_bias is not None: + out_dict["logit_bias"] = self.logit_bias + return out_dict + + def generate( + self, + image, + text=None, + seq_len=30, + max_seq_len=77, + temperature=1., + generation_type="beam_search", + top_p=0.1, # keep tokens in the 1 - top_p quantile + top_k=1, # keeps the top_k most probable tokens + pad_token_id=None, + eos_token_id=None, + sot_token_id=None, + num_beams=6, + num_beam_groups=3, + min_seq_len=5, + stopping_criteria=None, + repetition_penalty=1.0, + fixed_output_length=False # if True output.shape == (batch_size, seq_len) + ): + # taking many ideas and components from HuggingFace GenerationMixin + # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation + assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`." + assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len" + device = image.device + + with torch.no_grad(): + sot_token_id = _token_to_tensor(49406 if sot_token_id is None else sot_token_id, device=device) + eos_token_id = _token_to_tensor(49407 if eos_token_id is None else eos_token_id, device=device) + pad_token_id = self.pad_id if pad_token_id is None else pad_token_id + logit_processor = LogitsProcessorList( + [ + MinLengthLogitsProcessor(min_seq_len, eos_token_id), + RepetitionPenaltyLogitsProcessor(repetition_penalty), + ] + ) + + if stopping_criteria is None: + stopping_criteria = [MaxLengthCriteria(max_length=seq_len)] + stopping_criteria = StoppingCriteriaList(stopping_criteria) + + if generation_type == "beam_search": + output = self._generate_beamsearch( + image_inputs=image, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + sot_token_id=sot_token_id, + num_beams=num_beams, + num_beam_groups=num_beam_groups, + min_seq_len=min_seq_len, + stopping_criteria=stopping_criteria, + logit_processor=logit_processor, + ) + if fixed_output_length and output.shape[1] < seq_len: + pad_len = seq_len - output.shape[1] + return torch.cat(( + output, + torch.ones(output.shape[0], pad_len, device=device, dtype=output.dtype) * pad_token_id + ), + dim=1 + ) + return output + + elif generation_type == "top_p": + logit_warper = GENERATION_TYPES[generation_type](top_p) + elif generation_type == "top_k": + logit_warper = GENERATION_TYPES[generation_type](top_k) + else: + raise ValueError( + f"generation_type has to be one of " + f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}." 
+                )
+
+            image_latent, image_embs = self._encode_image(image)
+
+            if text is None:
+                text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id
+
+            was_training = self.training
+            num_dims = len(text.shape)
+
+            if num_dims == 1:
+                text = text[None, :]
+
+            self.eval()
+            out = text
+
+            while True:
+                x = out[:, -max_seq_len:]
+                cur_len = x.shape[1]
+                logits = self(
+                    image,
+                    x,
+                    image_latent=image_latent,
+                    image_embs=image_embs,
+                    output_labels=False,
+                )["logits"][:, -1]
+                mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
+                sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id
+
+                if mask.all():
+                    if not fixed_output_length:
+                        break
+                else:
+                    logits = logits[~mask, :]
+                    filtered_logits = logit_processor(x[~mask, :], logits)
+                    filtered_logits = logit_warper(x[~mask, :], filtered_logits)
+                    probs = F.softmax(filtered_logits / temperature, dim=-1)
+
+                    if (cur_len + 1 == seq_len):
+                        sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
+                    else:
+                        sample[~mask, :] = torch.multinomial(probs, 1)
+
+                out = torch.cat((out, sample), dim=-1)
+
+                cur_len += 1
+
+                if all(stopping_criteria(out, None)):
+                    break
+
+            if num_dims == 1:
+                out = out.squeeze(0)
+
+            self.train(was_training)
+            return out
+
+    def _generate_beamsearch(
+            self,
+            image_inputs,
+            pad_token_id=None,
+            eos_token_id=None,
+            sot_token_id=None,
+            num_beams=6,
+            num_beam_groups=3,
+            min_seq_len=5,
+            stopping_criteria=None,
+            logit_processor=None,
+            logit_warper=None,
+    ):
+        device = image_inputs.device
+        batch_size = image_inputs.shape[0]
+        image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
+        image_latent, image_embs = self._encode_image(image_inputs)
+
+        input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
+        input_ids = input_ids * sot_token_id
+        beam_scorer = BeamSearchScorer(
+            batch_size=batch_size,
+            num_beams=num_beams,
+            device=device,
+            num_beam_groups=num_beam_groups,
+        )
+        # instantiate logits processors
+        logits_processor = (
+            LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
+            if logit_processor is None
+            else logit_processor
+        )
+
+        num_beams = beam_scorer.num_beams
+        num_beam_groups = beam_scorer.num_beam_groups
+        num_sub_beams = num_beams // num_beam_groups
+        batch_size = len(beam_scorer._beam_hyps) // num_beam_groups
+        batch_beam_size, cur_len = input_ids.shape
+        beam_indices = None
+
+        if num_beams * batch_size != batch_beam_size:
+            raise ValueError(
+                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
+            )
+
+        beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
+        # initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
+        # the same group don't produce same tokens every time.
+        beam_scores[:, ::num_sub_beams] = 0
+        beam_scores = beam_scores.view((batch_size * num_beams,))
+
+        while True:
+
+            # predicted tokens in cur_len step
+            current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
+
+            # indices which will form the beams in the next time step
+            reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)
+
+            # do one decoder step on all beams of all sentences in batch
+            model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
+            outputs = self(
+                model_inputs['images'],
+                model_inputs['text'],
+                image_latent=image_latent,
+                image_embs=image_embs,
+                output_labels=False,
+            )
+
+            for beam_group_idx in range(num_beam_groups):
+                group_start_idx = beam_group_idx * num_sub_beams
+                group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
+                group_size = group_end_idx - group_start_idx
+
+                # indices of beams of current group among all sentences in batch
+                batch_group_indices = []
+
+                for batch_idx in range(batch_size):
+                    batch_group_indices.extend(
+                        [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
+                    )
+                group_input_ids = input_ids[batch_group_indices]
+
+                # select outputs of beams of current group only
+                next_token_logits = outputs['logits'][batch_group_indices, -1, :]
+                vocab_size = next_token_logits.shape[-1]
+
+                next_token_scores_processed = logits_processor(
+                    group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
+                )
+                next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
+                next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
+
+                # reshape for beam search
+                next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
+
+                next_token_scores, next_tokens = torch.topk(
+                    next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
+                )
+
+                next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
+                next_tokens = next_tokens % vocab_size
+
+                # stateless
+                process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
+                beam_outputs = beam_scorer.process(
+                    group_input_ids,
+                    next_token_scores,
+                    next_tokens,
+                    next_indices,
+                    pad_token_id=pad_token_id,
+                    eos_token_id=eos_token_id,
+                    beam_indices=process_beam_indices,
+                    group_index=beam_group_idx,
+                )
+                beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
+                beam_next_tokens = beam_outputs["next_beam_tokens"]
+                beam_idx = beam_outputs["next_beam_indices"]
+
+                input_ids[batch_group_indices] = group_input_ids[beam_idx]
+                group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
+                current_tokens[batch_group_indices] = group_input_ids[:, -1]
+
+                # (beam_idx // group_size) -> batch_idx
+                # (beam_idx % group_size) -> offset of idx inside the group
+                reordering_indices[batch_group_indices] = (
+                    num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
+                )
+
+            input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
+
+            # increase cur_len
+            cur_len = cur_len + 1
+            if beam_scorer.is_done or all(stopping_criteria(input_ids, None)):
+                break
+
+        final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
+        sequence_outputs = beam_scorer.finalize(
+            input_ids,
+            beam_scores,
+            next_tokens,
+            next_indices,
+            pad_token_id=pad_token_id,
+            eos_token_id=eos_token_id,
+
max_length=stopping_criteria.max_length, + beam_indices=final_beam_indices, + ) + return sequence_outputs['sequences'] + + +def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs): + if past: + input_ids = input_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + else: + position_ids = None + return { + "text": input_ids, + "images": image_inputs, + "past_key_values": past, + "position_ids": position_ids, + "attention_mask": attention_mask, + } diff --git a/open_clip/src/open_clip/loss.py b/open_clip/src/open_clip/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..b3e6dd256fb0368281ae3ce4a081ffbb955f37bb --- /dev/null +++ b/open_clip/src/open_clip/loss.py @@ -0,0 +1,448 @@ +from typing import Optional + +import torch +import torch.nn as nn +from torch.nn import functional as F + +try: + import torch.distributed.nn + from torch import distributed as dist + + has_distributed = True +except ImportError: + has_distributed = False + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + + +def gather_features( + image_features, + text_features, + local_loss=False, + gather_with_grad=False, + rank=0, + world_size=1, + use_horovod=False +): + assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.' + if use_horovod: + assert hvd is not None, 'Please install horovod' + if gather_with_grad: + all_image_features = hvd.allgather(image_features) + all_text_features = hvd.allgather(text_features) + else: + with torch.no_grad(): + all_image_features = hvd.allgather(image_features) + all_text_features = hvd.allgather(text_features) + if not local_loss: + # ensure grads for local rank when all_* features don't have a gradient + gathered_image_features = list(all_image_features.chunk(world_size, dim=0)) + gathered_text_features = list(all_text_features.chunk(world_size, dim=0)) + gathered_image_features[rank] = image_features + gathered_text_features[rank] = text_features + all_image_features = torch.cat(gathered_image_features, dim=0) + all_text_features = torch.cat(gathered_text_features, dim=0) + else: + # We gather tensors from all gpus + if gather_with_grad: + all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0) + all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0) + else: + gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)] + gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)] + dist.all_gather(gathered_image_features, image_features) + dist.all_gather(gathered_text_features, text_features) + if not local_loss: + # ensure grads for local rank when all_* features don't have a gradient + gathered_image_features[rank] = image_features + gathered_text_features[rank] = text_features + all_image_features = torch.cat(gathered_image_features, dim=0) + all_text_features = torch.cat(gathered_text_features, dim=0) + + return all_image_features, all_text_features + + +class ClipLoss(nn.Module): + + def __init__( + self, + local_loss=False, + gather_with_grad=False, + cache_labels=False, + rank=0, + world_size=1, + use_horovod=False, + ): + super().__init__() 
+ self.local_loss = local_loss + self.gather_with_grad = gather_with_grad + self.cache_labels = cache_labels + self.rank = rank + self.world_size = world_size + self.use_horovod = use_horovod + + # cache state + self.prev_num_logits = 0 + self.labels = {} + + def get_ground_truth(self, device, num_logits) -> torch.Tensor: + # calculated ground-truth and cache if enabled + if self.prev_num_logits != num_logits or device not in self.labels: + labels = torch.arange(num_logits, device=device, dtype=torch.long) + if self.world_size > 1 and self.local_loss: + labels = labels + num_logits * self.rank + if self.cache_labels: + self.labels[device] = labels + self.prev_num_logits = num_logits + else: + labels = self.labels[device] + return labels + + def get_logits(self, image_features, text_features, logit_scale): + if self.world_size > 1: + all_image_features, all_text_features = gather_features( + image_features, + text_features, + local_loss=self.local_loss, + gather_with_grad=self.gather_with_grad, + rank=self.rank, + world_size=self.world_size, + use_horovod=self.use_horovod, + ) + + if self.local_loss: + logits_per_image = logit_scale * image_features @ all_text_features.T + logits_per_text = logit_scale * text_features @ all_image_features.T + else: + logits_per_image = logit_scale * all_image_features @ all_text_features.T + logits_per_text = logits_per_image.T + else: + logits_per_image = logit_scale * image_features @ text_features.T + logits_per_text = logit_scale * text_features @ image_features.T + + return logits_per_image, logits_per_text + + def forward(self, image_features, text_features, logit_scale, output_dict=False): + device = image_features.device + logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale) + + labels = self.get_ground_truth(device, logits_per_image.shape[0]) + + total_loss = ( + F.cross_entropy(logits_per_image, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + + return {"contrastive_loss": total_loss} if output_dict else total_loss + + +class CoCaLoss(ClipLoss): + def __init__( + self, + caption_loss_weight, + clip_loss_weight, + pad_id=0, # pad_token for open_clip custom tokenizer + local_loss=False, + gather_with_grad=False, + cache_labels=False, + rank=0, + world_size=1, + use_horovod=False, + ): + super().__init__( + local_loss=local_loss, + gather_with_grad=gather_with_grad, + cache_labels=cache_labels, + rank=rank, + world_size=world_size, + use_horovod=use_horovod + ) + + self.clip_loss_weight = clip_loss_weight + self.caption_loss_weight = caption_loss_weight + self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id) + + def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False): + if self.clip_loss_weight: + clip_loss = super().forward(image_features, text_features, logit_scale) + clip_loss = self.clip_loss_weight * clip_loss + else: + clip_loss = torch.tensor(0, device=logits.device) + + caption_loss = self.caption_loss( + logits.permute(0, 2, 1), + labels, + ) + caption_loss = caption_loss * self.caption_loss_weight + + if output_dict: + return {"contrastive_loss": clip_loss, "caption_loss": caption_loss} + + return clip_loss, caption_loss + + +class DistillClipLoss(ClipLoss): + + def dist_loss(self, teacher_logits, student_logits): + return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0) + + def forward( + self, + image_features, + text_features, + logit_scale, + dist_image_features, + dist_text_features, + 
dist_logit_scale, + output_dict=False, + ): + logits_per_image, logits_per_text = \ + self.get_logits(image_features, text_features, logit_scale) + + dist_logits_per_image, dist_logits_per_text = \ + self.get_logits(dist_image_features, dist_text_features, dist_logit_scale) + + labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0]) + + contrastive_loss = ( + F.cross_entropy(logits_per_image, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + + distill_loss = ( + self.dist_loss(dist_logits_per_image, logits_per_image) + + self.dist_loss(dist_logits_per_text, logits_per_text) + ) / 2 + + if output_dict: + return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss} + + return contrastive_loss, distill_loss + + +def neighbour_exchange(from_rank, to_rank, tensor, group=None): + tensor_recv = torch.zeros_like(tensor) + send_op = torch.distributed.P2POp( + torch.distributed.isend, + tensor, + to_rank, + group=group, + ) + recv_op = torch.distributed.P2POp( + torch.distributed.irecv, + tensor_recv, + from_rank, + group=group, + ) + reqs = torch.distributed.batch_isend_irecv([send_op, recv_op]) + for req in reqs: + req.wait() + return tensor_recv + + +def neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None): + tensor_from_left = torch.zeros_like(tensor_to_right) + tensor_from_right = torch.zeros_like(tensor_to_left) + send_op_left = torch.distributed.P2POp( + torch.distributed.isend, + tensor_to_left, + left_rank, + group=group, + ) + send_op_right = torch.distributed.P2POp( + torch.distributed.isend, + tensor_to_right, + right_rank, + group=group, + ) + recv_op_left = torch.distributed.P2POp( + torch.distributed.irecv, + tensor_from_left, + left_rank, + group=group, + ) + recv_op_right = torch.distributed.P2POp( + torch.distributed.irecv, + tensor_from_right, + right_rank, + group=group, + ) + reqs = torch.distributed.batch_isend_irecv([send_op_right, send_op_left, recv_op_right, recv_op_left]) + for req in reqs: + req.wait() + return tensor_from_right, tensor_from_left + + +class NeighbourExchange(torch.autograd.Function): + @staticmethod + def forward(ctx, from_rank, to_rank, group, tensor): + ctx.group = group + ctx.from_rank = from_rank + ctx.to_rank = to_rank + return neighbour_exchange(from_rank, to_rank, tensor, group=group) + + @staticmethod + def backward(ctx, grad_output): + return (None, None, None) + (NeighbourExchange.apply(ctx.to_rank, ctx.from_rank, ctx.group, grad_output),) + + +def neighbour_exchange_with_grad(from_rank, to_rank, tensor, group=None): + return NeighbourExchange.apply(from_rank, to_rank, group, tensor) + + +class NeighbourExchangeBidir(torch.autograd.Function): + @staticmethod + def forward(ctx, left_rank, right_rank, group, tensor_to_left, tensor_to_right): + ctx.group = group + ctx.left_rank = left_rank + ctx.right_rank = right_rank + return neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=group) + + @staticmethod + def backward(ctx, *grad_outputs): + return (None, None, None) + \ + NeighbourExchangeBidir.apply(ctx.right_rank, ctx.left_rank, ctx.group, *grad_outputs) + + +def neighbour_exchange_bidir_with_grad(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None): + return NeighbourExchangeBidir.apply(left_rank, right_rank, group, tensor_to_left, tensor_to_right) + + +class SigLipLoss(nn.Module): + """ Sigmoid Loss for Language Image Pre-Training (SigLIP) - https://arxiv.org/abs/2303.15343 + + @article{zhai2023sigmoid, + 
title={Sigmoid loss for language image pre-training}, + author={Zhai, Xiaohua and Mustafa, Basil and Kolesnikov, Alexander and Beyer, Lucas}, + journal={arXiv preprint arXiv:2303.15343}, + year={2023} + } + """ + def __init__( + self, + cache_labels: bool = False, + rank: int = 0, + world_size: int = 1, + dist_impl: Optional[str] = None, + ): + super().__init__() + self.cache_labels = cache_labels + self.rank = rank + self.world_size = world_size + self.dist_impl = dist_impl or 'bidir' # default to bidir exchange for now, this will likely change + assert self.dist_impl in ('bidir', 'shift', 'reduce', 'gather') + + # cache state FIXME cache not currently used, worthwhile? + self.prev_num_logits = 0 + self.labels = {} + + def get_ground_truth(self, device, dtype, num_logits, negative_only=False) -> torch.Tensor: + labels = -torch.ones((num_logits, num_logits), device=device, dtype=dtype) + if not negative_only: + labels = 2 * torch.eye(num_logits, device=device, dtype=dtype) + labels + return labels + + def get_logits(self, image_features, text_features, logit_scale, logit_bias=None): + logits = logit_scale * image_features @ text_features.T + if logit_bias is not None: + logits += logit_bias + return logits + + def _loss(self, image_features, text_features, logit_scale, logit_bias=None, negative_only=False): + logits = self.get_logits(image_features, text_features, logit_scale, logit_bias) + labels = self.get_ground_truth( + image_features.device, + image_features.dtype, + image_features.shape[0], + negative_only=negative_only, + ) + loss = -F.logsigmoid(labels * logits).sum() / image_features.shape[0] + return loss + + def forward(self, image_features, text_features, logit_scale, logit_bias, output_dict=False): + loss = self._loss(image_features, text_features, logit_scale, logit_bias) + + if self.world_size > 1: + if self.dist_impl == 'bidir': + right_rank = (self.rank + 1) % self.world_size + left_rank = (self.rank - 1 + self.world_size) % self.world_size + text_features_to_right = text_features_to_left = text_features + num_bidir, remainder = divmod(self.world_size - 1, 2) + for i in range(num_bidir): + text_features_recv = neighbour_exchange_bidir_with_grad( + left_rank, + right_rank, + text_features_to_left, + text_features_to_right, + ) + for f in text_features_recv: + loss += self._loss( + image_features, + f, + logit_scale, + logit_bias, + negative_only=True, + ) + text_features_to_left, text_features_to_right = text_features_recv + + if remainder: + text_features_recv = neighbour_exchange_with_grad( + left_rank, + right_rank, + text_features_to_right + ) + loss += self._loss( + image_features, + text_features_recv, + logit_scale, + logit_bias, + negative_only=True, + ) + elif self.dist_impl == "shift": + right_rank = (self.rank + 1) % self.world_size + left_rank = (self.rank - 1 + self.world_size) % self.world_size + text_features_to_right = text_features + for i in range(self.world_size - 1): + text_features_from_left = neighbour_exchange_with_grad( + left_rank, + right_rank, + text_features_to_right, + ) + loss += self._loss( + image_features, + text_features_from_left, + logit_scale, + logit_bias, + negative_only=True, + ) + text_features_to_right = text_features_from_left + elif self.dist_impl == "reduce": + for i in range(self.world_size): + text_from_other = torch.distributed.nn.all_reduce( + text_features * (self.rank == i), + torch.distributed.ReduceOp.SUM, + ) + loss += float(i != self.rank) * self._loss( + image_features, + text_from_other, + logit_scale, + logit_bias, + 
negative_only=True, + ) + elif self.dist_impl == "gather": + all_text = torch.distributed.nn.all_gather(text_features) + for i in range(self.world_size): + loss += float(i != self.rank) * self._loss( + image_features, + all_text[i], + logit_scale, + logit_bias, + negative_only=True, + ) + else: + assert False + + return {"contrastive_loss": loss} if output_dict else loss diff --git a/open_clip/src/open_clip/model_configs/EVA01-g-14.json b/open_clip/src/open_clip/model_configs/EVA01-g-14.json new file mode 100644 index 0000000000000000000000000000000000000000..9d0e80f290d9491b7c46fafd576201b1258165aa --- /dev/null +++ b/open_clip/src/open_clip/model_configs/EVA01-g-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva_giant_patch14_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/EVA02-E-14.json b/open_clip/src/open_clip/model_configs/EVA02-E-14.json new file mode 100644 index 0000000000000000000000000000000000000000..4b6648e25092b151a9095e0a66956c7ebf835b16 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/EVA02-E-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva02_enormous_patch14_clip_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + }, + "custom_text": true +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/EVA02-L-14.json b/open_clip/src/open_clip/model_configs/EVA02-L-14.json new file mode 100644 index 0000000000000000000000000000000000000000..b4c7f377bc543aa92a145358f2630a58ae9be989 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/EVA02-L-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva02_large_patch14_clip_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/MobileCLIP-B.json b/open_clip/src/open_clip/model_configs/MobileCLIP-B.json new file mode 100644 index 0000000000000000000000000000000000000000..9907d86b37a60918405e5e3f2cf237bad889a0ce --- /dev/null +++ b/open_clip/src/open_clip/model_configs/MobileCLIP-B.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "vit_base_mci_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null, + "timm_drop": 0.0, + "timm_drop_path": 0.0, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12, + "no_causal_mask": false + }, + "custom_text": true +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/MobileCLIP-S1.json b/open_clip/src/open_clip/model_configs/MobileCLIP-S1.json new file mode 100644 index 0000000000000000000000000000000000000000..80780c5eac6f3f9e7b09bc891abb63599e4464f3 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/MobileCLIP-S1.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 512, + 
"vision_cfg": { + "timm_model_name": "fastvit_mci1", + "timm_model_pretrained": false, + "timm_pool": "avg", + "timm_proj": null, + "timm_drop": 0.0, + "timm_drop_path": 0.0, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12, + "no_causal_mask": true + }, + "custom_text": true +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/MobileCLIP-S2.json b/open_clip/src/open_clip/model_configs/MobileCLIP-S2.json new file mode 100644 index 0000000000000000000000000000000000000000..66ebc16aaab350091f29c8330c15ead59c228609 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/MobileCLIP-S2.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "fastvit_mci2", + "timm_model_pretrained": false, + "timm_pool": "avg", + "timm_proj": null, + "timm_drop": 0.0, + "timm_drop_path": 0.0, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12, + "no_causal_mask": true + }, + "custom_text": true +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/RN101-quickgelu.json b/open_clip/src/open_clip/model_configs/RN101-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..d0db2c161d13138788c4609d373b023b8454d624 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/RN101-quickgelu.json @@ -0,0 +1,22 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 23, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/RN50-quickgelu.json b/open_clip/src/open_clip/model_configs/RN50-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..8c2f91260cdeb043434dc1e893cce81d4ce7f0d1 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/RN50-quickgelu.json @@ -0,0 +1,22 @@ +{ + "embed_dim": 1024, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 6, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} diff --git a/open_clip/src/open_clip/model_configs/RN50.json b/open_clip/src/open_clip/model_configs/RN50.json new file mode 100644 index 0000000000000000000000000000000000000000..33aa884d54fee0076c33676831e49d5e1ffcb8f2 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/RN50.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 6, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/RN50x64-quickgelu.json b/open_clip/src/open_clip/model_configs/RN50x64-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..6da9d7e219b8e3ed233909055308f994187ebae7 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/RN50x64-quickgelu.json @@ -0,0 +1,22 @@ +{ + "embed_dim": 1024, + "quick_gelu": true, + "vision_cfg": { + "image_size": 448, + "layers": [ + 3, + 15, + 36, + 10 + ], + "width": 128, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + 
"vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/ViT-B-16-SigLIP-384.json b/open_clip/src/open_clip/model_configs/ViT-B-16-SigLIP-384.json new file mode 100644 index 0000000000000000000000000000000000000000..df9a25cdca5207a8954801c0f2cf28514c15a1cd --- /dev/null +++ b/open_clip/src/open_clip/model_configs/ViT-B-16-SigLIP-384.json @@ -0,0 +1,29 @@ +{ + "embed_dim": 768, + "init_logit_bias": -10, + "custom_text": true, + "vision_cfg": { + "image_size": 384, + "timm_model_name": "vit_base_patch16_siglip_384", + "timm_model_pretrained": false, + "timm_pool": "map", + "timm_proj": "none" + }, + "text_cfg": { + "context_length": 64, + "vocab_size": 32000, + "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", + "tokenizer_kwargs": { + "clean": "canonicalize" + }, + "width": 768, + "heads": 12, + "layers": 12, + "no_causal_mask": true, + "proj_bias": true, + "pool_type": "last", + "norm_kwargs":{ + "eps": 1e-6 + } + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/ViT-B-16-SigLIP-512.json b/open_clip/src/open_clip/model_configs/ViT-B-16-SigLIP-512.json new file mode 100644 index 0000000000000000000000000000000000000000..88b018528b2e7806cd11b95d5808136786ea0f97 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/ViT-B-16-SigLIP-512.json @@ -0,0 +1,29 @@ +{ + "embed_dim": 768, + "init_logit_bias": -10, + "custom_text": true, + "vision_cfg": { + "image_size": 512, + "timm_model_name": "vit_base_patch16_siglip_512", + "timm_model_pretrained": false, + "timm_pool": "map", + "timm_proj": "none" + }, + "text_cfg": { + "context_length": 64, + "vocab_size": 32000, + "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", + "tokenizer_kwargs": { + "clean": "canonicalize" + }, + "width": 768, + "heads": 12, + "layers": 12, + "no_causal_mask": true, + "proj_bias": true, + "pool_type": "last", + "norm_kwargs":{ + "eps": 1e-6 + } + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/ViT-B-16-SigLIP.json b/open_clip/src/open_clip/model_configs/ViT-B-16-SigLIP.json new file mode 100644 index 0000000000000000000000000000000000000000..a9f2b654a671c9bd235f351b2a253ca889758549 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/ViT-B-16-SigLIP.json @@ -0,0 +1,29 @@ +{ + "embed_dim": 768, + "init_logit_bias": -10, + "custom_text": true, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "vit_base_patch16_siglip_224", + "timm_model_pretrained": false, + "timm_pool": "map", + "timm_proj": "none" + }, + "text_cfg": { + "context_length": 64, + "vocab_size": 32000, + "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", + "tokenizer_kwargs": { + "clean": "canonicalize" + }, + "width": 768, + "heads": 12, + "layers": 12, + "no_causal_mask": true, + "proj_bias": true, + "pool_type": "last", + "norm_kwargs":{ + "eps": 1e-6 + } + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/ViT-B-32-256.json b/open_clip/src/open_clip/model_configs/ViT-B-32-256.json new file mode 100644 index 0000000000000000000000000000000000000000..80a2597d8f7d5d500df2aacbded9507196dad6da --- /dev/null +++ b/open_clip/src/open_clip/model_configs/ViT-B-32-256.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 256, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} diff --git 
a/open_clip/src/open_clip/model_configs/ViT-H-14-378.json b/open_clip/src/open_clip/model_configs/ViT-H-14-378.json new file mode 100644 index 0000000000000000000000000000000000000000..04b2e62d60d031b1a5762e365e070e52b6fea7b1 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/ViT-H-14-378.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 378, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/ViT-H-14-CLIPA-336.json b/open_clip/src/open_clip/model_configs/ViT-H-14-CLIPA-336.json new file mode 100644 index 0000000000000000000000000000000000000000..01fabb29db2bcbd9513e903064d61e3e1974d580 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/ViT-H-14-CLIPA-336.json @@ -0,0 +1,26 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 336, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14, + "no_ln_pre": true, + "pool_type": "avg", + "final_ln_after_pool": true + }, + "text_cfg": { + "context_length": 32, + "vocab_size": 32000, + "hf_tokenizer_name": "bert-base-uncased", + "tokenizer_kwargs": { + "strip_sep_token": true + }, + "width": 1024, + "heads": 16, + "layers": 24, + "pool_type": "last", + "no_causal_mask": true + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/ViT-H-14-quickgelu.json b/open_clip/src/open_clip/model_configs/ViT-H-14-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..41f22f65bb002c320111790e0cd0f2425a575df7 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/ViT-H-14-quickgelu.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/ViT-H-14.json b/open_clip/src/open_clip/model_configs/ViT-H-14.json new file mode 100644 index 0000000000000000000000000000000000000000..3e3a7e934e7f02e41f4829996c4950e05f015a74 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/ViT-H-14.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/ViT-L-14-280.json b/open_clip/src/open_clip/model_configs/ViT-L-14-280.json new file mode 100644 index 0000000000000000000000000000000000000000..2262deaefa82792d35d73c0d7c8e620525092581 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/ViT-L-14-280.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 280, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/ViT-L-14-336-quickgelu.json b/open_clip/src/open_clip/model_configs/ViT-L-14-336-quickgelu.json new file mode 100644 index 
0000000000000000000000000000000000000000..d928c0284c692dfe738be8cbf4a0e2eb939bcf41 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/ViT-L-14-336-quickgelu.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 768, + "quick_gelu": true, + "vision_cfg": { + "image_size": 336, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/ViT-L-14-quickgelu.json b/open_clip/src/open_clip/model_configs/ViT-L-14-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..d5a3fd36aa9cd9cc4a3dc29e362945cec13a02f3 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/ViT-L-14-quickgelu.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 768, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/ViT-L-14.json b/open_clip/src/open_clip/model_configs/ViT-L-14.json new file mode 100644 index 0000000000000000000000000000000000000000..d4a4bbb1dd4ed4edb317d3ace4f3ad13b211c241 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/ViT-L-14.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/ViT-L-16-320.json b/open_clip/src/open_clip/model_configs/ViT-L-16-320.json new file mode 100644 index 0000000000000000000000000000000000000000..fc2d13ca9ec7f0b56a886ddaf66c4a7ba7a442ba --- /dev/null +++ b/open_clip/src/open_clip/model_configs/ViT-L-16-320.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 320, + "layers": 24, + "width": 1024, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/ViT-L-16-SigLIP-384.json b/open_clip/src/open_clip/model_configs/ViT-L-16-SigLIP-384.json new file mode 100644 index 0000000000000000000000000000000000000000..fd2cc2e346f7110a5de01cfaf7eae8c94360de3a --- /dev/null +++ b/open_clip/src/open_clip/model_configs/ViT-L-16-SigLIP-384.json @@ -0,0 +1,29 @@ +{ + "embed_dim": 1024, + "init_logit_bias": -10, + "custom_text": true, + "vision_cfg": { + "image_size": 384, + "timm_model_name": "vit_large_patch16_siglip_384", + "timm_model_pretrained": false, + "timm_pool": "map", + "timm_proj": "none" + }, + "text_cfg": { + "context_length": 64, + "vocab_size": 32000, + "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", + "tokenizer_kwargs": { + "clean": "canonicalize" + }, + "width": 1024, + "heads": 16, + "layers": 24, + "no_causal_mask": true, + "proj_bias": true, + "pool_type": "last", + "norm_kwargs":{ + "eps": 1e-6 + } + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/ViT-M-16-alt.json b/open_clip/src/open_clip/model_configs/ViT-M-16-alt.json new file mode 100644 index 0000000000000000000000000000000000000000..1a317aad8e02d9c26d2decc7cc49a18dfdf9e0d8 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/ViT-M-16-alt.json @@ -0,0 
+1,17 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 16, + "ls_init_value": 1e-4 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/ViT-M-32-alt.json b/open_clip/src/open_clip/model_configs/ViT-M-32-alt.json new file mode 100644 index 0000000000000000000000000000000000000000..fd222aeac0f582ef6a1a33f1b3fec70a5b386ac0 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/ViT-M-32-alt.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/ViT-M-32.json b/open_clip/src/open_clip/model_configs/ViT-M-32.json new file mode 100644 index 0000000000000000000000000000000000000000..4f718642821035d9776d1e006817d65ede074366 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/ViT-M-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/ViT-S-16.json b/open_clip/src/open_clip/model_configs/ViT-S-16.json new file mode 100644 index 0000000000000000000000000000000000000000..1d8504e59658803f3093e5b05de45f30a09b8185 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/ViT-S-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 384, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/ViT-SO400M-14-SigLIP-378.json b/open_clip/src/open_clip/model_configs/ViT-SO400M-14-SigLIP-378.json new file mode 100644 index 0000000000000000000000000000000000000000..6bc14fabc30a9e11cbc9ca53d353f2d1216f9d2c --- /dev/null +++ b/open_clip/src/open_clip/model_configs/ViT-SO400M-14-SigLIP-378.json @@ -0,0 +1,30 @@ +{ + "embed_dim": 1152, + "init_logit_bias": -10, + "custom_text": true, + "vision_cfg": { + "image_size": 378, + "timm_model_name": "vit_so400m_patch14_siglip_378", + "timm_model_pretrained": false, + "timm_pool": "map", + "timm_proj": "none" + }, + "text_cfg": { + "context_length": 64, + "vocab_size": 32000, + "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", + "tokenizer_kwargs": { + "clean": "canonicalize" + }, + "width": 1152, + "heads": 16, + "layers": 27, + "mlp_ratio": 3.7362, + "no_causal_mask": true, + "proj_bias": true, + "pool_type": "last", + "norm_kwargs":{ + "eps": 1e-6 + } + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/ViT-SO400M-14-SigLIP-384.json b/open_clip/src/open_clip/model_configs/ViT-SO400M-14-SigLIP-384.json new file mode 100644 index 0000000000000000000000000000000000000000..4c527f581230938d7b39baf36b6bd749b0e7f169 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/ViT-SO400M-14-SigLIP-384.json @@ -0,0 +1,30 @@ +{ + "embed_dim": 1152, + "init_logit_bias": -10, + "custom_text": true, + "vision_cfg": { + "image_size": 384, + "timm_model_name": 
"vit_so400m_patch14_siglip_384", + "timm_model_pretrained": false, + "timm_pool": "map", + "timm_proj": "none" + }, + "text_cfg": { + "context_length": 64, + "vocab_size": 32000, + "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", + "tokenizer_kwargs": { + "clean": "canonicalize" + }, + "width": 1152, + "heads": 16, + "layers": 27, + "mlp_ratio": 3.7362, + "no_causal_mask": true, + "proj_bias": true, + "pool_type": "last", + "norm_kwargs":{ + "eps": 1e-6 + } + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/ViT-bigG-14-CLIPA.json b/open_clip/src/open_clip/model_configs/ViT-bigG-14-CLIPA.json new file mode 100644 index 0000000000000000000000000000000000000000..83ec709f8b8362d892067adafde9a0d78ce4db14 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/ViT-bigG-14-CLIPA.json @@ -0,0 +1,27 @@ +{ + "embed_dim": 1280, + "vision_cfg": { + "image_size": 224, + "layers": 48, + "width": 1664, + "head_width": 104, + "mlp_ratio": 4.9231, + "patch_size": 14, + "no_ln_pre": true, + "pool_type": "avg", + "final_ln_after_pool": true + }, + "text_cfg": { + "context_length": 32, + "vocab_size": 32000, + "hf_tokenizer_name": "bert-base-uncased", + "tokenizer_kwargs": { + "strip_sep_token": true + }, + "width": 1280, + "heads": 20, + "layers": 32, + "pool_type": "last", + "no_causal_mask": true + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/ViT-bigG-14-quickgelu.json b/open_clip/src/open_clip/model_configs/ViT-bigG-14-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..fed567cc670274e50e7ecd69954097cca1d5b081 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/ViT-bigG-14-quickgelu.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 1280, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": 48, + "width": 1664, + "head_width": 104, + "mlp_ratio": 4.9231, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32 + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/ViT-bigG-14.json b/open_clip/src/open_clip/model_configs/ViT-bigG-14.json new file mode 100644 index 0000000000000000000000000000000000000000..2cfba479a2e8f3737e71ce240732bf3bc743d8b7 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/ViT-bigG-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1280, + "vision_cfg": { + "image_size": 224, + "layers": 48, + "width": 1664, + "head_width": 104, + "mlp_ratio": 4.9231, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32 + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/ViT-e-14.json b/open_clip/src/open_clip/model_configs/ViT-e-14.json new file mode 100644 index 0000000000000000000000000000000000000000..91a0fe14d25a107fb8ec48dd7faae313fd26ed7b --- /dev/null +++ b/open_clip/src/open_clip/model_configs/ViT-e-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1280, + "vision_cfg": { + "image_size": 224, + "layers": 56, + "width": 1792, + "head_width": 112, + "mlp_ratio": 8.5715, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 36 + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/ViTamin-B-LTT.json b/open_clip/src/open_clip/model_configs/ViTamin-B-LTT.json new file mode 100644 index 
0000000000000000000000000000000000000000..775621409becce43a1b1aa5bd61cdaf93c578733 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/ViTamin-B-LTT.json @@ -0,0 +1,20 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "timm_model_name": "vitamin_base_224", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/ViTamin-B.json b/open_clip/src/open_clip/model_configs/ViTamin-B.json new file mode 100644 index 0000000000000000000000000000000000000000..bf09a8e698b2f133f531d1567755e9f9d3510047 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/ViTamin-B.json @@ -0,0 +1,20 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "vitamin_base_224", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/ViTamin-L-336.json b/open_clip/src/open_clip/model_configs/ViTamin-L-336.json new file mode 100644 index 0000000000000000000000000000000000000000..63aa8cebef0f19d5276e99b104380b01f0a8c58e --- /dev/null +++ b/open_clip/src/open_clip/model_configs/ViTamin-L-336.json @@ -0,0 +1,20 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "timm_model_name": "vitamin_large_336", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 336 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/ViTamin-L-384.json b/open_clip/src/open_clip/model_configs/ViTamin-L-384.json new file mode 100644 index 0000000000000000000000000000000000000000..1278d8393686b9818c7635c2ec3e97a4ae5e57e9 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/ViTamin-L-384.json @@ -0,0 +1,20 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "timm_model_name": "vitamin_large_384", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 384 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + }, + "custom_text": true +} diff --git a/open_clip/src/open_clip/model_configs/ViTamin-L.json b/open_clip/src/open_clip/model_configs/ViTamin-L.json new file mode 100644 index 0000000000000000000000000000000000000000..c74e56e9df1b5548863ef42a3a08f12fb28f09bd --- /dev/null +++ b/open_clip/src/open_clip/model_configs/ViTamin-L.json @@ -0,0 +1,20 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "timm_model_name": "vitamin_large_224", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/ViTamin-L2-256.json 
b/open_clip/src/open_clip/model_configs/ViTamin-L2-256.json new file mode 100644 index 0000000000000000000000000000000000000000..68465befbe72ab02dd31248fd322fd8d1950d2d0 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/ViTamin-L2-256.json @@ -0,0 +1,20 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "vitamin_large2_256", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + }, + "custom_text": true +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/ViTamin-L2-336.json b/open_clip/src/open_clip/model_configs/ViTamin-L2-336.json new file mode 100644 index 0000000000000000000000000000000000000000..4b48a526322de8c23912e258019c1737fb9336c8 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/ViTamin-L2-336.json @@ -0,0 +1,20 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "vitamin_large2_336", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 336 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + }, + "custom_text": true +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/ViTamin-L2-384.json b/open_clip/src/open_clip/model_configs/ViTamin-L2-384.json new file mode 100644 index 0000000000000000000000000000000000000000..cc0faaae7b3a17f571b91fa98b0748261ad16fcd --- /dev/null +++ b/open_clip/src/open_clip/model_configs/ViTamin-L2-384.json @@ -0,0 +1,20 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "vitamin_large2_384", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 384 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + }, + "custom_text": true +} diff --git a/open_clip/src/open_clip/model_configs/ViTamin-L2.json b/open_clip/src/open_clip/model_configs/ViTamin-L2.json new file mode 100644 index 0000000000000000000000000000000000000000..3d14b710906775c89143b9f227bc38414ee9ad11 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/ViTamin-L2.json @@ -0,0 +1,20 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "vitamin_large2_224", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + }, + "custom_text": true +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/ViTamin-S-LTT.json b/open_clip/src/open_clip/model_configs/ViTamin-S-LTT.json new file mode 100644 index 0000000000000000000000000000000000000000..b01c95b4132620e3908716f3a549e398b7d5089e --- /dev/null +++ b/open_clip/src/open_clip/model_configs/ViTamin-S-LTT.json @@ -0,0 +1,20 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "timm_model_name": "vitamin_small_224", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + }, + "custom_text": true +} \ No newline 
at end of file diff --git a/open_clip/src/open_clip/model_configs/ViTamin-S.json b/open_clip/src/open_clip/model_configs/ViTamin-S.json new file mode 100644 index 0000000000000000000000000000000000000000..1fb6cd24a681500d94284b29f645595ca2727e2a --- /dev/null +++ b/open_clip/src/open_clip/model_configs/ViTamin-S.json @@ -0,0 +1,20 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "timm_model_name": "vitamin_small_224", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/ViTamin-XL-256.json b/open_clip/src/open_clip/model_configs/ViTamin-XL-256.json new file mode 100644 index 0000000000000000000000000000000000000000..68f672f0cc3c3564f4c7ec6e25255034e5af45cb --- /dev/null +++ b/open_clip/src/open_clip/model_configs/ViTamin-XL-256.json @@ -0,0 +1,20 @@ +{ + "embed_dim": 1152, + "vision_cfg": { + "timm_model_name": "vitamin_xlarge_256", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1152, + "heads": 16, + "layers": 27 + }, + "custom_text": true +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/ViTamin-XL-336.json b/open_clip/src/open_clip/model_configs/ViTamin-XL-336.json new file mode 100644 index 0000000000000000000000000000000000000000..116c30e7301a5b7c3869c7adf3ecf6fc82436c17 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/ViTamin-XL-336.json @@ -0,0 +1,20 @@ +{ + "embed_dim": 1152, + "vision_cfg": { + "timm_model_name": "vitamin_xlarge_336", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 336 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1152, + "heads": 16, + "layers": 27 + }, + "custom_text": true +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/coca_ViT-B-32.json b/open_clip/src/open_clip/model_configs/coca_ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..7e7eb520a6a0096e5602d509ecd6186e278f4725 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/coca_ViT-B-32.json @@ -0,0 +1,30 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32, + "attentional_pool": true, + "attn_pooler_heads": 8, + "output_tokens": true + }, + "text_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12, + "embed_cls": true, + "output_tokens": true + }, + "multimodal_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12, + "attn_pooler_heads": 8 + }, + "custom_text": true +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/coca_roberta-ViT-B-32.json b/open_clip/src/open_clip/model_configs/coca_roberta-ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..aa9d3f562057f849e6ced8b495de2dd73387fe61 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/coca_roberta-ViT-B-32.json @@ -0,0 +1,24 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + 
"patch_size": 32, + "output_tokens": true + }, + "text_cfg": { + "hf_model_name": "roberta-base", + "hf_tokenizer_name": "roberta-base", + "hf_proj_type": "linear", + "width": 768, + "output_tokens": true + }, + "multimodal_cfg": { + "context_length": 76, + "width": 768, + "heads": 8, + "layers": 12 + }, + "custom_text": true +} diff --git a/open_clip/src/open_clip/model_configs/convnext_base.json b/open_clip/src/open_clip/model_configs/convnext_base.json new file mode 100644 index 0000000000000000000000000000000000000000..bb6dba181d950ea5081155c90d47e72c94816b80 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/convnext_base.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "convnext_base", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/convnext_base_w.json b/open_clip/src/open_clip/model_configs/convnext_base_w.json new file mode 100644 index 0000000000000000000000000000000000000000..82ea7ae3659e5514f37ff982f0ab1141dff4bd18 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/convnext_base_w.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "timm_model_name": "convnext_base", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/convnext_base_w_320.json b/open_clip/src/open_clip/model_configs/convnext_base_w_320.json new file mode 100644 index 0000000000000000000000000000000000000000..0a07c4e16abaa4015ecc5f82ec845de16e1f9d88 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/convnext_base_w_320.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "timm_model_name": "convnext_base", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 320 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/convnext_large_d.json b/open_clip/src/open_clip/model_configs/convnext_large_d.json new file mode 100644 index 0000000000000000000000000000000000000000..ae8fed21b58e1a6a411daf8b792ee50f0ab42346 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/convnext_large_d.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "timm_model_name": "convnext_large", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "mlp", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 16 + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/convnext_large_d_320.json b/open_clip/src/open_clip/model_configs/convnext_large_d_320.json new file mode 100644 index 0000000000000000000000000000000000000000..54c3df36a6f56ace0b12ada24c13058de96feed8 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/convnext_large_d_320.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 768, 
+ "vision_cfg": { + "timm_model_name": "convnext_large", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "mlp", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 320 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 16 + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/convnext_small.json b/open_clip/src/open_clip/model_configs/convnext_small.json new file mode 100644 index 0000000000000000000000000000000000000000..3592c2a5cd21aae8d2544931773cf7603f67ea28 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/convnext_small.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "convnext_small", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/convnext_xlarge.json b/open_clip/src/open_clip/model_configs/convnext_xlarge.json new file mode 100644 index 0000000000000000000000000000000000000000..2a909965932eef994177c829fefc2bdc1c219b3f --- /dev/null +++ b/open_clip/src/open_clip/model_configs/convnext_xlarge.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "convnext_xlarge", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 20 + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/convnext_xxlarge.json b/open_clip/src/open_clip/model_configs/convnext_xxlarge.json new file mode 100644 index 0000000000000000000000000000000000000000..23a55a681c346d1a315d8a163c1cb6ad495e6a91 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/convnext_xxlarge.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "convnext_xxlarge", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/convnext_xxlarge_320.json b/open_clip/src/open_clip/model_configs/convnext_xxlarge_320.json new file mode 100644 index 0000000000000000000000000000000000000000..ac5134ca12cbaa97772cde059270d345386a74c7 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/convnext_xxlarge_320.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "convnext_xxlarge", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 320 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/nllb-clip-base-siglip.json b/open_clip/src/open_clip/model_configs/nllb-clip-base-siglip.json new file mode 100644 index 0000000000000000000000000000000000000000..f7152d0bb6b9fd3333b46cb75934e500f1aab348 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/nllb-clip-base-siglip.json 
@@ -0,0 +1,18 @@ +{ + "embed_dim": 768, + "custom_text": true, + "init_logit_bias": -10, + "vision_cfg": { + "image_size": 384, + "timm_model_name": "vit_base_patch16_siglip_384", + "timm_model_pretrained": false, + "timm_pool": "map", + "timm_proj": "none" + }, + "text_cfg": { + "hf_model_name": "facebook/nllb-200-distilled-600M", + "hf_tokenizer_name": "facebook/nllb-200-distilled-600M", + "hf_proj_type": "linear", + "hf_pooler_type": "cls_pooler" + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/nllb-clip-base.json b/open_clip/src/open_clip/model_configs/nllb-clip-base.json new file mode 100644 index 0000000000000000000000000000000000000000..57265b33f7cfd21b07741744d50cbf30208017d1 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/nllb-clip-base.json @@ -0,0 +1,15 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "hf_model_name": "facebook/nllb-200-distilled-600M", + "hf_tokenizer_name": "facebook/nllb-200-distilled-600M", + "hf_proj_type": "linear", + "hf_pooler_type": "cls_pooler" + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/nllb-clip-large-siglip.json b/open_clip/src/open_clip/model_configs/nllb-clip-large-siglip.json new file mode 100644 index 0000000000000000000000000000000000000000..0ac3485762b5117597839b3274ed85340a2c76c2 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/nllb-clip-large-siglip.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1152, + "custom_text": true, + "init_logit_bias": -10, + "vision_cfg": { + "image_size": 384, + "timm_model_name": "vit_so400m_patch14_siglip_384", + "timm_model_pretrained": false, + "timm_pool": "map", + "timm_proj": "none" + }, + "text_cfg": { + "hf_model_name": "facebook/nllb-200-distilled-1.3B", + "hf_tokenizer_name": "facebook/nllb-200-distilled-1.3B", + "hf_proj_type": "linear", + "hf_pooler_type": "cls_pooler" + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/roberta-ViT-B-32.json b/open_clip/src/open_clip/model_configs/roberta-ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..c0c7a55995d50230c6b0f0af5fbd81d5889a3d59 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/roberta-ViT-B-32.json @@ -0,0 +1,15 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "hf_model_name": "roberta-base", + "hf_tokenizer_name": "roberta-base", + "hf_pooler_type": "mean_pooler" + } +} diff --git a/open_clip/src/open_clip/model_configs/swin_base_patch4_window7_224.json b/open_clip/src/open_clip/model_configs/swin_base_patch4_window7_224.json new file mode 100644 index 0000000000000000000000000000000000000000..bd6820f0cf2aa655e0a2723287f4b78895a58e6a --- /dev/null +++ b/open_clip/src/open_clip/model_configs/swin_base_patch4_window7_224.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "timm_model_name": "swin_base_patch4_window7_224", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json b/open_clip/src/open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json new file mode 100644 index 
0000000000000000000000000000000000000000..ed217b202d5e6071c5307f4547c97ff4cfe2abd1 --- /dev/null +++ b/open_clip/src/open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "vit_relpos_medium_patch16_cls_224", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/src/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json b/open_clip/src/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..375fa9e12f1629ef049a715d43ba2a8b1822ff1c --- /dev/null +++ b/open_clip/src/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json @@ -0,0 +1,14 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "hf_model_name": "xlm-roberta-base", + "hf_tokenizer_name": "xlm-roberta-base", + "hf_pooler_type": "mean_pooler" + } +} diff --git a/open_clip/src/open_clip/pretrained.py b/open_clip/src/open_clip/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..712f28192b92c5a7e7db92128a510a37a79153b7 --- /dev/null +++ b/open_clip/src/open_clip/pretrained.py @@ -0,0 +1,796 @@ +import copy +import hashlib +import os +import urllib +import warnings +from functools import partial +from typing import Dict, Iterable, Optional, Union + +from tqdm import tqdm + + +try: + import safetensors.torch + _has_safetensors = True +except ImportError: + _has_safetensors = False + + +from .constants import ( + IMAGENET_MEAN, + IMAGENET_STD, + INCEPTION_MEAN, + INCEPTION_STD, + OPENAI_DATASET_MEAN, + OPENAI_DATASET_STD, + HF_WEIGHTS_NAME, + HF_SAFE_WEIGHTS_NAME, +) +from .version import __version__ + +try: + from huggingface_hub import hf_hub_download + hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version=__version__) + _has_hf_hub = True +except ImportError: + hf_hub_download = None + _has_hf_hub = False + + +def _pcfg(url='', hf_hub='', **kwargs): + # OpenAI / OpenCLIP defaults + return { + 'url': url, + 'hf_hub': hf_hub, + 'mean': OPENAI_DATASET_MEAN, + 'std': OPENAI_DATASET_STD, + 'interpolation': 'bicubic', + 'resize_mode': 'shortest', + **kwargs, + } + + +def _slpcfg(url='', hf_hub='', **kwargs): + # SiGLIP defaults + return { + 'url': url, + 'hf_hub': hf_hub, + 'mean': INCEPTION_MEAN, + 'std': INCEPTION_STD, + 'interpolation': 'bicubic', + 'resize_mode': 'squash', + **kwargs, + } + + +def _apcfg(url='', hf_hub='', **kwargs): + # CLIPA defaults + return { + 'url': url, + 'hf_hub': hf_hub, + 'mean': IMAGENET_MEAN, + 'std': IMAGENET_STD, + 'interpolation': 'bilinear', + 'resize_mode': 'squash', + **kwargs, + } + + +def _mccfg(url='', hf_hub='', **kwargs): + # MobileCLIP + return { + 'url': url, + 'hf_hub': hf_hub, + 'mean': (0., 0., 0.), + 'std': (1., 1., 1.), + 'interpolation': 'bilinear', + 'resize_mode': 'shortest', + **kwargs, + } + + + +_RN50 = dict( + openai=_pcfg( + url="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + hf_hub="timm/resnet50_clip.openai/", + quick_gelu=True, + ), + yfcc15m=_pcfg( + url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", + 
hf_hub="timm/resnet50_clip.yfcc15m/", + quick_gelu=True, + ), + cc12m=_pcfg( + url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt", + hf_hub="timm/resnet50_clip.cc12m/", + quick_gelu=True, + ), +) + +_RN101 = dict( + openai=_pcfg( + url="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + hf_hub="timm/resnet101_clip.openai/", + quick_gelu=True, + ), + yfcc15m=_pcfg( + url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt", + hf_hub="timm/resnet101_clip.yfcc15m/", + quick_gelu=True, + ), +) + +_RN50x4 = dict( + openai=_pcfg( + url="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", + hf_hub="timm/resnet50x4_clip.openai/", + quick_gelu=True, + ), +) + +_RN50x16 = dict( + openai=_pcfg( + url="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", + hf_hub="timm/resnet50x16_clip.openai/", + quick_gelu=True, + ), +) + +_RN50x64 = dict( + openai=_pcfg( + url="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", + hf_hub="timm/resnet50x64_clip.openai/", + quick_gelu=True, + ), +) + +_VITB32 = dict( + openai=_pcfg( + url="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + hf_hub="timm/vit_base_patch32_clip_224.openai/", + quick_gelu=True, + ), + # LAION 400M (quick gelu) + laion400m_e31=_pcfg( + url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", + hf_hub="timm/vit_base_patch32_clip_224.laion400m_e31/", + quick_gelu=True, + ), + laion400m_e32=_pcfg( + url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", + hf_hub="timm/vit_base_patch32_clip_224.laion400m_e32/", + quick_gelu=True, + ), + # LAION 2B-en + laion2b_e16=_pcfg( + url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth", + hf_hub="timm/vit_base_patch32_clip_224.laion2b_e16/", + ), + laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/'), + # DataComp-XL models + datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K/'), + # DataComp-M models + datacomp_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.M-s128M-b4K/'), + commonpool_m_clip_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.clip-s128M-b4K/'), + commonpool_m_laion_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.laion-s128M-b4K/'), + commonpool_m_image_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.image-s128M-b4K/'), + commonpool_m_text_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.text-s128M-b4K/'), + commonpool_m_basic_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.basic-s128M-b4K/'), + commonpool_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M-s128M-b4K/'), + # DataComp-S models + datacomp_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.S-s13M-b4K/'), + commonpool_s_clip_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.clip-s13M-b4K/'), + commonpool_s_laion_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.laion-s13M-b4K/'), + 
commonpool_s_image_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.image-s13M-b4K/'), + commonpool_s_text_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.text-s13M-b4K/'), + commonpool_s_basic_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.basic-s13M-b4K/'), + commonpool_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S-s13M-b4K/'), + # MetaClip models (NOTE quick-gelu activation used) + metaclip_400m=_pcfg( + url="https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_400m.pt", + hf_hub="timm/vit_base_patch32_clip_224.metaclip_400m/", + quick_gelu=True, + ), + metaclip_fullcc=_pcfg( + url="https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_fullcc2.5b.pt", + hf_hub="timm/vit_base_patch32_clip_224.metaclip_2pt5b/", + quick_gelu=True, + ), +) + +_VITB32_256 = dict( + datacomp_s34b_b86k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-256x256-DataComp-s34B-b86K/'), +) + +_VITB16 = dict( + openai=_pcfg( + url="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", + hf_hub="timm/vit_base_patch16_clip_224.openai/", + quick_gelu=True, + ), + # LAION-400M + laion400m_e31=_pcfg( + url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt", + hf_hub="timm/vit_base_patch16_clip_224.laion400m_e31/", + ), + laion400m_e32=_pcfg( + url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt", + hf_hub="timm/vit_base_patch16_clip_224.laion400m_e32/", + ), + # LAION-2B + laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'), + # DataComp-XL models + datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.XL-s13B-b90K/'), + # DataComp-L models + datacomp_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.L-s1B-b8K/'), + commonpool_l_clip_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.clip-s1B-b8K/'), + commonpool_l_laion_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.laion-s1B-b8K/'), + commonpool_l_image_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.image-s1B-b8K/'), + commonpool_l_text_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.text-s1B-b8K/'), + commonpool_l_basic_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.basic-s1B-b8K/'), + commonpool_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L-s1B-b8K/'), + # DFN + dfn2b=_pcfg( + hf_hub='apple/DFN2B-CLIP-ViT-B-16/', + quick_gelu=True, + ), + # MetaCLIP (these are quick-gelu) + metaclip_400m=_pcfg( + url="https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_400m.pt", + hf_hub="timm/vit_base_patch16_clip_224.metaclip_400m/", + quick_gelu=True, + ), + metaclip_fullcc=_pcfg( + url="https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_fullcc2.5b.pt", + hf_hub="timm/vit_base_patch16_clip_224.metaclip_2pt5b/", + quick_gelu=True, + ), +) + +_VITB16_PLUS_240 = dict( + laion400m_e31=_pcfg( + url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt", + hf_hub="timm/vit_base_patch16_plus_clip_240.laion400m_e31/", + ), + laion400m_e32=_pcfg( + url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt", + hf_hub="timm/vit_base_patch16_plus_clip_240.laion400m_e31/", + ), +) + +_VITL14 = dict( + openai=_pcfg( + url="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", + 
hf_hub="timm/vit_large_patch14_clip_224.openai/", + quick_gelu=True, + ), + # LAION-400M + laion400m_e31=_pcfg( + url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt", + hf_hub="timm/vit_large_patch14_clip_224.laion400m_e31/", + ), + laion400m_e32=_pcfg( + url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt", + hf_hub="timm/vit_large_patch14_clip_224.laion400m_e32/", + ), + # LAION-2B-en + laion2b_s32b_b82k=_pcfg( + hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/', + mean=INCEPTION_MEAN, std=INCEPTION_STD), + # DataComp-XL models + datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K/'), + commonpool_xl_clip_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.clip-s13B-b90K/'), + commonpool_xl_laion_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.laion-s13B-b90K/'), + commonpool_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL-s13B-b90K/'), + # MetaCLIP + metaclip_400m=_pcfg( + url="https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_400m.pt", + hf_hub="timm/vit_large_patch14_clip_224.metaclip_400m/", + quick_gelu=True, + ), + metaclip_fullcc=_pcfg( + url="https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_fullcc2.5b.pt", + hf_hub="timm/vit_large_patch14_clip_224.metaclip_2pt5b/", + quick_gelu=True, + ), + # DFN-2B (quick-gelu) + dfn2b=_pcfg( + hf_hub='apple/DFN2B-CLIP-ViT-L-14/', + quick_gelu=True, + ), + # DFN-2B 39B SS + dfn2b_s39b=_pcfg( + hf_hub='apple/DFN2B-CLIP-ViT-L-14-39B/', + ), +) + +_VITL14_336 = dict( + openai=_pcfg( + url="https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", + hf_hub="timm/vit_large_patch14_clip_336.openai/", + quick_gelu=True, + ), +) + +_VITH14 = dict( + # LAION-2B-en + laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'), + # MetaCLIP (quick-gelu) + metaclip_fullcc=_pcfg( + url="https://dl.fbaipublicfiles.com/MMPT/metaclip/h14_fullcc2.5b.pt", + hf_hub="timm/vit_huge_patch14_clip_224.metaclip_2pt5b/", + quick_gelu=True, + ), + metaclip_altogether=_pcfg( + url="https://dl.fbaipublicfiles.com/MMPT/metaclip/h14_v1.2_altogether.pt", + hf_hub="timm/vit_huge_patch14_clip_224.metaclip_altogether/", + # NOTE unlike other MetaCLIP models, this is not using QuickGELU, yay! 
+ ), + # DFN-5B (quick-gelu) + dfn5b=_pcfg( + hf_hub='apple/DFN5B-CLIP-ViT-H-14/', + quick_gelu=True, + interpolation="bicubic", + resize_mode="squash" + ), +) + +_VITH14_378 = dict( + # DFN-5B (quick-gelu) + dfn5b=_pcfg( + hf_hub='apple/DFN5B-CLIP-ViT-H-14-378/', + quick_gelu=True, + interpolation="bicubic", + resize_mode="squash" + ), +) + +_VITg14 = dict( + laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'), + laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'), +) + +_VITbigG14 = dict( + # LAION-2B-en + laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'), + # MetaCLIP (quick-gelu) + metaclip_fullcc=_pcfg( + url='https://dl.fbaipublicfiles.com/MMPT/metaclip/G14_fullcc2.5b.pt', + hf_hub="timm/vit_gigantic_patch14_clip_224.metaclip_2pt5b/", + quick_gelu=True, + ), +) + +_robertaViTB32 = dict( + laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'), +) + +_xlmRobertaBaseViTB32 = dict( + laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'), +) + +_xlmRobertaLargeFrozenViTH14 = dict( + frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'), +) + +_convnext_base = dict( + laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'), +) + +_convnext_base_w = dict( + laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'), + laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'), + laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'), +) + +_convnext_base_w_320 = dict( + laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'), + laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'), +) + +_convnext_large_d = dict( + laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'), +) + +_convnext_large_d_320 = dict( + laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'), + laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'), +) + +_convnext_xxlarge = dict( + laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'), + laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'), + laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'), +) + +_coca_VITB32 = dict( + laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'), + mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/') +) + +_coca_VITL14 = dict( + laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'), + mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/') +) + + +_PRETRAINED = { + "RN50": _RN50, + "RN101": _RN101, + "RN50x4": _RN50x4, + "RN50x16": _RN50x16, + "RN50x64": _RN50x64, + + "ViT-B-32": _VITB32, + "ViT-B-32-256": _VITB32_256, + "ViT-B-16": _VITB16, + "ViT-B-16-plus-240": _VITB16_PLUS_240, + "ViT-L-14": _VITL14, + "ViT-L-14-336": _VITL14_336, + "ViT-H-14": _VITH14, + "ViT-H-14-378": _VITH14_378, + "ViT-g-14": _VITg14, + "ViT-bigG-14": _VITbigG14, + + "roberta-ViT-B-32": _robertaViTB32, + "xlm-roberta-base-ViT-B-32": 
_xlmRobertaBaseViTB32, + "xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14, + + "convnext_base": _convnext_base, + "convnext_base_w": _convnext_base_w, + "convnext_base_w_320": _convnext_base_w_320, + "convnext_large_d": _convnext_large_d, + "convnext_large_d_320": _convnext_large_d_320, + "convnext_xxlarge": _convnext_xxlarge, + + "coca_ViT-B-32": _coca_VITB32, + "coca_ViT-L-14": _coca_VITL14, + + "EVA01-g-14": dict( + # from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt + laion400m_s11b_b41k=_pcfg(hf_hub='timm/eva_giant_patch14_clip_224.laion400m_s11b_b41k/'), + ), + "EVA01-g-14-plus": dict( + # from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt + merged2b_s11b_b114k=_pcfg(hf_hub='timm/eva_giant_patch14_plus_clip_224.merged2b_s11b_b114k/'), + ), + "EVA02-B-16": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt + merged2b_s8b_b131k=_pcfg(hf_hub='timm/eva02_base_patch16_clip_224.merged2b_s8b_b131k/'), + ), + "EVA02-L-14": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt + merged2b_s4b_b131k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_224.merged2b_s4b_b131k/'), + ), + "EVA02-L-14-336": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt + merged2b_s6b_b61k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_336.merged2b_s6b_b61k/'), + ), + "EVA02-E-14": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt + laion2b_s4b_b115k=_pcfg(hf_hub='timm/eva02_enormous_patch14_clip_224.laion2b_s4b_b115k/'), + ), + "EVA02-E-14-plus": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt + laion2b_s9b_b144k=_pcfg(hf_hub='timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k/'), + ), + + "ViT-B-16-SigLIP": dict( + webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP/'), + ), + "ViT-B-16-SigLIP-256": dict( + webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-256/'), + ), + "ViT-B-16-SigLIP-i18n-256": dict( + webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-i18n-256/'), + ), + "ViT-B-16-SigLIP-384": dict( + webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-384/'), + ), + "ViT-B-16-SigLIP-512": dict( + webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-512/'), + ), + "ViT-L-16-SigLIP-256": dict( + webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP-256/'), + ), + "ViT-L-16-SigLIP-384": dict( + webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP-384/'), + ), + "ViT-SO400M-14-SigLIP": dict( + webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP/'), + ), + "ViT-SO400M-16-SigLIP-i18n-256": dict( + webli=_slpcfg(hf_hub='timm/ViT-SO400M-16-SigLIP-i18n-256/'), + ), + "ViT-SO400M-14-SigLIP-378": dict( + webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP-384/'), # NOTE using 384 weights, but diff img_size used + ), + "ViT-SO400M-14-SigLIP-384": dict( + webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP-384/'), + ), + + "ViT-L-14-CLIPA": dict( + datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-L-14-CLIPA-datacomp1B/'), + ), + "ViT-L-14-CLIPA-336": dict( + datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-L-14-CLIPA-336-datacomp1B/'), + ), + "ViT-H-14-CLIPA": dict( + datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-datacomp1B/'), + ), + "ViT-H-14-CLIPA-336": dict( + laion2b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-336-laion2B/'), + datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-336-datacomp1B/'), + ), + "ViT-bigG-14-CLIPA": dict( + datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-bigG-14-CLIPA-datacomp1B/'), + ), + "ViT-bigG-14-CLIPA-336": dict( + datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-bigG-14-CLIPA-336-datacomp1B/'), + ), + + "nllb-clip-base": dict( + v1=_pcfg(hf_hub='visheratin/nllb-clip-base-oc/'), + ), + "nllb-clip-large": dict( + 
v1=_pcfg(hf_hub='visheratin/nllb-clip-large-oc/'), + ), + + "nllb-clip-base-siglip": dict( + v1=_slpcfg(hf_hub='visheratin/nllb-clip-base-siglip/'), + mrl=_slpcfg(hf_hub='visheratin/nllb-siglip-mrl-base/'), + ), + "nllb-clip-large-siglip": dict( + v1=_slpcfg(hf_hub='visheratin/nllb-clip-large-siglip/'), + mrl=_slpcfg(hf_hub='visheratin/nllb-siglip-mrl-large/'), + ), + + "MobileCLIP-S1": dict( + datacompdr=_mccfg(hf_hub='apple/MobileCLIP-S1-OpenCLIP/')), + "MobileCLIP-S2": dict( + datacompdr=_mccfg(hf_hub='apple/MobileCLIP-S2-OpenCLIP/')), + "MobileCLIP-B": dict( + datacompdr=_mccfg(hf_hub='apple/MobileCLIP-B-OpenCLIP/'), + datacompdr_lt=_mccfg(hf_hub='apple/MobileCLIP-B-LT-OpenCLIP/'), + ), + + "ViTamin-S": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-S/pytorch_model.bin'), + ), + "ViTamin-S-LTT": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-S-LTT/pytorch_model.bin'), + ), + "ViTamin-B": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-B/pytorch_model.bin'), + ), + "ViTamin-B-LTT": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-B-LTT/pytorch_model.bin'), + ), + "ViTamin-L": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-224px/pytorch_model.bin'), + ), + "ViTamin-L-256": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-256px/pytorch_model.bin'), + ), + "ViTamin-L-336": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-336px/pytorch_model.bin'), + ), + "ViTamin-L-384": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-384px/pytorch_model.bin'), + ), + "ViTamin-L2": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-224px/pytorch_model.bin'), + ), + "ViTamin-L2-256": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-256px/pytorch_model.bin'), + ), + "ViTamin-L2-336": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-336px/pytorch_model.bin'), + ), + "ViTamin-L2-384": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-384px/pytorch_model.bin'), + ), + "ViTamin-XL-256": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-256px/pytorch_model.bin'), + ), + "ViTamin-XL-336": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-336px/pytorch_model.bin'), + ), + "ViTamin-XL-384": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-384px/pytorch_model.bin'), + ), +} + +_PRETRAINED_quickgelu = {} +for k, v in _PRETRAINED.items(): + quick_gelu_tags = {} + for tk, tv in v.items(): + if tv.get('quick_gelu', False): + quick_gelu_tags[tk] = copy.deepcopy(tv) + if quick_gelu_tags: + _PRETRAINED_quickgelu[k + '-quickgelu'] = quick_gelu_tags +_PRETRAINED.update(_PRETRAINED_quickgelu) + +def _clean_tag(tag: str): + # normalize pretrained tags + return tag.lower().replace('-', '_') + + +def list_pretrained(as_str: bool = False): + """ returns list of pretrained models + Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True + """ + return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()] + + +def list_pretrained_models_by_tag(tag: str): + """ return all models having the specified pretrain tag """ + models = [] + tag = _clean_tag(tag) + for k in _PRETRAINED.keys(): + if tag in _PRETRAINED[k]: + models.append(k) + return models + + +def list_pretrained_tags_by_model(model: str): + """ return all pretrain tags for the specified model architecture """ + tags = [] + if model in _PRETRAINED: + tags.extend(_PRETRAINED[model].keys()) + return tags + + +def is_pretrained_cfg(model: str, tag: str): + if model not in _PRETRAINED: + return False + return 
_clean_tag(tag) in _PRETRAINED[model] + + +def get_pretrained_cfg(model: str, tag: str): + if model not in _PRETRAINED: + return {} + model_pretrained = _PRETRAINED[model] + return model_pretrained.get(_clean_tag(tag), {}) + + +def get_pretrained_url(model: str, tag: str): + cfg = get_pretrained_cfg(model, _clean_tag(tag)) + return cfg.get('url', '') + + +def download_pretrained_from_url( + url: str, + cache_dir: Optional[str] = None, +): + if not cache_dir: + cache_dir = os.path.expanduser("~/.cache/clip") + os.makedirs(cache_dir, exist_ok=True) + filename = os.path.basename(url) + + if 'openaipublic' in url: + expected_sha256 = url.split("/")[-2] + elif 'mlfoundations' in url: + expected_sha256 = os.path.splitext(filename)[0].split("-")[-1] + else: + expected_sha256 = '' + + download_target = os.path.join(cache_dir, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if expected_sha256: + if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + else: + return download_target + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): + raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not match") + + return download_target + + +def has_hf_hub(necessary=False): + if not _has_hf_hub and necessary: + # if no HF Hub module installed, and it is necessary to continue, raise error + raise RuntimeError( + 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.') + return _has_hf_hub + + +def _get_safe_alternatives(filename: str) -> Iterable[str]: + """Returns potential safetensors alternatives for a given filename. + + Use case: + When downloading a model from the Huggingface Hub, we first look if a .safetensors file exists and if yes, we use it.
+ """ + if filename == HF_WEIGHTS_NAME: + yield HF_SAFE_WEIGHTS_NAME + + if filename not in (HF_WEIGHTS_NAME,) and (filename.endswith(".bin") or filename.endswith(".pth")): + yield filename[:-4] + ".safetensors" + + +def download_pretrained_from_hf( + model_id: str, + filename: Optional[str] = None, + revision: Optional[str] = None, + cache_dir: Optional[str] = None, +): + has_hf_hub(True) + + filename = filename or HF_WEIGHTS_NAME + + # Look for .safetensors alternatives and load from it if it exists + if _has_safetensors: + for safe_filename in _get_safe_alternatives(filename): + try: + cached_file = hf_hub_download( + repo_id=model_id, + filename=safe_filename, + revision=revision, + cache_dir=cache_dir, + ) + return cached_file + except Exception: + pass + + try: + # Attempt to download the file + cached_file = hf_hub_download( + repo_id=model_id, + filename=filename, + revision=revision, + cache_dir=cache_dir, + ) + return cached_file # Return the path to the downloaded file if successful + except Exception as e: + raise FileNotFoundError(f"Failed to download file ({filename}) for {model_id}. Last error: {e}") + + +def download_pretrained( + cfg: Dict, + prefer_hf_hub: bool = True, + cache_dir: Optional[str] = None, +): + target = '' + if not cfg: + return target + + has_hub = has_hf_hub() + download_url = cfg.get('url', '') + download_hf_hub = cfg.get('hf_hub', '') + if has_hub and prefer_hf_hub and download_hf_hub: + # prefer to use HF hub, remove url info + download_url = '' + + if download_url: + target = download_pretrained_from_url(download_url, cache_dir=cache_dir) + elif download_hf_hub: + has_hf_hub(True) + # we assume the hf_hub entries in pretrained config combine model_id + filename in + # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and + # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'. + model_id, filename = os.path.split(download_hf_hub) + if filename: + target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir) + else: + target = download_pretrained_from_hf(model_id, cache_dir=cache_dir) + + return target diff --git a/open_clip/src/open_clip/timm_model.py b/open_clip/src/open_clip/timm_model.py new file mode 100644 index 0000000000000000000000000000000000000000..975e37a9d3bb527d02382ea5a2e102caa98b1798 --- /dev/null +++ b/open_clip/src/open_clip/timm_model.py @@ -0,0 +1,153 @@ +""" timm model adapter + +Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 
+""" +import logging +from collections import OrderedDict + +import torch +import torch.nn as nn + +try: + import timm + try: + # new timm imports >= 0.8.1 + from timm.layers import RotAttentionPool2d + from timm.layers import AttentionPool2d as AbsAttentionPool2d + from timm.layers import Mlp, to_2tuple + except ImportError as e: + # fallback, try old timm imports < 0.8.1 + from timm.models.layers.attention_pool2d import RotAttentionPool2d + from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d + from timm.models.layers import Mlp, to_2tuple +except ImportError: + timm = None + +from .utils import freeze_batch_norm_2d + + +class TimmModel(nn.Module): + """ timm model adapter + """ + + def __init__( + self, + model_name, + embed_dim, + image_size=224, + pool='avg', + proj='linear', + proj_bias=False, + drop=0., + drop_path=None, + patch_drop=None, + pretrained=False, + ): + super().__init__() + if timm is None: + raise RuntimeError("Please `pip install timm` to use timm models.") + self.image_size = to_2tuple(image_size) + + # setup kwargs that may not be common across all models + timm_kwargs = {} + if drop_path is not None: + timm_kwargs['drop_path_rate'] = drop_path + if patch_drop is not None: + timm_kwargs['patch_drop_rate'] = patch_drop + + custom_pool = pool in ('abs_attn', 'rot_attn') + if proj: + assert proj in ("linear", "mlp", "none") + extra_proj = proj in ("linear", "mlp") + if not extra_proj and not custom_pool: + # use network classifier head as projection if no proj specified and no custom pooling used + # if projection is explicitly set to "none" will be pass through from network trunk + proj_dim = 0 if proj == 'none' else embed_dim + self.trunk = timm.create_model( + model_name, + num_classes=proj_dim, + global_pool=pool, + pretrained=pretrained, + **timm_kwargs, + ) + prev_chs = embed_dim + else: + self.trunk = timm.create_model( + model_name, + pretrained=pretrained, + **timm_kwargs, + ) + feat_size = self.trunk.default_cfg.get('pool_size', None) + feature_ndim = 1 if not feat_size else 2 + if custom_pool: + assert feature_ndim == 2 + # if attn pooling used, remove both classifier and default pool + self.trunk.reset_classifier(0, global_pool='') + else: + # reset global pool if pool config set, otherwise leave as network default + reset_kwargs = dict(global_pool=pool) if pool else {} + self.trunk.reset_classifier(0, **reset_kwargs) + prev_chs = self.trunk.num_features + + head_layers = OrderedDict() + + # Add custom pooling to head + if pool == 'abs_attn': + head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) + prev_chs = embed_dim + elif pool == 'rot_attn': + head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) + prev_chs = embed_dim + + # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used + if proj == 'linear': + head_layers['drop'] = nn.Dropout(drop) + head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) + elif proj == 'mlp': + head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias)) + + self.head = nn.Sequential(head_layers) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + """ lock modules + Args: + unlocked_groups (int): leave last n layer groups unlocked (default: 0) + """ + if not unlocked_groups: + # lock full model + for param in self.trunk.parameters(): + param.requires_grad = False + if freeze_bn_stats: + 
freeze_batch_norm_2d(self.trunk) + else: + # NOTE: partial freeze requires latest timm (master) branch and is subject to change + try: + # FIXME import here until API stable and in an official release + from timm.models.helpers import group_parameters, group_modules + except ImportError: + raise RuntimeError( + 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') + matcher = self.trunk.group_matcher() + gparams = group_parameters(self.trunk, matcher) + max_layer_id = max(gparams.keys()) + max_layer_id = max_layer_id - unlocked_groups + for group_idx in range(max_layer_id + 1): + group = gparams[group_idx] + for param in group: + self.trunk.get_parameter(param).requires_grad = False + if freeze_bn_stats: + gmodules = group_modules(self.trunk, matcher, reverse=True) + gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} + freeze_batch_norm_2d(self.trunk, gmodules) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + try: + self.trunk.set_grad_checkpointing(enable) + except Exception as e: + logging.warning('grad checkpointing not supported for this timm image tower, continuing without...') + + def forward(self, x): + x = self.trunk(x) + x = self.head(x) + return x diff --git a/open_clip/src/open_clip/utils.py b/open_clip/src/open_clip/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bb0bb8868ae1f2d31493ca32b73accd6bf1d3cdb --- /dev/null +++ b/open_clip/src/open_clip/utils.py @@ -0,0 +1,89 @@ +from itertools import repeat +import collections.abc + +import torch +from torch import nn as nn +from torchvision.ops.misc import FrozenBatchNorm2d + + +def freeze_batch_norm_2d(module, module_match={}, name=''): + """ + Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is + itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and + returned. Otherwise, the module is walked recursively and submodules are converted in place. + + Args: + module (torch.nn.Module): Any PyTorch module. 
+ module_match (dict): Dictionary of full module names to freeze (all if empty) + name (str): Full module name (prefix) + + Returns: + torch.nn.Module: Resulting module + + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 + """ + res = module + is_match = True + if module_match: + is_match = name in module_match + if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): + res = FrozenBatchNorm2d(module.num_features) + res.num_features = module.num_features + res.affine = module.affine + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for child_name, child in module.named_children(): + full_child_name = '.'.join([name, child_name]) if name else child_name + new_child = freeze_batch_norm_2d(child, module_match, full_child_name) + if new_child is not child: + res.add_module(child_name, new_child) + return res + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = lambda n, x: _ntuple(n)(x) + +# Replaces all linear layers with linear_replacement +# TODO: add int8 support for other linear layers including attn and convnets +def replace_linear(model, linear_replacement, include_modules=['c_fc', 'c_proj'], copy_weights=True): + for name, module in model.named_children(): + if len(list(module.children())) > 0: + replace_linear(module, linear_replacement, include_modules, copy_weights) + + if isinstance(module, torch.nn.Linear) and name in include_modules: + old_module = model._modules[name] + model._modules[name] = linear_replacement( + module.in_features, + module.out_features, + module.bias is not None, + ) + if copy_weights: + model._modules[name].weight.data.copy_(old_module.weight.data) + if model._modules[name].bias is not None: + model._modules[name].bias.data.copy_(old_module.bias) + + return model + +def convert_int8_model_to_inference_mode(model): + for m in model.modules(): + if hasattr(m, 'prepare_for_eval'): + int8_original_dtype = m.weight.dtype + m.prepare_for_eval() + m.int8_original_dtype = int8_original_dtype \ No newline at end of file diff --git a/open_clip/src/open_clip_train/__init__.py b/open_clip/src/open_clip_train/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/open_clip/src/open_clip_train/file_utils.py b/open_clip/src/open_clip_train/file_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..395cf7df0acc164c6851f17834d793f5852d4605 --- /dev/null +++ b/open_clip/src/open_clip_train/file_utils.py @@ -0,0 +1,83 @@ +import logging +import os +import multiprocessing +import subprocess +import time +import fsspec +import torch +from tqdm import tqdm + +def remote_sync_s3(local_dir, remote_dir): + # skip epoch_latest which can change during sync. 
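+    # NOTE 'aws s3 sync' only transfers new or changed files, so repeated invocations stay incremental.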
+ result = subprocess.run(["aws", "s3", "sync", local_dir, remote_dir, '--exclude', '*epoch_latest.pt'], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + if result.returncode != 0: + logging.error(f"Error: Failed to sync with S3 bucket {result.stderr.decode('utf-8')}") + return False + + logging.info(f"Successfully synced with S3 bucket") + return True + +def remote_sync_fsspec(local_dir, remote_dir): + # FIXME currently this is slow and not recommended. Look into speeding up. + a = fsspec.get_mapper(local_dir) + b = fsspec.get_mapper(remote_dir) + + for k in a: + # skip epoch_latest which can change during sync. + if 'epoch_latest.pt' in k: + continue + + logging.info(f'Attempting to sync {k}') + if k in b and len(a[k]) == len(b[k]): + logging.debug(f'Skipping remote sync for {k}.') + continue + + try: + logging.info(f'Successful sync for {k}.') + b[k] = a[k] + except Exception as e: + logging.info(f'Error during remote sync for {k}: {e}') + return False + + return True + +def remote_sync(local_dir, remote_dir, protocol): + logging.info('Starting remote sync.') + if protocol == 's3': + return remote_sync_s3(local_dir, remote_dir) + elif protocol == 'fsspec': + return remote_sync_fsspec(local_dir, remote_dir) + else: + logging.error('Remote protocol not known') + return False + +def keep_running_remote_sync(sync_every, local_dir, remote_dir, protocol): + while True: + time.sleep(sync_every) + remote_sync(local_dir, remote_dir, protocol) + +def start_sync_process(sync_every, local_dir, remote_dir, protocol): + p = multiprocessing.Process(target=keep_running_remote_sync, args=(sync_every, local_dir, remote_dir, protocol)) + return p + +# Note: we are not currently using this save function. +def pt_save(pt_obj, file_path): + of = fsspec.open(file_path, "wb") + with of as f: + torch.save(pt_obj, file_path) + +def pt_load(file_path, map_location=None): + if file_path.startswith('s3'): + logging.info('Loading remote checkpoint, which may take a bit.') + of = fsspec.open(file_path, "rb") + with of as f: + out = torch.load(f, map_location=map_location) + return out + +def check_exists(file_path): + try: + with fsspec.open(file_path): + pass + except FileNotFoundError: + return False + return True diff --git a/open_clip/src/open_clip_train/main.py b/open_clip/src/open_clip_train/main.py new file mode 100644 index 0000000000000000000000000000000000000000..a53da6d14a9ba1a09449ed15600451f464797378 --- /dev/null +++ b/open_clip/src/open_clip_train/main.py @@ -0,0 +1,555 @@ +import copy +import glob +import logging +import os +import re +import subprocess +import sys +import random +from datetime import datetime +from functools import partial + +import numpy as np +import torch +from torch import optim + +try: + import wandb +except ImportError: + wandb = None + +try: + import torch.utils.tensorboard as tensorboard +except ImportError: + tensorboard = None + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + +from open_clip import create_model_and_transforms, trace_model, get_tokenizer, create_loss +from open_clip_train.data import get_data +from open_clip_train.distributed import is_master, init_distributed_device, broadcast_object +from open_clip_train.logger import setup_logging +from open_clip_train.params import parse_args +from open_clip_train.scheduler import cosine_lr, const_lr, const_lr_cooldown +from open_clip_train.train import train_one_epoch, evaluate +from open_clip_train.file_utils import pt_load, check_exists, start_sync_process, remote_sync + + 
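+# Fixed filename for the rolling 'latest' checkpoint written with --save-most-recent; the remote sync
+# helpers above deliberately exclude it because it is rewritten in place during training.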
+LATEST_CHECKPOINT_NAME = "epoch_latest.pt"
+
+
+def random_seed(seed=42, rank=0):
+    torch.manual_seed(seed + rank)
+    np.random.seed(seed + rank)
+    random.seed(seed + rank)
+
+
+def natural_key(string_):
+    """See http://www.codinghorror.com/blog/archives/001018.html"""
+    return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
+
+
+def get_latest_checkpoint(path: str, remote: bool):
+    # as written, this glob recurses, so it can pick up checkpoints across multiple sub-folders
+    if remote:
+        result = subprocess.run(["aws", "s3", "ls", path + "/"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        print(result)
+        if result.returncode == 1:
+            return None
+        checkpoints = [os.path.join(path, x.split(' ')[-1]) for x in result.stdout.decode().split('\n')[:-1]]
+    else:
+        checkpoints = glob.glob(path + '**/*.pt', recursive=True)
+    if checkpoints:
+        checkpoints = sorted(checkpoints, key=natural_key)
+        return checkpoints[-1]
+    return None
+
+
+def main(args):
+    args = parse_args(args)
+
+    if torch.cuda.is_available():
+        # This enables tf32 on Ampere GPUs which is only 8% slower than
+        # float16 and almost as accurate as float32
+        # This was a default in pytorch until 1.12
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.benchmark = True
+        torch.backends.cudnn.deterministic = False
+
+    # fully initialize distributed device environment
+    device = init_distributed_device(args)
+
+    # get the name of the experiment
+    if args.name is None:
+        # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule?
+        model_name_safe = args.model.replace('/', '-')
+        date_str = datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
+        if args.distributed:
+            # sync date_str from master to all ranks
+            date_str = broadcast_object(args, date_str)
+        args.name = '-'.join([
+            date_str,
+            f"model_{model_name_safe}",
+            f"lr_{args.lr}",
+            f"b_{args.batch_size}",
+            f"j_{args.workers}",
+            f"p_{args.precision}",
+        ])
+
+    resume_latest = args.resume == 'latest'
+    log_base_path = os.path.join(args.logs, args.name)
+    args.log_path = None
+    if is_master(args, local=args.log_local):
+        os.makedirs(log_base_path, exist_ok=True)
+        log_filename = f'out-{args.rank}' if args.log_local else 'out.log'
+        args.log_path = os.path.join(log_base_path, log_filename)
+        if os.path.exists(args.log_path) and not resume_latest:
+            print(
+                "Error. Experiment already exists. Use --name to specify a new experiment."
+            )
+            return -1
+
+    # Setup text logger
+    args.log_level = logging.DEBUG if args.debug else logging.INFO
+    setup_logging(args.log_path, args.log_level)
+
+    # Setup wandb, tensorboard, checkpoint logging
+    args.wandb = 'wandb' in args.report_to or 'all' in args.report_to
+    args.tensorboard = 'tensorboard' in args.report_to or 'all' in args.report_to
+    args.checkpoint_path = os.path.join(log_base_path, "checkpoints")
+    if is_master(args):
+        args.tensorboard_path = os.path.join(log_base_path, "tensorboard") if args.tensorboard else ''
+        for dirname in [args.tensorboard_path, args.checkpoint_path]:
+            if dirname:
+                os.makedirs(dirname, exist_ok=True)
+    else:
+        args.tensorboard_path = ''
+
+    if resume_latest:
+        resume_from = None
+        checkpoint_path = args.checkpoint_path
+        # If using remote_sync, need to check the remote instead of the local checkpoints folder.
+        if args.remote_sync is not None:
+            checkpoint_path = os.path.join(args.remote_sync, args.name, "checkpoints")
+            if args.save_most_recent:
+                print('Error. Cannot use save-most-recent with remote_sync and resume latest.')
+                return -1
+            if args.remote_sync_protocol != 's3':
+                print('Error. Sync protocol not supported when using resume latest.')
+                return -1
+        if is_master(args):
+            # Checking for existing checkpoint via master rank only. It is possible for
+            # different rank processes to see different files if a shared file-system is under
+            # stress, however it's very difficult to fully work around such situations.
+            if args.save_most_recent:
+                # if --save-most-recent flag is set, look for latest at a fixed filename
+                resume_from = os.path.join(checkpoint_path, LATEST_CHECKPOINT_NAME)
+                if not os.path.exists(resume_from):
+                    # If no latest checkpoint has been saved yet, don't try to resume
+                    resume_from = None
+            else:
+                # otherwise, list checkpoint dir contents and pick the newest checkpoint
+                resume_from = get_latest_checkpoint(checkpoint_path, remote=args.remote_sync is not None)
+            if resume_from:
+                logging.info(f'Found latest resume checkpoint at {resume_from}.')
+            else:
+                logging.info(f'No latest resume checkpoint found in {checkpoint_path}.')
+        if args.distributed:
+            # sync found checkpoint path to all ranks
+            resume_from = broadcast_object(args, resume_from)
+        args.resume = resume_from
+
+    if args.copy_codebase:
+        copy_codebase(args)
+
+    # start the sync process if remote-sync is not None
+    remote_sync_process = None
+    if is_master(args) and args.remote_sync is not None:
+        # first make sure it works
+        result = remote_sync(
+            os.path.join(args.logs, args.name),
+            os.path.join(args.remote_sync, args.name),
+            args.remote_sync_protocol
+        )
+        if result:
+            logging.info('remote sync successful.')
+        else:
+            logging.info('Error: remote sync failed. Exiting.')
+            return -1
+        # if all looks good, start a process to do this every args.remote_sync_frequency seconds
+        remote_sync_process = start_sync_process(
+            args.remote_sync_frequency,
+            os.path.join(args.logs, args.name),
+            os.path.join(args.remote_sync, args.name),
+            args.remote_sync_protocol
+        )
+        remote_sync_process.start()
+
+    if args.precision == 'fp16':
+        logging.warning(
+            'It is recommended to use AMP mixed-precision instead of FP16. '
+            'FP16 support needs further verification and tuning, especially for train.')
+
+    if args.horovod:
+        logging.info(
+            f'Running in horovod mode with multiple processes / nodes. Device: {args.device}.'
+            f'Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}.')
+    elif args.distributed:
+        logging.info(
+            f'Running in distributed mode with multiple processes. Device: {args.device}.'
+            f'Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}.')
+    else:
+        logging.info(f'Running with a single process. Device {args.device}.')
+
+    dist_model = None
+    args.distill = args.distill_model is not None and args.distill_pretrained is not None
+    if args.distill:
+        #FIXME: support distillation with grad accum.
+        assert args.accum_freq == 1
+        #FIXME: support distillation with coca.
+ assert 'coca' not in args.model.lower() + + if isinstance(args.force_image_size, (tuple, list)) and len(args.force_image_size) == 1: + # arg is nargs, single (square) image size list -> int + args.force_image_size = args.force_image_size[0] + random_seed(args.seed, 0) + model_kwargs = {} + if args.siglip: + model_kwargs['init_logit_scale'] = np.log(10) # different from CLIP + model_kwargs['init_logit_bias'] = -10 + model, preprocess_train, preprocess_val = create_model_and_transforms( + args.model, + args.pretrained, + precision=args.precision, + device=device, + jit=args.torchscript, + force_quick_gelu=args.force_quick_gelu, + force_custom_text=args.force_custom_text, + force_patch_dropout=args.force_patch_dropout, + force_image_size=args.force_image_size, + image_mean=args.image_mean, + image_std=args.image_std, + image_interpolation=args.image_interpolation, + image_resize_mode=args.image_resize_mode, # only effective for inference + aug_cfg=args.aug_cfg, + pretrained_image=args.pretrained_image, + output_dict=True, + cache_dir=args.cache_dir, + **model_kwargs, + ) + if args.distill: + # FIXME: currently assumes the model you're distilling from has the same tokenizer & transforms. + dist_model, _, _ = create_model_and_transforms( + args.distill_model, + args.distill_pretrained, + device=device, + precision=args.precision, + output_dict=True, + cache_dir=args.cache_dir, + ) + if args.use_bnb_linear is not None: + print('=> using a layer from bitsandbytes.\n' + ' this is an experimental feature which requires two extra pip installs\n' + ' pip install bitsandbytes triton' + ' please make sure to use triton 2.0.0') + import bitsandbytes as bnb + from open_clip.utils import replace_linear + print(f'=> replacing linear layers with {args.use_bnb_linear}') + linear_replacement_cls = getattr(bnb.nn.triton_based_modules, args.use_bnb_linear) + replace_linear(model, linear_replacement_cls) + model = model.to(device) + + random_seed(args.seed, args.rank) + + if args.trace: + model = trace_model(model, batch_size=args.batch_size, device=device) + + if args.lock_image: + # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 + model.lock_image_tower( + unlocked_groups=args.lock_image_unlocked_groups, + freeze_bn_stats=args.lock_image_freeze_bn_stats) + if args.lock_text: + model.lock_text_tower( + unlocked_layers=args.lock_text_unlocked_layers, + freeze_layer_norm=args.lock_text_freeze_layer_norm) + + if args.grad_checkpointing: + model.set_grad_checkpointing() + + if is_master(args): + logging.info("Model:") + logging.info(f"{str(model)}") + logging.info("Params:") + params_file = os.path.join(args.logs, args.name, "params.txt") + with open(params_file, "w") as f: + for name in sorted(vars(args)): + val = getattr(args, name) + logging.info(f" {name}: {val}") + f.write(f"{name}: {val}\n") + + if args.distributed and not args.horovod: + if args.use_bn_sync: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + ddp_args = {} + if args.ddp_static_graph: + # this doesn't exist in older PyTorch, arg only added if enabled + ddp_args['static_graph'] = True + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], **ddp_args) + + if args.distill: + dist_model = torch.nn.parallel.DistributedDataParallel(dist_model, device_ids=[device], **ddp_args) + + # create optimizer and scaler + optimizer = None + scaler = None + + if args.train_data or args.dataset_type == "synthetic": + assert not args.trace, 'Cannot train with traced model' + + opt = getattr(args, 'opt', 
'adamw').lower()
+        if opt.startswith('timm/'):
+            from timm.optim import create_optimizer_v2
+            timm_opt = opt.split('timm/')[-1]
+            opt_kwargs = {}
+            assert (args.beta1 is None) == (args.beta2 is None), \
+                'When using timm optimizer, BOTH beta1 and beta2 must be specified (or not specified).'
+            if args.beta1 is not None:
+                opt_kwargs['betas'] = (args.beta1, args.beta2)
+            if args.momentum is not None:
+                opt_kwargs['momentum'] = args.momentum
+            optimizer = create_optimizer_v2(
+                model,
+                timm_opt,
+                lr=args.lr,
+                weight_decay=args.wd,
+                eps=args.eps,
+                **opt_kwargs,
+            )
+        else:
+            # If some params are not passed, we use the default values based on model name.
+            exclude = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n
+            include = lambda n, p: not exclude(n, p)
+
+            named_parameters = list(model.named_parameters())
+            gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad]
+            rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad]
+
+            if opt == 'adamw':
+                optimizer = optim.AdamW(
+                    [
+                        {"params": gain_or_bias_params, "weight_decay": 0.},
+                        {"params": rest_params, "weight_decay": args.wd},
+                    ],
+                    lr=args.lr,
+                    betas=(args.beta1, args.beta2),
+                    eps=args.eps,
+                )
+            else:
+                assert False, f'Unknown optimizer {opt}'
+
+        if is_master(args):
+            defaults = copy.deepcopy(optimizer.defaults)
+            defaults['weight_decay'] = args.wd
+            defaults = ', '.join([f'{k}: {v}' for k, v in defaults.items()])
+            logging.info(
+                f'Created {type(optimizer).__name__} ({args.opt}) optimizer: {defaults}'
+            )
+
+        if args.horovod:
+            optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())
+            hvd.broadcast_parameters(model.state_dict(), root_rank=0)
+            hvd.broadcast_optimizer_state(optimizer, root_rank=0)
+
+        scaler = None
+        if args.precision == "amp":
+            try:
+                scaler = torch.amp.GradScaler(device=device)
+            except (AttributeError, TypeError) as e:
+                scaler = torch.cuda.amp.GradScaler()
+
+    # optionally resume from a checkpoint
+    start_epoch = 0
+    if args.resume is not None:
+        checkpoint = pt_load(args.resume, map_location='cpu')
+        if 'epoch' in checkpoint:
+            # resuming a train checkpoint w/ epoch and optimizer state
+            start_epoch = checkpoint["epoch"]
+            sd = checkpoint["state_dict"]
+            if not args.distributed and next(iter(sd.items()))[0].startswith('module'):
+                sd = {k[len('module.'):]: v for k, v in sd.items()}
+            model.load_state_dict(sd)
+            if optimizer is not None:
+                optimizer.load_state_dict(checkpoint["optimizer"])
+            if scaler is not None and 'scaler' in checkpoint:
+                scaler.load_state_dict(checkpoint['scaler'])
+            logging.info(f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})")
+        else:
+            # loading a bare (model only) checkpoint for fine-tune or evaluation
+            model.load_state_dict(checkpoint)
+            logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})")
+
+    # initialize datasets
+    tokenizer = get_tokenizer(args.model, cache_dir=args.cache_dir)
+    data = get_data(
+        args,
+        (preprocess_train, preprocess_val),
+        epoch=start_epoch,
+        tokenizer=tokenizer,
+    )
+    assert len(data), 'At least one train or eval dataset must be specified.'
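The AdamW branch above splits parameters into two groups so that biases, 1-D parameters (e.g. LayerNorm/BatchNorm gains), and logit_scale receive zero weight decay. A minimal self-contained sketch of the same grouping pattern, with a toy module and an illustrative weight-decay value (not the defaults of this script):

    import torch
    from torch import nn, optim

    # Toy model used only to illustrate the grouping; any 1-D param or name containing
    # "bn"/"ln"/"bias"/"logit_scale" is excluded from weight decay.
    toy = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
    exclude = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n
    named = list(toy.named_parameters())
    no_decay = [p for n, p in named if exclude(n, p) and p.requires_grad]
    decay = [p for n, p in named if not exclude(n, p) and p.requires_grad]
    optimizer = optim.AdamW(
        [
            {"params": no_decay, "weight_decay": 0.0},
            {"params": decay, "weight_decay": 0.2},  # 0.2 is an example value only
        ],
        lr=5e-4, betas=(0.9, 0.98), eps=1e-6,
    )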
+ + # create scheduler if train + scheduler = None + if 'train' in data and optimizer is not None: + total_steps = (data["train"].dataloader.num_batches // args.accum_freq) * args.epochs + if args.lr_scheduler == "cosine": + scheduler = cosine_lr(optimizer, args.lr, args.warmup, total_steps) + elif args.lr_scheduler == "const": + scheduler = const_lr(optimizer, args.lr, args.warmup, total_steps) + elif args.lr_scheduler == "const-cooldown": + assert args.epochs_cooldown is not None,\ + "Please specify the number of cooldown epochs for this lr schedule." + cooldown_steps = (data["train"].dataloader.num_batches // args.accum_freq) * args.epochs_cooldown + scheduler = const_lr_cooldown( + optimizer, args.lr, args.warmup, total_steps, + cooldown_steps, args.lr_cooldown_power, args.lr_cooldown_end) + else: + logging.error( + f'Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const, const-cooldown.') + exit(1) + + # determine if this worker should save logs and checkpoints. only do so if it is rank == 0 + args.save_logs = args.logs and args.logs.lower() != 'none' and is_master(args) + writer = None + if args.save_logs and args.tensorboard: + assert tensorboard is not None, "Please install tensorboard." + writer = tensorboard.SummaryWriter(args.tensorboard_path) + + if args.wandb and is_master(args): + assert wandb is not None, 'Please install wandb.' + logging.debug('Starting wandb.') + args.train_sz = data["train"].dataloader.num_samples + if args.val_data is not None: + args.val_sz = data["val"].dataloader.num_samples + # you will have to configure this for your project! + wandb.init( + project=args.wandb_project_name, + name=args.name, + id=args.name, + notes=args.wandb_notes, + tags=[], + resume='auto' if args.resume == "latest" else None, + config=vars(args), + ) + if args.debug: + wandb.watch(model, log='all') + wandb.save(params_file) + logging.debug('Finished loading wandb.') + + # Pytorch 2.0 adds '_orig_mod.' prefix to keys of state_dict() of compiled models. + # For compatibility, we save state_dict() of the original model, which shares the + # weights without the prefix. + original_model = model + if args.torchcompile: + logging.info('Compiling model...') + + if args.grad_checkpointing and args.distributed: + logging.info('Disabling DDP dynamo optimizer when grad checkpointing enabled.') + # As of now (~PyTorch 2.4/2.5), compile + grad checkpointing work, but DDP optimizer must be disabled + torch._dynamo.config.optimize_ddp = False + + model = torch.compile(original_model) + + if 'train' not in data: + # If using int8, convert to inference mode. + if args.use_bnb_linear is not None: + from open_clip.utils import convert_int8_model_to_inference_mode + convert_int8_model_to_inference_mode(model) + # Evaluate. + evaluate(model, data, start_epoch, args, tb_writer=writer, tokenizer=tokenizer) + return + + loss = create_loss(args) + + for epoch in range(start_epoch, args.epochs): + if is_master(args): + logging.info(f'Start epoch {epoch}') + + train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist_model, args, tb_writer=writer) + completed_epoch = epoch + 1 + + if any(v in data for v in ('val', 'imagenet-val', 'imagenet-v2')): + evaluate(model, data, completed_epoch, args, tb_writer=writer, tokenizer=tokenizer) + + # Saving checkpoints. 
+ if args.save_logs: + checkpoint_dict = { + "epoch": completed_epoch, + "name": args.name, + "state_dict": original_model.state_dict(), + "optimizer": optimizer.state_dict(), + } + if scaler is not None: + checkpoint_dict["scaler"] = scaler.state_dict() + + if completed_epoch == args.epochs or ( + args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0 + ): + torch.save( + checkpoint_dict, + os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"), + ) + if args.delete_previous_checkpoint: + previous_checkpoint = os.path.join(args.checkpoint_path, f"epoch_{completed_epoch - 1}.pt") + if os.path.exists(previous_checkpoint): + os.remove(previous_checkpoint) + + if args.save_most_recent: + # try not to corrupt the latest checkpoint if save fails + tmp_save_path = os.path.join(args.checkpoint_path, "tmp.pt") + latest_save_path = os.path.join(args.checkpoint_path, LATEST_CHECKPOINT_NAME) + torch.save(checkpoint_dict, tmp_save_path) + os.replace(tmp_save_path, latest_save_path) + + if args.wandb and is_master(args): + wandb.finish() + + # run a final sync. + if remote_sync_process is not None: + logging.info('Final remote sync.') + remote_sync_process.terminate() + result = remote_sync( + os.path.join(args.logs, args.name), + os.path.join(args.remote_sync, args.name), + args.remote_sync_protocol + ) + if result: + logging.info('Final remote sync successful.') + else: + logging.info('Final remote sync failed.') + + +def copy_codebase(args): + from shutil import copytree, ignore_patterns + new_code_path = os.path.join(args.logs, args.name, "code") + if os.path.exists(new_code_path): + print( + f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment." + ) + return -1 + print(f"Copying codebase to {new_code_path}") + current_code_path = os.path.realpath(__file__) + for _ in range(3): + current_code_path = os.path.dirname(current_code_path) + copytree(current_code_path, new_code_path, ignore=ignore_patterns('log', 'logs', 'wandb')) + print("Done copying code.") + return 1 + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/open_clip/src/open_clip_train/precision.py b/open_clip/src/open_clip_train/precision.py new file mode 100644 index 0000000000000000000000000000000000000000..5af494892d1c2c0c26fc878f2e1fa69b585194cb --- /dev/null +++ b/open_clip/src/open_clip_train/precision.py @@ -0,0 +1,14 @@ +import torch +from contextlib import suppress +from functools import partial + + +def get_autocast(precision, device_type='cuda'): + if precision =='amp': + amp_dtype = torch.float16 + elif precision == 'amp_bfloat16' or precision == 'amp_bf16': + amp_dtype = torch.bfloat16 + else: + return suppress + + return partial(torch.amp.autocast, device_type=device_type, dtype=amp_dtype) \ No newline at end of file
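get_autocast returns either a partially-applied torch.amp.autocast (for the amp precisions) or contextlib.suppress as a no-op context manager. A small usage sketch under that reading; model and images are placeholders rather than objects from this diff:

    # Hedged usage sketch for get_autocast; `model` and `images` are placeholders.
    autocast = get_autocast('amp_bf16', device_type='cuda')  # partial(torch.amp.autocast, dtype=torch.bfloat16)
    with autocast():
        image_features = model.encode_image(images)

    noop = get_autocast('fp32')  # any non-amp precision falls through to contextlib.suppress
    with noop():
        image_features = model.encode_image(images)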