diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..6e77c631547db050cb6e79c29d6e8df139becc6f --- /dev/null +++ b/app.py @@ -0,0 +1,68 @@ +import gradio as gr +from PIL import Image +from detector import FakeImageDetector + +print("正在初始化检测器,请稍候...") +try: + detector = FakeImageDetector() + + print("检测器初始化完成,Web 服务准备就绪。") + models_loaded = True +except Exception as e: + print(f"模型加载失败: {e}") + models_loaded = False + detector = None + +def predict_image(input_image_numpy, threshold): + """ + 接收 Gradio 的输入 (numpy array),调用检测器,并返回结果。 + """ + if not models_loaded or detector is None: + return "错误:模型未能成功加载,请检查后台日志。", None + + pil_image = Image.fromarray(input_image_numpy) + + result_text, score = detector.detect(pil_image, threshold) + + label_color = "red" if score > threshold else "green" + + return result_text, gr.Label(value=f"{score:.10f}", label=label_color) + + +with gr.Blocks(title="伪造图像检测器", theme=gr.themes.Soft()) as demo: + gr.Markdown( + """ + # 伪造图像检测器 (Fake Image Detector) + 上传一张图片,模型将判断其为 **真实的 (Real)** 还是 **AI 生成的伪造图像 (Fake)**。 + """ + ) + + with gr.Row(): + with gr.Column(scale=1): + # 输入组件 + image_input = gr.Image(type="numpy", label="上传图片", height=300) + # threshold_slider = gr.Slider( + # minimum=0.495, maximum=0.55, value=0.499892068, step=0.0001, + # label="检测门限 (Threshold)", + # info="得分低于此门限的图片被认为是伪造的" + # ) + submit_btn = gr.Button("开始检测", variant="primary") + + with gr.Column(scale=1): + # 输出组件 + result_output_text = gr.Textbox(label="检测结论", lines=2) + # 这里我们用一个临时的 Label 来显示带颜色的分数 + result_output_score = gr.Label(label="模型原始得分") + + submit_btn.click( + fn=predict_image, + inputs=[image_input, 0.49999], + outputs=[result_output_text, result_output_score] + ) + +if not models_loaded: + print("\n由于模型加载失败,Gradio Web服务无法启动。") +else: + print("正在启动 Gradio 服务...") + + demo.launch(server_name="0.0.0.0") \ No newline at end of file diff --git a/augmentations_clip.py b/augmentations_clip.py new file mode 100644 
index 0000000000000000000000000000000000000000..369e027236e2c6bbb377e8dc1f74f02ac4b97418 --- /dev/null +++ b/augmentations_clip.py @@ -0,0 +1,255 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# import logging + +from torchvision import transforms + +import torch +import cv2 +from PIL import Image +import numpy as np + +from my_transforms import ( + GaussianBlur, + make_normalize_transform, + make_normalize_transform_clip, +) + +def add_gaussian_noise(tensor, mean=0.0, std=0.1): + noise = torch.randn(tensor.size()).cuda() * std + mean + return tensor + noise + + + + +class DataAugmentationCLIP(object): + def __init__( + self, + global_crops_scale, + local_crops_scale, + local_crops_number, + global_crops_size=224, + local_crops_size=96, + ): + + self.source_trans = transforms.Compose([ + # transforms.RandomCrop(224), + # transforms.CenterCrop(224), + transforms.ToTensor(), + make_normalize_transform_clip(), + ]) + + # self.crop = transforms.Compose([ + # transforms.CenterCrop(224), + + # ]) + + self.crop = transforms.Compose([ + transforms.Resize(224), # 将短边缩放到 224,长边会按比例缩放 + transforms.RandomCrop(224), # 然后裁剪到 224x224 +]) + + self.centercrop = transforms.Compose([ + transforms.CenterCrop(224), + + ]) + + self.randomcrop = transforms.Compose([ + transforms.RandomCrop(224), + + ]) + + self.local_crops_number = local_crops_number + + def __call__(self, image): + output = {} + output["source"] = [] + + if np.array(image).shape[0]<224 or np.array(image).shape[1]<224: + crops_all = [ + self.centercrop(image) for _ in range(self.local_crops_number) + ] + else: + crops_all = [ + self.centercrop(image) for _ in range(self.local_crops_number) + ] + + for crops_image in crops_all: + output["source"].append(self.source_trans(crops_image)) #单独使用好一些 + + + output["offsets"] = () + + return output + + +class 
DataAugmentationDINO(object): + def __init__( + self, + global_crops_scale, + local_crops_scale, + local_crops_number, + global_crops_size=224, + local_crops_size=96, + ): + + self.source_trans = transforms.Compose([ + # transforms.RandomCrop(224), + # transforms.CenterCrop(224), + transforms.ToTensor(), + make_normalize_transform(), + ]) + + # self.crop = transforms.Compose([ + # transforms.CenterCrop(224), + + # ]) + + self.crop = transforms.Compose([ + transforms.Resize(224), # 将短边缩放到 224,长边会按比例缩放 + transforms.CenterCrop(224), # 然后裁剪到 224x224 +]) + + self.centercrop = transforms.Compose([ + transforms.CenterCrop(224), + + ]) + + self.local_crops_number = local_crops_number + + def __call__(self, image): + output = {} + output["source"] = [] + + if np.array(image).shape[0]<224 or np.array(image).shape[1]<224: + crops_all = [ + self.centercrop(image) for _ in range(self.local_crops_number) + ] + else: + crops_all = [ + self.centercrop(image) for _ in range(self.local_crops_number) + ] + + for crops_image in crops_all: + output["source"].append(self.source_trans(crops_image)) #单独使用好一些 + + + output["offsets"] = () + + return output + + +class DataAugmentationResNet_test(object): + def __init__( + self, + global_crops_scale, + local_crops_scale, + local_crops_number, + global_crops_size=224, + local_crops_size=96, + ): + + self.source_trans = transforms.Compose([ + # transforms.RandomCrop(224), + # transforms.CenterCrop(224), + transforms.ToTensor(), + make_normalize_transform(), + ]) + + # self.crop = transforms.Compose([ + # transforms.CenterCrop(224), + + # ]) + + self.crop = transforms.Compose([ + transforms.Resize(224), # 将短边缩放到 224,长边会按比例缩放 + transforms.CenterCrop(224), # 然后裁剪到 224x224 +]) + + self.centercrop = transforms.Compose([ + transforms.CenterCrop(224), + + ]) + + self.local_crops_number = local_crops_number + + def __call__(self, image): + output = {} + output["source"] = [] + + if np.array(image).shape[0]<224 or np.array(image).shape[1]<224: + 
crops_all = [ + self.centercrop(image) for _ in range(self.local_crops_number) + ] + else: + crops_all = [ + self.centercrop(image) for _ in range(self.local_crops_number) + ] + + for crops_image in crops_all: + output["source"].append(self.source_trans(crops_image)) #单独使用好一些 + + + output["offsets"] = () + + return output + + + +class DataAugmentationCLIP_gen(object): + def __init__( + self, + global_crops_scale, + local_crops_scale, + local_crops_number, + global_crops_size=224, + local_crops_size=96, + ): + + self.source_trans = transforms.Compose([ + # transforms.RandomCrop(224), + # transforms.CenterCrop(224), + transforms.ToTensor(), + make_normalize_transform_clip(), + ]) + + # self.crop = transforms.Compose([ + # transforms.RandomCrop(224), + + # ]) + + self.crop = transforms.Compose([ + transforms.Resize(224), # 将短边缩放到 224,长边会按比例缩放 + transforms.CenterCrop(224), # 然后裁剪到 224x224 +]) + + self.centercrop = transforms.Compose([ + transforms.CenterCrop(224), + + ]) + + self.local_crops_number = local_crops_number + + def __call__(self, image): + output = {} + output["source"] = [] + + # if np.array(image).shape[0]<224 or np.array(image).shape[1]<224: + # crops_all = [ + # self.crop(self.centercrop(image)) for _ in range(self.local_crops_number) + # ] + # else: + crops_all = [ + self.crop(image) for _ in range(self.local_crops_number) + ] + + for crops_image in crops_all: + output["source"].append(self.source_trans(crops_image)) #单独使用好一些 + + + output["offsets"] = () + + return output \ No newline at end of file diff --git a/clip/.ipynb_checkpoints/clip-checkpoint.py b/clip/.ipynb_checkpoints/clip-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..b4bf700321b793cf996cdb3b7c516f1f7861690d --- /dev/null +++ b/clip/.ipynb_checkpoints/clip-checkpoint.py @@ -0,0 +1,225 @@ +import hashlib +import os +import urllib +import warnings +from typing import Union, List + +import torch +from PIL import Image +from torchvision.transforms import 
Compose, Resize, CenterCrop, ToTensor, Normalize +from tqdm import tqdm + +from .model import build_model +from .simple_tokenizer import SimpleTokenizer as _Tokenizer + +try: + from torchvision.transforms import InterpolationMode + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC + + +if torch.__version__.split(".") < ["1", "7", "1"]: + warnings.warn("PyTorch version 1.7.1 or higher is recommended") + + +__all__ = ["available_models", "load", "tokenize"] +_tokenizer = _Tokenizer() + +_MODELS = { + "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", + "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", + "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", +} + + +def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + expected_sha256 = url.split("/")[-2] + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not 
match; re-downloading the file") + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: + raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def _transform(n_px): + return Compose([ + Resize(n_px, interpolation=BICUBIC), + CenterCrop(n_px), + lambda image: image.convert("RGB"), + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + + +def available_models() -> List[str]: + """Returns the names of available CLIP models""" + return list(_MODELS.keys()) + + +def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + + device : Union[str, torch.device] + The device to put the loaded model + + jit : bool + Whether to load the optimized JIT model or more hackable non-JIT model (default). 
+ + Returns + ------- + model : torch.nn.Module + The CLIP model + + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + ''' + if name in _MODELS: + model_path = _download(_MODELS[name]) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {available_models()}") + ''' + model_path = '/model/4DaiRui/pretrained_ood/ViT-B-16.pt' + + + try: + # loading JIT archive + model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") + jit = False + state_dict = torch.load(model_path, map_location="cpu") + + if not jit: + model = build_model(state_dict or model.state_dict()).to(device) + if str(device) == "cpu": + model.float() + return model, _transform(model.visual.input_resolution) + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = 
list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()["value"] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + + model.float() + + return model, _transform(model.input_resolution.item()) + + +def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + + context_length : int + The context length to use; all CLIP models use 77 as the context length + + truncate: bool + Whether to truncate the text in case its encoding is longer than the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder["<|startoftext|>"] + eot_token = _tokenizer.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") + result[i, :len(tokens)] = torch.tensor(tokens) 
+ + return result diff --git a/clip/.ipynb_checkpoints/model-checkpoint.py b/clip/.ipynb_checkpoints/model-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..f2c95c481724270116998b90de64cee8ef58c94e --- /dev/null +++ b/clip/.ipynb_checkpoints/model-checkpoint.py @@ -0,0 +1,432 @@ +from collections import OrderedDict +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu(self.bn1(self.conv1(x))) + out = self.relu(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: 
int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + + return x[0] + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.avgpool = nn.AvgPool2d(2) + self.relu = nn.ReLU(inplace=True) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + def stem(x): + for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: + x = self.relu(bn(conv(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +class 
LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class VisionTransformer(nn.Module): + def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = 
nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + + return x + + +class CLIP(nn.Module): + def __init__(self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int + ): + super().__init__() + + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width + ) + else: + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim + ) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + 
attn_mask=self.build_attention_mask() + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = 
torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image): + return self.visual(image.type(self.dtype)) + + def encode_text(self, text): + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x + + def forward(self, image, text): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm(dim=-1, keepdim=True) + text_features = text_features / text_features.norm(dim=-1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logit_scale * text_features @ image_features.t() + + # shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text + + +def convert_weights(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + if isinstance(l, nn.MultiheadAttention): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in 
["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +def build_model(state_dict: dict): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_resolution = vision_patch_size * grid_size + else: + counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_resolution = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) + + model = CLIP( + embed_dim, + image_resolution, vision_layers, vision_width, vision_patch_size, + context_length, vocab_size, transformer_width, transformer_heads, transformer_layers + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + if key in state_dict: + del state_dict[key] + + convert_weights(model) + model.load_state_dict(state_dict) + return model.eval() diff --git a/clip/__init__.py b/clip/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..dcc5619538c0f7c782508bdbd9587259d805e0d9 --- /dev/null +++ b/clip/__init__.py @@ -0,0 +1 @@ +from .clip import * diff --git a/clip/__pycache__/__init__.cpython-311.pyc b/clip/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3117e9001a6a21f49bec3bd8ef37a358fdf38d4 Binary files /dev/null and b/clip/__pycache__/__init__.cpython-311.pyc differ diff --git a/clip/__pycache__/__init__.cpython-36.pyc b/clip/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ebff55a3f878074b0707155cc42828b130dcbc0 Binary files /dev/null and b/clip/__pycache__/__init__.cpython-36.pyc differ diff --git a/clip/__pycache__/__init__.cpython-38.pyc b/clip/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..744aefdae161a7b857aad4d2937565f2266bffed Binary files /dev/null and b/clip/__pycache__/__init__.cpython-38.pyc differ diff --git a/clip/__pycache__/clip.cpython-311.pyc b/clip/__pycache__/clip.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..227fd4cc4e1376af1034272f01e61e0ee1201daa Binary files /dev/null and b/clip/__pycache__/clip.cpython-311.pyc differ diff --git a/clip/__pycache__/clip.cpython-36.pyc b/clip/__pycache__/clip.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..628cbc31e7c7e585def0a987949cd87fe5685f12 Binary files /dev/null and b/clip/__pycache__/clip.cpython-36.pyc differ diff --git a/clip/__pycache__/clip.cpython-38.pyc b/clip/__pycache__/clip.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e09a7bb5ddc013ae03de221764fdfcc0903c7d23 Binary files /dev/null and b/clip/__pycache__/clip.cpython-38.pyc differ diff --git a/clip/__pycache__/model.cpython-311.pyc b/clip/__pycache__/model.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..ea59d03241a5ecb8c59c6d16acb6ca1c320444d0 Binary files /dev/null and b/clip/__pycache__/model.cpython-311.pyc differ diff --git a/clip/__pycache__/model.cpython-36.pyc b/clip/__pycache__/model.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf9e2354d5fa743cf12ffd46706712b0106d53c6 Binary files /dev/null and b/clip/__pycache__/model.cpython-36.pyc differ diff --git a/clip/__pycache__/model.cpython-38.pyc b/clip/__pycache__/model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5dc7e3001b8f44b844b154c842c109312698264c Binary files /dev/null and b/clip/__pycache__/model.cpython-38.pyc differ diff --git a/clip/__pycache__/simple_tokenizer.cpython-311.pyc b/clip/__pycache__/simple_tokenizer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d98dab1d15385bc3324f610f0d630548829bead5 Binary files /dev/null and b/clip/__pycache__/simple_tokenizer.cpython-311.pyc differ diff --git a/clip/__pycache__/simple_tokenizer.cpython-36.pyc b/clip/__pycache__/simple_tokenizer.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ce2f1fc75930e00477ae3918b4e9894735eccd2 Binary files /dev/null and b/clip/__pycache__/simple_tokenizer.cpython-36.pyc differ diff --git a/clip/__pycache__/simple_tokenizer.cpython-38.pyc b/clip/__pycache__/simple_tokenizer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28899197c90989d077324409fa186df5d95638b0 Binary files /dev/null and b/clip/__pycache__/simple_tokenizer.cpython-38.pyc differ diff --git a/clip/__pycache__/utils.cpython-38.pyc b/clip/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73c0499ce226d10ccc82105d4245406849982ac8 Binary files /dev/null and b/clip/__pycache__/utils.cpython-38.pyc differ diff --git a/clip/bpe_simple_vocab_16e6.txt.gz b/clip/bpe_simple_vocab_16e6.txt.gz new 
file mode 100644 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/clip/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/clip/clip.py b/clip/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..14e7ec7fd040b2ca9d154e8736af150eb63ca9c2 --- /dev/null +++ b/clip/clip.py @@ -0,0 +1,228 @@ +import hashlib +import os +import urllib +import warnings +from typing import Union, List + +import torch +from PIL import Image +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +from tqdm import tqdm + +from .model import build_model +from .simple_tokenizer import SimpleTokenizer as _Tokenizer + +try: + from torchvision.transforms import InterpolationMode + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC + + +if torch.__version__.split(".") < ["1", "7", "1"]: + warnings.warn("PyTorch version 1.7.1 or higher is recommended") + + +__all__ = ["available_models", "load", "tokenize"] +_tokenizer = _Tokenizer() + +_MODELS = { + "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", + "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", + "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + "ViT-B/16": 
"https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", + "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", + "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", +} + + +def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + expected_sha256 = url.split("/")[-2] + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: + raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def _transform(n_px): + return Compose([ + Resize(n_px, interpolation=BICUBIC), + CenterCrop(n_px), + lambda image: image.convert("RGB"), + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + + +def available_models() -> List[str]: + """Returns the names of available CLIP models""" + return list(_MODELS.keys()) + + +def 
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict

    device : Union[str, torch.device]
        The device to put the loaded model

    jit : bool
        Whether to load the optimized JIT model or more hackable non-JIT model (default).

    Returns
    -------
    model : torch.nn.Module
        The CLIP model

    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """

    if name in _MODELS:
        model_path = _download(_MODELS[name])
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    # NOTE(review): leftover developer-machine override paths; dead code.
    #model_path = 'E:/code/lsn/clip/RN50.pt'
    # model_path = 'E:/code/lsn/clip/ViT-B-16.pt'


    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        # loading saved state dict
        if jit:
            warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
            jit = False
        state_dict = torch.load(model_path, map_location="cpu")

    if not jit:
        # Rebuild a plain nn.Module from the state dict (either loaded directly
        # or extracted out of the JIT archive).
        model = build_model(state_dict or model.state_dict()).to(device)
        if str(device) == "cpu":
            # convert_weights() left parts of the model in fp16; cast everything
            # back to fp32 for CPU execution.
            model.float()
        return model, _transform(model.visual.input_resolution)

    # patch the device names
    # The traced graph has device constants baked in; capture a constant node
    # for the requested device and splice its attributes over every "cuda"
    # constant found in the model's graphs.
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def patch_device(module):
        # Some submodules have no graph attribute; accessing .graph can also
        # raise RuntimeError on certain traced modules — treat both as "skip".
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []

        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        # value 5 is presumably the JIT enum for half precision —
                        # confirm against torch's ScalarType before changing.
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    return model, _transform(model.input_resolution.item())
class Bottleneck(nn.Module):
    """CLIP's anti-aliased ResNet bottleneck block.

    Every convolution runs with stride 1; when ``stride > 1`` the spatial
    downsampling is performed by an average pool inserted after the second
    convolution (and prepended to the shortcut projection), which reduces
    aliasing compared to strided convolutions.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        # Main branch: 1x1 reduce -> 3x3 -> (avgpool if striding) -> 1x1 expand.
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        self.stride = stride

        needs_projection = stride > 1 or inplanes != planes * Bottleneck.expansion
        if needs_projection:
            # Shortcut path: avgpool first, then a stride-1 projection conv.
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion)),
            ]))

    def forward(self, x: torch.Tensor):
        shortcut = x if self.downsample is None else self.downsample(x)

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.avgpool(y)
        y = self.bn3(self.conv3(y))

        y += shortcut
        return self.relu(y)
class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool

    Args:
        layers: number of Bottleneck blocks in each of the 4 residual stages.
        output_dim: dimensionality of the attention-pooled output embedding.
        heads: number of attention heads in the final AttentionPool2d.
        input_resolution: square input size in pixels; the final feature map
            is 32x smaller (stem /4, stages 2-4 each /2).
        width: channel width of the first residual stage.
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.avgpool = nn.AvgPool2d(2)
        self.relu = nn.ReLU(inplace=True)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        # Only the first block of a stage strides and projects; _inplanes tracks
        # the running channel count across consecutive _make_layer calls.
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        def stem(x):
            for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
                x = self.relu(bn(conv(x)))
            x = self.avgpool(x)
            return x

        # Match the parameter dtype so fp16-converted weights accept fp32 input.
        x = x.type(self.conv1.weight.dtype)
        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)

        return x
class VisionTransformer(nn.Module):
    """ViT visual encoder: patchify -> prepend class token -> transformer -> LN on the CLS token."""

    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        # Non-overlapping patch embedding: stride == kernel_size == patch_size.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        # Still registered even though forward() no longer applies it, so
        # standard CLIP checkpoints load without key mismatches.
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.ln_post(x[:, 0, :])

        # NOTE(review): the final CLIP projection is disabled here, so forward
        # returns `width`-dim CLS features rather than `output_dim`-dim
        # embeddings — presumably what the downstream detector expects; confirm.
        # if self.proj is not None:
        #     x = x @ self.proj

        return x
int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int + ): + super().__init__() + + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width + ) + else: + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim + ) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask() + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + 
    def encode_text(self, text):
        """Embed a batch of tokenized text into the joint embedding space.

        `text` is a [batch_size, n_ctx] tensor of token ids (as produced by
        `tokenize`).  The output embedding is taken at the end-of-text token
        position and projected to `embed_dim`.
        """
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x
def convert_weights(model: nn.Module):
    """Cast the fp16-eligible parameters of *model* to half precision in place.

    Conv / linear weights and biases, the projection tensors of
    ``nn.MultiheadAttention`` modules, and any ``text_projection`` / ``proj``
    attributes are converted; everything else (e.g. layer norms) is left as-is.
    """

    def _to_fp16(module):
        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            module.weight.data = module.weight.data.half()
            if module.bias is not None:
                module.bias.data = module.bias.data.half()

        if isinstance(module, nn.MultiheadAttention):
            proj_attrs = ["in_proj_weight", "q_proj_weight", "k_proj_weight", "v_proj_weight",
                          "in_proj_bias", "bias_k", "bias_v"]
            for attr_name in proj_attrs:
                tensor = getattr(module, attr_name)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        # CLIP stores these projections as bare nn.Parameter attributes.
        for attr_name in ("text_projection", "proj"):
            attr = getattr(module, attr_name, None)
            if attr is not None:
                attr.data = attr.data.half()

    model.apply(_to_fp16)
state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) + + model = CLIP( + embed_dim, + image_resolution, vision_layers, vision_width, vision_patch_size, + context_length, vocab_size, transformer_width, transformer_heads, transformer_layers + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + if key in state_dict: + del state_dict[key] + + convert_weights(model) + model.load_state_dict(state_dict) + return model.eval() diff --git a/clip/simple_tokenizer.py b/clip/simple_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..0a66286b7d5019c6e221932a813768038f839c91 --- /dev/null +++ b/clip/simple_tokenizer.py @@ -0,0 +1,132 @@ +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. 
+ """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} + self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '',) + pairs = get_pairs(word) + + if not pairs: + return token+'' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, 
'''This file configures the training procedure because handling arguments in every single function is so exhaustive for
research purposes. Don't try this code if you are a software engineer.'''

# device settings
device = 'cuda' # or 'cpu'
import torch
# NOTE(review): runs at import time and will raise on CPU-only machines even
# when device='cpu' — confirm this module is only imported with a GPU present.
torch.cuda.set_device(0)

# data settings
dataset_path = "dummy_dataset"
class_name = "dummy_class"
modelname = "dummy_test"

img_size = (448, 448)
img_dims = [3] + list(img_size)  # CHW dims derived from img_size

# transformation settings
transf_rotations = True
transf_brightness = 0.0
transf_contrast = 0.0
transf_saturation = 0.0
# ImageNet normalization statistics
norm_mean, norm_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

# network hyperparameters
n_scales = 3 # number of scales at which features are extracted, img_size is the highest - others are //2, //4,...
clamp_alpha = 3 # see paper equation 2 for explanation
n_coupling_blocks = 2
fc_internal = 4096 # number of neurons in hidden layers of s-t-networks
dropout = 0 # dropout in s-t-networks
lr_init = 2e-4
n_feat = 256 * n_scales # do not change except you change the feature extractor

# dataloader parameters
n_transforms = 4 # number of transformations per sample in training
n_transforms_test = 64 # number of transformations per sample in testing
batch_size = 24 # actual batch size is this value multiplied by n_transforms(_test)
batch_size_test = batch_size * n_transforms // n_transforms_test

# total epochs = meta_epochs * sub_epochs
# evaluation after epochs
meta_epochs = 24
sub_epochs = 8

# output settings
verbose = True
grad_map_viz = False
hide_tqdm_bar = True
save_model = True
"flow_fake_detector_centercrop_v4.pth" +CLIP_MODEL_FILENAME = "my_clip_ViT-L-14.pt" +class FakeImageDetector: + + def __init__(self): + + self.device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"检测器初始化在 CPU 上,运行时将使用 {self.device}") + + print(f"正在从 {MODEL_REPO_ID} 下载 CLIP 模型...") + clip_model_path = hf_hub_download( + repo_id=MODEL_REPO_ID, + filename=CLIP_MODEL_FILENAME + ) + print("CLIP 模型已下载。") + self.clip_model, _ = clip.load(clip_model_path, device="cpu") + self.clip_model.eval() + print("CLIP 模型已加载到 CPU。") + + print(f"正在从 {MODEL_REPO_ID} 下载 Flow 模型...") + flow_model_path = hf_hub_download( + repo_id=MODEL_REPO_ID, + filename=FLOW_MODEL_FILENAME + ) + print("Flow 模型已下载。") + self.flow = flow_model() + self.flow.load_state_dict(torch.load(flow_model_path, map_location="cpu")) + self.flow = self.flow.to("cpu") + self.flow.eval() + print("Flow 模型已加载到 CPU。") + + print("模型加载完成。") + + self.transform = DataAugmentationCLIP_test( + (0.9, 1.0), (0.05, 0.4), 1, + global_crops_size=224, local_crops_size=96, + ) + + @spaces.GPU(duration=10) + def detect(self, image_pil, threshold=0.5): + + if not isinstance(image_pil, Image.Image): + raise TypeError("输入必须是 PIL Image 对象") + + img_rgb = image_pil.convert("RGB") + + current_device = "cuda" if torch.cuda.is_available() else "cpu" + + flow_model_gpu = self.flow.to(current_device) + clip_model_gpu = self.clip_model.to(current_device) + + transformed_img_dict = self.transform(img_rgb) + img_tensor = transformed_img_dict["source"][0].unsqueeze(0).to(current_device) + + with torch.no_grad(): + if current_device == "cuda": + with autocast(): + embedding = clip_model_gpu.visual(img_tensor.half()) + z = flow_model_gpu(embedding) + score = 1 - torch.sigmoid(torch.mean(z.float()**2 / 10000, dim=1)).item() + else: + embedding = clip_model_gpu.visual(img_tensor) + z = flow_model_gpu(embedding.float()) + score = 1 - torch.sigmoid(torch.mean(z.float()**2 / 10000, dim=1)).item() + + if current_device == "cuda": + 
torch.cuda.empty_cache() + + if score > threshold: + result_text = f"结论: 伪造的 (Fake)\n分数: {score:.10f}" + else: + result_text = f"结论: 真实的 (Real)\n分数: {score:.10f}" + + return result_text, score \ No newline at end of file diff --git a/freia_funcs.py b/freia_funcs.py new file mode 100644 index 0000000000000000000000000000000000000000..24947f878af948070ee9f90a530575f6389770c8 --- /dev/null +++ b/freia_funcs.py @@ -0,0 +1,473 @@ +'''This Code is based on the FrEIA Framework, source: https://github.com/VLL-HD/FrEIA +It is a assembly of the necessary modules/functions from FrEIA that are needed for our purposes.''' +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +from math import exp +import numpy as np +VERBOSE = False + + +class dummy_data: + def __init__(self, *dims): + self.dims = dims + + @property + def shape(self): + return self.dims + +class F_fully_connected(nn.Module): + '''Fully connected tranformation, not reversible, but used below.''' + + def __init__(self, size_in, size, internal_size=None, dropout=0.0): + super(F_fully_connected, self).__init__() + if not internal_size: + internal_size = 2*size + + self.d1 = nn.Dropout(p=dropout) + self.d2 = nn.Dropout(p=dropout) + self.d2b = nn.Dropout(p=dropout) + + self.fc1 = nn.Linear(size_in, internal_size) + self.fc2 = nn.Linear(internal_size, internal_size) + self.fc2b = nn.Linear(internal_size, internal_size) + self.fc3 = nn.Linear(internal_size, size) + + self.nl1 = nn.ReLU() + self.nl2 = nn.ReLU() + self.nl2b = nn.ReLU() + + self.bn = nn.BatchNorm1d(size_in) + + + def forward(self, x): + out = self.nl1(self.d1(self.fc1(x))) + out = self.nl2(self.d2(self.fc2(out))) + out = self.nl2b(self.d2b(self.fc2b(out))) + out = self.fc3(out) + return out + +class permute_layer(nn.Module): + '''permutes input vector in a random but fixed way''' + + def __init__(self, dims_in, seed): + super(permute_layer, self).__init__() + self.in_channels = dims_in[0][0] + + 
np.random.seed(seed) + self.perm = np.random.permutation(self.in_channels) + np.random.seed() + + self.perm_inv = np.zeros_like(self.perm) + for i, p in enumerate(self.perm): + self.perm_inv[p] = i + + self.perm = torch.LongTensor(self.perm) + self.perm_inv = torch.LongTensor(self.perm_inv) + + def forward(self, x, rev=False): + if not rev: + return [x[0][:, self.perm]] + else: + return [x[0][:, self.perm_inv]] + + def jacobian(self, x, rev=False): + # TODO: use batch size, set as nn.Parameter so cuda() works + return 0. + + def output_dims(self, input_dims): + assert len(input_dims) == 1, "Can only use 1 input" + return input_dims + + + +class glow_coupling_layer(nn.Module): + def __init__(self, dims_in, F_class=F_fully_connected, F_args={}, + clamp=5.): + super(glow_coupling_layer, self).__init__() + channels = dims_in[0][0] + self.ndims = len(dims_in[0]) + + self.split_len1 = channels // 2 + self.split_len2 = channels - channels // 2 + + self.clamp = clamp + self.max_s = exp(clamp) + self.min_s = exp(-clamp) + + self.s1 = F_class(self.split_len1, self.split_len2*2, **F_args) + self.s2 = F_class(self.split_len2, self.split_len1*2, **F_args) + + def e(self, s): + return torch.exp(self.log_e(s)) + + def log_e(self, s): + return self.clamp * 0.636 * torch.atan(s / self.clamp) + + def forward(self, x, rev=False): + x1, x2 = (x[0].narrow(1, 0, self.split_len1), + x[0].narrow(1, self.split_len1, self.split_len2)) + + if not rev: + r2 = self.s2(x2) + s2, t2 = r2[:, :self.split_len1], r2[:, self.split_len1:] + #print(s2.shape, x1.shape, t2.shape) + y1 = self.e(s2) * x1 + t2 + + r1 = self.s1(y1) + s1, t1 = r1[:, :self.split_len2], r1[:, self.split_len2:] + y2 = self.e(s1) * x2 + t1 + + else: # names of x and y are swapped! 
+ r1 = self.s1(x1) + s1, t1 = r1[:, :self.split_len2], r1[:, self.split_len2:] + y2 = (x2 - t1) / self.e(s1) + + r2 = self.s2(y2) + s2, t2 = r2[:, :self.split_len1], r2[:, self.split_len1:] + y1 = (x1 - t2) / self.e(s2) + y = torch.cat((y1, y2), 1) + y = torch.clamp(y, -1e6, 1e6) + return [y] + + def jacobian(self, x, rev=False): + x1, x2 = (x[0].narrow(1, 0, self.split_len1), + x[0].narrow(1, self.split_len1, self.split_len2)) + + if not rev: + r2 = self.s2(x2) + s2, t2 = r2[:, :self.split_len1], r2[:, self.split_len1:] + y1 = self.e(s2) * x1 + t2 + + r1 = self.s1(y1) + s1, t1 = r1[:, :self.split_len2], r1[:, self.split_len2:] + + else: # names of x and y are swapped! + r1 = self.s1(x1) + s1, t1 = r1[:, :self.split_len2], r1[:, self.split_len2:] + y2 = (x2 - t1) / self.e(s1) + + r2 = self.s2(y2) + s2, t2 = r2[:, :self.split_len1], r2[:, self.split_len1:] + + jac = (torch.sum(self.log_e(s1), dim=1) + + torch.sum(self.log_e(s2), dim=1)) + for i in range(self.ndims-1): + jac = torch.sum(jac, dim=1) + + return jac + + def output_dims(self, input_dims): + assert len(input_dims) == 1, "Can only use 1 input" + return input_dims + +class Node: + '''The Node class represents one transformation in the graph, with an + arbitrary number of in- and outputs.''' + def __init__(self, inputs, module_type, module_args, name=None): + self.inputs = inputs + self.outputs = [] + self.module_type = module_type + self.module_args = module_args + + self.input_dims, self.module = None, None + self.computed = None + self.computed_rev = None + self.id = None + + if name: + self.name = name + else: + self.name = hex(id(self))[-6:] + for i in range(255): + exec('self.out{0} = (self, {0})'.format(i)) + + def build_modules(self, verbose=VERBOSE): + ''' Returns a list with the dimension of each output of this node, + recursively calling build_modules of the nodes connected to the input. + Use this information to initialize the pytorch nn.Module of this node. 
+ ''' + + if not self.input_dims: # Only do it if this hasn't been computed yet + self.input_dims = [n.build_modules(verbose=verbose)[c] + for n, c in self.inputs] + try: + self.module = self.module_type(self.input_dims, + **self.module_args) + except Exception as e: + print('Error in node %s' % (self.name)) + raise e + + if verbose: + print("Node %s has following input dimensions:" % (self.name)) + for d, (n, c) in zip(self.input_dims, self.inputs): + print("\t Output #%i of node %s:" % (c, n.name), d) + print() + + self.output_dims = self.module.output_dims(self.input_dims) + self.n_outputs = len(self.output_dims) + + return self.output_dims + + def run_forward(self, op_list): + '''Determine the order of operations needed to reach this node. Calls + run_forward of parent nodes recursively. Each operation is appended to + the global list op_list, in the form (node ID, input variable IDs, + output variable IDs)''' + + if not self.computed: + + # Compute all nodes which provide inputs, filter out the + # channels you need + self.input_vars = [] + for i, (n, c) in enumerate(self.inputs): + self.input_vars.append(n.run_forward(op_list)[c]) + # Register youself as an output in the input node + n.outputs.append((self, i)) + + # All outputs could now be computed + self.computed = [(self.id, i) for i in range(self.n_outputs)] + op_list.append((self.id, self.input_vars, self.computed)) + + # Return the variables you have computed (this happens mulitple times + # without recomputing if called repeatedly) + return self.computed + + def run_backward(self, op_list): + '''See run_forward, this is the same, only for the reverse computation. 
+ Need to call run_forward first, otherwise this function will not + work''' + + assert len(self.outputs) > 0, "Call run_forward first" + if not self.computed_rev: + + # These are the input variables that must be computed first + output_vars = [(self.id, i) for i in range(self.n_outputs)] + + # Recursively compute these + for n, c in self.outputs: + n.run_backward(op_list) + + # The variables that this node computes are the input variables + # from the forward pass + self.computed_rev = self.input_vars + op_list.append((self.id, output_vars, self.computed_rev)) + + return self.computed_rev + + +class InputNode(Node): + '''Special type of node that represents the input data of the whole net (or + ouput when running reverse)''' + + def __init__(self, *dims, name='node'): + self.name = name + self.data = dummy_data(*dims) + self.outputs = [] + self.module = None + self.computed_rev = None + self.n_outputs = 1 + self.input_vars = [] + self.out0 = (self, 0) + + def build_modules(self, verbose=VERBOSE): + return [self.data.shape] + + def run_forward(self, op_list): + return [(self.id, 0)] + + +class OutputNode(Node): + '''Special type of node that represents the output of the whole net (of the + input when running in reverse)''' + class dummy(nn.Module): + + def __init__(self, *args): + super(OutputNode.dummy, self).__init__() + + def __call__(*args): + return args + + def output_dims(*args): + return args + + def __init__(self, inputs, name='node'): + self.module_type, self.module_args = self.dummy, {} + self.output_dims = [] + self.inputs = inputs + self.input_dims, self.module = None, None + self.computed = None + self.id = None + self.name = name + + for c, inp in enumerate(self.inputs): + inp[0].outputs.append((self, c)) + + def run_backward(self, op_list): + return [(self.id, 0)] + + +class ReversibleGraphNet(nn.Module): + '''This class represents the invertible net itself. It is a subclass of + torch.nn.Module and supports the same methods. 
The forward method has an + additional option 'rev', whith which the net can be computed in reverse.''' + + def __init__(self, node_list, ind_in=None, ind_out=None, verbose=False): + '''node_list should be a list of all nodes involved, and ind_in, + ind_out are the indexes of the special nodes InputNode and OutputNode + in this list.''' + super(ReversibleGraphNet, self).__init__() + + # Gather lists of input and output nodes + if ind_in is not None: + if isinstance(ind_in, int): + self.ind_in = list([ind_in]) + else: + self.ind_in = ind_in + else: + self.ind_in = [i for i in range(len(node_list)) + if isinstance(node_list[i], InputNode)] + assert len(self.ind_in) > 0, "No input nodes specified." + if ind_out is not None: + if isinstance(ind_out, int): + self.ind_out = list([ind_out]) + else: + self.ind_out = ind_out + else: + self.ind_out = [i for i in range(len(node_list)) + if isinstance(node_list[i], OutputNode)] + assert len(self.ind_out) > 0, "No output nodes specified." + + self.return_vars = [] + self.input_vars = [] + + # Assign each node a unique ID + self.node_list = node_list + for i, n in enumerate(node_list): + n.id = i + + # Recursively build the nodes nn.Modules and determine order of + # operations + ops = [] + for i in self.ind_out: + node_list[i].build_modules(verbose=verbose) + node_list[i].run_forward(ops) + + # create list of Pytorch variables that are used + variables = set() + for o in ops: + variables = variables.union(set(o[1] + o[2])) + self.variables_ind = list(variables) + + self.indexed_ops = self.ops_to_indexed(ops) + + self.module_list = nn.ModuleList([n.module for n in node_list]) + self.variable_list = [Variable(requires_grad=True) for v in variables] + + # Find out the order of operations for reverse calculations + ops_rev = [] + for i in self.ind_in: + node_list[i].run_backward(ops_rev) + self.indexed_ops_rev = self.ops_to_indexed(ops_rev) + + def ops_to_indexed(self, ops): + '''Helper function to translate the list of variables 
(origin ID, channel), + to variable IDs.''' + result = [] + + for o in ops: + try: + vars_in = [self.variables_ind.index(v) for v in o[1]] + except ValueError: + vars_in = -1 + + vars_out = [self.variables_ind.index(v) for v in o[2]] + + # Collect input/output nodes in separate lists, but don't add to + # indexed ops + if o[0] in self.ind_out: + self.return_vars.append(self.variables_ind.index(o[1][0])) + continue + if o[0] in self.ind_in: + self.input_vars.append(self.variables_ind.index(o[1][0])) + continue + + result.append((o[0], vars_in, vars_out)) + + # Sort input/output variables so they correspond to initial node list + # order + self.return_vars.sort(key=lambda i: self.variables_ind[i][0]) + self.input_vars.sort(key=lambda i: self.variables_ind[i][0]) + + return result + + def forward(self, x, rev=False): + '''Forward or backward computation of the whole net.''' + if rev: + use_list = self.indexed_ops_rev + input_vars, output_vars = self.return_vars, self.input_vars + else: + use_list = self.indexed_ops + input_vars, output_vars = self.input_vars, self.return_vars + + if isinstance(x, (list, tuple)): + assert len(x) == len(input_vars), ( + f"Got list of {len(x)} input tensors for " + f"{'inverse' if rev else 'forward'} pass, but expected " + f"{len(input_vars)}." 
+ ) + for i in range(len(input_vars)): + self.variable_list[input_vars[i]] = x[i] + else: + assert len(input_vars) == 1, (f"Got single input tensor for " + f"{'inverse' if rev else 'forward'} " + f"pass, but expected list of " + f"{len(input_vars)}.") + self.variable_list[input_vars[0]] = x + + for o in use_list: + try: + results = self.module_list[o[0]]([self.variable_list[i] + for i in o[1]], rev=rev) + except TypeError: + raise RuntimeError("Are you sure all used Nodes are in the " + "Node list?") + for i, r in zip(o[2], results): + self.variable_list[i] = r + # self.variable_list[o[2][0]] = self.variable_list[o[1][0]] + + out = [self.variable_list[output_vars[i]] + for i in range(len(output_vars))] + if len(out) == 1: + return out[0] + else: + return out + + def jacobian(self, x=None, rev=False, run_forward=True): + '''Compute the jacobian determinant of the whole net.''' + jacobian = 0 + + if rev: + use_list = self.indexed_ops_rev + else: + use_list = self.indexed_ops + + if run_forward: + if x is None: + raise RuntimeError("You need to provide an input if you want " + "to run a forward pass") + self.forward(x, rev=rev) + jacobian_list = list() + for o in use_list: + try: + node_jac = self.module_list[o[0]].jacobian( + [self.variable_list[i] for i in o[1]], rev=rev + ) + jacobian += node_jac + jacobian_list.append(jacobian) + except TypeError: + raise RuntimeError("Are you sure all used Nodes are in the " + "Node list?") + + return jacobian \ No newline at end of file diff --git a/loralib/__init__.py b/loralib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1aa9c48e838e1bcc9ce9f5c3a790aa94295bbf87 --- /dev/null +++ b/loralib/__init__.py @@ -0,0 +1,2 @@ +from .layers import * +from .utils import * \ No newline at end of file diff --git a/loralib/__pycache__/__init__.cpython-38.pyc b/loralib/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..76f3ae4ebf25b9a2a584a640707ea162ef04b339 Binary files /dev/null and b/loralib/__pycache__/__init__.cpython-38.pyc differ diff --git a/loralib/__pycache__/layers.cpython-38.pyc b/loralib/__pycache__/layers.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1625c7d0b6cdc27a9cb10219b0d4aae973b9fc96 Binary files /dev/null and b/loralib/__pycache__/layers.cpython-38.pyc differ diff --git a/loralib/__pycache__/utils.cpython-38.pyc b/loralib/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ccff90067ef6065ec21473bdc653c715b574d27 Binary files /dev/null and b/loralib/__pycache__/utils.cpython-38.pyc differ diff --git a/loralib/easymultiheadattention.py b/loralib/easymultiheadattention.py new file mode 100644 index 0000000000000000000000000000000000000000..be028b176155fe8c34f3cc335b885184a4dd4d35 --- /dev/null +++ b/loralib/easymultiheadattention.py @@ -0,0 +1,124 @@ +import torch.nn as nn +import torch +import torch.nn.functional as F + +""" +Source : https://github.com/KyanChen/MakeMultiHeadNaive/blob/master/main.py +""" + +class PlainMultiHeadAttention(nn.Module): + def __init__( + self, + existing_mha: nn.MultiheadAttention): + super().__init__() + + self.dropout = 0 # this module is not used to retrain the main block + self.embed_dim = existing_mha.embed_dim + self.kdim = existing_mha.kdim + self.vdim = existing_mha.vdim + self._qkv_same_embed_dim = existing_mha._qkv_same_embed_dim + self.num_heads = existing_mha.num_heads + self.batch_first = existing_mha.batch_first + self.head_dim = existing_mha.head_dim + self.qkv = nn.Linear(self.embed_dim, self.embed_dim * 3, bias=existing_mha.in_proj_bias is not None) + self.proj = nn.Linear(self.embed_dim, self.embed_dim, bias=existing_mha.out_proj.bias is not None) + + # Initialize parameters + with torch.no_grad(): + self.qkv.weight.data.copy_(existing_mha.in_proj_weight.data) + if self.qkv.bias is 
not None: + self.qkv.bias.data.copy_(existing_mha.in_proj_bias.data) + self.proj.weight.data.copy_(existing_mha.out_proj.weight.data) + if self.proj.bias is not None: + self.proj.bias.data.copy_(existing_mha.out_proj.bias.data) + + self.scaled_dot_product_attention = F.scaled_dot_product_attention + + def forward( + self, + query, + key, + value, + key_padding_mask=None, + need_weights=True, + attn_mask=None, + average_attn_weights=True, + is_causal=False): + + if attn_mask is not None and is_causal: + raise AssertionError("Only allow causal mask or attn_mask") + is_batched = query.dim() == 3 + key_padding_mask = F._canonical_mask( + mask=key_padding_mask, + mask_name="key_padding_mask", + other_type=F._none_or_dtype(attn_mask), + other_name="attn_mask", + target_type=query.dtype + ) + + if self.batch_first and is_batched: + if key is value: + if query is key: + query = key = value = query.transpose(1, 0) + else: + query, key = [x.transpose(1, 0) for x in (query, key)] + value = key + else: + query, key, value = [x.transpose(1, 0) for x in (query, key, value)] + + tgt_len, bsz, embed_dim = query.shape + src_len, _, _ = key.shape + + E = query.size(-1) + qkv = self.qkv(query) + qkv = qkv.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] + + attn_mask = F._canonical_mask( + mask=attn_mask, + mask_name="attn_mask", + other_type=F._none_or_dtype(key_padding_mask), + other_name="key_padding_mask", + target_type=q.dtype, + check_other=False, + ) + + if attn_mask is not None: + # ensure attn_mask's dim is 3 + if attn_mask.dim() == 2: + correct_2d_size = (tgt_len, src_len) + if attn_mask.shape != correct_2d_size: + raise RuntimeError( + f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.") + attn_mask = attn_mask.unsqueeze(0) + elif attn_mask.dim() == 3: + correct_3d_size = (bsz * self.num_heads, tgt_len, src_len) + if attn_mask.shape != correct_3d_size: + raise RuntimeError( + 
f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.") + else: + raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported") + + if attn_mask is not None: + if attn_mask.size(0) == 1 and attn_mask.dim() == 3: + attn_mask = attn_mask.unsqueeze(0) + else: + attn_mask = attn_mask.view(bsz, self.num_heads, -1, src_len) + + dropout_p = self.dropout if self.training else 0. + + q = q.view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) + k = k.view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) + v = v.view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) + src_len = k.size(1) + q = q.view(bsz, self.num_heads, tgt_len, self.head_dim) + k = k.view(bsz, self.num_heads, src_len, self.head_dim) + v = v.view(bsz, self.num_heads, src_len, self.head_dim) + + attn_output = self.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal) + attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim) + attn_output = self.proj(attn_output) + attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1)) + if self.batch_first and is_batched: + return attn_output.transpose(1, 0), None + return attn_output, None \ No newline at end of file diff --git a/loralib/layers.py b/loralib/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..cb5a75977d02a828ef24c4515694f4ab1f2163f1 --- /dev/null +++ b/loralib/layers.py @@ -0,0 +1,598 @@ +# ------------------------------------------------------------------------------------------ +# This code is reconstructed based on loralib (https://github.com/microsoft/LoRA) by Baijiong Lin. 
+# ------------------------------------------------------------------------------------------ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import math +from typing import Optional, List + +def set_param(curr_mod, name, param=None, mode='update'): + r"""Refer to https://github.com/Baijiong-Lin/MOML/blob/main/MTL/utils.py""" + if '.' in name: + n = name.split('.') + module_name = n[0] + rest = '.'.join(n[1:]) + for name, mod in curr_mod.named_children(): + if module_name == name: + return set_param(mod, rest, param, mode=mode) + else: + if mode == 'update': + delattr(curr_mod, name) + setattr(curr_mod, name, param) + elif mode == 'get': + if hasattr(curr_mod, name): + p = getattr(curr_mod, name) + return p + +class LoRALayer(): + def __init__( + self, + r: int, + lora_alpha: int, + fan_in_fan_out: bool = False, + dropout_rate:float = 0, + ): + self.r = r + self.lora_alpha = lora_alpha + self.dropout_rate = dropout_rate + if self.r > 0: + #self.scaling = self.lora_alpha / self.r + self.scaling = self.lora_alpha/math.sqrt(self.r) # + # Mark the weight as unmerged + self.merged = False + # Set this to True if the layer to replace stores weight like (fan_in, fan_out) + self.fan_in_fan_out = fan_in_fan_out + # define params that require LoRA {'param_name': 'lora_name'} + self.params_with_lora = {} + + def register_lora_param(self): + r"""Register LoRA matrix""" + for param_name, lora_name in self.params_with_lora.items(): + assert len(eval(f'self.{param_name}').size()) == 2 + self.register_parameter(f'{lora_name}_lora_A', + nn.Parameter(eval(f'self.{param_name}').new_zeros((self.r, eval(f'self.{param_name}').size()[1]))) + ) + self.register_parameter(f'{lora_name}_lora_B', + nn.Parameter(eval(f'self.{param_name}').new_zeros((eval(f'self.{param_name}').size()[0], self.r))) + ) + + eval(f'self.{param_name}').requires_grad = False + + def init_lora_param(self): + for param_name, lora_name in self.params_with_lora.items(): + if hasattr(self, 
f'{lora_name}_lora_A'): + # initialize A the same way as the default for nn.Linear and B to zero + nn.init.kaiming_uniform_(eval(f'self.{lora_name}_lora_A'), a=math.sqrt(5)) + nn.init.zeros_(eval(f'self.{lora_name}_lora_B')) + + def transpose(self, w: torch.Tensor): + return w.transpose(0, 1) if self.fan_in_fan_out else w + + def merge_BA(self, param_name: str): + lora_name = self.params_with_lora[param_name] + return self.transpose((eval(f'self.{lora_name}_lora_B') @ eval(f'self.{lora_name}_lora_A')).view(eval(f'self.{param_name}').shape)) + + + + + def merge_lora_param(self): + r"""p_new = p + scaling * B @ A and keep differentiable to A and B""" + for param_name, lora_name in self.params_with_lora.items(): + p = set_param(self, param_name, mode='get') + # detach() is very important here + + p_new = p.detach() + self.merge_BA(param_name) * self.scaling + set_param(self, param_name, param=p_new, mode='update') + + def add_lora_data(self): + r"""NOT differentiable""" + for param_name, lora_name in self.params_with_lora.items(): + eval(f'self.{param_name}').data += self.merge_BA(param_name) * self.scaling + + def sub_lora_data(self): + r"""NOT differentiable""" + for param_name, lora_name in self.params_with_lora.items(): + eval(f'self.{param_name}').data -= self.merge_BA(param_name) * self.scaling + + + def lora_train(self, mode: bool = True): + if mode: + if self.merged and self.r > 0: + # Make sure that the weights are not merged + self.sub_lora_data() + self.merged = False + else: + if not self.merged and self.r > 0: + # Merge the weights and mark it + self.add_lora_data() + self.merged = True + + +class Embedding(nn.Embedding, LoRALayer): + # LoRA implemented in a Embedding layer + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + r: int = 0, + lora_alpha: int = 1, + **kwargs + ): + nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs) + LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha) + + self.params_with_lora = {'weight': 
'w'} + if r > 0: + self.register_lora_param() + nn.Embedding.reset_parameters(self) + self.init_lora_param() + + def init_lora_param(self): + if hasattr(self, 'w_lora_A'): + # initialize A the same way as the default for nn.Linear and B to zero + nn.init.zeros_(self.w_lora_A) + nn.init.normal_(self.w_lora_B) + + def train(self, mode: bool = True): + nn.Embedding.train(self, mode) + self.lora_train(mode) + + def forward(self, x: torch.Tensor, **kwargs): + + if self.r > 0 and not self.merged: + self.merge_lora_param() + result = nn.Embedding.forward(self, x, **kwargs) + self.sub_lora_data() + return result + else: + return nn.Embedding.forward(self, x, **kwargs) + +class LinearLoRA(nn.Linear, LoRALayer): + # LoRA implemented in a Linear layer + def __init__( + self, + existing_linear: nn.Linear, + r: int = 0, + lora_alpha: int = 1, + fan_in_fan_out: bool = False, + dropout_rate = 0., + **kwargs + ): + super().__init__( + in_features=existing_linear.in_features, + out_features=existing_linear.out_features) + + self.load_state_dict(existing_linear.state_dict()) + LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, fan_in_fan_out=fan_in_fan_out) + + # Actual trainable parameters + self.params_with_lora = {'weight': 'w'} + if r > 0: + self.register_lora_param() + self.init_lora_param() + self.weight.data = self.transpose(self.weight.data) + if dropout_rate > 0: + self.dropout = nn.Dropout(dropout_rate) + else: + self.dropout = None + + def train(self, mode: bool = True): + super().train(mode) + self.lora_train(mode) + + + def forward(self, x: torch.Tensor, **kwargs): + + if self.dropout is None: # do as before + if self.r > 0 and not self.merged: + self.merge_lora_param() + result = nn.Linear.forward(self, x, **kwargs) + self.sub_lora_data() + return result + else: + return nn.Linear.forward(self, x, **kwargs) + + # Compute the original linear transformation + original_output = nn.Linear.forward(self, x) + + if self.training and self.dropout.p > 0: + x = self.dropout(x) 
+ + if self.r > 0 and not self.merged: + lora_adjustment = torch.matmul(x,self.merge_BA('weight').transpose(0, 1)) * self.scaling + result = original_output + lora_adjustment + else: + result = original_output + return result + +class Conv1d(nn.Conv1d, LoRALayer): + # LoRA implemented in a Conv1d layer + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + r: int = 0, + lora_alpha: int = 1, + **kwargs + ): + nn.Conv1d.__init__(self, in_channels, out_channels, kernel_size, **kwargs) + LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha) + + assert type(kernel_size) is int + # Actual trainable parameters + self.params_with_lora = {'weight': 'w'} + if r > 0: + self.w_lora_A = nn.Parameter( + self.weight.new_zeros((r*kernel_size, in_channels*kernel_size)) + ) + self.w_lora_B = nn.Parameter( + self.weight.new_zeros((out_channels//self.groups*kernel_size, r*kernel_size)) + ) + # Freezing the pre-trained weight matrix + self.weight.requires_grad = False + nn.Conv1d.reset_parameters(self) + self.init_lora_param() + + def train(self, mode: bool = True): + nn.Conv1d.train(self, mode) + self.lora_train(mode) + + def forward(self, x: torch.Tensor, **kwargs): + + if self.r > 0 and not self.merged: + self.merge_lora_param() + result = nn.Conv1d.forward(self, x, **kwargs) + self.sub_lora_data() + return result + else: + return nn.Conv1d.forward(self, x, **kwargs) + +class Conv2d(nn.Conv2d, LoRALayer): + # LoRA implemented in a Conv2d layer + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + r: int = 0, + lora_alpha: int = 1, + **kwargs + ): + nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs) + LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha) + + assert type(kernel_size) is int + # Actual trainable parameters + self.params_with_lora = {'weight': 'w'} + if r > 0: + self.w_lora_A = nn.Parameter( + self.weight.new_zeros((r*kernel_size, in_channels*kernel_size)) + ) + self.w_lora_B = 
nn.Parameter( + self.weight.new_zeros((out_channels//self.groups*kernel_size, r*kernel_size)) + ) + # Freezing the pre-trained weight matrix + self.weight.requires_grad = False + nn.Conv2d.reset_parameters(self) + self.init_lora_param() + + def train(self, mode: bool = True): + nn.Conv2d.train(self, mode) + self.lora_train(mode) + + def forward(self, x: torch.Tensor, **kwargs): + + if self.r > 0 and not self.merged: + self.merge_lora_param() + result = nn.Conv2d.forward(self, x, **kwargs) + self.sub_lora_data() + return result + else: + return nn.Conv2d.forward(self, x, **kwargs) + +class Conv3d(nn.Conv3d, LoRALayer): + # LoRA implemented in a Conv3d layer + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + r: int = 0, + lora_alpha: int = 1, + **kwargs + ): + nn.Conv3d.__init__(self, in_channels, out_channels, kernel_size, **kwargs) + LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha) + + assert type(kernel_size) is int + # Actual trainable parameters + self.params_with_lora = {'weight': 'w'} + if r > 0: + self.w_lora_A = nn.Parameter( + self.weight.new_zeros((r*kernel_size, in_channels*kernel_size)) + ) + self.w_lora_B = nn.Parameter( + self.weight.new_zeros((out_channels//self.groups*kernel_size, r*kernel_size)) + ) + # Freezing the pre-trained weight matrix + self.weight.requires_grad = False + nn.Conv3d.reset_parameters(self) + self.init_lora_param() + + def train(self, mode: bool = True): + nn.Conv3d.train(self, mode) + self.lora_train(mode) + + def forward(self, x: torch.Tensor, **kwargs): + + if self.r > 0 and not self.merged: + self.merge_lora_param() + result = nn.Conv3d.forward(self, x, **kwargs) + self.sub_lora_data() + return result + else: + return nn.Conv3d.forward(self, x, **kwargs) + + +class PlainMultiheadAttentionLoRA(nn.Module): + def __init__( + self, + existing_mha: nn.MultiheadAttention, + enable_lora: list = ['q', 'k', 'v', 'o'], + r: int = 0, + lora_alpha: int = 1, + dropout_rate:float = 0., + **kwargs + 
): + super().__init__() + + self.dropout = 0 # this module is not used to retrain the main block + self.embed_dim = existing_mha.embed_dim + self.kdim = existing_mha.kdim + self.vdim = existing_mha.vdim + self._qkv_same_embed_dim = existing_mha._qkv_same_embed_dim + self.num_heads = existing_mha.num_heads + self.batch_first = existing_mha.batch_first + self.head_dim = existing_mha.head_dim + #self.qkv = nn.Linear(self.embed_dim, self.embed_dim * 3, bias=existing_mha.in_proj_bias is not None) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=existing_mha.in_proj_bias is not None) + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=existing_mha.in_proj_bias is not None) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=existing_mha.in_proj_bias is not None) + self.proj = nn.Linear(self.embed_dim, self.embed_dim, bias=existing_mha.out_proj.bias is not None) + + # Initialize parameters + with torch.no_grad(): + + # Extract the existing weights and biases + existing_weight = existing_mha.in_proj_weight.data + existing_bias = existing_mha.in_proj_bias.data if existing_mha.in_proj_bias is not None else None + + # Initialize q_proj + self.q_proj.weight.data.copy_(existing_weight[:self.embed_dim, :]) + if existing_bias is not None: + self.q_proj.bias.data.copy_(existing_bias[:self.embed_dim]) + + # Initialize k_proj + self.k_proj.weight.data.copy_(existing_weight[self.embed_dim:2*self.embed_dim, :]) + if existing_bias is not None: + self.k_proj.bias.data.copy_(existing_bias[self.embed_dim:2*self.embed_dim]) + + # Initialize v_proj + self.v_proj.weight.data.copy_(existing_weight[2*self.embed_dim:, :]) + if existing_bias is not None: + self.v_proj.bias.data.copy_(existing_bias[2*self.embed_dim:]) + + # Initialize proj + self.proj.weight.data.copy_(existing_mha.out_proj.weight.data) + if self.proj.bias is not None: + self.proj.bias.data.copy_(existing_mha.out_proj.bias.data) + + self.scaled_dot_product_attention = 
    def forward_module(
            self,
            query,
            key,
            value,
            key_padding_mask=None,
            need_weights=True,
            attn_mask=None,
            average_attn_weights=True,
            is_causal=False):
        """Multi-head attention forward using separate q/k/v/out projections.

        Reimplements the pre/post-processing of ``nn.MultiheadAttention`` so
        each projection can be a LoRA-wrapped linear, then dispatches to
        ``F.scaled_dot_product_attention``. Always returns ``(output, None)``
        — attention weights are never computed, regardless of
        ``need_weights``/``average_attn_weights``.
        """

        if attn_mask is not None and is_causal:
            raise AssertionError("Only allow causal mask or attn_mask")
        is_batched = query.dim() == 3
        # Normalize key_padding_mask dtype/shape (private torch helper).
        key_padding_mask = F._canonical_mask(
            mask=key_padding_mask,
            mask_name="key_padding_mask",
            other_type=F._none_or_dtype(attn_mask),
            other_name="attn_mask",
            target_type=query.dtype
        )

        # Internally operate in (seq, batch, dim); convert batch-first inputs,
        # preserving identity aliasing (self-attention) to avoid extra transposes.
        if self.batch_first and is_batched:
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = [x.transpose(1, 0) for x in (query, key)]
                    value = key
            else:
                query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

        tgt_len, bsz, embed_dim = query.shape
        src_len, _, _ = key.shape
        """
        E = query.size(-1)
        qkv = self.qkv(query)
        qkv = qkv.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
        q, k, v = qkv[0], qkv[1], qkv[2]
        """

        # Per-projection (possibly LoRA-augmented) linears instead of a fused qkv.
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)

        attn_mask = F._canonical_mask(
            mask=attn_mask,
            mask_name="attn_mask",
            other_type=F._none_or_dtype(key_padding_mask),
            other_name="key_padding_mask",
            target_type=q.dtype,
            check_other=False,
        )

        if attn_mask is not None:
            # ensure attn_mask's dim is 3
            if attn_mask.dim() == 2:
                correct_2d_size = (tgt_len, src_len)
                if attn_mask.shape != correct_2d_size:
                    raise RuntimeError(
                        f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
                attn_mask = attn_mask.unsqueeze(0)
            elif attn_mask.dim() == 3:
                correct_3d_size = (bsz * self.num_heads, tgt_len, src_len)
                if attn_mask.shape != correct_3d_size:
                    raise RuntimeError(
                        f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
            else:
                raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

        # Reshape mask to the 4D (batch, heads, tgt, src) layout SDPA expects.
        if attn_mask is not None:
            if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
                attn_mask = attn_mask.unsqueeze(0)
            else:
                attn_mask = attn_mask.view(bsz, self.num_heads, -1, src_len)

        # self.dropout is fixed to 0 in __init__; dropout only applies in train mode.
        dropout_p = self.dropout if self.training else 0.

        # (seq, batch, dim) -> (batch, heads, seq, head_dim) for SDPA.
        q = q.view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        k = k.view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        v = v.view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        src_len = k.size(1)
        q = q.view(bsz, self.num_heads, tgt_len, self.head_dim)
        k = k.view(bsz, self.num_heads, src_len, self.head_dim)
        v = v.view(bsz, self.num_heads, src_len, self.head_dim)

        attn_output = self.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
        # (batch, heads, tgt, head_dim) -> (tgt*batch, embed_dim), then output projection.
        attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
        attn_output = self.proj(attn_output)
        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), None
        return attn_output, None
pre-trained weight matrix + self.weight.requires_grad = False + # Compute the indices + self.lora_ind = self.weight.new_zeros( + (out_features, ), dtype=torch.bool + ).view(len(enable_lora), -1) + self.lora_ind[enable_lora, :] = True + self.lora_ind = self.lora_ind.view(-1) + nn.Linear.reset_parameters(self) + self.init_lora_param() + self.weight.data = self.transpose(self.weight.data) + + def zero_pad(self, x): + result = x.new_zeros((len(self.lora_ind), *x.shape[1:])) + result[self.lora_ind] = x + return result + + def merge_BA(self, param_name: str): + lora_name = self.params_with_lora[param_name] + delta_w = F.conv1d( + eval(f'self.{lora_name}_lora_A').unsqueeze(0), + eval(f'self.{lora_name}_lora_B').unsqueeze(-1), + groups=sum(self.enable_lora) + ).squeeze(0) + return self.transpose(self.zero_pad(delta_w)) + + def train(self, mode: bool = True): + nn.Linear.train(self, mode) + self.lora_train(mode) + + def forward(self, x: torch.Tensor, **kwargs): + + if self.r > 0 and not self.merged: + self.merge_lora_param() + result = nn.Linear.forward(self, x, **kwargs) + self.sub_lora_data() + return result + else: + return nn.Linear.forward(self, x, **kwargs) \ No newline at end of file diff --git a/loralib/utils.py b/loralib/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7a594d68724babf7b1a12a17b14d71d36e5c7e83 --- /dev/null +++ b/loralib/utils.py @@ -0,0 +1,236 @@ +import os + +import torch +import torch.nn as nn + +from typing import Dict + +from .layers import LoRALayer, PlainMultiheadAttentionLoRA + +INDEX_POSITIONS_TEXT = { + 'top1': [11], + 'top2': [10, 11], + 'top3': [9, 10, 11], + 'bottom': [0, 1, 2, 3], + 'mid': [4, 5, 6, 7], + 'up': [8, 9, 10, 11], + 'half-up': [6, 7, 8, 9, 10, 11], + 'half-bottom': [0, 1, 2, 3, 4, 5], + 'all': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]} + + +INDEX_POSITIONS_VISION = { + 'ViT-B/16': { + 'top': [11], + 'top3': [9, 10, 11], + 'bottom': [0, 1, 2, 3], + 'mid': [4, 5, 6, 7], + 'up': [8, 9, 10, 11], + 
def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None:
    """Freeze all non-LoRA parameters of `model` in place.

    `bias` selects what happens to bias terms afterwards:
      - 'none': biases stay frozen (default);
      - 'all': every parameter whose name contains 'bias' is unfrozen;
      - 'lora_only': only biases living on `LoRALayer` modules are unfrozen.
    Any other value raises NotImplementedError.
    """
    # First pass: freeze everything that is not a LoRA factor. Parameters whose
    # names contain 'lora_' keep whatever requires_grad they already had.
    for param_name, param in model.named_parameters():
        if 'lora_' not in param_name:
            param.requires_grad = False

    if bias == 'none':
        return
    if bias == 'all':
        for param_name, param in model.named_parameters():
            if 'bias' in param_name:
                param.requires_grad = True
    elif bias == 'lora_only':
        for module in model.modules():
            # getattr with default covers both "no bias attr" and "bias is None".
            if isinstance(module, LoRALayer) and getattr(module, 'bias', None) is not None:
                module.bias.requires_grad = True
    else:
        raise NotImplementedError
def apply_lora(args, clip_model):
    """Replace nn.MultiheadAttention modules of `clip_model` with LoRA versions.

    Depending on ``args.encoder`` ('text', 'vision', or 'both'), walks the
    residual blocks of the corresponding transformer and, for block indices
    selected by ``args.position`` (via the INDEX_POSITIONS_* tables), swaps
    each ``nn.MultiheadAttention`` child for a ``PlainMultiheadAttentionLoRA``
    initialized from it. Mutates `clip_model` in place and returns the list of
    created LoRA attention modules (text layers first, then vision).

    NOTE(review): prints every residual block, selected or not — looks like
    debug output; confirm it is intentional before removing.
    """
    list_lora_layers = []
    if args.encoder == 'text' or args.encoder == 'both':
        indices = INDEX_POSITIONS_TEXT[args.position]
        text_encoder = clip_model.transformer
        for i, block in enumerate(text_encoder.resblocks):
            print(f"Residual Attention Block {i}: {block}")
            if i in indices:
                for name, submodule in block.named_children():
                    if isinstance(submodule, nn.MultiheadAttention):
                        # args.params is the enable_lora list, e.g. ['q', 'k', 'v'].
                        new_multi_head_lora = PlainMultiheadAttentionLoRA(
                            submodule, enable_lora=args.params, r=args.r, lora_alpha=args.alpha, dropout_rate=args.dropout_rate)
                        setattr(block, name, new_multi_head_lora)
                        list_lora_layers.append(new_multi_head_lora)

    if args.encoder == 'vision' or args.encoder == 'both':
        # Vision block indices additionally depend on the backbone depth.
        indices = INDEX_POSITIONS_VISION[args.backbone][args.position]
        vision_encoder = clip_model.visual.transformer
        for i, block in enumerate(vision_encoder.resblocks):
            print(f"Residual Attention Block {i}: {block}")
            if i in indices:
                for name, submodule in block.named_children():
                    if isinstance(submodule, nn.MultiheadAttention):
                        new_multi_head_lora = PlainMultiheadAttentionLoRA(
                            submodule, enable_lora=args.params, r=args.r, lora_alpha=args.alpha, dropout_rate=args.dropout_rate)
                        setattr(block, name, new_multi_head_lora)
                        list_lora_layers.append(new_multi_head_lora)
    return list_lora_layers
def load_lora(args, list_lora_layers):
    """Load LoRA A/B factors saved by `save_lora` into `list_lora_layers`.

    Reconstructs the checkpoint path from args (backbone/dataset/shots/seed/
    filename), validates the stored metadata against the current run
    configuration, and copies the per-layer q/k/v/o factors in place.

    Raises FileNotFoundError if the checkpoint is missing and ValueError on
    any metadata mismatch.
    """
    # to manage names like ViT-B/16
    backbone = args.backbone.replace('/', '').replace('-', '').lower()
    load_path = f'{args.save_path}/{backbone}/{args.dataset}/{args.shots}shots/seed{args.seed}/{args.filename}.pt'

    if not os.path.exists(load_path):
        raise FileNotFoundError(f'File {load_path} does not exist.')

    # map_location='cpu' lets checkpoints written on GPU load on CPU-only
    # hosts; copy_() below moves the data onto each parameter's own device.
    loaded_data = torch.load(load_path, map_location='cpu')

    metadata = loaded_data['metadata']
    if metadata['r'] != args.r:
        raise ValueError(
            f"r mismatch: expected {args.r}, found {metadata['r']}")
    if metadata['alpha'] != args.alpha:
        raise ValueError(
            f"alpha mismatch: expected {args.alpha}, found {metadata['alpha']}")
    if metadata['encoder'] != args.encoder:
        raise ValueError(
            f"Encoder mismatch: expected {args.encoder}, found {metadata['encoder']}")
    if metadata['params'] != args.params:
        raise ValueError(
            f"Params mismatch: expected {args.params}, found {metadata['params']}")
    if metadata['position'] != args.position:
        raise ValueError(
            f"Position mismatch: expected {args.position}, found {metadata['position']}")

    weights = loaded_data['weights']
    for i, layer in enumerate(list_lora_layers):
        layer_weights = weights[f'layer_{i}']
        if 'q' in args.params and 'q_proj' in layer_weights:
            layer.q_proj.w_lora_A.data.copy_(
                layer_weights['q_proj']['w_lora_A'])
            layer.q_proj.w_lora_B.data.copy_(
                layer_weights['q_proj']['w_lora_B'])
        if 'k' in args.params and 'k_proj' in layer_weights:
            layer.k_proj.w_lora_A.data.copy_(
                layer_weights['k_proj']['w_lora_A'])
            layer.k_proj.w_lora_B.data.copy_(
                layer_weights['k_proj']['w_lora_B'])
        if 'v' in args.params and 'v_proj' in layer_weights:
            layer.v_proj.w_lora_A.data.copy_(
                layer_weights['v_proj']['w_lora_A'])
            layer.v_proj.w_lora_B.data.copy_(
                layer_weights['v_proj']['w_lora_B'])
        if 'o' in args.params and 'proj' in layer_weights:
            layer.proj.w_lora_A.data.copy_(layer_weights['proj']['w_lora_A'])
            layer.proj.w_lora_B.data.copy_(layer_weights['proj']['w_lora_B'])

    print(f'LoRA weights loaded from {load_path}')
class DifferNet(nn.Module):
    """AlexNet feature extractor followed by a normalizing-flow head.

    Features are extracted at ``c.n_scales`` image scales (each scale halves
    ``c.img_size[0]``), globally average-pooled, concatenated, and mapped to
    latent codes by the flow.
    """

    def __init__(self):
        super(DifferNet, self).__init__()
        # Frozen-architecture ImageNet AlexNet; `pretrained=True` is the
        # legacy torchvision API (newer versions use `weights=`).
        self.feature_extractor = alexnet(pretrained=True)
        # nf_head() with its default input dim c.n_feat.
        self.nf = nf_head()

    def forward(self, x):
        # One pooled feature vector per scale; `x` is assumed NCHW image batch.
        y_cat = list()

        for s in range(c.n_scales):
            # Scale 0 uses the input as-is; scale s downsamples to img_size / 2**s.
            x_scaled = F.interpolate(x, size=c.img_size[0] // (2 ** s)) if s > 0 else x
            feat_s = self.feature_extractor.features(x_scaled)
            # Global average pooling over the spatial dimensions.
            y_cat.append(torch.mean(feat_s, dim=(2, 3)))

        # Concatenate per-scale descriptors along the channel axis.
        y = torch.cat(y_cat, dim=1)
        z = self.nf(y)
        return z
VALID_NAMES = [
    'Imagenet:resnet18',
    'Imagenet:resnet34',
    'Imagenet:resnet50',
    'Imagenet:resnet101',
    'Imagenet:resnet152',
    'Imagenet:vgg11',
    'Imagenet:vgg19',
    'Imagenet:swin-b',
    'Imagenet:swin-s',
    'Imagenet:swin-t',
    'Imagenet:vit_b_16',
    'Imagenet:vit_b_32',
    'Imagenet:vit_l_16',
    'Imagenet:vit_l_32',

    'CLIP:RN50',
    'CLIP:RN101',
    'CLIP:RN50x4',
    'CLIP:RN50x16',
    'CLIP:RN50x64',
    'CLIP:ViT-B/32',
    'CLIP:ViT-B/16',
    'CLIP:ViT-L/14',
    'CLIP:ViT-L/14@336px',
]


def get_model(name):
    """Instantiate a backbone by registry name, e.g. 'CLIP:ViT-B/16'.

    The prefix before ':' selects the wrapper class (ImagenetModel or
    CLIPModel); the remainder is passed through as the architecture name.
    Raises ValueError for unknown names. (Previously used bare `assert`,
    which is silently stripped under `python -O`.)
    """
    if name not in VALID_NAMES:
        raise ValueError(f"Unknown model name {name!r}; valid names: {VALID_NAMES}")
    if name.startswith("Imagenet:"):
        return ImagenetModel(name[len("Imagenet:"):])
    if name.startswith("CLIP:"):
        return CLIPModel(name[len("CLIP:"):])
    # Unreachable: every entry of VALID_NAMES carries one of the two prefixes.
    raise ValueError(f"Unknown model name {name!r}; valid names: {VALID_NAMES}")
a/models/__pycache__/vision_transformer.cpython-38.pyc b/models/__pycache__/vision_transformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe4f0dc5a6c5dc7a99859291b06b0c3bb7dedcee Binary files /dev/null and b/models/__pycache__/vision_transformer.cpython-38.pyc differ diff --git a/models/__pycache__/vision_transformer_misc.cpython-38.pyc b/models/__pycache__/vision_transformer_misc.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..792024c3b8bfc87a8f03570be4b0b2559b282e5c Binary files /dev/null and b/models/__pycache__/vision_transformer_misc.cpython-38.pyc differ diff --git a/models/__pycache__/vision_transformer_utils.cpython-38.pyc b/models/__pycache__/vision_transformer_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..209f0bdea754a8cc500ef7d92b47b2292d18c6e8 Binary files /dev/null and b/models/__pycache__/vision_transformer_utils.cpython-38.pyc differ diff --git a/models/clip/__init__.py b/models/clip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dcc5619538c0f7c782508bdbd9587259d805e0d9 --- /dev/null +++ b/models/clip/__init__.py @@ -0,0 +1 @@ +from .clip import * diff --git a/models/clip/__pycache__/__init__.cpython-310.pyc b/models/clip/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..432c7bbc639332fd8a9ca8c918e5abf60dc0f6e7 Binary files /dev/null and b/models/clip/__pycache__/__init__.cpython-310.pyc differ diff --git a/models/clip/__pycache__/__init__.cpython-38.pyc b/models/clip/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f855b83e6066e19aafd7d94f215da5878590a644 Binary files /dev/null and b/models/clip/__pycache__/__init__.cpython-38.pyc differ diff --git a/models/clip/__pycache__/__init__.cpython-39.pyc b/models/clip/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..88a3e863c93116cbddc25c6bc982b704c49876a8 Binary files /dev/null and b/models/clip/__pycache__/__init__.cpython-39.pyc differ diff --git a/models/clip/__pycache__/clip.cpython-310.pyc b/models/clip/__pycache__/clip.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f839e4303bbe911212f80a5ff27855cfa089f8b Binary files /dev/null and b/models/clip/__pycache__/clip.cpython-310.pyc differ diff --git a/models/clip/__pycache__/clip.cpython-38.pyc b/models/clip/__pycache__/clip.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55a5105afa39b10c57b5978c299343293765532e Binary files /dev/null and b/models/clip/__pycache__/clip.cpython-38.pyc differ diff --git a/models/clip/__pycache__/clip.cpython-39.pyc b/models/clip/__pycache__/clip.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b02b7c686dc8e91d7709ff9bd8f1135d2f880601 Binary files /dev/null and b/models/clip/__pycache__/clip.cpython-39.pyc differ diff --git a/models/clip/__pycache__/model.cpython-310.pyc b/models/clip/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fe5ed8576243c0dd4feb59baaf9af5395fd21fe Binary files /dev/null and b/models/clip/__pycache__/model.cpython-310.pyc differ diff --git a/models/clip/__pycache__/model.cpython-38.pyc b/models/clip/__pycache__/model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bdd966af24881708d3f2ee89152220a606e9d234 Binary files /dev/null and b/models/clip/__pycache__/model.cpython-38.pyc differ diff --git a/models/clip/__pycache__/model.cpython-39.pyc b/models/clip/__pycache__/model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9508370bcab3e891fb9164af0c5628e1577bb7b4 Binary files /dev/null and b/models/clip/__pycache__/model.cpython-39.pyc differ diff --git 
a/models/clip/__pycache__/simple_tokenizer.cpython-310.pyc b/models/clip/__pycache__/simple_tokenizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e05bfbaf5118749c7a574133c43e3995ef57951a Binary files /dev/null and b/models/clip/__pycache__/simple_tokenizer.cpython-310.pyc differ diff --git a/models/clip/__pycache__/simple_tokenizer.cpython-38.pyc b/models/clip/__pycache__/simple_tokenizer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eeca6cfbb0bef561f84ff453c3001b0822bcfabd Binary files /dev/null and b/models/clip/__pycache__/simple_tokenizer.cpython-38.pyc differ diff --git a/models/clip/__pycache__/simple_tokenizer.cpython-39.pyc b/models/clip/__pycache__/simple_tokenizer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31351f58c8b617d682cdcc547e745e6fbc429ba9 Binary files /dev/null and b/models/clip/__pycache__/simple_tokenizer.cpython-39.pyc differ diff --git a/models/clip/bpe_simple_vocab_16e6.txt.gz b/models/clip/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/models/clip/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/models/clip/clip.py b/models/clip/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..257511e1d40c120e0d64a0f1562d44b2b8a40a17 --- /dev/null +++ b/models/clip/clip.py @@ -0,0 +1,237 @@ +import hashlib +import os +import urllib +import warnings +from typing import Any, Union, List +from pkg_resources import packaging + +import torch +from PIL import Image +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +from tqdm import tqdm + +from .model import build_model +from .simple_tokenizer import SimpleTokenizer as _Tokenizer + 
def _download(url: str, root: str):
    """Download `url` into directory `root` and verify its SHA256 checksum.

    The expected checksum is the second-to-last path segment of the OpenAI
    model URL. Returns the local file path; a valid cached file is reused.
    Raises RuntimeError if the target path is not a regular file or the
    downloaded file fails verification.
    """
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if _file_sha256(download_target) == expected_sha256:
            return download_target
        warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if _file_sha256(download_target) != expected_sha256:
        # (fixed: message previously read "does not not match")
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")

    return download_target


def _file_sha256(path: str) -> str:
    """SHA256 of a file, hashed in 1 MiB chunks so large checkpoints are not read into memory at once."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()
+ + download_root: str + path to download the model files; by default, it uses "~/.cache/clip" + + Returns + ------- + model : torch.nn.Module + The CLIP model + + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if name in _MODELS: + model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {available_models()}") + + with open(model_path, 'rb') as opened_file: + try: + # loading JIT archive + model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") + jit = False + state_dict = torch.load(opened_file, map_location="cpu") + + if not jit: + model = build_model(state_dict or model.state_dict()).to(device) + if str(device) == "cpu": + model.float() + return model, _transform(model.visual.input_resolution) + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + 
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()["value"] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + + model.float() + + return model, _transform(model.input_resolution.item()) + + +def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + + context_length : int + The context length to use; all CLIP models use 77 as the context length + + truncate: bool + Whether to truncate the text in case its encoding is longer than the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. + We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 
+ """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder["<|startoftext|>"] + eot_token = _tokenizer.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + else: + result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") + result[i, :len(tokens)] = torch.tensor(tokens) + + return result diff --git a/models/clip/model.py b/models/clip/model.py new file mode 100644 index 0000000000000000000000000000000000000000..df803b4a41403eb92a89d249ba141f95ec7d89eb --- /dev/null +++ b/models/clip/model.py @@ -0,0 +1,454 @@ +from collections import OrderedDict +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu1(self.bn1(self.conv1(x))) + out = self.relu2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = 
x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x[:1], key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + return x.squeeze(0) + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.relu3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, 
layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + def stem(x): + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: 
torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) + + def forward(self, x: torch.Tensor): + out = {} + for idx, layer in enumerate(self.resblocks.children()): + x = layer(x) + out['layer'+str(idx)] = x[0] # shape:LND. choose cls token feature + return out, x + + # return self.resblocks(x) # This is the original code + + +class VisionTransformer(nn.Module): + def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + out, x = self.transformer(x) + x = x.permute(1, 0, 2) # 
LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + + out['before_projection'] = x + # print(x.shape) + + # if self.proj is not None: + # x = x @ self.proj + out['after_projection'] = x + # print(x.shape) + + # Return both intermediate features and final clip feature + # return out + + # This only returns CLIP features + return x + + +class CLIP(nn.Module): + def __init__(self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int + ): + super().__init__() + + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width + ) + else: + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim + ) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask() + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, 
ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image): + return self.visual(image.type(self.dtype)) + + def encode_text(self, text): + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + # x.shape = [batch_size, n_ctx, 
transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x + + def forward(self, image, text): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm(dim=1, keepdim=True) + text_features = text_features / text_features.norm(dim=1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + + # shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text + + +def convert_weights(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + if isinstance(l, nn.MultiheadAttention): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +def build_model(state_dict: dict): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_resolution = vision_patch_size * grid_size + else: + counts: list = 
[len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_resolution = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) + + model = CLIP( + embed_dim, + image_resolution, vision_layers, vision_width, vision_patch_size, + context_length, vocab_size, transformer_width, transformer_heads, transformer_layers + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + if key in state_dict: + del state_dict[key] + + convert_weights(model) + model.load_state_dict(state_dict) + return model.eval() diff --git a/models/clip/simple_tokenizer.py b/models/clip/simple_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..0a66286b7d5019c6e221932a813768038f839c91 --- /dev/null +++ b/models/clip/simple_tokenizer.py @@ -0,0 +1,132 @@ +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. 
+ + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'</w>' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>':
'<|endoftext|>'} + self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '</w>',) + pairs = get_pairs(word) + + if not pairs: + return token+'</w>' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') + return text diff --git a/models/clip_models.py b/models/clip_models.py new file mode 100644 index 0000000000000000000000000000000000000000..292c990e411c0be5e03f25d1c0f60966e92364fa --- /dev/null +++ b/models/clip_models.py @@ -0,0 +1,25 @@ +from .clip import clip +from PIL import Image +import torch.nn as nn + + +CHANNELS = { + "RN50" : 1024, + "ViT-L/14" : 768 +} + +class CLIPModel(nn.Module): + def __init__(self, name, num_classes=1): + super(CLIPModel,
self).__init__() + + self.model, self.preprocess = clip.load(name, device="cpu") # self.preprecess will not be used during training, which is handled in Dataset class + self.fc = nn.Linear( CHANNELS[name], num_classes ) + + + def forward(self, x, return_feature=False): + features = self.model.encode_image(x) + # print(features.shape) + if return_feature: + return features + return self.fc(features) + diff --git a/models/imagenet_models.py b/models/imagenet_models.py new file mode 100644 index 0000000000000000000000000000000000000000..20a40b916793d926c915aa2f62602651613fec04 --- /dev/null +++ b/models/imagenet_models.py @@ -0,0 +1,40 @@ +from .resnet import resnet18, resnet34, resnet50, resnet101, resnet152 +from .vision_transformer import vit_b_16, vit_b_32, vit_l_16, vit_l_32 + +from torchvision import transforms +from PIL import Image +import torch +import torch.nn as nn + + +model_dict = { + 'resnet18': resnet18, + 'resnet34': resnet34, + 'resnet50': resnet50, + 'resnet101': resnet101, + 'resnet152': resnet152, + 'vit_b_16': vit_b_16, + 'vit_b_32': vit_b_32, + 'vit_l_16': vit_l_16, + 'vit_l_32': vit_l_32 +} + + +CHANNELS = { + "resnet50" : 2048, + "vit_b_16" : 768, +} + + + +class ImagenetModel(nn.Module): + def __init__(self, name, num_classes=1): + super(ImagenetModel, self).__init__() + + self.model = model_dict[name](pretrained=True) + self.fc = nn.Linear(CHANNELS[name], num_classes) #manually define a fc layer here + + + def forward(self, x): + feature = self.model(x)["penultimate"] + return self.fc(feature) diff --git a/models/resnet.py b/models/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..a78e3d65e263cb9dbd1afa0e1a88dba9f5ddd164 --- /dev/null +++ b/models/resnet.py @@ -0,0 +1,337 @@ +import torch +from torch import Tensor +import torch.nn as nn +from typing import Type, Any, Callable, Union, List, Optional + +try: + from torch.hub import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import 
load_url as load_state_dict_from_url + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth', + 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', + 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', + 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', + 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', +} + + + + +def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion: int = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = 
conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. + # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
+ + expansion: int = 4 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__( + self, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + num_classes: int = 1000, + zero_init_residual: bool = False, + groups: int = 1, + width_per_group: int = 64, + replace_stride_with_dilation: Optional[List[bool]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if 
len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + dilate=replace_stride_with_dilation[2]) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
        # -- tail of ResNet.__init__ (the method opens in the previous chunk) --
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
                    stride: int = 1, dilate: bool = False) -> nn.Sequential:
        """Build one ResNet stage of ``blocks`` residual blocks.

        The first block carries the stride (or dilation) change and, when the
        resolution or channel count changes, a 1x1-conv downsample branch; the
        remaining blocks keep the resolution.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # Replace striding with dilation: keeps spatial size, grows the
            # receptive field (used for segmentation-style backbones).
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        """Run the backbone and collect intermediate feature maps in a dict."""
        # The comment resolution is based on input size is 224*224 imagenet
        # NOTE(review): the channel counts below match BasicBlock (expansion=1);
        # for Bottleneck architectures (resnet50+) f1..f4 have 4x the channels
        # and 'penultimate' is N*2048 — confirm against the block type used.
        out = {}
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        out['f0'] = x  # N*64*56*56

        x = self.layer1(x)
        out['f1'] = x  # N*64*56*56

        x = self.layer2(x)
        out['f2'] = x  # N*128*28*28

        x = self.layer3(x)
        out['f3'] = x  # N*256*14*14

        x = self.layer4(x)
        out['f4'] = x  # N*512*7*7

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        out['penultimate'] = x  # N*512

        x = self.fc(x)
        out['logits'] = x  # N*1000

        # return all features
        return out

        # return final classification result
        # return x

    def forward(self, x):
        # NOTE: returns the dict of intermediate features, not a bare logits
        # tensor — callers must index with 'logits' / 'penultimate' / 'f0'..'f4'.
        return self._forward_impl(x)


def _resnet(
    arch: str,
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    pretrained: bool,
    progress: bool,
    **kwargs: Any
) -> ResNet:
    """Shared constructor: build a ResNet and optionally load ImageNet weights
    for ``arch`` from ``model_urls``."""
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        model.load_state_dict(state_dict)
    return model


def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)


def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs)


def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)


def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNet-101 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)


def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNet-152 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs)


# ===== models/vgg.py =====
import torch
import torch.nn as nn
from typing import Union, List, Dict, Any, cast
import torchvision
import torch.nn.functional as F


class VGG(torch.nn.Module):
    """Feature-extracting wrapper around torchvision's VGG.

    The official ``features`` module is re-grouped into five conv stages so
    that ``forward`` can expose each stage's output ('f0'..'f4') in addition
    to the penultimate FC activation and the logits.
    """

    def __init__(self, arch_type, pretrained, progress):
        super().__init__()

        self.layer1 = torch.nn.Sequential()
        self.layer2 = torch.nn.Sequential()
        self.layer3 = torch.nn.Sequential()
        self.layer4 = torch.nn.Sequential()
        self.layer5 = torch.nn.Sequential()

        # Index ranges below split torchvision's flat `features` Sequential at
        # its max-pool boundaries; the final max-pool (`last_idx`) is kept
        # separate so layer5's output is pre-pooling.
        if arch_type == 'vgg11':
            official_vgg = torchvision.models.vgg11(pretrained=pretrained, progress=progress)
            blocks = [ [0,2], [2,5], [5,10], [10,15], [15,20] ]
            last_idx = 20
        elif arch_type == 'vgg19':
            official_vgg = torchvision.models.vgg19(pretrained=pretrained, progress=progress)
            blocks = [ [0,4], [4,9], [9,18], [18,27], [27,36] ]
            last_idx = 36
        else:
            raise NotImplementedError

        # Child modules keep their original index as their name, so pretrained
        # state_dict keys remain recognizable.
        for x in range( *blocks[0] ):
            self.layer1.add_module(str(x), official_vgg.features[x])
        for x in range( *blocks[1] ):
            self.layer2.add_module(str(x), official_vgg.features[x])
        for x in range( *blocks[2] ):
            self.layer3.add_module(str(x), official_vgg.features[x])
        for x in range( *blocks[3] ):
            self.layer4.add_module(str(x), official_vgg.features[x])
        for x in range( *blocks[4] ):
            self.layer5.add_module(str(x), official_vgg.features[x])

        self.max_pool = official_vgg.features[last_idx]
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))

        # Reuse the official classifier's three Linear layers; dropout is
        # re-instantiated (fresh nn.Dropout, default p=0.5).
        self.fc1 = official_vgg.classifier[0]
        self.fc2 = official_vgg.classifier[3]
        self.fc3 = official_vgg.classifier[6]
        self.dropout = nn.Dropout()


    def forward(self, x):
        """Return a dict of per-stage features plus 'penultimate' and 'logits'."""
        out = {}

        x = self.layer1(x)
        out['f0'] = x

        x = self.layer2(x)
        out['f1'] = x

        x = self.layer3(x)
        out['f2'] = x

        x = self.layer4(x)
        out['f3'] = x

        x = self.layer5(x)
        out['f4'] = x

        x = self.max_pool(x)
        x = self.avgpool(x)
        x = x.view(-1,512*7*7)

        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = F.relu(x)
        # 'penultimate' is taken before dropout, after the second FC + ReLU.
        out['penultimate'] = x
        x = self.dropout(x)
        x = self.fc3(x)
        out['logits'] = x

        return out


def vgg11(pretrained=False, progress=True):
    r"""VGG 11-layer model (configuration "A") from
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return VGG('vgg11', pretrained, progress)



def vgg19(pretrained=False, progress=True):
    r"""VGG 19-layer model (configuration "E")
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
+ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return VGG('vgg19', pretrained, progress) + + + + diff --git a/models/vision_transformer.py b/models/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..618e9626ca43f1afdb3419e19be11f3a3048f81e --- /dev/null +++ b/models/vision_transformer.py @@ -0,0 +1,481 @@ +import math +from collections import OrderedDict +from functools import partial +from typing import Any, Callable, List, NamedTuple, Optional + +import torch +import torch.nn as nn + +# from .._internally_replaced_utils import load_state_dict_from_url +from .vision_transformer_misc import ConvNormActivation +from .vision_transformer_utils import _log_api_usage_once + +try: + from torch.hub import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url + +# __all__ = [ +# "VisionTransformer", +# "vit_b_16", +# "vit_b_32", +# "vit_l_16", +# "vit_l_32", +# ] + +model_urls = { + "vit_b_16": "https://download.pytorch.org/models/vit_b_16-c867db91.pth", + "vit_b_32": "https://download.pytorch.org/models/vit_b_32-d86f8d99.pth", + "vit_l_16": "https://download.pytorch.org/models/vit_l_16-852ce7e3.pth", + "vit_l_32": "https://download.pytorch.org/models/vit_l_32-c7638314.pth", +} + + +class ConvStemConfig(NamedTuple): + out_channels: int + kernel_size: int + stride: int + norm_layer: Callable[..., nn.Module] = nn.BatchNorm2d + activation_layer: Callable[..., nn.Module] = nn.ReLU + + +class MLPBlock(nn.Sequential): + """Transformer MLP block.""" + + def __init__(self, in_dim: int, mlp_dim: int, dropout: float): + super().__init__() + self.linear_1 = nn.Linear(in_dim, mlp_dim) + self.act = nn.GELU() + self.dropout_1 = nn.Dropout(dropout) + self.linear_2 = nn.Linear(mlp_dim, in_dim) + self.dropout_2 = nn.Dropout(dropout) + + 
nn.init.xavier_uniform_(self.linear_1.weight) + nn.init.xavier_uniform_(self.linear_2.weight) + nn.init.normal_(self.linear_1.bias, std=1e-6) + nn.init.normal_(self.linear_2.bias, std=1e-6) + + +class EncoderBlock(nn.Module): + """Transformer encoder block.""" + + def __init__( + self, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + dropout: float, + attention_dropout: float, + norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), + ): + super().__init__() + self.num_heads = num_heads + + # Attention block + self.ln_1 = norm_layer(hidden_dim) + self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True) + self.dropout = nn.Dropout(dropout) + + # MLP block + self.ln_2 = norm_layer(hidden_dim) + self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout) + + def forward(self, input: torch.Tensor): + torch._assert(input.dim() == 3, f"Expected (seq_length, batch_size, hidden_dim) got {input.shape}") + x = self.ln_1(input) + x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False) + x = self.dropout(x) + x = x + input + + y = self.ln_2(x) + y = self.mlp(y) + return x + y + + +class Encoder(nn.Module): + """Transformer Model Encoder for sequence to sequence translation.""" + + def __init__( + self, + seq_length: int, + num_layers: int, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + dropout: float, + attention_dropout: float, + norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), + ): + super().__init__() + # Note that batch_size is on the first dim because + # we have batch_first=True in nn.MultiAttention() by default + self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02)) # from BERT + self.dropout = nn.Dropout(dropout) + layers: OrderedDict[str, nn.Module] = OrderedDict() + for i in range(num_layers): + layers[f"encoder_layer_{i}"] = EncoderBlock( + num_heads, + hidden_dim, + mlp_dim, + dropout, + 
                # -- continuation of Encoder.__init__ from the previous chunk --
                attention_dropout,
                norm_layer,
            )
        self.layers = nn.Sequential(layers)
        self.ln = norm_layer(hidden_dim)

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        # Positional embedding has batch dim 1 and broadcasts over the batch.
        input = input + self.pos_embedding
        return self.ln(self.layers(self.dropout(input)))


class VisionTransformer(nn.Module):
    """Vision Transformer as per https://arxiv.org/abs/2010.11929."""

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        num_classes: int = 1000,
        representation_size: Optional[int] = None,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
        conv_stem_configs: Optional[List[ConvStemConfig]] = None,
    ):
        super().__init__()
        _log_api_usage_once(self)
        torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!")
        self.image_size = image_size
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim
        self.mlp_dim = mlp_dim
        self.attention_dropout = attention_dropout
        self.dropout = dropout
        self.num_classes = num_classes
        self.representation_size = representation_size
        self.norm_layer = norm_layer

        if conv_stem_configs is not None:
            # As per https://arxiv.org/abs/2106.14881
            # Patchify via a stack of conv-bn-relu layers + final 1x1 conv
            # instead of a single big-stride convolution.
            seq_proj = nn.Sequential()
            prev_channels = 3
            for i, conv_stem_layer_config in enumerate(conv_stem_configs):
                seq_proj.add_module(
                    f"conv_bn_relu_{i}",
                    ConvNormActivation(
                        in_channels=prev_channels,
                        out_channels=conv_stem_layer_config.out_channels,
                        kernel_size=conv_stem_layer_config.kernel_size,
                        stride=conv_stem_layer_config.stride,
                        norm_layer=conv_stem_layer_config.norm_layer,
                        activation_layer=conv_stem_layer_config.activation_layer,
                    ),
                )
                prev_channels = conv_stem_layer_config.out_channels
            seq_proj.add_module(
                "conv_last", nn.Conv2d(in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1)
            )
            self.conv_proj: nn.Module = seq_proj
        else:
            # Standard patchify stem: one conv with kernel == stride == patch.
            self.conv_proj = nn.Conv2d(
                in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size
            )

        seq_length = (image_size // patch_size) ** 2

        # Add a class token
        self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        seq_length += 1

        self.encoder = Encoder(
            seq_length,
            num_layers,
            num_heads,
            hidden_dim,
            mlp_dim,
            dropout,
            attention_dropout,
            norm_layer,
        )
        self.seq_length = seq_length

        heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
        if representation_size is None:
            heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
        else:
            heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
            heads_layers["act"] = nn.Tanh()
            heads_layers["head"] = nn.Linear(representation_size, num_classes)

        self.heads = nn.Sequential(heads_layers)

        if isinstance(self.conv_proj, nn.Conv2d):
            # Init the patchify stem
            fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
            nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
            if self.conv_proj.bias is not None:
                nn.init.zeros_(self.conv_proj.bias)
        elif self.conv_proj.conv_last is not None and isinstance(self.conv_proj.conv_last, nn.Conv2d):
            # Init the last 1x1 conv of the conv stem
            nn.init.normal_(
                self.conv_proj.conv_last.weight, mean=0.0, std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels)
            )
            if self.conv_proj.conv_last.bias is not None:
                nn.init.zeros_(self.conv_proj.conv_last.bias)

        if hasattr(self.heads, "pre_logits") and isinstance(self.heads.pre_logits, nn.Linear):
            fan_in = self.heads.pre_logits.in_features
            nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))
            nn.init.zeros_(self.heads.pre_logits.bias)

        if isinstance(self.heads.head, nn.Linear):
            # Classification head starts at exactly zero output.
            nn.init.zeros_(self.heads.head.weight)
            nn.init.zeros_(self.heads.head.bias)

    def _process_input(self, x: torch.Tensor) -> torch.Tensor:
        """Patchify (n, c, h, w) into token sequence (n, n_h*n_w, hidden_dim)."""
        n, c, h, w = x.shape
        p = self.patch_size
        torch._assert(h == self.image_size, "Wrong image height!")
        torch._assert(w == self.image_size, "Wrong image width!")
        n_h = h // p
        n_w = w // p

        # (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
        x = self.conv_proj(x)
        # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
        x = x.reshape(n, self.hidden_dim, n_h * n_w)

        # (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)
        # The self attention layer expects inputs in the format (N, S, E)
        # where S is the source sequence length, N is the batch size, E is the
        # embedding dimension
        x = x.permute(0, 2, 1)

        return x

    def forward(self, x: torch.Tensor):
        """Return {'f4': patch features (n, C, H, W), 'penultimate': class token,
        'logits': classification output} instead of a bare logits tensor."""
        out = {}

        # Reshape and permute the input tensor
        x = self._process_input(x)
        n = x.shape[0]

        # Expand the class token to the full batch
        batch_class_token = self.class_token.expand(n, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)


        x = self.encoder(x)
        # Drop the class token and fold the patch tokens back into a square
        # H x W feature map (inputs are asserted square in _process_input).
        img_feature = x[:,1:]
        H = W = int(self.image_size / self.patch_size)
        out['f4'] = img_feature.view(n, H, W, self.hidden_dim).permute(0,3,1,2)

        # Classifier "token" as used by standard language architectures
        x = x[:, 0]
        out['penultimate'] = x

        x = self.heads(x)  # I checked that for all pretrained ViT, this is just a fc
        out['logits'] = x

        return out


def _vision_transformer(
    arch: str,
    patch_size: int,
    num_layers: int,
    num_heads: int,
    hidden_dim: int,
    mlp_dim: int,
    pretrained: bool,
    progress: bool,
    **kwargs: Any,
) -> VisionTransformer:
    """Shared constructor for the vit_* factory functions; optionally loads
    pretrained weights for ``arch`` from ``model_urls``."""
    image_size = kwargs.pop("image_size", 224)

    model = VisionTransformer(
        image_size=image_size,
        patch_size=patch_size,
        num_layers=num_layers,
        num_heads=num_heads,
        hidden_dim=hidden_dim,
        mlp_dim=mlp_dim,
        **kwargs,
    )

    if pretrained:
        if arch not in model_urls:
            raise ValueError(f"No checkpoint is available for model type '{arch}'!")
        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        model.load_state_dict(state_dict)

    return model


def vit_b_16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a vit_b_16 architecture from
    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vision_transformer(
        arch="vit_b_16",
        patch_size=16,
        num_layers=12,
        num_heads=12,
        hidden_dim=768,
        mlp_dim=3072,
        pretrained=pretrained,
        progress=progress,
        **kwargs,
    )


def vit_b_32(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a vit_b_32 architecture from
    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vision_transformer(
        arch="vit_b_32",
        patch_size=32,
        num_layers=12,
        num_heads=12,
        hidden_dim=768,
        mlp_dim=3072,
        pretrained=pretrained,
        progress=progress,
        **kwargs,
    )


def vit_l_16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a vit_l_16 architecture from
    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
def vit_l_16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> "VisionTransformer":
    """
    Constructs a vit_l_16 architecture from
    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    # arch, patch, layers, heads, hidden, mlp, pretrained, progress
    return _vision_transformer("vit_l_16", 16, 24, 16, 1024, 4096, pretrained, progress, **kwargs)


def vit_l_32(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> "VisionTransformer":
    """
    Constructs a vit_l_32 architecture from
    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vision_transformer("vit_l_32", 32, 24, 16, 1024, 4096, pretrained, progress, **kwargs)


def interpolate_embeddings(
    image_size: int,
    patch_size: int,
    model_state: "OrderedDict[str, torch.Tensor]",
    interpolation_mode: str = "bicubic",
    reset_heads: bool = False,
) -> "OrderedDict[str, torch.Tensor]":
    """Adapt a pre-trained ViT state dict to a new input resolution.

    Reshapes the (non-class-token) positional embeddings into their 2-d grid,
    resamples that grid to the new ``image_size // patch_size`` resolution, and
    flattens it back. The class-token embedding is never interpolated.

    Args:
        image_size (int): Image size of the new model.
        patch_size (int): Patch size of the new model.
        model_state (OrderedDict[str, torch.Tensor]): State dict of the pre-trained model.
        interpolation_mode (str): The algorithm used for upsampling. Default: bicubic.
        reset_heads (bool): If true, not copying the state of heads. Default: False.

    Returns:
        OrderedDict[str, torch.Tensor]: A state dict which can be loaded into the new model.
    """
    # Shape of pos_embedding is (1, seq_length, hidden_dim).
    pos_embedding = model_state["encoder.pos_embedding"]
    batch, seq_length, hidden_dim = pos_embedding.shape
    if batch != 1:
        raise ValueError(f"Unexpected position embedding shape: {pos_embedding.shape}")

    new_seq_length = (image_size // patch_size) ** 2 + 1
    if new_seq_length == seq_length:
        # Grid already matches the requested resolution — nothing to resample.
        return model_state

    # Separate the class token from the image-grid embeddings.
    grid_len = seq_length - 1
    side = int(math.sqrt(grid_len))
    torch._assert(side * side == grid_len, "seq_length is not a perfect square!")
    class_token_embedding = pos_embedding[:, :1, :]

    # (1, grid_len, hidden) -> (1, hidden, side, side) for 2-d interpolation.
    grid_embedding = pos_embedding[:, 1:, :].permute(0, 2, 1).reshape(1, hidden_dim, side, side)

    new_side = image_size // patch_size
    resized = nn.functional.interpolate(
        grid_embedding,
        size=new_side,
        mode=interpolation_mode,
        align_corners=True,
    )

    # Back to (1, new_grid_len, hidden) and re-attach the class token in front.
    resized = resized.reshape(1, hidden_dim, new_side * new_side).permute(0, 2, 1)
    model_state["encoder.pos_embedding"] = torch.cat([class_token_embedding, resized], dim=1)

    if reset_heads:
        # Drop every classifier-head entry from the returned copy.
        model_state = OrderedDict((k, v) for k, v in model_state.items() if not k.startswith("heads"))

    return model_state
        # -- continuation of interpolate_embeddings from the previous chunk --
        # (1, hidden_dim, seq_l_1d, seq_l_1d) -> (1, hidden_dim, new_seq_l_1d, new_seq_l_1d)
        new_pos_embedding_img = nn.functional.interpolate(
            pos_embedding_img,
            size=new_seq_length_1d,
            mode=interpolation_mode,
            align_corners=True,
        )

        # (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) -> (1, hidden_dim, new_seq_length)
        new_pos_embedding_img = new_pos_embedding_img.reshape(1, hidden_dim, new_seq_length)

        # (1, hidden_dim, new_seq_length) -> (1, new_seq_length, hidden_dim)
        new_pos_embedding_img = new_pos_embedding_img.permute(0, 2, 1)
        new_pos_embedding = torch.cat([pos_embedding_token, new_pos_embedding_img], dim=1)

        model_state["encoder.pos_embedding"] = new_pos_embedding

        if reset_heads:
            # Rebuild the dict without any classifier-head entries.
            model_state_copy: "OrderedDict[str, torch.Tensor]" = OrderedDict()
            for k, v in model_state.items():
                if not k.startswith("heads"):
                    model_state_copy[k] = v
            model_state = model_state_copy

    return model_state


# ===== models/vision_transformer_misc.py =====
from typing import Callable, List, Optional

import torch
from torch import Tensor

from .vision_transformer_utils import _log_api_usage_once


interpolate = torch.nn.functional.interpolate


# This is not in nn
class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed

    Args:
        num_features (int): Number of features ``C`` from an expected input of size ``(N, C, H, W)``
        eps (float): a value added to the denominator for numerical stability. Default: 1e-5
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
    ):
        super().__init__()
        _log_api_usage_once(self)
        self.eps = eps
        # All statistics are buffers (not Parameters): they are saved/loaded
        # with the state dict but never updated by the optimizer.
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def _load_from_state_dict(
        self,
        state_dict: dict,
        prefix: str,
        local_metadata: dict,
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ):
        # Regular BatchNorm checkpoints carry `num_batches_tracked`; this
        # frozen variant has no such buffer, so drop the key before loading.
        num_batches_tracked_key = prefix + "num_batches_tracked"
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def forward(self, x: Tensor) -> Tensor:
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        scale = w * (rv + self.eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.weight.shape[0]}, eps={self.eps})"


class ConvNormActivation(torch.nn.Sequential):
    """
    Configurable block used for Convolution-Normalization-Activation blocks.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
        kernel_size: (int, optional): Size of the convolving kernel. Default: 3
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int, tuple or str, optional): Padding added to all four sides of the input.
            Default: None, in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation``
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm2d``
        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
        dilation (int): Spacing between kernel elements. Default: 1
        inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
        bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.

    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: Optional[int] = None,
        groups: int = 1,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        dilation: int = 1,
        inplace: Optional[bool] = True,
        bias: Optional[bool] = None,
    ) -> None:
        # "Same"-style padding for odd kernels when not specified explicitly.
        if padding is None:
            padding = (kernel_size - 1) // 2 * dilation
        # The conv bias is redundant when a norm layer follows it.
        if bias is None:
            bias = norm_layer is None
        layers = [
            torch.nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                dilation=dilation,
                groups=groups,
                bias=bias,
            )
        ]
        if norm_layer is not None:
            layers.append(norm_layer(out_channels))
        if activation_layer is not None:
            params = {} if inplace is None else {"inplace": inplace}
            layers.append(activation_layer(**params))
        super().__init__(*layers)
        _log_api_usage_once(self)
        self.out_channels = out_channels


class SqueezeExcitation(torch.nn.Module):
    """
    This block implements the Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507 (see Fig. 1).
    Parameters ``activation``, and ``scale_activation`` correspond to ``delta`` and ``sigma`` in eq. 3.

    Args:
        input_channels (int): Number of channels in the input image
        squeeze_channels (int): Number of squeeze channels
        activation (Callable[..., torch.nn.Module], optional): ``delta`` activation. Default: ``torch.nn.ReLU``
        scale_activation (Callable[..., torch.nn.Module]): ``sigma`` activation. Default: ``torch.nn.Sigmoid``
    """

    def __init__(
        self,
        input_channels: int,
        squeeze_channels: int,
        activation: Callable[..., torch.nn.Module] = torch.nn.ReLU,
        scale_activation: Callable[..., torch.nn.Module] = torch.nn.Sigmoid,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.avgpool = torch.nn.AdaptiveAvgPool2d(1)
        # 1x1 convs act as per-channel fully-connected layers.
        self.fc1 = torch.nn.Conv2d(input_channels, squeeze_channels, 1)
        self.fc2 = torch.nn.Conv2d(squeeze_channels, input_channels, 1)
        self.activation = activation()
        self.scale_activation = scale_activation()

    def _scale(self, input: Tensor) -> Tensor:
        # Squeeze (global average pool) then excite (two 1x1 convs) to produce
        # a per-channel gate in the range of `scale_activation`.
        scale = self.avgpool(input)
        scale = self.fc1(scale)
        scale = self.activation(scale)
        scale = self.fc2(scale)
        return self.scale_activation(scale)

    def forward(self, input: Tensor) -> Tensor:
        scale = self._scale(input)
        return scale * input


# ===== models/vision_transformer_utils.py =====
import math
import pathlib
import warnings
from types import FunctionType
from typing import Any, BinaryIO, List, Optional, Tuple, Union

import numpy as np
import torch
from PIL import Image, ImageColor, ImageDraw, ImageFont

__all__ = [
    "make_grid",
    "save_image",
    "draw_bounding_boxes",
    # -- continuation of __all__ from the previous chunk --
    "draw_segmentation_masks",
    "draw_keypoints",
    "flow_to_image",
]


@torch.no_grad()
def make_grid(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    nrow: int = 8,
    padding: int = 2,
    normalize: bool = False,
    value_range: Optional[Tuple[int, int]] = None,
    scale_each: bool = False,
    pad_value: float = 0.0,
    **kwargs,
) -> torch.Tensor:
    """
    Make a grid of images.

    Args:
        tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
            or a list of images all of the same size.
        nrow (int, optional): Number of images displayed in each row of the grid.
            The final grid size is ``(B / nrow, nrow)``. Default: ``8``.
        padding (int, optional): amount of padding. Default: ``2``.
        normalize (bool, optional): If True, shift the image to the range (0, 1),
            by the min and max values specified by ``value_range``. Default: ``False``.
        value_range (tuple, optional): tuple (min, max) where min and max are numbers,
            then these numbers are used to normalize the image. By default, min and max
            are computed from the tensor.
        range (tuple. optional):
            .. warning::
                This parameter was deprecated in ``0.12`` and will be removed in ``0.14``. Please use ``value_range``
                instead.
        scale_each (bool, optional): If ``True``, scale each image in the batch of
            images separately rather than the (min, max) over all images. Default: ``False``.
        pad_value (float, optional): Value for the padded pixels. Default: ``0``.

    Returns:
        grid (Tensor): the tensor containing grid of images.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(make_grid)
    if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}")

    # Legacy keyword: `range` was renamed to `value_range`.
    if "range" in kwargs.keys():
        warnings.warn(
            "The parameter 'range' is deprecated since 0.12 and will be removed in 0.14. "
            "Please use 'value_range' instead."
        )
        value_range = kwargs["range"]

    # if list of tensors, convert to a 4D mini-batch Tensor
    if isinstance(tensor, list):
        tensor = torch.stack(tensor, dim=0)

    if tensor.dim() == 2:  # single image H x W
        tensor = tensor.unsqueeze(0)
    if tensor.dim() == 3:  # single image
        if tensor.size(0) == 1:  # if single-channel, convert to 3-channel
            tensor = torch.cat((tensor, tensor, tensor), 0)
        tensor = tensor.unsqueeze(0)

    if tensor.dim() == 4 and tensor.size(1) == 1:  # single-channel images
        tensor = torch.cat((tensor, tensor, tensor), 1)

    if normalize is True:
        tensor = tensor.clone()  # avoid modifying tensor in-place
        if value_range is not None:
            assert isinstance(
                value_range, tuple
            ), "value_range has to be a tuple (min, max) if specified. min and max are numbers"

        def norm_ip(img, low, high):
            # Clamp then affinely map [low, high] -> [0, 1] (in place).
            img.clamp_(min=low, max=high)
            img.sub_(low).div_(max(high - low, 1e-5))

        def norm_range(t, value_range):
            if value_range is not None:
                norm_ip(t, value_range[0], value_range[1])
            else:
                norm_ip(t, float(t.min()), float(t.max()))

        if scale_each is True:
            for t in tensor:  # loop over mini-batch dimension
                norm_range(t, value_range)
        else:
            norm_range(tensor, value_range)

    assert isinstance(tensor, torch.Tensor)
    if tensor.size(0) == 1:
        return tensor.squeeze(0)

    # make the mini-batch of images into a grid
    nmaps = tensor.size(0)
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(float(nmaps) / xmaps))
    height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
    num_channels = tensor.size(1)
    grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)
    k = 0
    for y in range(ymaps):
        for x in range(xmaps):
            if k >= nmaps:
                break
            # Tensor.copy_() is a valid method but seems to be missing from the stubs
            # https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_
            grid.narrow(1, y * height + padding, height - padding).narrow(  # type: ignore[attr-defined]
                2, x * width + padding, width - padding
            ).copy_(tensor[k])
            k = k + 1
    return grid


@torch.no_grad()
def save_image(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    fp: Union[str, pathlib.Path, BinaryIO],
    format: Optional[str] = None,
    **kwargs,
) -> None:
    """
    Save a given Tensor into an image file.

    Args:
        tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
            saves the tensor as a grid of images by calling ``make_grid``.
        fp (string or file object): A filename or a file object
        format(Optional): If omitted, the format to use is determined from the filename extension.
            If a file object was used instead of a filename, this parameter should always be used.
        **kwargs: Other arguments are documented in ``make_grid``.
    """

    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(save_image)
    grid = make_grid(tensor, **kwargs)
    # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    im = Image.fromarray(ndarr)
    im.save(fp, format=format)


@torch.no_grad()
def draw_bounding_boxes(
    image: torch.Tensor,
    boxes: torch.Tensor,
    labels: Optional[List[str]] = None,
    colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None,
    fill: Optional[bool] = False,
    width: int = 1,
    font: Optional[str] = None,
    font_size: int = 10,
) -> torch.Tensor:

    """
    Draws bounding boxes on given image.
    The values of the input image should be uint8 between 0 and 255.
    If fill is True, Resulting Tensor should be saved as PNG image.

    Args:
        image (Tensor): Tensor of shape (C x H x W) and dtype uint8.
        boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that
            the boxes are absolute coordinates with respect to the image.
In other words: `0 <= xmin < xmax < W` and + `0 <= ymin < ymax < H`. + labels (List[str]): List containing the labels of bounding boxes. + colors (color or list of colors, optional): List containing the colors + of the boxes or single color for all boxes. The color can be represented as + PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. + By default, random colors are generated for boxes. + fill (bool): If `True` fills the bounding box with specified color. + width (int): Width of bounding box. + font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may + also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`, + `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS. + font_size (int): The requested font size in points. + + Returns: + img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted. + """ + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(draw_bounding_boxes) + if not isinstance(image, torch.Tensor): + raise TypeError(f"Tensor expected, got {type(image)}") + elif image.dtype != torch.uint8: + raise ValueError(f"Tensor uint8 expected, got {image.dtype}") + elif image.dim() != 3: + raise ValueError("Pass individual images, not batches") + elif image.size(0) not in {1, 3}: + raise ValueError("Only grayscale and RGB images are supported") + + num_boxes = boxes.shape[0] + + if labels is None: + labels: Union[List[str], List[None]] = [None] * num_boxes # type: ignore[no-redef] + elif len(labels) != num_boxes: + raise ValueError( + f"Number of boxes ({num_boxes}) and labels ({len(labels)}) mismatch. Please specify labels for each box." + ) + + if colors is None: + colors = _generate_color_palette(num_boxes) + elif isinstance(colors, list): + if len(colors) < num_boxes: + raise ValueError(f"Number of colors ({len(colors)}) is less than number of boxes ({num_boxes}). 
") + else: # colors specifies a single color for all boxes + colors = [colors] * num_boxes + + colors = [(ImageColor.getrgb(color) if isinstance(color, str) else color) for color in colors] + + # Handle Grayscale images + if image.size(0) == 1: + image = torch.tile(image, (3, 1, 1)) + + ndarr = image.permute(1, 2, 0).cpu().numpy() + img_to_draw = Image.fromarray(ndarr) + img_boxes = boxes.to(torch.int64).tolist() + + if fill: + draw = ImageDraw.Draw(img_to_draw, "RGBA") + else: + draw = ImageDraw.Draw(img_to_draw) + + txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size) + + for bbox, color, label in zip(img_boxes, colors, labels): # type: ignore[arg-type] + if fill: + fill_color = color + (100,) + draw.rectangle(bbox, width=width, outline=color, fill=fill_color) + else: + draw.rectangle(bbox, width=width, outline=color) + + if label is not None: + margin = width + 1 + draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font) + + return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) + + +@torch.no_grad() +def draw_segmentation_masks( + image: torch.Tensor, + masks: torch.Tensor, + alpha: float = 0.8, + colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None, +) -> torch.Tensor: + + """ + Draws segmentation masks on given RGB image. + The values of the input image should be uint8 between 0 and 255. + + Args: + image (Tensor): Tensor of shape (3, H, W) and dtype uint8. + masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool. + alpha (float): Float number between 0 and 1 denoting the transparency of the masks. + 0 means full transparency, 1 means no transparency. + colors (color or list of colors, optional): List containing the colors + of the masks or single color for all masks. The color can be represented as + PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. 
+            By default, random colors are generated for each mask.
+
+    Returns:
+        img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top.
+    """
+
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(draw_segmentation_masks)
+    if not isinstance(image, torch.Tensor):
+        raise TypeError(f"The image must be a tensor, got {type(image)}")
+    elif image.dtype != torch.uint8:
+        raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
+    elif image.dim() != 3:
+        raise ValueError("Pass individual images, not batches")
+    elif image.size()[0] != 3:
+        raise ValueError("Pass an RGB image. Other Image formats are not supported")
+    if masks.ndim == 2:
+        masks = masks[None, :, :]
+    if masks.ndim != 3:
+        raise ValueError("masks must be of shape (H, W) or (batch_size, H, W)")
+    if masks.dtype != torch.bool:
+        raise ValueError(f"The masks must be of dtype bool. Got {masks.dtype}")
+    if masks.shape[-2:] != image.shape[-2:]:
+        raise ValueError("The image and the masks must have the same height and width")
+
+    num_masks = masks.size()[0]
+    if colors is not None and num_masks > len(colors):
+        raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})")
+
+    if colors is None:
+        colors = _generate_color_palette(num_masks)
+
+    if not isinstance(colors, list):
+        colors = [colors]
+    if not isinstance(colors[0], (tuple, str)):
+        raise ValueError("colors must be a tuple or a string, or a list thereof")
+    if isinstance(colors[0], tuple) and len(colors[0]) != 3:
+        raise ValueError("It seems that you passed a tuple of colors instead of a list of colors")
+
+    out_dtype = torch.uint8
+
+    colors_ = []
+    for color in colors:
+        if isinstance(color, str):
+            color = ImageColor.getrgb(color)
+        colors_.append(torch.tensor(color, dtype=out_dtype))
+
+    img_to_draw = image.detach().clone()
+    # TODO: There might be a way to vectorize this
+    for mask, color in zip(masks, colors_):
+        img_to_draw[:, mask] = color[:, None]
+
+    # Alpha-blend the colored masks over the original image.
+    out = image * (1 - alpha) + img_to_draw * alpha
+    return out.to(out_dtype)
+
+
+@torch.no_grad()
+def draw_keypoints(
+    image: torch.Tensor,
+    keypoints: torch.Tensor,
+    connectivity: Optional[List[Tuple[int, int]]] = None,
+    colors: Optional[Union[str, Tuple[int, int, int]]] = None,
+    radius: int = 2,
+    width: int = 3,
+) -> torch.Tensor:
+
+    """
+    Draws Keypoints on given RGB image.
+    The values of the input image should be uint8 between 0 and 255.
+
+    Args:
+        image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
+        keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoints location for each of the N instances,
+            in the format [x, y].
+        connectivity (List[Tuple[int, int]]]): A List of tuple where,
+            each tuple contains pair of keypoints to be connected.
+        colors (str, Tuple): The color can be represented as
+            PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
+        radius (int): Integer denoting radius of keypoint.
+        width (int): Integer denoting width of line connecting keypoints.
+
+    Returns:
+        img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn.
+    """
+
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(draw_keypoints)
+    if not isinstance(image, torch.Tensor):
+        raise TypeError(f"The image must be a tensor, got {type(image)}")
+    elif image.dtype != torch.uint8:
+        raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
+    elif image.dim() != 3:
+        raise ValueError("Pass individual images, not batches")
+    elif image.size()[0] != 3:
+        raise ValueError("Pass an RGB image. Other Image formats are not supported")
+
+    if keypoints.ndim != 3:
+        raise ValueError("keypoints must be of shape (num_instances, K, 2)")
+
+    ndarr = image.permute(1, 2, 0).cpu().numpy()
+    img_to_draw = Image.fromarray(ndarr)
+    draw = ImageDraw.Draw(img_to_draw)
+    img_kpts = keypoints.to(torch.int64).tolist()
+
+    # NOTE(review): the loop variable names look swapped — "kpt_id" actually
+    # iterates instances and "inst_id" iterates keypoints. Harmless (neither
+    # index is used), but confusing to readers.
+    for kpt_id, kpt_inst in enumerate(img_kpts):
+        for inst_id, kpt in enumerate(kpt_inst):
+            x1 = kpt[0] - radius
+            x2 = kpt[0] + radius
+            y1 = kpt[1] - radius
+            y2 = kpt[1] + radius
+            draw.ellipse([x1, y1, x2, y2], fill=colors, outline=None, width=0)
+
+        if connectivity:
+            for connection in connectivity:
+                start_pt_x = kpt_inst[connection[0]][0]
+                start_pt_y = kpt_inst[connection[0]][1]
+
+                end_pt_x = kpt_inst[connection[1]][0]
+                end_pt_y = kpt_inst[connection[1]][1]
+
+                draw.line(
+                    ((start_pt_x, start_pt_y), (end_pt_x, end_pt_y)),
+                    width=width,
+                )
+
+    return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)
+
+
+# Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization
+@torch.no_grad()
+def flow_to_image(flow: torch.Tensor) -> torch.Tensor:
+
+    """
+    Converts a flow to an RGB image.
+
+    Args:
+        flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float.
+
+    Returns:
+        img (Tensor): Image Tensor of dtype uint8 where each color corresponds
+            to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input.
+    """
+
+    if flow.dtype != torch.float:
+        raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.")
+
+    orig_shape = flow.shape
+    if flow.ndim == 3:
+        flow = flow[None]  # Add batch dim
+
+    if flow.ndim != 4 or flow.shape[1] != 2:
+        raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.")
+
+    # Normalize by the largest flow magnitude so colors are comparable within the batch.
+    max_norm = torch.sum(flow ** 2, dim=1).sqrt().max()
+    epsilon = torch.finfo((flow).dtype).eps
+    normalized_flow = flow / (max_norm + epsilon)
+    img = _normalized_flow_to_image(normalized_flow)
+
+    if len(orig_shape) == 3:
+        img = img[0]  # Remove batch dim
+    return img
+
+
+@torch.no_grad()
+def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
+
+    """
+    Converts a batch of normalized flow to an RGB image.
+
+    Args:
+        normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W)
+    Returns:
+        img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8.
+    """
+
+    N, _, H, W = normalized_flow.shape
+    device = normalized_flow.device
+    flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8, device=device)
+    colorwheel = _make_colorwheel().to(device)  # shape [55x3]
+    num_cols = colorwheel.shape[0]
+    norm = torch.sum(normalized_flow ** 2, dim=1).sqrt()
+    # Flow direction (angle) selects a fractional position on the color wheel;
+    # k0/k1 are the neighboring wheel entries and f the interpolation weight.
+    a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi
+    fk = (a + 1) / 2 * (num_cols - 1)
+    k0 = torch.floor(fk).to(torch.long)
+    k1 = k0 + 1
+    k1[k1 == num_cols] = 0
+    f = fk - k0
+
+    for c in range(colorwheel.shape[1]):
+        tmp = colorwheel[:, c]
+        col0 = tmp[k0] / 255.0
+        col1 = tmp[k1] / 255.0
+        col = (1 - f) * col0 + f * col1
+        # Larger flow magnitude -> more saturated color.
+        col = 1 - norm * (1 - col)
+        flow_image[:, c, :, :] = torch.floor(255 * col)
+    return flow_image
+
+
+def _make_colorwheel() -> torch.Tensor:
+    """
+    Generates a color wheel for optical flow visualization as presented in:
+    Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
+    URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf.
+
+    Returns:
+        colorwheel (Tensor[55, 3]): Colorwheel Tensor.
+    """
+
+    RY = 15
+    YG = 6
+    GC = 4
+    CB = 11
+    BM = 13
+    MR = 6
+
+    ncols = RY + YG + GC + CB + BM + MR
+    colorwheel = torch.zeros((ncols, 3))
+    col = 0
+
+    # RY
+    colorwheel[0:RY, 0] = 255
+    colorwheel[0:RY, 1] = torch.floor(255 * torch.arange(0, RY) / RY)
+    col = col + RY
+    # YG
+    colorwheel[col : col + YG, 0] = 255 - torch.floor(255 * torch.arange(0, YG) / YG)
+    colorwheel[col : col + YG, 1] = 255
+    col = col + YG
+    # GC
+    colorwheel[col : col + GC, 1] = 255
+    colorwheel[col : col + GC, 2] = torch.floor(255 * torch.arange(0, GC) / GC)
+    col = col + GC
+    # CB
+    colorwheel[col : col + CB, 1] = 255 - torch.floor(255 * torch.arange(CB) / CB)
+    colorwheel[col : col + CB, 2] = 255
+    col = col + CB
+    # BM
+    colorwheel[col : col + BM, 2] = 255
+    colorwheel[col : col + BM, 0] = torch.floor(255 * torch.arange(0, BM) / BM)
+    col = col + BM
+    # MR
+    colorwheel[col : col + MR, 2] = 255 - torch.floor(255 * torch.arange(MR) / MR)
+    colorwheel[col : col + MR, 0] = 255
+    return colorwheel
+
+
+def _generate_color_palette(num_objects: int):
+    # Deterministic palette: scale a fixed base vector by the object index, mod 255.
+    palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
+    return [tuple((i * palette) % 255) for i in range(num_objects)]
+
+
+def _log_api_usage_once(obj: Any) -> None:
+
+    """
+    Logs API usage(module and name) within an organization.
+    In a large ecosystem, it's often useful to track the PyTorch and
+    TorchVision APIs usage. This API provides the similar functionality to the
+    logging module in the Python stdlib. It can be used for debugging purpose
+    to log which methods are used and by default it is inactive, unless the user
+    manually subscribes a logger via the `SetAPIUsageLogger method `_.
+    Please note it is triggered only once for the same API call within a process.
+    It does not collect any data from open-source users since it is no-op by default.
+    For more information, please refer to
+    * PyTorch note: https://pytorch.org/docs/stable/notes/large_scale_deployments.html#api-usage-logging;
+    * Logging policy: https://github.com/pytorch/vision/issues/5052;
+
+    Args:
+        obj (class instance or method): an object to extract info from.
+    """
+    if not obj.__module__.startswith("torchvision"):
+        return
+    name = obj.__class__.__name__
+    if isinstance(obj, FunctionType):
+        name = obj.__name__
+    torch._C._log_api_usage_once(f"{obj.__module__}.{name}")
diff --git a/my_transforms.py b/my_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..90566ea2ace99209622000b8f5fbe7a58c407c4f
--- /dev/null
+++ b/my_transforms.py
@@ -0,0 +1,102 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from typing import Sequence
+
+import torch
+from torchvision import transforms
+import random
+
+
+class GaussianBlur(transforms.RandomApply):
+    """
+    Apply Gaussian Blur to the PIL image.
+    """
+
+    # Earlier sigma trials: 0.1/2.0 and 1.8/2.0 with kernel 9; 0.7/1.0 worked best.
+    def __init__(self, *, p: float = 0.5, radius_min: float = 0.7, radius_max: float = 1):
+        # NOTE: torchvision is applying 1 - probability to return the original image
+        # NOTE(review): RandomApply applies its transforms with probability `p`,
+        # so passing keep_p here means the blur fires with probability 1 - p.
+        # Confirm that is the intended meaning of this class's `p` argument.
+        keep_p = 1 - p
+        transform = transforms.GaussianBlur(kernel_size=9, sigma=(radius_min, radius_max))
+        super().__init__(transforms=[transform], p=keep_p)
+
+
+class MaybeToTensor(transforms.ToTensor):
+    """
+    Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor, or keep as is if already a tensor.
+    """
+
+    def __call__(self, pic):
+        """
+        Args:
+            pic (PIL Image, numpy.ndarray or torch.tensor): Image to be converted to tensor.
+        Returns:
+            Tensor: Converted image.
+        """
+        if isinstance(pic, torch.Tensor):
+            return pic
+        return super().__call__(pic)
+
+
+# Use timm's names
+IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
+
+
+def make_normalize_transform(
+    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
+    std: Sequence[float] = IMAGENET_DEFAULT_STD,
+) -> transforms.Normalize:
+    return transforms.Normalize(mean=mean, std=std)
+
+
+# Separate normalization statistics for CLIP-preprocessed inputs.
+IMAGENET_DEFAULT_MEAN_clip = (0.48145466, 0.4578275, 0.40821073)
+IMAGENET_DEFAULT_STD_clip = (0.26862954, 0.26130258, 0.27577711)
+
+def make_normalize_transform_clip(
+    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN_clip,
+    std: Sequence[float] = IMAGENET_DEFAULT_STD_clip,
+) -> transforms.Normalize:
+    return transforms.Normalize(mean=mean, std=std)
+
+
+# This roughly matches torchvision's preset for classification training:
+# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L6-L44
+def make_classification_train_transform(
+    *,
+    crop_size: int = 224,
+    interpolation=transforms.InterpolationMode.BICUBIC,
+    hflip_prob: float = 0.5,
+    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
+    std: Sequence[float] = IMAGENET_DEFAULT_STD,
+):
+    transforms_list = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
+    if hflip_prob > 0.0:
+        transforms_list.append(transforms.RandomHorizontalFlip(hflip_prob))
+    transforms_list.extend(
+        [
+            MaybeToTensor(),
+            make_normalize_transform(mean=mean, std=std),
+        ]
+    )
+    return transforms.Compose(transforms_list)
+
+
+# This matches (roughly) torchvision's preset for classification evaluation:
+# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L47-L69
+def make_classification_eval_transform(
+    *,
+    resize_size: int = 256,
+    interpolation=transforms.InterpolationMode.BICUBIC,
+    crop_size: int = 224,
+    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
+    std: Sequence[float] = IMAGENET_DEFAULT_STD,
+) -> transforms.Compose:
+    transforms_list = [
transforms.Resize(resize_size, interpolation=interpolation),
+        transforms.CenterCrop(crop_size),
+        MaybeToTensor(),
+        make_normalize_transform(mean=mean, std=std),
+    ]
+    return transforms.Compose(transforms_list)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..84c1cad9c4b1c971ffb09f30a5d2ab90053eb2f4
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,11 @@
+# NOTE(review): "pytorch" is not the PyPI package name — `pip install pytorch` fails.
+# The package is published as "torch"; fix this line before release.
+pytorch
+torchvision
+torchaudio
+scikit-learn
+tqdm
+ftfy
+regex
+opencv-python
+gradio
+spaces
+huggingface_hub
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..30af82b6b61f754458b26d6a690cff0fb499ac36
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,333 @@
+import os
+import torch
+from torch.utils.data import DataLoader
+from torchvision import datasets, transforms
+from torchvision.transforms.functional import rotate
+import config as c
+
+
+import sklearn.metrics as sk
+
+import numpy as np
+
+from copy import deepcopy
+
+def stable_cumsum(arr, rtol=1e-05, atol=1e-08):
+    """Use high precision for cumsum and check that final value matches sum
+    Parameters
+    ----------
+    arr : array-like
+        To be cumulatively summed as flat
+    rtol : float
+        Relative tolerance, see ``np.allclose``
+    atol : float
+        Absolute tolerance, see ``np.allclose``
+    """
+    out = np.cumsum(arr, dtype=np.float64)
+    expected = np.sum(arr, dtype=np.float64)
+    if not np.allclose(out[-1], expected, rtol=rtol, atol=atol):
+        raise RuntimeError('cumsum was found to be unstable: '
+                           'its last element does not correspond to sum')
+    return out
+
+def fpr_and_fdr_at_recall(y_true, y_score, recall_level=0.95, pos_label=None):
+    # Returns (false-positive rate at the requested recall level, the score
+    # threshold achieving it). Follows sklearn's ranking-metric conventions.
+    classes = np.unique(y_true)
+    if (pos_label is None and
+            not (np.array_equal(classes, [0, 1]) or
+                 np.array_equal(classes, [-1, 1]) or
+                 np.array_equal(classes, [0]) or
+                 np.array_equal(classes, [-1]) or
+                 np.array_equal(classes, [1]))):
+        raise ValueError("Data is not binary and pos_label is not specified")
+    elif pos_label is None:
+        pos_label = 1.
+
+    # make y_true a boolean vector
+    y_true = (y_true == pos_label)
+
+    # sort scores and corresponding truth values
+    desc_score_indices = np.argsort(y_score, kind="mergesort")[::-1]
+    y_score = y_score[desc_score_indices]
+    #print(y_score)
+    y_true = y_true[desc_score_indices]
+
+    # y_score typically has many tied values. Here we extract
+    # the indices associated with the distinct values. We also
+    # concatenate a value for the end of the curve.
+    distinct_value_indices = np.where(np.diff(y_score))[0]
+    threshold_idxs = np.r_[distinct_value_indices, y_true.size - 1]
+
+    # accumulate the true positives with decreasing threshold
+    tps = stable_cumsum(y_true)[threshold_idxs]
+    fps = 1 + threshold_idxs - tps      # add one because of zero-based indexing
+
+    thresholds = y_score[threshold_idxs]
+
+    recall = tps / tps[-1]
+
+    last_ind = tps.searchsorted(tps[-1])
+    sl = slice(last_ind, None, -1)      # [last_ind::-1]
+    recall, fps, tps, thresholds = np.r_[recall[sl], 1], np.r_[fps[sl], 0], np.r_[tps[sl], 0], thresholds[sl]
+    #print(recall)
+    cutoff = np.argmin(np.abs(recall - recall_level))
+    return fps[cutoff] / (np.sum(np.logical_not(y_true))), thresholds[cutoff]  # , fps[cutoff]/(fps[cutoff] + tps[cutoff])
+
+def get_random_transforms():
+    # Training-time augmentation pipeline assembled from `config` module flags.
+    augmentative_transforms = []
+    if c.transf_rotations:
+        augmentative_transforms += [transforms.RandomRotation(180)]
+    if c.transf_brightness > 0.0 or c.transf_contrast > 0.0 or c.transf_saturation > 0.0:
+        augmentative_transforms += [transforms.ColorJitter(brightness=c.transf_brightness, contrast=c.transf_contrast,
+                                                           saturation=c.transf_saturation)]
+
+    tfs = [transforms.Resize(c.img_size)] + augmentative_transforms + [transforms.ToTensor(),
+                                                                      transforms.Normalize(c.norm_mean, c.norm_std)]
+
+    transform_train = transforms.Compose(tfs)
+    return transform_train
+
+
+def get_fixed_transforms(degrees):
+    # Same pipeline as get_random_transforms, but with a fixed rotation angle.
+    cust_rot = lambda x: rotate(x, degrees, False, False, None)
+    augmentative_transforms = [cust_rot]
+    if c.transf_brightness > 0.0 or c.transf_contrast > 0.0 or c.transf_saturation > 0.0:
+        augmentative_transforms += [
+            transforms.ColorJitter(brightness=c.transf_brightness, contrast=c.transf_contrast,
+                                   saturation=c.transf_saturation)]
+    tfs = [transforms.Resize(c.img_size)] + augmentative_transforms + [transforms.ToTensor(),
+                                                                      transforms.Normalize(c.norm_mean,
+                                                                                           c.norm_std)]
+    return transforms.Compose(tfs)
+
+
+def t2np(tensor):
+    '''pytorch tensor -> numpy array'''
+    return tensor.cpu().data.numpy() if tensor is not None else None
+
+
+def get_loss(z, jac):
+    '''check equation 4 of the paper why this makes sense - oh and just ignore the scaling here'''
+    return torch.mean(0.5 * torch.sum(z ** 2, dim=(1,)) - jac) / z.shape[1]
+
+# def get_loss_neg_pos(z, jac, labels):
+#     '''Loss: positives pulled toward the Gaussian prior, negatives pushed away.'''
+#     # standard flow generative loss
+#     normalizing_loss = torch.mean(0.5 * torch.sum(z ** 2, dim=(1,)) - jac) / z.shape[1]
+
+#     # positives (label 0): latent features should be close to the Gaussian
+#     positive_loss = normalizing_loss * (labels == 0).float()
+
+#     # negatives (label 1): latent features should be far from the Gaussian
+#     negative_loss = -normalizing_loss * (labels == 1).float()
+
+#     # total loss
+#     total_loss = torch.mean(positive_loss + negative_loss)
+
+#     return total_loss
+
+
+def get_loss_neg_pos(z, jac, labels, target_distribution="gaussian", margin = 500):
+    # Per-sample flow losses: positives (label 0) are matched to a Gaussian
+    # centered at 10, negatives (label 1) to a standard Gaussian at 0.
+    # NOTE(review): `target_distribution` and `margin` are currently unused.
+
+    loss_sample_pos = 0.5 * torch.sum((z-10) ** 2, dim=(1,)) - jac  # should these losses all be > 0?
+
+    loss_sample_neg = 0.5 * torch.sum(z ** 2, dim=(1,)) - jac
+
+
+    positive_loss = loss_sample_pos * (labels == 0).float()
+
+    negative_loss = loss_sample_neg * (labels == 1).float()
+
+    # total loss
+    total_loss = torch.mean(positive_loss + negative_loss )/ z.shape[1]
+
+    return total_loss
+
+
+def get_loss_neg_pos_margin(z, jac, labels, margin = 500):
+    # Margin-based variant of the pos/neg flow loss (see get_loss_neg_pos).
+
+    # print(jac)
+
+    # jac = torch.clamp(jac, min=1e-5, max=1e5)
+    # z = torch.clamp(z, min=-1e5, max=1e5)
+
+
+    loss_sample = 0.5 * torch.sum(z ** 2, dim=(1,))  # should these losses all be > 0?
+    # print(loss_sample)
+
+
+    # 
positive_loss = (-loss_sample) * (labels == 0).float()* (loss_sample 0.5).float() # 趋向0,差距越小越好 + consistent_loss = consistent_loss/len(labels) + # total_loss = shape_loss + consistent_loss * 0.05 + total_loss = consistent_loss + + return shape_loss, consistent_loss, total_loss + +def get_measures(_pos, _neg, recall_level=0.95): + pos = np.array(_pos[:]).reshape((-1, 1)) + neg = np.array(_neg[:]).reshape((-1, 1)) + examples = np.squeeze(np.vstack((pos, neg))) + labels = np.zeros(len(examples), dtype=np.int32) + labels[:len(pos)] += 1 + + auroc = sk.roc_auc_score(labels, examples) + aupr = sk.average_precision_score(labels, examples) + fpr, threshold = fpr_and_fdr_at_recall(labels, examples, recall_level) + return auroc, aupr, fpr + +def find_best_threshold(y_true, y_pred): + "We assume first half is real 0, and the second half is fake 1" + + N = y_true.shape[0] + + if y_pred[0:N//2].max() <= y_pred[N//2:N].min(): # perfectly separable case + return (y_pred[0:N//2].max() + y_pred[N//2:N].min()) / 2 + + best_acc = 0 + best_thres = 0 + for thres in y_pred: + temp = deepcopy(y_pred) + temp[temp>=thres] = 1 + temp[temp= best_acc: + best_thres = thres + best_acc = acc + + return best_thres + +def get_loss_neg(z, jac, labels, margin = 500): + # 计算流模型的标准生成损失 + + # print(jac) + + loss_sample = 0.5 * torch.sum(z ** 2, dim=(1,)) -jac #损失是否应该都大于零 + # print(loss_sample) + + + # positive_loss = (-loss_sample) * (labels == 0).float()* (loss_sample