diff --git a/GDPOSR/datasets/params_GDPO.yml b/GDPOSR/datasets/params_GDPO.yml
new file mode 100644
index 0000000000000000000000000000000000000000..e62322e79da69b91004f0f336f5cc0896852e399
--- /dev/null
+++ b/GDPOSR/datasets/params_GDPO.yml
@@ -0,0 +1,42 @@
+scale: 4
+color_jitter_prob: 0.0
+gray_prob: 0.0
+
+# the first degradation process
+resize_prob: [0.2, 0.7, 0.1] # up, down, keep
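+# e.g. with weights [0.2, 0.7, 0.1]: 20% chance to upscale, 70% to downscale, 10% to keep the size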
+resize_range: [0.3, 1.5]
+gaussian_noise_prob: 0.5
+noise_range: [1, 15]
+poisson_scale_range: [0.05, 2.0]
+gray_noise_prob: 0.4
+jpeg_range: [60, 95]
+
+# the second degradation process
+second_phase_prob: 1.0
+second_blur_prob: 0.5
+resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
+resize_range2: [0.6, 1.2]
+gaussian_noise_prob2: 0.5
+noise_range2: [1, 12]
+poisson_scale_range2: [0.05, 1.0]
+gray_noise_prob2: 0.4
+jpeg_range2: [60, 100]
+
+kernel_info:
+ blur_kernel_size: 21
+ kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+ kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+ sinc_prob: 0.1
+ blur_sigma: [0.2, 3]
+ betag_range: [0.5, 4]
+ betap_range: [1, 2]
+
+ blur_kernel_size2: 21
+ kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+ kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+ sinc_prob2: 0.1
+ blur_sigma2: [0.2, 1.5]
+ betag_range2: [0.5, 4]
+ betap_range2: [1, 2]
+
+ final_sinc_prob: 0.8
diff --git a/GDPOSR/datasets/realesrgan.py b/GDPOSR/datasets/realesrgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..e88b7bc0cb4329e3fff7b8f3e9510c429b88b132
--- /dev/null
+++ b/GDPOSR/datasets/realesrgan.py
@@ -0,0 +1,305 @@
+import os
+import numpy as np
+import cv2
+import glob
+import math
+import yaml
+import random
+from collections import OrderedDict
+import torch
+import torch.nn.functional as F
+
+from basicsr.data.transforms import augment
+from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
+from basicsr.utils import DiffJPEG, USMSharp, img2tensor, tensor2img
+from basicsr.utils.img_process_util import filter2D
+from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
+from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
+ normalize, rgb_to_grayscale)
+
+cur_path = os.path.dirname(os.path.abspath(__file__))
+
+def ordered_yaml():
+ """Support OrderedDict for yaml.
+
+ Returns:
+ yaml Loader and Dumper.
+ """
+ try:
+ from yaml import CDumper as Dumper
+ from yaml import CLoader as Loader
+ except ImportError:
+ from yaml import Dumper, Loader
+
+ _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
+
+ def dict_representer(dumper, data):
+ return dumper.represent_dict(data.items())
+
+ def dict_constructor(loader, node):
+ return OrderedDict(loader.construct_pairs(node))
+
+ Dumper.add_representer(OrderedDict, dict_representer)
+ Loader.add_constructor(_mapping_tag, dict_constructor)
+ return Loader, Dumper
+
+def opt_parse(opt_path):
+ with open(opt_path, mode='r') as f:
+ Loader, _ = ordered_yaml()
+        opt = yaml.load(f, Loader=Loader)  # local, trusted config file; prefer yaml.safe_load for untrusted input
+
+ return opt
+
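+# Usage sketch (hypothetical values, assuming params_GDPO.yml above ships next to this file):
+#
+#   opt = opt_parse(f'{cur_path}/params_GDPO.yml')
+#   opt['scale']                             # -> 4
+#   opt['kernel_info']['blur_kernel_size']   # -> 21
+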
+class RealESRGAN_degradation(object):
+ def __init__(self, opt_name='params_realesrgan.yml', device='cpu'):
+ opt_path = f'{cur_path}/{opt_name}'
+ self.opt = opt_parse(opt_path)
+        self.device = device
+ optk = self.opt['kernel_info']
+
+ # blur settings for the first degradation
+ self.blur_kernel_size = optk['blur_kernel_size']
+ self.kernel_list = optk['kernel_list']
+ self.kernel_prob = optk['kernel_prob']
+ self.blur_sigma = optk['blur_sigma']
+ self.betag_range = optk['betag_range']
+ self.betap_range = optk['betap_range']
+ self.sinc_prob = optk['sinc_prob']
+
+ # blur settings for the second degradation
+ self.blur_kernel_size2 = optk['blur_kernel_size2']
+ self.kernel_list2 = optk['kernel_list2']
+ self.kernel_prob2 = optk['kernel_prob2']
+ self.blur_sigma2 = optk['blur_sigma2']
+ self.betag_range2 = optk['betag_range2']
+ self.betap_range2 = optk['betap_range2']
+ self.sinc_prob2 = optk['sinc_prob2']
+
+ # a final sinc filter
+ self.final_sinc_prob = optk['final_sinc_prob']
+
+ self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
+        self.pulse_tensor = torch.zeros(21, 21).float()  # convolving with a pulse (identity) kernel applies no blur
+ self.pulse_tensor[10, 10] = 1
+
+ self.jpeger = DiffJPEG(differentiable=False).to(self.device)
+        self.usm_sharpener = USMSharp().to(self.device)
+
+ def color_jitter_pt(self, img, brightness, contrast, saturation, hue):
+ fn_idx = torch.randperm(4)
+ for fn_id in fn_idx:
+ if fn_id == 0 and brightness is not None:
+ brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
+ img = adjust_brightness(img, brightness_factor)
+
+ if fn_id == 1 and contrast is not None:
+ contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
+ img = adjust_contrast(img, contrast_factor)
+
+ if fn_id == 2 and saturation is not None:
+ saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
+ img = adjust_saturation(img, saturation_factor)
+
+ if fn_id == 3 and hue is not None:
+ hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
+ img = adjust_hue(img, hue_factor)
+ return img
+
+ def random_augment(self, img_gt):
+ # random horizontal flip
+ img_gt, status = augment(img_gt, hflip=True, rotation=False, return_status=True)
+ """
+ # random color jitter
+ if np.random.uniform() < self.opt['color_jitter_prob']:
+ jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
+ img_gt = img_gt + jitter_val
+ img_gt = np.clip(img_gt, 0, 1)
+
+ # random grayscale
+ if np.random.uniform() < self.opt['gray_prob']:
+ #img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
+ img_gt = cv2.cvtColor(img_gt, cv2.COLOR_RGB2GRAY)
+ img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])
+ """
+        # HWC to CHW, numpy to tensor (the input is already RGB, hence bgr2rgb=False)
+ img_gt = img2tensor([img_gt], bgr2rgb=False, float32=True)[0].unsqueeze(0)
+ return img_gt
+
+ def random_kernels(self):
+ # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
+ kernel_size = random.choice(self.kernel_range)
+ if np.random.uniform() < self.sinc_prob:
+ # this sinc filter setting is for kernels ranging from [7, 21]
+ if kernel_size < 13:
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
+ else:
+ omega_c = np.random.uniform(np.pi / 5, np.pi)
+ kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
+ else:
+ kernel = random_mixed_kernels(
+ self.kernel_list,
+ self.kernel_prob,
+ kernel_size,
+ self.blur_sigma,
+ self.blur_sigma, [-math.pi, math.pi],
+ self.betag_range,
+ self.betap_range,
+ noise_range=None)
+ # pad kernel
+ pad_size = (21 - kernel_size) // 2
+ kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
+
+ # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
+ kernel_size = random.choice(self.kernel_range)
+ if np.random.uniform() < self.sinc_prob2:
+ if kernel_size < 13:
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
+ else:
+ omega_c = np.random.uniform(np.pi / 5, np.pi)
+ kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
+ else:
+ kernel2 = random_mixed_kernels(
+ self.kernel_list2,
+ self.kernel_prob2,
+ kernel_size,
+ self.blur_sigma2,
+ self.blur_sigma2, [-math.pi, math.pi],
+ self.betag_range2,
+ self.betap_range2,
+ noise_range=None)
+
+ # pad kernel
+ pad_size = (21 - kernel_size) // 2
+ kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
+
+ # ------------------------------------- sinc kernel ------------------------------------- #
+ if np.random.uniform() < self.final_sinc_prob:
+ kernel_size = random.choice(self.kernel_range)
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
+ sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
+ sinc_kernel = torch.FloatTensor(sinc_kernel)
+ else:
+ sinc_kernel = self.pulse_tensor
+
+ kernel = torch.FloatTensor(kernel)
+ kernel2 = torch.FloatTensor(kernel2)
+
+ return kernel, kernel2, sinc_kernel
+
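+    # Shape sketch: after padding, `kernel` and `kernel2` are 21x21 torch.FloatTensor
+    # objects, and `sinc_kernel` is either a padded 21x21 sinc kernel or the 21x21
+    # identity pulse, so all three can be passed straight to filter2D() below.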
+ @torch.no_grad()
+ def degrade_process(self, img_gt, resize_bak=False):
+ img_gt = self.random_augment(img_gt)
+ kernel1, kernel2, sinc_kernel = self.random_kernels()
+ img_gt, kernel1, kernel2, sinc_kernel = img_gt.to(self.device), kernel1.to(self.device), kernel2.to(self.device), sinc_kernel.to(self.device)
+        # img_gt = self.usm_sharpener(img_gt)  # optionally sharpen the GT with USM
+ ori_h, ori_w = img_gt.size()[2:4]
+
+        # scale_final = random.randint(4, 16)
+        scale_final = self.opt.get('scale', 4)  # overall downscaling factor from the config
+
+ # ----------------------- The first degradation process ----------------------- #
+ # blur
+ out = filter2D(img_gt, kernel1)
+ # random resize
+ updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
+ if updown_type == 'up':
+ scale = np.random.uniform(1, self.opt['resize_range'][1])
+ elif updown_type == 'down':
+ scale = np.random.uniform(self.opt['resize_range'][0], 1)
+ else:
+ scale = 1
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = F.interpolate(out, scale_factor=scale, mode=mode)
+ # noise
+ gray_noise_prob = self.opt['gray_noise_prob']
+ if np.random.uniform() < self.opt['gaussian_noise_prob']:
+ out = random_add_gaussian_noise_pt(
+ out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
+ else:
+ out = random_add_poisson_noise_pt(
+ out,
+ scale_range=self.opt['poisson_scale_range'],
+ gray_prob=gray_noise_prob,
+ clip=True,
+ rounds=False)
+ # JPEG compression
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
+ out = torch.clamp(out, 0, 1)
+ out = self.jpeger(out, quality=jpeg_p)
+
+ # ----------------------- The second degradation process ----------------------- #
+ # blur
+ if self.opt['second_phase_prob'] > random.random():
+ if np.random.uniform() < self.opt['second_blur_prob']:
+ out = filter2D(out, kernel2)
+ # random resize
+ updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
+ if updown_type == 'up':
+ scale = np.random.uniform(1, self.opt['resize_range2'][1])
+ elif updown_type == 'down':
+ scale = np.random.uniform(self.opt['resize_range2'][0], 1)
+ else:
+ scale = 1
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = F.interpolate(
+ out, size=(int(ori_h / scale_final * scale), int(ori_w / scale_final * scale)), mode=mode)
+ # noise
+ gray_noise_prob = self.opt['gray_noise_prob2']
+ if np.random.uniform() < self.opt['gaussian_noise_prob2']:
+ out = random_add_gaussian_noise_pt(
+ out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
+ else:
+ out = random_add_poisson_noise_pt(
+ out,
+ scale_range=self.opt['poisson_scale_range2'],
+ gray_prob=gray_noise_prob,
+ clip=True,
+ rounds=False)
+
+ # JPEG compression + the final sinc filter
+ # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
+ # as one operation.
+ # We consider two orders:
+ # 1. [resize back + sinc filter] + JPEG compression
+ # 2. JPEG compression + [resize back + sinc filter]
+ # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
+ if np.random.uniform() < 0.5:
+ # resize back + the final sinc filter
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = F.interpolate(out, size=(ori_h // scale_final, ori_w // scale_final), mode=mode)
+ out = filter2D(out, sinc_kernel)
+ # JPEG compression
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
+ out = torch.clamp(out, 0, 1)
+ out = self.jpeger(out, quality=jpeg_p)
+ else:
+ # JPEG compression
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
+ out = torch.clamp(out, 0, 1)
+ out = self.jpeger(out, quality=jpeg_p)
+ # resize back + the final sinc filter
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = F.interpolate(out, size=(ori_h // scale_final, ori_w // scale_final), mode=mode)
+ out = filter2D(out, sinc_kernel)
+
+ if np.random.uniform() < self.opt['gray_prob']:
+ out = rgb_to_grayscale(out, num_output_channels=1)
+
+ if np.random.uniform() < self.opt['color_jitter_prob']:
+ brightness = self.opt.get('brightness', (0.5, 1.5))
+ contrast = self.opt.get('contrast', (0.5, 1.5))
+ saturation = self.opt.get('saturation', (0, 1.5))
+ hue = self.opt.get('hue', (-0.1, 0.1))
+ out = self.color_jitter_pt(out, brightness, contrast, saturation, hue)
+
+        # quantize to 8-bit levels; this is the LQ image before any resize-back
+        img_lq_noresize = torch.clamp((out * 255.0).round(), 0, 255) / 255.
+
+ if resize_bak:
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = F.interpolate(out, size=(ori_h, ori_w), mode=mode)
+ # clamp and round
+ img_lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
+
+ return img_gt, img_lq, img_lq_noresize
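+
+# Usage sketch (hypothetical file names; input is an HWC float32 RGB array in [0, 1]):
+#
+#   degrader = RealESRGAN_degradation('params_GDPO.yml', device='cpu')
+#   hq = cv2.cvtColor(cv2.imread('hq.png'), cv2.COLOR_BGR2RGB).astype(np.float32) / 255.
+#   img_gt, img_lq, img_lq_noresize = degrader.degrade_process(hq, resize_bak=True)
+#   # img_gt: 1x3xHxW ground truth; img_lq: degraded image resized back to HxW;
+#   # img_lq_noresize: degraded image at (H // 4) x (W // 4) for scale_final = 4.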
diff --git a/GDPOSR/diffusermodels/autoencoder_kl.py b/GDPOSR/diffusermodels/autoencoder_kl.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ac16c6cf1f0ded4f4496a63c00e4e3781066852
--- /dev/null
+++ b/GDPOSR/diffusermodels/autoencoder_kl.py
@@ -0,0 +1,560 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import FromOriginalVAEMixin
+from diffusers.utils.accelerate_utils import apply_forward_hook
+from diffusers.models.attention_processor import (
+ ADDED_KV_ATTENTION_PROCESSORS,
+ CROSS_ATTENTION_PROCESSORS,
+ Attention,
+ AttentionProcessor,
+ AttnAddedKVProcessor,
+ AttnProcessor,
+)
+from diffusers.models.modeling_outputs import AutoencoderKLOutput
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.autoencoders.vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
+
+class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
+ r"""
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
+ Tuple of downsample block types.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
+ Tuple of upsample block types.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
+ Tuple of block output channels.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
+ scaling_factor (`float`, *optional*, defaults to 0.18215):
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+        force_upcast (`bool`, *optional*, defaults to `True`):
+            If enabled, it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. The
+            VAE can be fine-tuned / trained to a lower range without losing too much precision, in which case
+            `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
+ block_out_channels: Tuple[int] = (64,),
+ layers_per_block: int = 1,
+ act_fn: str = "silu",
+ latent_channels: int = 4,
+ norm_num_groups: int = 32,
+ sample_size: int = 32,
+ scaling_factor: float = 0.18215,
+        force_upcast: bool = True,
+ ):
+ super().__init__()
+
+ # pass init params to Encoder
+ self.encoder = Encoder(
+ in_channels=in_channels,
+ out_channels=latent_channels,
+ down_block_types=down_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ norm_num_groups=norm_num_groups,
+ double_z=True,
+ )
+
+ # pass init params to Decoder
+ self.decoder = Decoder(
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ up_block_types=up_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ norm_num_groups=norm_num_groups,
+ act_fn=act_fn,
+ )
+
+ self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
+ self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
+
+ self.use_slicing = False
+ self.use_tiling = False
+
+ # only relevant if vae tiling is enabled
+ self.tile_sample_min_size = self.config.sample_size
+ sample_size = (
+ self.config.sample_size[0]
+ if isinstance(self.config.sample_size, (list, tuple))
+ else self.config.sample_size
+ )
+ self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
+ self.tile_overlap_factor = 0.25
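+        # Worked example (sketch): for the SD 1.x VAE, sample_size = 512 and
+        # len(block_out_channels) = 4, so tile_latent_min_size = 512 // 2**3 = 64,
+        # i.e. a 512x512 pixel tile corresponds to a 64x64 latent tile.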
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (Encoder, Decoder)):
+ module.gradient_checkpointing = value
+
+ def enable_tiling(self, use_tiling: bool = True):
+ r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and for
+        processing larger images.
+ """
+ self.use_tiling = use_tiling
+
+ def disable_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.enable_tiling(False)
+
+ def enable_slicing(self):
+ r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor into slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.use_slicing = True
+
+ def disable_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_slicing = False
+
+ @property
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by their weight names.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+ ):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor, _remove_lora=_remove_lora)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnAddedKVProcessor()
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor, _remove_lora=True)
+
+ @apply_forward_hook
+ def encode(
+ self, x: torch.FloatTensor, return_dict: bool = True
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+ """
+ Encode a batch of images into latents.
+
+ Args:
+ x (`torch.FloatTensor`): Input batch of images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ The latent representations of the encoded images. If `return_dict` is True, a
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+ """
+ if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
+ return self.tiled_encode(x, return_dict=return_dict)
+
+ if self.use_slicing and x.shape[0] > 1:
+ encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
+ h = torch.cat(encoded_slices)
+ else:
+ h = self.encoder(x)
+
+ moments = self.quant_conv(h.to(dtype=self.quant_conv.weight.dtype))
+ posterior = DiagonalGaussianDistribution(moments)
+
+ if not return_dict:
+ return (posterior,)
+
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
+ return self.tiled_decode(z, return_dict=return_dict)
+
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ @apply_forward_hook
+ def decode(
+ self, z: torch.FloatTensor, return_dict: bool = True, generator=None
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
+ """
+ Decode a batch of images.
+
+ Args:
+ z (`torch.FloatTensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+
+ """
+ if self.use_slicing and z.shape[0] > 1:
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+ decoded = torch.cat(decoded_slices)
+ else:
+ decoded = self._decode(z).sample
+
+ if not return_dict:
+ return (decoded,)
+
+ return DecoderOutput(sample=decoded)
+
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[2], b.shape[2], blend_extent)
+ for y in range(blend_extent):
+ b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
+ return b
+
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
+ return b
+
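+    # Blending sketch: both helpers apply a linear cross-fade over `blend_extent`
+    # rows/columns; e.g. with blend_extent = 4, the weights on `b` are 0/4, 1/4,
+    # 2/4, 3/4, while `a` fades out with the complementary weights.
+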
+ def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
+ r"""Encode a batch of images using a tiled encoder.
+
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
+ different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
+ output, but they should be much less noticeable.
+
+ Args:
+ x (`torch.FloatTensor`): Input batch of images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
+ If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
+ `tuple` is returned.
+ """
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
+ row_limit = self.tile_latent_min_size - blend_extent
+
+        # Split the image into overlapping tiles (tile_sample_min_size per side) and encode them separately.
+ rows = []
+ for i in range(0, x.shape[2], overlap_size):
+ row = []
+ for j in range(0, x.shape[3], overlap_size):
+ tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
+ tile = self.encoder(tile)
+ tile = self.quant_conv(tile)
+ row.append(tile)
+ rows.append(row)
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
+ result_row.append(tile[:, :, :row_limit, :row_limit])
+ result_rows.append(torch.cat(result_row, dim=3))
+
+ moments = torch.cat(result_rows, dim=2)
+ posterior = DiagonalGaussianDistribution(moments)
+
+ if not return_dict:
+ return (posterior,)
+
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ r"""
+ Decode a batch of images using a tiled decoder.
+
+ Args:
+ z (`torch.FloatTensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
+ row_limit = self.tile_sample_min_size - blend_extent
+
+        # Split z into overlapping tiles (tile_latent_min_size per side) and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+ rows = []
+ for i in range(0, z.shape[2], overlap_size):
+ row = []
+ for j in range(0, z.shape[3], overlap_size):
+ tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
+ tile = self.post_quant_conv(tile)
+ decoded = self.decoder(tile)
+ row.append(decoded)
+ rows.append(row)
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
+ result_row.append(tile[:, :, :row_limit, :row_limit])
+ result_rows.append(torch.cat(result_row, dim=3))
+
+ dec = torch.cat(result_rows, dim=2)
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ sample_posterior: bool = False,
+ return_dict: bool = True,
+ generator: Optional[torch.Generator] = None,
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
+ r"""
+ Args:
+ sample (`torch.FloatTensor`): Input sample.
+ sample_posterior (`bool`, *optional*, defaults to `False`):
+ Whether to sample from the posterior.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+ """
+ x = sample
+ posterior = self.encode(x).latent_dist
+ if sample_posterior:
+ z = posterior.sample(generator=generator)
+ else:
+ z = posterior.mode()
+ dec = self.decode(z).sample
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
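+    # Roundtrip sketch (hypothetical shapes): encode an image batch to a posterior,
+    # sample a latent, then decode it back to pixel space:
+    #
+    #   vae = AutoencoderKL(block_out_channels=(64, 128),
+    #                       down_block_types=("DownEncoderBlock2D",) * 2,
+    #                       up_block_types=("UpDecoderBlock2D",) * 2)
+    #   x = torch.randn(1, 3, 64, 64)
+    #   z = vae.encode(x).latent_dist.sample()   # -> 1x4x32x32
+    #   rec = vae.decode(z).sample               # -> 1x3x64x64
+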
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        This API is 🧪 experimental.
+        """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+        This API is 🧪 experimental.
+        """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
+    def merge_and_unload(
+        self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
+    ) -> torch.nn.Module:
+        r"""
+        Merge any attached PEFT adapter weights (e.g. LoRA) into the base layers and remove the adapter wrappers,
+        mirroring PEFT's `merge_and_unload`, so the VAE can be used as a plain `torch.nn.Module`.
+        """
+ return self._unload_and_optionally_merge(
+ progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names
+ )
+
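+    # Usage sketch (hypothetical; assumes PEFT LoRA adapters were injected beforehand):
+    #
+    #   from peft import LoraConfig, inject_adapter_in_model
+    #   vae = inject_adapter_in_model(LoraConfig(target_modules=["to_q", "to_v"]), vae)
+    #   ...  # train the adapter
+    #   vae = vae.merge_and_unload()  # bake the LoRA weights into the base layers
+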
+ def _unload_and_optionally_merge(
+ self,
+ merge=True,
+ progressbar: bool = False,
+ safe_merge: bool = False,
+ adapter_names: Optional[list[str]] = None,
+ ):
+ from tqdm import tqdm
+ from peft.tuners.tuners_utils import onload_layer
+ from peft.utils import _get_submodules, ModulesToSaveWrapper
+
+ key_list = [key for key, _ in self.named_modules() if "lora_" not in key]
+ desc = "Unloading " + ("and merging " if merge else "") + "model"
+ for key in tqdm(key_list, disable=not progressbar, desc=desc):
+ try:
+ parent, target, target_name = _get_submodules(self, key)
+ except AttributeError:
+ continue
+ with onload_layer(target):
+ if hasattr(target, "base_layer"):
+ if merge:
+ target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
+ self._replace_module(parent, target_name, target.get_base_layer(), target)
+ elif isinstance(target, ModulesToSaveWrapper):
+ # save any additional trainable modules part of `modules_to_save`
+ new_module = target.modules_to_save[target.active_adapter]
+ if hasattr(new_module, "base_layer"):
+ # check if the module is itself a tuner layer
+ if merge:
+ new_module.merge(safe_merge=safe_merge, adapter_names=adapter_names)
+ new_module = new_module.get_base_layer()
+ setattr(parent, target_name, new_module)
+
+ return self
+
+ def _replace_module(self, parent, child_name, new_module, child):
+ setattr(parent, child_name, new_module)
+ # It's not necessary to set requires_grad here, as that is handled by
+ # _mark_only_adapters_as_trainable
+
+ # child layer wraps the original module, unpack it
+ if hasattr(child, "base_layer"):
+ child = child.base_layer
+
+ if not hasattr(new_module, "base_layer"):
+ new_module.weight = child.weight
+ if hasattr(child, "bias"):
+ new_module.bias = child.bias
+
+ if getattr(child, "state", None) is not None:
+ if hasattr(new_module, "base_layer"):
+ new_module.base_layer.state = child.state
+ else:
+ new_module.state = child.state
+ new_module.to(child.weight.device)
+
+ # dispatch to correct device
+ for name, module in new_module.named_modules():
+ if ("lora_" in name) or ("ranknum" in name):
+ weight = child.qweight if hasattr(child, "qweight") else child.weight
+ module.to(weight.device)
diff --git a/GDPOSR/diffusermodels/unet_2d_condition.py b/GDPOSR/diffusermodels/unet_2d_condition.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d30e61c1aab964372dbab233e29e1c47efe6f97
--- /dev/null
+++ b/GDPOSR/diffusermodels/unet_2d_condition.py
@@ -0,0 +1,1280 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import UNet2DConditionLoadersMixin
+from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from diffusers.models.activations import get_activation
+from diffusers.models.attention_processor import (
+ ADDED_KV_ATTENTION_PROCESSORS,
+ CROSS_ATTENTION_PROCESSORS,
+ Attention,
+ AttentionProcessor,
+ AttnAddedKVProcessor,
+ AttnProcessor,
+)
+from diffusers.models.embeddings import (
+ GaussianFourierProjection,
+ ImageHintTimeEmbedding,
+ ImageProjection,
+ ImageTimeEmbedding,
+ PositionNet,
+ TextImageProjection,
+ TextImageTimeEmbedding,
+ TextTimeEmbedding,
+ TimestepEmbedding,
+ Timesteps,
+)
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.unet_2d_blocks import (
+ UNetMidBlock2D,
+ UNetMidBlock2DCrossAttn,
+ UNetMidBlock2DSimpleCrossAttn,
+ get_down_block,
+ get_up_block,
+)
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class UNet2DConditionOutput(BaseOutput):
+ """
+ The output of [`UNet2DConditionModel`].
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
+ """
+
+ sample: torch.FloatTensor = None
+
+
+class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+ r"""
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
+ shaped output.
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+ Height and width of input/output sample.
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ The tuple of downsample blocks to use.
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
+ The tuple of upsample blocks to use.
+        only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
+ Whether to include self-attention in the basic transformer blocks, see
+ [`~models.attention.BasicTransformerBlock`].
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
+            If `None`, normalization and activation layers are skipped in post-processing.
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
+ The dimension of the cross attention features.
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+        reverse_transformer_layers_per_block (`Tuple[Tuple]`, *optional*, defaults to `None`):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
+ blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ encoder_hid_dim (`int`, *optional*, defaults to None):
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+ dimension to `cross_attention_dim`.
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
+            embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+ num_attention_heads (`int`, *optional*):
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
+ class_embed_type (`str`, *optional*, defaults to `None`):
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+ addition_embed_type (`str`, *optional*, defaults to `None`):
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+ "text". "text" will use the `TextTimeEmbedding` layer.
+        addition_time_embed_dim (`int`, *optional*, defaults to `None`):
+ Dimension for the timestep embeddings.
+ num_class_embeds (`int`, *optional*, defaults to `None`):
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
+ class conditioning with `class_embed_type` equal to `None`.
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
+ An optional override for the dimension of the projected time embedding.
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
+ timestep_post_act (`str`, *optional*, defaults to `None`):
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
+ The dimension of `cond_proj` layer in the timestep embedding.
+        conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_in` layer.
+        conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_out` layer.
+        projection_class_embeddings_input_dim (`int`, *optional*):
+            The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
+            `class_embed_type="projection"`.
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
+ embeddings with the class embeddings.
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
+ otherwise.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ center_input_sample: bool = False,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: Union[int, Tuple[int]] = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ dropout: float = 0.0,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
+ encoder_hid_dim: Optional[int] = None,
+ encoder_hid_dim_type: Optional[str] = None,
+ attention_head_dim: Union[int, Tuple[int]] = 8,
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ addition_embed_type: Optional[str] = None,
+ addition_time_embed_dim: Optional[int] = None,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ resnet_skip_time_act: bool = False,
+        resnet_out_scale_factor: float = 1.0,
+ time_embedding_type: str = "positional",
+ time_embedding_dim: Optional[int] = None,
+ time_embedding_act_fn: Optional[str] = None,
+ timestep_post_act: Optional[str] = None,
+ time_cond_proj_dim: Optional[int] = None,
+ conv_in_kernel: int = 3,
+ conv_out_kernel: int = 3,
+ projection_class_embeddings_input_dim: Optional[int] = None,
+ attention_type: str = "default",
+ class_embeddings_concat: bool = False,
+ mid_block_only_cross_attention: Optional[bool] = None,
+ cross_attention_norm: Optional[str] = None,
+ addition_embed_type_num_heads=64,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+
+ if num_attention_heads is not None:
+ raise ValueError(
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
+ )
+
+ # If `num_attention_heads` is not defined (which is the case for most models)
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+ # which is why we correct for the naming here.
+ num_attention_heads = num_attention_heads or attention_head_dim
+
+ # Check inputs
+ if len(down_block_types) != len(up_block_types):
+ raise ValueError(
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+ )
+
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
+ )
+ if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
+ for layer_number_per_block in transformer_layers_per_block:
+ if isinstance(layer_number_per_block, list):
+ raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
+
+ # input
+ conv_in_padding = (conv_in_kernel - 1) // 2
+ self.conv_in = nn.Conv2d(
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+ )
+
+ # time
+ if time_embedding_type == "fourier":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
+ if time_embed_dim % 2 != 0:
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
+ self.time_proj = GaussianFourierProjection(
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
+ )
+ timestep_input_dim = time_embed_dim
+ elif time_embedding_type == "positional":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
+
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+ else:
+ raise ValueError(
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
+ )
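+        # e.g. (sketch) for SD 1.x: block_out_channels[0] = 320 with "positional"
+        # embeddings gives timestep_input_dim = 320 and time_embed_dim = 320 * 4 = 1280.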
+
+ self.time_embedding = TimestepEmbedding(
+ timestep_input_dim,
+ time_embed_dim,
+ act_fn=act_fn,
+ post_act_fn=timestep_post_act,
+ cond_proj_dim=time_cond_proj_dim,
+ )
+
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
+ encoder_hid_dim_type = "text_proj"
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
+ )
+
+ if encoder_hid_dim_type == "text_proj":
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+ elif encoder_hid_dim_type == "text_image_proj":
+            # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much,
+            # it is set to `cross_attention_dim` here, as this is exactly the required dimension for the currently
+            # only use case when `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1).
+ self.encoder_hid_proj = TextImageProjection(
+ text_embed_dim=encoder_hid_dim,
+ image_embed_dim=cross_attention_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ elif encoder_hid_dim_type == "image_proj":
+ # Kandinsky 2.2
+ self.encoder_hid_proj = ImageProjection(
+ image_embed_dim=encoder_hid_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ elif encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
+ )
+ else:
+ self.encoder_hid_proj = None
+
+ # class embedding
+ if class_embed_type is None and num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+ elif class_embed_type == "timestep":
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
+ elif class_embed_type == "identity":
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+ elif class_embed_type == "projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
+ # 2. it projects from an arbitrary input dimension.
+ #
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ elif class_embed_type == "simple_projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
+ else:
+ self.class_embedding = None
+
+ if addition_embed_type == "text":
+ if encoder_hid_dim is not None:
+ text_time_embedding_from_dim = encoder_hid_dim
+ else:
+ text_time_embedding_from_dim = cross_attention_dim
+
+ self.add_embedding = TextTimeEmbedding(
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+ )
+ elif addition_embed_type == "text_image":
+            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__
+            # too much, they are set to `cross_attention_dim` here, as this is exactly the required dimension for the
+            # currently only use case when `addition_embed_type == "text_image"` (Kandinsky 2.1).
+ self.add_embedding = TextImageTimeEmbedding(
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+ )
+ elif addition_embed_type == "text_time":
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ elif addition_embed_type == "image":
+ # Kandinsky 2.2
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+ elif addition_embed_type == "image_hint":
+ # Kandinsky 2.2 ControlNet
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+ elif addition_embed_type is not None:
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
+
+ if time_embedding_act_fn is None:
+ self.time_embed_act = None
+ else:
+ self.time_embed_act = get_activation(time_embedding_act_fn)
+
+ self.down_blocks = nn.ModuleList([])
+ self.up_blocks = nn.ModuleList([])
+
+ if isinstance(only_cross_attention, bool):
+ if mid_block_only_cross_attention is None:
+ mid_block_only_cross_attention = only_cross_attention
+
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+ if mid_block_only_cross_attention is None:
+ mid_block_only_cross_attention = False
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+ if isinstance(cross_attention_dim, int):
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+ if isinstance(layers_per_block, int):
+ layers_per_block = [layers_per_block] * len(down_block_types)
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+ if class_embeddings_concat:
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
+ # regular time embeddings
+ blocks_time_embed_dim = time_embed_dim * 2
+ else:
+ blocks_time_embed_dim = time_embed_dim
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block[i],
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim[i],
+ num_attention_heads=num_attention_heads[i],
+ downsample_padding=downsample_padding,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_type=attention_type,
+ resnet_skip_time_act=resnet_skip_time_act,
+ resnet_out_scale_factor=resnet_out_scale_factor,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ dropout=dropout,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
+ self.mid_block = UNetMidBlock2DCrossAttn(
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ in_channels=block_out_channels[-1],
+ temb_channels=blocks_time_embed_dim,
+ dropout=dropout,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim[-1],
+ num_attention_heads=num_attention_heads[-1],
+ resnet_groups=norm_num_groups,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ attention_type=attention_type,
+ )
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
+ in_channels=block_out_channels[-1],
+ temb_channels=blocks_time_embed_dim,
+ dropout=dropout,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ cross_attention_dim=cross_attention_dim[-1],
+ attention_head_dim=attention_head_dim[-1],
+ resnet_groups=norm_num_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ skip_time_act=resnet_skip_time_act,
+ only_cross_attention=mid_block_only_cross_attention,
+ cross_attention_norm=cross_attention_norm,
+ )
+ elif mid_block_type == "UNetMidBlock2D":
+ self.mid_block = UNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ temb_channels=blocks_time_embed_dim,
+ dropout=dropout,
+ num_layers=0,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_groups=norm_num_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ add_attention=False,
+ )
+ elif mid_block_type is None:
+ self.mid_block = None
+ else:
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
+
+ # count how many layers upsample the images
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
+ reversed_layers_per_block = list(reversed(layers_per_block))
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+ reversed_transformer_layers_per_block = (
+ list(reversed(transformer_layers_per_block))
+ if reverse_transformer_layers_per_block is None
+ else reverse_transformer_layers_per_block
+ )
+ only_cross_attention = list(reversed(only_cross_attention))
+
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=reversed_layers_per_block[i] + 1,
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resolution_idx=i,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=reversed_cross_attention_dim[i],
+ num_attention_heads=reversed_num_attention_heads[i],
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_type=attention_type,
+ resnet_skip_time_act=resnet_skip_time_act,
+ resnet_out_scale_factor=resnet_out_scale_factor,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ dropout=dropout,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ if norm_num_groups is not None:
+ self.conv_norm_out = nn.GroupNorm(
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
+ )
+
+            self.conv_act = get_activation(act_fn)
+        else:
+ self.conv_norm_out = None
+ self.conv_act = None
+
+ conv_out_padding = (conv_out_kernel - 1) // 2
+ self.conv_out = nn.Conv2d(
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
+ )
+
+ if attention_type in ["gated", "gated-text-image"]:
+ positive_len = 768
+ if isinstance(cross_attention_dim, int):
+ positive_len = cross_attention_dim
+ elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
+ positive_len = cross_attention_dim[0]
+
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
+ self.position_net = PositionNet(
+ positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
+ )
+
+ @property
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by their weight names.
+        """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_attn_processor(
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+ ):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor, _remove_lora=_remove_lora)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
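+    # Usage sketch (illustrative, not part of the original file): replace every
+    # attention processor in one call, or pass a dict keyed by the names from
+    # `self.attn_processors`:
+    #   from diffusers.models.attention_processor import AttnProcessor
+    #   model.set_attn_processor(AttnProcessor())
+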
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnAddedKVProcessor()
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor, _remove_lora=True)
+
+ def set_attention_slice(self, slice_size):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
+ """
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_sliceable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_sliceable_dims(module)
+
+ num_sliceable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_sliceable_layers * [1]
+
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+                f"You have provided {len(slice_size)} slice sizes, but {self.config} has {len(sliceable_head_dims)} different"
+                f" attention layers. Make sure to match `len(slice_size)` to {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
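+    # Usage sketch (illustrative): halve every sliceable head dimension, trading a
+    # little speed for memory; `"max"`, an int, or a per-layer list also work, as
+    # described in the docstring above:
+    #   model.set_attention_slice("auto")
+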
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+ def enable_freeu(self, s1, s2, b1, b2):
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
+
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
+
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ for i, upsample_block in enumerate(self.up_blocks):
+ setattr(upsample_block, "s1", s1)
+ setattr(upsample_block, "s2", s2)
+ setattr(upsample_block, "b1", b1)
+ setattr(upsample_block, "b2", b2)
+
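+    # Usage sketch (illustrative values only; see the FreeU repository for
+    # per-pipeline recommendations):
+    #   model.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
+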
+ def disable_freeu(self):
+ """Disables the FreeU mechanism."""
+ freeu_keys = {"s1", "s2", "b1", "b2"}
+ for i, upsample_block in enumerate(self.up_blocks):
+ for k in freeu_keys:
+ if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
+ setattr(upsample_block, k, None)
+
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        This API is 🧪 experimental.
+        """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        This API is 🧪 experimental.
+        """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
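+    # Usage sketch (illustrative): fuse before inference to cut the number of
+    # projection GEMMs, and unfuse afterwards to restore the original processors:
+    #   model.fuse_qkv_projections()
+    #   ...  # run inference
+    #   model.unfuse_qkv_projections()
+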
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ) -> Union[UNet2DConditionOutput, Tuple]:
+ r"""
+ The [`UNet2DConditionModel`] forward method.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.FloatTensor`):
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ added_cond_kwargs: (`dict`, *optional*):
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
+ are passed along to the UNet blocks.
+            down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
+                A tuple of tensors that if specified are added to the residuals of the UNet long skip connections
+                from down blocks to up blocks, for example from ControlNet side model(s).
+            mid_block_additional_residual: (`torch.Tensor`, *optional*):
+                A tensor that if specified is added to the residual of the middle UNet block, for example from a
+                ControlNet side model.
+            down_intrablock_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
+                Additional residuals to be added within the UNet down blocks, for example from T2I-Adapter side
+                model(s).
+ encoder_attention_mask (`torch.Tensor`):
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
+ a `tuple` is returned where the first element is the sample tensor.
+ """
+        # By default samples have to be at least a multiple of the overall upsampling factor.
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
+ # on the fly if necessary.
+ default_overall_up_factor = 2**self.num_upsamplers
+
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+ forward_upsample_size = False
+ upsample_size = None
+
+ for dim in sample.shape[-2:]:
+ if dim % default_overall_up_factor != 0:
+ # Forward upsample size to force interpolation output size.
+ forward_upsample_size = True
+ break
+
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ if attention_mask is not None:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
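+            # e.g. a mask [1, 1, 0] becomes the additive bias [0.0, 0.0, -10000.0]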
+
+ # 0. center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ emb = self.time_embedding(t_emb, timestep_cond)
+ aug_emb = None
+
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # there might be better ways to encapsulate this.
+ class_labels = class_labels.to(dtype=sample.dtype)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
+
+ if self.config.class_embeddings_concat:
+ emb = torch.cat([emb, class_emb], dim=-1)
+ else:
+ emb = emb + class_emb
+
+ if self.config.addition_embed_type == "text":
+ aug_emb = self.add_embedding(encoder_hidden_states)
+ elif self.config.addition_embed_type == "text_image":
+ # Kandinsky 2.1 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ )
+
+ image_embs = added_cond_kwargs.get("image_embeds")
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
+ aug_emb = self.add_embedding(text_embs, image_embs)
+ elif self.config.addition_embed_type == "text_time":
+ # SDXL - style
+ if "text_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+ )
+ text_embeds = added_cond_kwargs.get("text_embeds")
+ if "time_ids" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+ )
+ time_ids = added_cond_kwargs.get("time_ids")
+ time_embeds = self.add_time_proj(time_ids.flatten())
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+ add_embeds = add_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(add_embeds)
+ elif self.config.addition_embed_type == "image":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ )
+ image_embs = added_cond_kwargs.get("image_embeds")
+ aug_emb = self.add_embedding(image_embs)
+ elif self.config.addition_embed_type == "image_hint":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
+ )
+ image_embs = added_cond_kwargs.get("image_embeds")
+ hint = added_cond_kwargs.get("hint")
+ aug_emb, hint = self.add_embedding(image_embs, hint)
+ sample = torch.cat([sample, hint], dim=1)
+
+ emb = emb + aug_emb if aug_emb is not None else emb
+
+ if self.time_embed_act is not None:
+ emb = self.time_embed_act(emb)
+
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
+            # Kandinsky 2.1 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
+ encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
+
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ # 2.5 GLIGEN position net
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
+ cross_attention_kwargs = cross_attention_kwargs.copy()
+ gligen_args = cross_attention_kwargs.pop("gligen")
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
+
+ # 3. down
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
+ is_adapter = down_intrablock_additional_residuals is not None
+ # maintain backward compatibility for legacy usage, where
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
+ # but can only use one or the other
+ if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
+            deprecate(
+                "T2I should not use down_block_additional_residuals",
+                "1.3.0",
+                "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
+                       and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
+                       for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead.",
+                standard_warn=False,
+            )
+ down_intrablock_additional_residuals = down_block_additional_residuals
+ is_adapter = True
+
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ # For t2i-adapter CrossAttnDownBlock2D
+ additional_residuals = {}
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
+ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
+
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ **additional_residuals,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
+ sample += down_intrablock_additional_residuals.pop(0)
+
+ down_block_res_samples += res_samples
+
+ if is_controlnet:
+ new_down_block_res_samples = ()
+
+ for down_block_res_sample, down_block_additional_residual in zip(
+ down_block_res_samples, down_block_additional_residuals
+ ):
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = new_down_block_res_samples
+
+ # 4. mid
+ if self.mid_block is not None:
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ sample = self.mid_block(sample, emb)
+
+ # To support T2I-Adapter-XL
+ if (
+ is_adapter
+ and len(down_intrablock_additional_residuals) > 0
+ and sample.shape == down_intrablock_additional_residuals[0].shape
+ ):
+ sample += down_intrablock_additional_residuals.pop(0)
+
+ if is_controlnet:
+ sample = sample + mid_block_additional_residual
+
+ # 5. up
+ for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ upsample_size=upsample_size,
+ scale=lora_scale,
+ )
+
+ # 6. post-process
+ if self.conv_norm_out:
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet2DConditionOutput(sample=sample)
+
+
+ def merge_and_unload(
+ self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
+ ) -> torch.nn.Module:
+
+ return self._unload_and_optionally_merge(
+ progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names
+ )
+
+ def _unload_and_optionally_merge(
+ self,
+ merge=True,
+ progressbar: bool = False,
+ safe_merge: bool = False,
+ adapter_names: Optional[list[str]] = None,
+ ):
+ from tqdm import tqdm
+ from peft.tuners.tuners_utils import onload_layer
+ from peft.utils import _get_submodules, ModulesToSaveWrapper
+
+ key_list = [key for key, _ in self.named_modules() if "lora_" not in key]
+ desc = "Unloading " + ("and merging " if merge else "") + "model"
+ for key in tqdm(key_list, disable=not progressbar, desc=desc):
+ try:
+ parent, target, target_name = _get_submodules(self, key)
+ except AttributeError:
+ continue
+ with onload_layer(target):
+ if hasattr(target, "base_layer"):
+ if merge:
+ target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
+ self._replace_module(parent, target_name, target.get_base_layer(), target)
+ elif isinstance(target, ModulesToSaveWrapper):
+ # save any additional trainable modules part of `modules_to_save`
+ new_module = target.modules_to_save[target.active_adapter]
+ if hasattr(new_module, "base_layer"):
+ # check if the module is itself a tuner layer
+ if merge:
+ new_module.merge(safe_merge=safe_merge, adapter_names=adapter_names)
+ new_module = new_module.get_base_layer()
+ setattr(parent, target_name, new_module)
+
+ return self
+
+ def _replace_module(self, parent, child_name, new_module, child):
+ setattr(parent, child_name, new_module)
+ # It's not necessary to set requires_grad here, as that is handled by
+ # _mark_only_adapters_as_trainable
+
+ # child layer wraps the original module, unpack it
+ if hasattr(child, "base_layer"):
+ child = child.base_layer
+
+ if not hasattr(new_module, "base_layer"):
+ new_module.weight = child.weight
+ if hasattr(child, "bias"):
+ new_module.bias = child.bias
+
+ if getattr(child, "state", None) is not None:
+ if hasattr(new_module, "base_layer"):
+ new_module.base_layer.state = child.state
+ else:
+ new_module.state = child.state
+ new_module.to(child.weight.device)
+
+ # dispatch to correct device
+ for name, module in new_module.named_modules():
+ if ("lora_" in name) or ("ranknum" in name):
+ weight = child.qweight if hasattr(child, "qweight") else child.weight
+ module.to(weight.device)
\ No newline at end of file
diff --git a/GDPOSR/inferences/test.py b/GDPOSR/inferences/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..56dda0ddbf5001888cd862784115e61adbad708b
--- /dev/null
+++ b/GDPOSR/inferences/test.py
@@ -0,0 +1,107 @@
+import os
+import argparse
+import numpy as np
+from PIL import Image
+import torch
+from torchvision import transforms
+import torchvision.transforms.functional as F
+import sys
+sys.path.append("GDPOSR")
+from modelfile.GDPOSR import GDPOSRTest
+from my_utils.wavelet_color_fix import adain_color_fix, wavelet_color_fix
+import glob
+sys.path.append('./')
+from ram.models.ram_lora import ram
+from ram import inference_ram as inference
+tensor_transforms = transforms.Compose([
+ transforms.ToTensor(),
+ ])
+
+ram_transforms = transforms.Compose([
+ transforms.Resize((384, 384)),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ])
+
+def get_validation_prompt(args, image, model, device='cuda'):
+    # caption the LR image with RAM, then append the user-provided positive prompt
+    lq = tensor_transforms(image).unsqueeze(0).to(device)
+    lq = ram_transforms(lq)
+    captions = inference(lq, model)
+    validation_prompt = f"{captions[0]}, {args.prompt},"
+
+    return validation_prompt
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--input_image', type=str, default="", help='path to the input image')
+ parser.add_argument('--model_name', type=str, default='realsr', help='name of the pretrained model to be used')
+ parser.add_argument('--pretrained_path', type=str, default='', help='path to a model state dict to be used')
+ parser.add_argument('--output_dir', type=str, default='', help='the directory to save the output')
+ parser.add_argument('--seed', type=int, default=42, help='Random seed to be used')
+ parser.add_argument("--process_size", type=int, default=512)
+ parser.add_argument("--upscale", type=int, default=4)
+ parser.add_argument("--align_method", type=str, choices=['wavelet', 'adain', 'nofix'], default='adain')
+ parser.add_argument("--pretrained_model_name_or_path", type=str, default="")
+    parser.add_argument('--ram_ft_path', type=str, default=None, help='path to fine-tuned RAM weights (optional)')
+ parser.add_argument('--prompt', type=str, default='', help='positive prompts')
+ parser.add_argument('--negprompt', type=str, default='', help='negative prompts')
+ parser.add_argument("--time_step", type=int, default=1)
+ parser.add_argument("--time_step_noise", type=int, default=1)
+ args = parser.parse_args()
+
+ # initialize the model
+ model = GDPOSRTest(args)
+ model.set_eval()
+
+ if os.path.isdir(args.input_image):
+ image_names = sorted(glob.glob(f'{args.input_image}/*.png'))
+ else:
+ image_names = [args.input_image]
+
+ print("=== use ram ===")
+ model_vlm = ram(pretrained='./ckp/ram_swin_large_14m.pth',
+ pretrained_condition=args.ram_ft_path,
+ image_size=384,
+ vit='swin_l')
+ model_vlm.eval()
+ model_vlm.to("cuda")
+
+ # make the output dir
+ os.makedirs(args.output_dir, exist_ok=True)
+ print(f'There are {len(image_names)} images.')
+ for image_name in image_names:
+
+        # load the LR image and pre-upscale it so the model input reaches `process_size`
+ input_image = Image.open(image_name).convert('RGB')
+ ori_width, ori_height = input_image.size
+ rscale = args.upscale
+ resize_flag = False
+ if ori_width < args.process_size//rscale or ori_height < args.process_size//rscale:
+ scale = (args.process_size//rscale)/min(ori_width, ori_height)
+ input_image = input_image.resize((int(scale*ori_width), int(scale*ori_height)))
+ resize_flag = True
+ input_image = input_image.resize((input_image.size[0]*rscale, input_image.size[1]*rscale))
+
+        # round the dimensions down to a multiple of 8 (required by the VAE)
+        new_width = input_image.width - input_image.width % 8
+ new_height = input_image.height - input_image.height % 8
+ input_image = input_image.resize((new_width, new_height), Image.LANCZOS)
+ bname = os.path.basename(image_name)
+
+ # get caption
+ validation_prompt = get_validation_prompt(args, input_image, model_vlm)
+ # translate the image
+ with torch.no_grad():
+ c_t = F.to_tensor(input_image).unsqueeze(0).cuda()*2-1
+ output_image = model(c_t, positive_prompt=[validation_prompt])
+ output_pil = transforms.ToPILImage()(output_image[0].cpu() * 0.5 + 0.5)
+ if args.align_method == 'adain':
+ output_pil = adain_color_fix(target=output_pil, source=input_image)
+ elif args.align_method == 'wavelet':
+ output_pil = wavelet_color_fix(target=output_pil, source=input_image)
+ else:
+ pass
+        if resize_flag:
+            output_pil = output_pil.resize((int(args.upscale*ori_width), int(args.upscale*ori_height)))
+
+ output_pil.save(os.path.join(args.output_dir, bname))
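+
+# Example invocation (hypothetical paths, for illustration only):
+#   python GDPOSR/inferences/test.py \
+#       --input_image inputs/lr_images \
+#       --output_dir results \
+#       --pretrained_model_name_or_path /path/to/stable-diffusion-base \
+#       --pretrained_path /path/to/gdposr_checkpoint.pkl \
+#       --upscale 4 --align_method adain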
diff --git a/GDPOSR/losses/__pycache__/grpo.cpython-310.pyc b/GDPOSR/losses/__pycache__/grpo.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..af58ea62e4f143ef08023627c02054fdda7fd4db
Binary files /dev/null and b/GDPOSR/losses/__pycache__/grpo.cpython-310.pyc differ
diff --git a/GDPOSR/losses/grpo.py b/GDPOSR/losses/grpo.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e48cef2ffc674001d17f4929c36833cfda3e2e0
--- /dev/null
+++ b/GDPOSR/losses/grpo.py
@@ -0,0 +1,52 @@
+import pyiqa
+
+from basicsr.utils import img2tensor, tensor2img
+from torch.utils import data as data
+import glob
+import numpy as np
+import math
+import random
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.distributions import Normal
+import sys
+
+class AdaptiveReward(nn.Module):
+ def __init__(self):
+ super().__init__()
+ device = torch.device("cuda")
+ self.iqa_psnr = pyiqa.create_metric('psnr', test_y_channel=True, color_space='ycbcr').to(device)
+ self.iqa_maniqa = pyiqa.create_metric('maniqa', device=device)
+ self.iqa_musiq = pyiqa.create_metric('musiq', device=device)
+
+    def normalize_tensor(self, tensor):
+        # min-max normalize to [0, 1]; the epsilon guards against a zero range
+        # when all group members receive identical scores
+        min_val = tensor.min()
+        max_val = tensor.max()
+        normalized_tensor = (tensor - min_val) / (max_val - min_val + 1e-8)
+        return normalized_tensor
+
+    def forward(self, x, y, fidelity_ratio, detail_ratio):
+        # map inputs from [-1, 1] back to [0, 1] for the IQA metrics
+        x = x * 0.5 + 0.5
+        y = y * 0.5 + 0.5
+        b, gs, c, h, w = x.shape
+        reward = torch.zeros([b, gs], device=x.device)
+        for i in range(b):
+            fidelity_i = fidelity_ratio[i]
+            detail_i = detail_ratio[i]
+            x_i = x[i]
+            y_i = y[i]
+            psnr_result = self.normalize_tensor(self.iqa_psnr(x_i, y_i))
+            musiq_result = self.normalize_tensor(self.iqa_musiq(x_i).squeeze(1))
+            maniqa_result = self.normalize_tensor(self.iqa_maniqa(x_i).squeeze(1))
+
+            # weighted mix of fidelity (PSNR) and perceptual detail (MANIQA + MUSIQ),
+            # standardized within each group to obtain GRPO-style advantages
+            reward_i = fidelity_i * psnr_result + detail_i * 0.5 * (maniqa_result + musiq_result)
+            combined_mean = torch.mean(reward_i)
+            combined_std = reward_i.std(unbiased=True)
+            reward_i = (reward_i - combined_mean) / (combined_std + 1e-8)
+            reward[i] = reward_i
+
+ return reward.detach()
+
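+# Minimal usage sketch (hypothetical shapes; assumes a CUDA device):
+#   reward_fn = AdaptiveReward()
+#   x = torch.rand(2, 4, 3, 128, 128, device='cuda') * 2 - 1  # (batch, group, C, H, W) candidates in [-1, 1]
+#   y = torch.rand(2, 4, 3, 128, 128, device='cuda') * 2 - 1  # references with the same shape
+#   advantages = reward_fn(x, y, torch.full((2,), 0.5), torch.full((2,), 0.5))  # -> (batch, group)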
diff --git a/GDPOSR/mergelora.py b/GDPOSR/mergelora.py
new file mode 100644
index 0000000000000000000000000000000000000000..0aacd5713a08af77e05bdc86656c3512a8bebb1e
--- /dev/null
+++ b/GDPOSR/mergelora.py
@@ -0,0 +1,72 @@
+import os
+import requests
+import sys
+import copy
+import random
+import time
+import glob
+import math
+import yaml
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from tqdm import tqdm
+from peft import LoraConfig
+from types import SimpleNamespace
+
+sys.path.append("./")
+from diffusermodels.autoencoder_kl import AutoencoderKL as AutoencoderKLMerge
+from diffusermodels.unet_2d_condition import UNet2DConditionModel as UNet2DConditionModelMerge
+
+
+def UNetMergeLoRA(basemodel_path='', trainedmodel_path='', savepath='', savename=''):
+
+    loraweight = torch.load(trainedmodel_path)
+    unet = UNet2DConditionModelMerge.from_pretrained(basemodel_path, subfolder="unet")
+
+ # load unet lora
+ lora_conf_encoder = LoraConfig(r=loraweight["rank_unet"], init_lora_weights="gaussian", target_modules=loraweight["unet_lora_encoder_modules"])
+ lora_conf_decoder = LoraConfig(r=loraweight["rank_unet"], init_lora_weights="gaussian", target_modules=loraweight["unet_lora_decoder_modules"])
+ lora_conf_others = LoraConfig(r=loraweight["rank_unet"], init_lora_weights="gaussian", target_modules=loraweight["unet_lora_others_modules"])
+ unet.add_adapter(lora_conf_encoder, adapter_name="default_encoder")
+ unet.add_adapter(lora_conf_decoder, adapter_name="default_decoder")
+ unet.add_adapter(lora_conf_others, adapter_name="default_others")
+ for n, p in unet.named_parameters():
+ if "lora" in n or "conv_in" in n:
+ p.data.copy_(loraweight["state_dict_unet"][n])
+
+ unet.set_adapter(['default_encoder', 'default_decoder', 'default_others'])
+ unet = unet.merge_and_unload()
+ unet.save_pretrained(os.path.join(savepath, savename))
+
+def VAEMergeLoRA(basemodel_path='', trainedmodel_path='', savepath='', savename=''):
+
+ loraweight = torch.load(trainedmodel_path)
+    vae = AutoencoderKLMerge.from_pretrained(basemodel_path, subfolder="vae")
+
+ # load vae lora
+ vae_lora_conf_encoder = LoraConfig(r=loraweight["rank_vae"], init_lora_weights="gaussian", target_modules=loraweight["vae_lora_encoder_modules"])
+ vae_lora_conf_decoder = LoraConfig(r=loraweight["rank_vae"], init_lora_weights="gaussian", target_modules=loraweight["vae_lora_decoder_modules"])
+ vae.add_adapter(vae_lora_conf_encoder, adapter_name="default_encoder")
+ vae.add_adapter(vae_lora_conf_decoder, adapter_name="default_decoder")
+ for n, p in vae.named_parameters():
+ if "lora" in n:
+ p.data.copy_(loraweight["state_dict_vae"][n])
+
+ vae.set_adapter(['default_encoder'])
+ vae = vae.merge_and_unload()
+ vae.save_pretrained(os.path.join(savepath, savename))
+
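+# Fill in the paths below before running; for example (hypothetical values):
+#   unetbasemodel_path = '/path/to/stable-diffusion-base'
+#   unettrainedmodel_path = '/path/to/gdposr_lora.pkl'
+#   unetsavepath, unetsavename = './merged', 'unet_merged'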
+unetbasemodel_path=''
+unettrainedmodel_path=''
+unetsavepath=''
+unetsavename=''
+UNetMergeLoRA(unetbasemodel_path, unettrainedmodel_path, unetsavepath, unetsavename)
+
+vaebasemodel_path=''
+vaetrainedmodel_path=''
+vaesavepath=''
+vaesavename=''
+VAEMergeLoRA(vaebasemodel_path, vaetrainedmodel_path, vaesavepath, vaesavename)
\ No newline at end of file
diff --git a/GDPOSR/modelfile/GDPOSR.py b/GDPOSR/modelfile/GDPOSR.py
new file mode 100644
index 0000000000000000000000000000000000000000..56027f3799b3c4cf3aaaf087612d97c4e5e27316
--- /dev/null
+++ b/GDPOSR/modelfile/GDPOSR.py
@@ -0,0 +1,623 @@
+import os
+import requests
+import sys
+import copy
+import random
+import time
+import glob
+import math
+import yaml
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from tqdm import tqdm
+from peft import LoraConfig
+from types import SimpleNamespace
+from transformers import AutoTokenizer, CLIPTextModel
+from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
+from diffusers.utils.peft_utils import set_weights_and_activate_adapters
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.training_utils import compute_snr
+
+def make_1step_sched(pretrained_model_path):
+ noise_scheduler_1step = DDPMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
+ noise_scheduler_1step.set_timesteps(1, device="cuda")
+ noise_scheduler_1step.alphas_cumprod = noise_scheduler_1step.alphas_cumprod.cuda()
+ return noise_scheduler_1step
+
+def find_filepath(directory, filename):
+ matches = glob.glob(f"{directory}/**/{filename}", recursive=True)
+ return matches[0] if matches else None
+
+
+def read_yaml(file_path):
+ with open(file_path, 'r') as file:
+ data = yaml.safe_load(file)
+ return data
+
+def initialize_vae(rank, return_lora_module_names=False, pretrained_model_name_or_path=None):
+ vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
+ vae.requires_grad_(False)
+ vae.train()
+
+ l_target_modules_encoder, l_target_modules_decoder, l_modules_others = [], [], []
+ l_grep = ["conv1","conv2","conv_in", "conv_shortcut",
+ "conv", "conv_out", "to_k", "to_q", "to_v", "to_out.0",
+ ]
+ for n, p in vae.named_parameters():
+ if "bias" in n or "norm" in n: continue
+ for pattern in l_grep:
+ if pattern in n and ("encoder" in n):
+ l_target_modules_encoder.append(n.replace(".weight",""))
+ break
+ elif pattern in n and ("decoder" in n):
+ l_target_modules_decoder.append(n.replace(".weight",""))
+ break
+ elif ('quant_conv' in n) and ('post_quant_conv' not in n):
+ l_target_modules_encoder.append(n.replace(".weight",""))
+ break
+ elif 'post_quant_conv' in n:
+ l_target_modules_decoder.append(n.replace(".weight",""))
+ break
+ elif pattern in n:
+ l_modules_others.append(n.replace(".weight",""))
+ break
+ lora_conf_encoder = LoraConfig(r=rank, init_lora_weights="gaussian",target_modules=l_target_modules_encoder)
+ lora_conf_decoder = LoraConfig(r=rank, init_lora_weights="gaussian",target_modules=l_target_modules_decoder)
+ vae.add_adapter(lora_conf_encoder, adapter_name="default_encoder")
+ vae.add_adapter(lora_conf_decoder, adapter_name="default_decoder")
+ # vae.set_adapter(["default_encoder", "default_decoder"])
+ if return_lora_module_names:
+ return vae, l_target_modules_encoder, l_target_modules_decoder, l_modules_others
+ else:
+ return vae
+
+def initialize_unet(rank, return_lora_module_names=False, pretrained_model_name_or_path=None):
+ unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")
+ unet.requires_grad_(False)
+ unet.train()
+
+ l_target_modules_encoder, l_target_modules_decoder, l_modules_others = [], [], []
+ l_grep = ["to_k", "to_q", "to_v", "to_out.0", "conv", "conv1", "conv2", "conv_in", "conv_shortcut", "conv_out", "proj_out", "proj_in", "ff.net.2", "ff.net.0.proj"]
+ for n, p in unet.named_parameters():
+ if "bias" in n or "norm" in n: continue
+ for pattern in l_grep:
+ if pattern in n and ("down_blocks" in n or "conv_in" in n):
+ l_target_modules_encoder.append(n.replace(".weight",""))
+ break
+ elif pattern in n and "up_blocks" in n:
+ l_target_modules_decoder.append(n.replace(".weight",""))
+ break
+ elif pattern in n:
+ l_modules_others.append(n.replace(".weight",""))
+ break
+ lora_conf_encoder = LoraConfig(r=rank, init_lora_weights="gaussian",target_modules=l_target_modules_encoder)
+ lora_conf_decoder = LoraConfig(r=rank, init_lora_weights="gaussian",target_modules=l_target_modules_decoder)
+ lora_conf_others = LoraConfig(r=rank, init_lora_weights="gaussian",target_modules=l_modules_others)
+ unet.add_adapter(lora_conf_encoder, adapter_name="default_encoder")
+ unet.add_adapter(lora_conf_decoder, adapter_name="default_decoder")
+ unet.add_adapter(lora_conf_others, adapter_name="default_others")
+ if return_lora_module_names:
+ return unet, l_target_modules_encoder, l_target_modules_decoder, l_modules_others
+ else:
+ return unet
+
+def initialize_unet_sr(rank, return_lora_module_names=False, pretrained_model_name_or_path=None, args=None):
+ unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")
+ if args.use_lr_concat_lr_999noise:
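+        # expand the pretrained 4-channel conv_in to 8 channels so the LR latent can be
+        # concatenated with the noisy latent; both input halves reuse the pretrained weights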
+ new_conv_in = torch.nn.Conv2d(8, 320, 3, 1, 1)
+ new_conv_in.weight.data[:, :4, ...] = unet.conv_in.weight.data
+ new_conv_in.weight.data[:, -4:, ...] = unet.conv_in.weight.data
+ new_conv_in.bias.data = unet.conv_in.bias.data
+ unet.conv_in = new_conv_in
+ unet.requires_grad_(False)
+ unet.train()
+
+ l_target_modules_encoder, l_target_modules_decoder, l_modules_others = [], [], []
+ l_grep = ["to_k", "to_q", "to_v", "to_out.0", "conv", "conv1", "conv2", "conv_in", "conv_shortcut", "conv_out", "proj_out", "proj_in", "ff.net.2", "ff.net.0.proj"]
+ for n, p in unet.named_parameters():
+ if "bias" in n or "norm" in n: continue
+ for pattern in l_grep:
+ if pattern in n and ("down_blocks" in n or "conv_in" in n):
+ l_target_modules_encoder.append(n.replace(".weight",""))
+ break
+ elif pattern in n and "up_blocks" in n:
+ l_target_modules_decoder.append(n.replace(".weight",""))
+ break
+ elif pattern in n:
+ l_modules_others.append(n.replace(".weight",""))
+ break
+ lora_conf_encoder = LoraConfig(r=rank, init_lora_weights="gaussian",target_modules=l_target_modules_encoder)
+ lora_conf_decoder = LoraConfig(r=rank, init_lora_weights="gaussian",target_modules=l_target_modules_decoder)
+ lora_conf_others = LoraConfig(r=rank, init_lora_weights="gaussian",target_modules=l_modules_others)
+ unet.add_adapter(lora_conf_encoder, adapter_name="default_encoder")
+ unet.add_adapter(lora_conf_decoder, adapter_name="default_decoder")
+ unet.add_adapter(lora_conf_others, adapter_name="default_others")
+ if return_lora_module_names:
+ return unet, l_target_modules_encoder, l_target_modules_decoder, l_modules_others
+ else:
+ return unet
+
+class VSD(torch.nn.Module):
+ def __init__(self, args, accelerator):
+ super().__init__()
+
+ self.tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+ self.text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
+ self.sched = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ self.args = args
+
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ self.vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
+ self.unet_fix = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
+ self.unet_update, self.lora_unet_modules_encoder, self.lora_unet_modules_decoder, self.lora_unet_others =\
+ initialize_unet(rank=args.lora_rank_unet_vsd, pretrained_model_name_or_path=args.pretrained_model_name_or_path, return_lora_module_names=True)
+ self.lora_rank_unet = args.lora_rank_unet_vsd
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ self.unet_fix.enable_xformers_memory_efficient_attention()
+ self.unet_update.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available, please install it by running `pip install xformers`")
+
+ if args.gradient_checkpointing:
+ self.unet_fix.enable_gradient_checkpointing()
+ self.unet_update.enable_gradient_checkpointing()
+
+ self.text_encoder.to(accelerator.device, dtype=weight_dtype)
+ self.unet_fix.to(accelerator.device, dtype=weight_dtype)
+ self.unet_update.to(accelerator.device)
+ self.vae.to(accelerator.device)
+
+ self.text_encoder.requires_grad_(False)
+ self.vae.requires_grad_(False)
+ self.unet_fix.requires_grad_(False)
+
+    def set_eval(self):
+        self.unet_fix.eval()
+        self.unet_update.eval()
+
+ def set_train(self):
+ self.unet_update.train()
+ for n, _p in self.unet_update.named_parameters():
+ if "lora" in n:
+ _p.requires_grad = True
+
+ def forward(self, c_t, prompt=None, neg_prompt_tokens=None, prompt_tokens=None, deterministic=True, r=1.0, noise_map=None, args=None):
+
+ caption_enc = self.text_encoder(prompt_tokens)[0]
+ neg_caption_enc = self.text_encoder(neg_prompt_tokens)[0]
+
+ encoded_control = self.vae.encode(c_t).latent_dist.sample() * self.vae.config.scaling_factor
+ model_pred = self.unet(encoded_control, self.timesteps, encoder_hidden_states=caption_enc.to(torch.float32),).sample
+ x_denoised = self.sched.step(model_pred, self.timesteps, encoded_control, return_dict=True).prev_sample
+
+ output_image = (self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample).clamp(-1, 1)
+
+ return output_image, caption_enc, neg_caption_enc
+
+ def forward_latent(self, model, latents, timestep, prompt_embeds):
+
+ noise_pred = model(
+ latents,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ ).sample
+
+ return noise_pred
+
+ def compute_lora_loss(self, latents_pred, prompt_embeds, args):
+
+ latents_pred = latents_pred.detach()
+ prompt_embeds = prompt_embeds.detach()
+ noise = torch.randn_like(latents_pred)
+ bsz = latents_pred.shape[0]
+ timesteps = torch.randint(0, self.sched.config.num_train_timesteps, (bsz,), device=latents_pred.device)
+ timesteps = timesteps.long()
+ noisy_latents = self.sched.add_noise(latents_pred, noise, timesteps)
+ disc_pred = self.forward_latent(
+ self.unet_update,
+ timestep=timesteps,
+ latents=noisy_latents,
+ prompt_embeds=prompt_embeds
+ )
+ if args.snr_gamma_vsd is None:
+ loss_d = F.mse_loss(disc_pred.float(), noise.float(), reduction="mean")
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
+ snr = compute_snr(self.sched, timesteps)
+ if self.sched.config.prediction_type == "v_prediction":
+ # Velocity objective requires that we add one to SNR values before we divide by them.
+ snr = snr + 1
+            mse_loss_weights = torch.stack([snr, args.snr_gamma_vsd * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+
+            loss = F.mse_loss(disc_pred.float(), noise.float(), reduction="none")
+            loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+            loss_d = loss.mean()
+
+ return loss_d
+
+ def eps_to_mu(self, scheduler, model_output, sample, timesteps):
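+        # epsilon -> x0 estimate: x0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t)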
+ alphas_cumprod = scheduler.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
+ alpha_prod_t = alphas_cumprod[timesteps]
+ while len(alpha_prod_t.shape) < len(sample.shape):
+ alpha_prod_t = alpha_prod_t.unsqueeze(-1)
+ beta_prod_t = 1 - alpha_prod_t
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ return pred_original_sample
+
+ def distribution_matching_loss(
+ self,
+ real_model,
+ fake_model,
+ noise_scheduler,
+ latents,
+ prompt_embeds,
+ negative_prompt_embeds,
+ args,
+ ):
+ bsz = latents.shape[0]
+ min_dm_step = int(noise_scheduler.config.num_train_timesteps * args.min_dm_step_ratio)
+ max_dm_step = int(noise_scheduler.config.num_train_timesteps * args.max_dm_step_ratio)
+
+ timestep = torch.randint(min_dm_step, max_dm_step, (bsz,), device=latents.device).long()
+ noise = torch.randn_like(latents)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timestep)
+
+ with torch.no_grad():
+ noise_pred = self.forward_latent(
+ fake_model,
+ latents=noisy_latents,
+ timestep=timestep,
+ prompt_embeds=prompt_embeds.float(),
+ )
+ pred_fake_latents = self.eps_to_mu(noise_scheduler, noise_pred, noisy_latents, timestep)
+
+ noisy_latents_input = torch.cat([noisy_latents] * 2)
+ timestep_input = torch.cat([timestep] * 2)
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+
+ noise_pred = self.forward_latent(
+ real_model,
+ latents=noisy_latents_input.to(dtype=torch.float16),
+ timestep=timestep_input,
+ prompt_embeds=prompt_embeds.to(dtype=torch.float16),
+ )
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + args.cfg_vsd * (noise_pred_text - noise_pred_uncond)
+ noise_pred = noise_pred.to(dtype=torch.float32)  # .to() is out-of-place; assign to keep the cast
+
+ pred_real_latents = self.eps_to_mu(noise_scheduler, noise_pred, noisy_latents, timestep)
+
+ weighting_factor = torch.abs(latents - pred_real_latents).mean(dim=[1, 2, 3], keepdim=True)
+
+ grad = (pred_fake_latents - pred_real_latents) / weighting_factor
+ loss = F.mse_loss(latents, self.stopgrad(latents - grad))
+ return loss
+
+ def stopgrad(self, x):
+ return x.detach()
+
+ def save_model(self, outf):
+ sd = {}
+ sd["unet_lora_encoder_modules"], sd["unet_lora_decoder_modules"], sd["unet_lora_others_modules"] =\
+ self.lora_unet_modules_encoder, self.lora_unet_modules_decoder, self.lora_unet_others
+ sd["rank_unet"] = self.lora_rank_unet
+ sd["state_dict_unet"] = {k: v for k, v in self.unet.state_dict().items() if "lora" in k}
+ torch.save(sd, outf)
+
+class NAOSD(torch.nn.Module):
+ def __init__(self, args):
+ super().__init__()
+
+ self.tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+ self.text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder").cuda()
+ self.sched = make_1step_sched(args.pretrained_model_name_or_path)
+ self.sched2 = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ self.args = args
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
+ unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
+
+ if args.pretrained_path is None:
+ vae, lora_vae_modules_encoder, lora_vae_modules_decoder, lora_vae_others =\
+ initialize_vae(rank=args.lora_rank_vae, pretrained_model_name_or_path=args.pretrained_model_name_or_path, return_lora_module_names=True)
+ unet, lora_unet_modules_encoder, lora_unet_modules_decoder, lora_unet_others =\
+ initialize_unet_sr(rank=args.lora_rank_unet, pretrained_model_name_or_path=args.pretrained_model_name_or_path, return_lora_module_names=True, args=args)
+ self.lora_rank_unet = args.lora_rank_unet
+ self.lora_rank_vae = args.lora_rank_vae
+ self.lora_vae_modules_encoder, self.lora_vae_modules_decoder, self.lora_vae_others = \
+ lora_vae_modules_encoder, lora_vae_modules_decoder, lora_vae_others
+ self.lora_unet_modules_encoder, self.lora_unet_modules_decoder, self.lora_unet_others = \
+ lora_unet_modules_encoder, lora_unet_modules_decoder, lora_unet_others
+
+ self.unet, self.vae = unet, vae
+
+ if args.pretrained_path is not None:
+ print('==================================> loading pre-trained weight')
+ sd = torch.load(args.pretrained_path)
+ self.load_ckpt_from_state_dict(sd)
+ self.lora_rank_unet = sd['rank_unet']
+ self.lora_rank_vae = sd['rank_vae']
+ self.lora_vae_modules_encoder, self.lora_vae_modules_decoder, self.lora_vae_others = \
+ sd['vae_lora_encoder_modules'], sd['vae_lora_decoder_modules'], sd['vae_lora_others_modules']
+ self.lora_unet_modules_encoder, self.lora_unet_modules_decoder, self.lora_unet_others = \
+ sd['unet_lora_encoder_modules'], sd['unet_lora_decoder_modules'], sd['unet_lora_others_modules']
+
+ self.unet, self.vae = self.unet.cuda(), self.vae.cuda()
+ self.timesteps = torch.tensor([args.time_step], device="cuda").long()
+ self.timestepsnoise = torch.tensor([args.time_step_noise], device="cuda").long()
+ self.text_encoder.requires_grad_(False)
+
+ def set_eval(self):
+ self.unet.eval()
+ self.vae.eval()
+ self.unet.requires_grad_(False)
+ self.vae.requires_grad_(False)
+
+ def set_train(self):
+ self.unet.train()
+ self.vae.train()
+ for n, _p in self.unet.named_parameters():
+ if "lora" in n:
+ _p.requires_grad = True
+ self.unet.conv_in.requires_grad_(True)
+ for n, _p in self.vae.named_parameters():
+ if "lora" in n:
+ _p.requires_grad = True
+
+ def encode_prompt(self, prompt):
+ with torch.no_grad():
+ text_input_ids = self.tokenizer(
+ prompt, max_length=self.tokenizer.model_max_length,
+ padding="max_length", truncation=True, return_tensors="pt"
+ ).input_ids
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(self.text_encoder.device),
+ )[0]
+ return prompt_embeds
+
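+ # One-step restoration pipeline (as implemented below): encode the LR input to
+ # latents, perturb them at time_step_noise with the DDPM scheduler, run a single
+ # UNet prediction at time_step, take one scheduler step, and decode back to
+ # image space.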
+ def forward(self, c_t, positive_prompt=None, negative_prompt=None, args=None):
+ caption_enc = self.encode_prompt(positive_prompt)
+ neg_caption_enc = self.encode_prompt(negative_prompt)
+ encoded_control = self.vae.encode(c_t).latent_dist.sample() * self.vae.config.scaling_factor
+ noise = torch.randn_like(encoded_control)
+ encoded_control = self.sched2.add_noise(encoded_control, noise, self.timestepsnoise)
+
+ model_pred = self.unet(encoded_control, self.timesteps, encoder_hidden_states=caption_enc.to(torch.float32),).sample
+ x_denoised = self.sched.step(model_pred, self.timesteps, encoded_control, return_dict=True).prev_sample
+ output_image = self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample
+ output_image = output_image.clamp(-1, 1)
+
+ return output_image, x_denoised, caption_enc, neg_caption_enc, noise
+
+ def save_model(self, outf):
+ sd = {}
+ sd["vae_lora_encoder_modules"], sd["vae_lora_decoder_modules"], sd["vae_lora_others_modules"] =\
+ self.lora_vae_modules_encoder, self.lora_vae_modules_decoder, self.lora_vae_others
+ sd["unet_lora_encoder_modules"], sd["unet_lora_decoder_modules"], sd["unet_lora_others_modules"] =\
+ self.lora_unet_modules_encoder, self.lora_unet_modules_decoder, self.lora_unet_others
+ sd["rank_unet"] = self.lora_rank_unet
+ sd["rank_vae"] = self.lora_rank_vae
+ sd["state_dict_unet"] = {k: v for k, v in self.unet.state_dict().items() if "lora" in k or "conv_in" in k}
+ sd["state_dict_vae"] = {k: v for k, v in self.vae.state_dict().items() if "lora" in k or "skip" in k}
+ torch.save(sd, outf)
+
+ def load_ckpt_from_state_dict(self, sd):
+ # load unet lora
+ lora_conf_encoder = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["unet_lora_encoder_modules"])
+ lora_conf_decoder = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["unet_lora_decoder_modules"])
+ lora_conf_others = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["unet_lora_others_modules"])
+ self.unet.add_adapter(lora_conf_encoder, adapter_name="default_encoder")
+ self.unet.add_adapter(lora_conf_decoder, adapter_name="default_decoder")
+ self.unet.add_adapter(lora_conf_others, adapter_name="default_others")
+ for n, p in self.unet.named_parameters():
+ if "lora" in n or "conv_in" in n:
+ p.data.copy_(sd["state_dict_unet"][n])
+
+ # load vae lora
+ vae_lora_conf_encoder = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian", target_modules=sd["vae_lora_encoder_modules"])
+ vae_lora_conf_decoder = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian", target_modules=sd["vae_lora_decoder_modules"])
+ self.vae.add_adapter(vae_lora_conf_encoder, adapter_name="default_encoder")
+ self.vae.add_adapter(vae_lora_conf_decoder, adapter_name="default_decoder")
+ for n, p in self.vae.named_parameters():
+ if "lora" in n:
+ p.data.copy_(sd["state_dict_vae"][n])
+
+class GDPOSR(torch.nn.Module):
+ def __init__(self, args):
+ super().__init__()
+
+ self.tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+ self.text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder").cuda()
+ self.sched = make_1step_sched(args.pretrained_model_name_or_path)
+ self.sched2 = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ self.args = args
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ vae = AutoencoderKL.from_pretrained(args.basemodel_path, subfolder="vae")
+ unet = UNet2DConditionModel.from_pretrained(args.basemodel_path, subfolder="unet")
+ ref_unet = UNet2DConditionModel.from_pretrained(args.basemodel_path, subfolder="unet")
+
+ if args.pretrained_path is None:
+ print('==================================> randomly initiate the weight')
+ unet, lora_unet_modules_encoder, lora_unet_modules_decoder, lora_unet_others =\
+ initialize_unet_sr(rank=args.lora_rank_unet, pretrained_model_name_or_path=args.basemodel_path, return_lora_module_names=True, args=args)
+ self.lora_rank_unet = args.lora_rank_unet
+ self.lora_unet_modules_encoder, self.lora_unet_modules_decoder, self.lora_unet_others = \
+ lora_unet_modules_encoder, lora_unet_modules_decoder, lora_unet_others
+
+ self.unet, self.vae = unet, vae
+
+ if args.pretrained_path is not None:
+ print('==================================> loading pre-trained weight')
+ sd = torch.load(args.pretrained_path)
+ self.load_ckpt_from_state_dict(sd)
+ self.lora_rank_unet = sd['rank_unet']
+ self.lora_unet_modules_encoder, self.lora_unet_modules_decoder, self.lora_unet_others = \
+ sd['unet_lora_encoder_modules'], sd['unet_lora_decoder_modules'], sd['unet_lora_others_modules']
+
+ self.unet, self.vae = self.unet.cuda(), self.vae.cuda()
+ self.ref_unet = ref_unet.cuda()
+ self.timesteps = torch.tensor([args.time_step], device="cuda").long()
+ self.timestepsnoise = torch.tensor([args.time_step_noise], device="cuda").long()
+ self.text_encoder.requires_grad_(False)
+
+ def set_eval(self):
+ self.unet.eval()
+ self.vae.eval()
+ self.ref_unet.eval()
+ self.unet.requires_grad_(False)
+ self.vae.requires_grad_(False)
+ self.ref_unet.requires_grad_(False)
+
+ def set_train(self):
+ self.unet.train()
+ self.vae.train()
+ for n, _p in self.unet.named_parameters():
+ if "lora" in n:
+ _p.requires_grad = True
+ for n, _p in self.ref_unet.named_parameters():
+ _p.requires_grad = False
+
+ def encode_prompt(self, prompt):
+ with torch.no_grad():
+ text_input_ids = self.tokenizer(
+ prompt, max_length=self.tokenizer.model_max_length,
+ padding="max_length", truncation=True, return_tensors="pt"
+ ).input_ids
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(self.text_encoder.device),
+ )[0]
+ return prompt_embeds
+
+ def forward(self, c_t, positive_prompt=[''], negative_prompt=[''], args=None):
+ caption_enc = self.encode_prompt(positive_prompt)
+ neg_caption_enc = self.encode_prompt(negative_prompt)
+ with torch.no_grad():
+ encoded_control = self.vae.encode(c_t).latent_dist.sample() * self.vae.config.scaling_factor
+ encoded_control_ref = encoded_control
+ noise = torch.randn_like(encoded_control)
+ encoded_control = self.sched2.add_noise(encoded_control, noise, self.timestepsnoise)
+
+ model_pred = self.unet(encoded_control, self.timesteps, encoder_hidden_states=caption_enc.to(torch.float32),).sample
+ x_denoised = self.sched.step(model_pred, self.timesteps, encoded_control, return_dict=True).prev_sample
+ output_image = self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample
+ output_image = output_image.clamp(-1, 1)
+
+ with torch.no_grad():
+ encoded_control_ref = self.sched2.add_noise(encoded_control_ref, noise, self.timestepsnoise)
+ ref_model_pred = self.ref_unet(encoded_control_ref, self.timesteps, encoder_hidden_states=caption_enc.to(torch.float32),).sample
+ ref_x_denoised = self.sched.step(ref_model_pred, self.timesteps, encoded_control_ref, return_dict=True).prev_sample
+ ref_output_image = self.vae.decode(ref_x_denoised / self.vae.config.scaling_factor).sample
+ ref_output_image = ref_output_image.clamp(-1, 1)
+
+ return output_image, x_denoised, model_pred, caption_enc, neg_caption_enc, noise, ref_output_image, ref_x_denoised, ref_model_pred
+
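+ # Usage sketch (hypothetical shapes): for a batch of LR crops c_t of shape
+ # (B, 3, H, W), GDPOReference draws `groupsize` rollouts per sample from the
+ # frozen reference UNet, e.g.
+ #   imgs, lats, preds = model.GDPOReference(c_t, positive_prompt=prompts, groupsize=6)
+ # where imgs has shape (B, groupsize, 3, H, W) and is later ranked by the
+ # reward. Note that negative_prompt is accepted but unused here.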
+ def GDPOReference(self, c_t, positive_prompt=[''], negative_prompt=[''], args=None, groupsize=6):
+
+ with torch.no_grad():
+
+ caption_enc = self.encode_prompt(positive_prompt).unsqueeze(1)
+ encoded_control = self.vae.encode(c_t).latent_dist.sample() * self.vae.config.scaling_factor
+ b, c, h, w = encoded_control.shape
+ encoded_control = encoded_control.unsqueeze(1)
+ caption_enc = caption_enc.repeat(1,groupsize,1,1)
+ encoded_control = encoded_control.repeat(1, groupsize, 1, 1, 1)
+ noise = torch.randn_like(encoded_control)
+ output_image = torch.zeros_like(c_t).unsqueeze(1).repeat(1,groupsize,1,1,1)
+ x_denoised = torch.zeros_like(noise)
+ model_pred = torch.zeros_like(noise)
+ for i in range(b):
+ encoded_control_i = self.sched2.add_noise(encoded_control[i], noise[i], self.timestepsnoise)
+ # print(encoded_control.shape, caption_enc.shape, self.timesteps.shape)
+ model_pred_i = self.ref_unet(encoded_control_i, self.timesteps, encoder_hidden_states=caption_enc[i],).sample
+ x_denoised_i = self.sched.step(model_pred_i, self.timesteps, encoded_control_i, return_dict=True).prev_sample
+ output_image_i = self.vae.decode(x_denoised_i / self.vae.config.scaling_factor).sample
+ output_image_i = output_image_i.clamp(-1, 1)
+ output_image[i] = output_image_i
+ x_denoised[i] = x_denoised_i
+ model_pred[i] = model_pred_i
+
+ return output_image, x_denoised, model_pred
+
+ def save_model(self, outf):
+ sd = {}
+ sd["unet_lora_encoder_modules"], sd["unet_lora_decoder_modules"], sd["unet_lora_others_modules"] =\
+ self.lora_unet_modules_encoder, self.lora_unet_modules_decoder, self.lora_unet_others
+ sd["rank_unet"] = self.lora_rank_unet
+ sd["state_dict_unet"] = {k: v for k, v in self.unet.state_dict().items() if "lora" in k or "conv_in" in k}
+ torch.save(sd, outf)
+
+ def load_ckpt_from_state_dict(self, sd):
+ # load unet lora
+ lora_conf_encoder = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["unet_lora_encoder_modules"])
+ lora_conf_decoder = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["unet_lora_decoder_modules"])
+ lora_conf_others = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["unet_lora_others_modules"])
+ self.unet.add_adapter(lora_conf_encoder, adapter_name="default_encoder")
+ self.unet.add_adapter(lora_conf_decoder, adapter_name="default_decoder")
+ self.unet.add_adapter(lora_conf_others, adapter_name="default_others")
+ for n, p in self.unet.named_parameters():
+ if "lora" in n or "conv_in" in n:
+ p.data.copy_(sd["state_dict_unet"][n])
+
+class GDPOSRTest(torch.nn.Module):
+ def __init__(self, args):
+ super().__init__()
+
+ self.tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+ self.text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder").cuda()
+ self.sched = make_1step_sched(args.pretrained_model_name_or_path)
+ self.sched2 = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ self.args = args
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ vae = AutoencoderKL.from_pretrained(args.pretrained_path, subfolder="vae")
+ unet = UNet2DConditionModel.from_pretrained(args.pretrained_path, subfolder="unet")
+
+ self.unet, self.vae = unet, vae
+ self.unet, self.vae = self.unet.cuda(), self.vae.cuda()
+ self.timesteps = torch.tensor([args.time_step], device="cuda").long()
+ self.timestepsnoise = torch.tensor([args.time_step_noise], device="cuda").long()
+ self.text_encoder.requires_grad_(False)
+
+ def set_eval(self):
+ self.unet.eval()
+ self.vae.eval()
+ self.unet.requires_grad_(False)
+ self.vae.requires_grad_(False)
+
+ def encode_prompt(self, prompt):
+ with torch.no_grad():
+ text_input_ids = self.tokenizer(
+ prompt, max_length=self.tokenizer.model_max_length,
+ padding="max_length", truncation=True, return_tensors="pt"
+ ).input_ids
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(self.text_encoder.device),
+ )[0]
+ return prompt_embeds
+
+ def forward(self, c_t, positive_prompt=['']):
+
+ caption_enc = self.encode_prompt(positive_prompt)
+ encoded_control = self.vae.encode(c_t).latent_dist.sample() * self.vae.config.scaling_factor
+ noise = torch.randn_like(encoded_control)
+ encoded_control = self.sched2.add_noise(encoded_control, noise, self.timestepsnoise)
+
+ model_pred = self.unet(encoded_control, self.timesteps, encoder_hidden_states=caption_enc.to(torch.float32),).sample
+ x_denoised = self.sched.step(model_pred, self.timesteps, encoded_control, return_dict=True).prev_sample
+ output_image = self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample
+ output_image = output_image.clamp(-1, 1)
+
+ return output_image
diff --git a/GDPOSR/modelfile/__pycache__/GDPOSR.cpython-310.pyc b/GDPOSR/modelfile/__pycache__/GDPOSR.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ea62ea6f42e269a6bfbc5551e4c4937413c9e5e4
Binary files /dev/null and b/GDPOSR/modelfile/__pycache__/GDPOSR.cpython-310.pyc differ
diff --git a/GDPOSR/my_utils/__pycache__/mask.cpython-310.pyc b/GDPOSR/my_utils/__pycache__/mask.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4d66ec3f5c94ab9a31ff5c38fec28e259b8d2bf8
Binary files /dev/null and b/GDPOSR/my_utils/__pycache__/mask.cpython-310.pyc differ
diff --git a/GDPOSR/my_utils/__pycache__/training_utils_realsr.cpython-310.pyc b/GDPOSR/my_utils/__pycache__/training_utils_realsr.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..04a40cc82d25d2bf516f42221bad9e2fb4819671
Binary files /dev/null and b/GDPOSR/my_utils/__pycache__/training_utils_realsr.cpython-310.pyc differ
diff --git a/GDPOSR/my_utils/__pycache__/wavelet_color_fix.cpython-310.pyc b/GDPOSR/my_utils/__pycache__/wavelet_color_fix.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2698bfc72e955915979911d64c324b05a40c5856
Binary files /dev/null and b/GDPOSR/my_utils/__pycache__/wavelet_color_fix.cpython-310.pyc differ
diff --git a/GDPOSR/my_utils/mask.py b/GDPOSR/my_utils/mask.py
new file mode 100644
index 0000000000000000000000000000000000000000..528a374cfec4b6ba0458e591dab931d8b5684412
--- /dev/null
+++ b/GDPOSR/my_utils/mask.py
@@ -0,0 +1,126 @@
+import cv2
+import math
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+
+def calculate_complexity_degree(img, weight=1.0):
+ try:
+ # img = (img + 1.0) / 2.0 # (0, 1)
+ h, w = img.shape
+ # img = cv2.resize(img, dsize=(256,256))
+ img = cv2.resize(img, dsize=(64,64))
+ sobel_x = cv2.Sobel(img, cv2.CV_64F, 1, 0, ksize=3, borderType=cv2.BORDER_REPLICATE) # (-4, 4)
+ sobel_y = cv2.Sobel(img, cv2.CV_64F, 0, 1, ksize=3, borderType=cv2.BORDER_REPLICATE)
+ edge_magnitude = np.sqrt(sobel_x ** 2 + sobel_y ** 2) # [-4*2**0.5, 4*2**0.5]
+ hist, _ = np.histogram(edge_magnitude, bins=512, range=(-5.5, 5.5))
+ hist = hist / (hist.sum()+1e-5)
+ entropy = -np.sum(hist * np.log2(hist + 1e-10))
+
+ complexity_degree = entropy**weight*h*w
+ except Exception as e:
+ print(f"calculate_complexity_degree failed for img of shape {img.shape}: {e}")
+ complexity_degree = 0.0  # fall back so the function never returns an unbound name
+ return complexity_degree
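+
+ # Minimal usage sketch (synthetic input): the score grows with the entropy of
+ # the Sobel edge-magnitude histogram and with patch area, so flat patches score
+ # low and textured patches score high:
+ #   patch = np.random.rand(60, 60)  # grayscale values in [0, 1]
+ #   score = calculate_complexity_degree(patch)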
+
+def create_complexity_matrix(gray_img, patch_size=60):
+ """
+ Divide a grayscale image into patches and calculate the complexity of each patch
+ to generate a complexity matrix.
+
+ Parameters:
+ gray_img: Input grayscale image (NumPy array).
+ patch_size: Size of each patch (default: 60x60).
+
+ Returns:
+ complexity_matrix: The complexity matrix (same size as the input image).
+ """
+
+ h, w = gray_img.shape
+ complexity_matrix = np.zeros((h, w))
+
+ rows = h // patch_size
+ cols = w // patch_size
+
+ for i in range(rows):
+ for j in range(cols):
+ patch = gray_img[i*patch_size:(i+1)*patch_size, j*patch_size:(j+1)*patch_size]
+
+ complexity = calculate_complexity_degree(patch)
+
+ complexity_matrix[i*patch_size:(i+1)*patch_size, j*patch_size:(j+1)*patch_size] = complexity
+
+ if rows * patch_size < h:
+ for j in range(cols):
+ patch = gray_img[rows*patch_size:, j*patch_size:(j+1)*patch_size]
+ complexity = calculate_complexity_degree(patch)
+ complexity_matrix[rows*patch_size:, j*patch_size:(j+1)*patch_size] = complexity
+
+ if cols * patch_size < w:
+ for i in range(rows):
+ patch = gray_img[i*patch_size:(i+1)*patch_size, cols*patch_size:]
+ complexity = calculate_complexity_degree(patch)
+ complexity_matrix[i*patch_size:(i+1)*patch_size, cols*patch_size:] = complexity
+
+ if rows * patch_size < h and cols * patch_size < w:
+ patch = gray_img[rows*patch_size:, cols*patch_size:]
+ complexity = calculate_complexity_degree(patch)
+ complexity_matrix[rows*patch_size:, cols*patch_size:] = complexity
+
+ return complexity_matrix
+
+def binarize_complexity_matrix(complexity_matrix, threshold=50):
+ """
+ Binarize the complexity matrix.
+
+ Parameters:
+ complexity_matrix: The complexity matrix.
+ threshold: The threshold value. Elements greater than this value are set to 1,
+ and elements less than or equal to it are set to 0.
+
+ Returns:
+ binary_matrix: The binarized matrix.
+ fidelity_zero_ratio: The proportion of the fidelity region (ratio of zeros).
+ detail_one_ratio: The proportion of the detail region (ratio of ones).
+ """
+ binary_matrix = np.zeros_like(complexity_matrix, dtype=np.uint8)
+ binary_matrix[complexity_matrix > threshold] = 1
+
+ unique_values = np.unique(binary_matrix)
+ if not np.all(np.isin(unique_values, [0, 1])):
+ raise ValueError("The input matrix must be a binary matrix containing only 0 and 1.")
+
+ total_elements = binary_matrix.size
+
+ zero_count = np.count_nonzero(binary_matrix == 0)
+ one_count = np.count_nonzero(binary_matrix == 1)
+
+ fidelity_zero_ratio = round((zero_count / total_elements), 2)
+ detail_one_ratio = round((one_count / total_elements), 2)
+
+ return binary_matrix, fidelity_zero_ratio, detail_one_ratio
+
+def extract_and_dilate_edges(gray, threshold1=100, threshold2=200, dilation_size=3, downscale_factor=8):
+
+ edges = cv2.Canny(gray, threshold1, threshold2)
+ kernel = np.ones((dilation_size, dilation_size), np.uint8)
+ dilated_edges = cv2.dilate(edges, kernel, iterations=1)
+
+ h, w = dilated_edges.shape
+ new_h, new_w = h // downscale_factor, w // downscale_factor
+ downsampled_edges = cv2.resize(dilated_edges, (new_w, new_h), interpolation=cv2.INTER_AREA)
+ downsampled_edges = downsampled_edges/255.0
+
+ _, downsampled_edges_mask = cv2.threshold(downsampled_edges, 0.498, 1.0, cv2.THRESH_BINARY)
+
+ return downsampled_edges_mask
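+
+ # Usage note: cv2.Canny expects an 8-bit (0-255) grayscale input, and the
+ # returned mask is `downscale_factor` times smaller than the input, with 1.0
+ # on dilated edge pixels and 0.0 elsewhere.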
+
+if __name__ == '__main__':
+ img = cv2.imread("DrealSR/test_HR/DSC_1412_x1.png")
+ gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ gray_img = gray_img/255
+
+ complexity_matrix = create_complexity_matrix(gray_img, patch_size=10)
+ print(complexity_matrix.shape)
+ binary_matrix, fidelity_zero_ratio, detail_one_ratio = binarize_complexity_matrix(complexity_matrix, threshold=50)
+ print(fidelity_zero_ratio, detail_one_ratio)
+ print(binary_matrix.shape)
\ No newline at end of file
diff --git a/GDPOSR/my_utils/training_utils_realsr.py b/GDPOSR/my_utils/training_utils_realsr.py
new file mode 100644
index 0000000000000000000000000000000000000000..1469d0cea8902ebc82a22f460aceeb42c8ceb712
--- /dev/null
+++ b/GDPOSR/my_utils/training_utils_realsr.py
@@ -0,0 +1,305 @@
+import os
+import cv2
+import random
+import argparse
+import json
+import glob
+import torch
+import numpy as np
+from PIL import Image
+from torchvision import transforms
+import torchvision.transforms.functional as F
+
+import sys
+sys.path.append('./')
+from GDPOSR.my_utils.mask import create_complexity_matrix, binarize_complexity_matrix, extract_and_dilate_edges
+from GDPOSR.datasets.realesrgan import RealESRGAN_degradation
+
+
+def parse_args_realsr_training(input_args=None):
+ """
+ Parses command-line arguments used for configuring a paired training session (pix2pix-Turbo).
+ This function sets up an argument parser to handle various training options.
+
+ Returns:
+ argparse.Namespace: The parsed command-line arguments.
+ """
+ parser = argparse.ArgumentParser()
+ # args for grpo training
+ parser.add_argument("--groupsize", default=6, type=int)
+ parser.add_argument("--time_min", default=150, type=int)
+ parser.add_argument("--time_max", default=350, type=int)
+ parser.add_argument("--updatestep", default=4000, type=int)
+ parser.add_argument("--patchsize", default=125, type=int)
+ parser.add_argument("--beta_dpo", default=0.25, type=float)
+ parser.add_argument("--klloss", default=1.0, type=float)
+ parser.add_argument("--grpoloss", default=1.0, type=float)
+ # args for the vsd training
+ parser.add_argument("--positive_prompt", type=str, default='')
+ parser.add_argument("--negative_prompt", type=str, default='')
+ parser.add_argument("--lambda_vsd", default=1.0, type=float)
+ parser.add_argument("--lambda_vsd_lora", default=1.0, type=float)
+ parser.add_argument("--lambda_klloss", default=0.0, type=float)
+ parser.add_argument("--min_dm_step_ratio", default=0.02, type=float)
+ parser.add_argument("--max_dm_step_ratio", default=0.98, type=float)
+ parser.add_argument("--cfg_vsd", default=7.5, type=float)
+ parser.add_argument("--cfg_csd", default=7.5, type=float)
+ parser.add_argument("--snr_gamma_vsd", default=None)
+ parser.add_argument("--lora_rank_unet_vsd", default=8, type=int)
+ parser.add_argument("--pretrained_model_name_or_path_vsd", default='', type=str)
+ parser.add_argument("--basemodel_path", default='', type=str)
+
+ # args for the loss function
+ parser.add_argument("--gan_disc_type", default="vagan_clip")
+ parser.add_argument("--gan_loss_type", default="multilevel_sigmoid_s")
+ parser.add_argument("--lambda_gan", default=0.2, type=float)
+ parser.add_argument("--lambda_lpips", default=2, type=float)
+ parser.add_argument("--lambda_l2", default=1.0, type=float)
+
+ # dataset options
+ parser.add_argument("--dataset_folder", default='', type=str)
+ parser.add_argument("--testdataset_folder", default='', type=str)
+ parser.add_argument("--train_image_prep", default="resized_crop_512", type=str)
+ parser.add_argument("--test_image_prep", default="resized_crop_512", type=str)
+ parser.add_argument("--null_text_ratio", default=1., type=float)
+
+ # validation eval args
+ parser.add_argument("--eval_freq", default=500, type=int)
+ parser.add_argument("--track_val_fid", default=False, action="store_true")
+ parser.add_argument("--num_samples_eval", type=int, default=100, help="Number of samples to use for all evaluation")
+
+ parser.add_argument("--viz_freq", type=int, default=100, help="Frequency of visualizing the outputs.")
+ parser.add_argument("--tracker_project_name", type=str, default="train_pix2pix_turbo", help="The name of the wandb project to log to.")
+ parser.add_argument('--tiled_size', type=int, default=768)
+ parser.add_argument('--tiled_overlap', type=int, default=256)
+
+
+ # details about the model architecture
+ parser.add_argument("--pretrained_model_name_or_path", default='', type=str)
+ parser.add_argument("--revision", type=str, default=None,)
+ parser.add_argument("--variant", type=str, default=None,)
+ parser.add_argument("--cliptextmodule", type=str, default=None,)
+ parser.add_argument("--upsampler", type=str, default=None,)
+ parser.add_argument("--tokenizer_name", type=str, default=None)
+ parser.add_argument("--lora_rank_unet", default=8, type=int)
+ parser.add_argument("--lora_rank_unet2", default=0, type=int)
+ parser.add_argument("--lora_rank_vae", default=4, type=int)
+ parser.add_argument("--time_step", default=999, type=int)
+ parser.add_argument("--time_step_noise", default=250, type=int)
+ parser.add_argument("--pretrained_path", default=None, type=str)
+ parser.add_argument("--pretrained_unet_path", default=None, type=str)
+ parser.add_argument("--pretrained_vae_path", default=None, type=str)
+ parser.add_argument("--stage2", default=None, type=str)
+ parser.add_argument("--stage3", default=None, type=str)
+
+
+ # training details
+ parser.add_argument("--output_dir", default='experience/OSSR_vaeEcLora_ntr1_vsd_ntr0_nostage_clip_test')
+ parser.add_argument("--cache_dir", default=None,)
+ parser.add_argument("--seed", type=int, default=123, help="A seed for reproducible training.")
+ parser.add_argument("--resolution", type=int, default=512,)
+ parser.add_argument("--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader.")
+ parser.add_argument("--num_training_epochs", type=int, default=10)
+ parser.add_argument("--max_train_steps", type=int, default=10_000,)
+ parser.add_argument("--checkpointing_steps", type=int, default=500,)
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=2, help="Number of updates steps to accumulate before performing a backward/update pass.",)
+ parser.add_argument("--gradient_checkpointing", action="store_true",)
+ parser.add_argument("--learning_rate", type=float, default=5e-5)
+ parser.add_argument("--lr_scheduler", type=str, default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument("--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler.")
+ parser.add_argument("--lr_num_cycles", type=int, default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+
+ parser.add_argument("--dataloader_num_workers", type=int, default=0,)
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--ema_decay", type=float, default=0.999, help="EMA decay rate for model parameters.")
+ parser.add_argument("--allow_tf32", action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument("--report_to", type=str, default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["no", "fp16", "bf16"],)
+ parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.")
+ parser.add_argument("--set_grads_to_none", action="store_true",)
+
+ parser.add_argument("--logging_dir", type=str, default="logs")
+ parser.add_argument("--use_online_deg", action="store_true",)
+ parser.add_argument("--deg_file_path", default="params_pasd.yml", type=str)
+ parser.add_argument("--align_method", type=str, choices=['wavelet', 'adain', 'nofix'], default='adain')
+
+ # vae lora
+ parser.add_argument("--use_vae_encode_lora", action="store_true",)
+ parser.add_argument("--use_vae_decode_lora", action="store_true",)
+
+ # use_lr_999noise
+ parser.add_argument("--use_lr_999noise", action="store_true",)
+ parser.add_argument("--use_lr_concat_lr_999noise", action="store_true",)
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ # # only for debug
+ # args.enable_xformers_memory_efficient_attention = True
+ # args.use_online_deg = True
+ # args.use_lr_concat_lr_999noise = True
+
+ return args
+
+
+def build_transform(image_prep):
+ """
+ Constructs a transformation pipeline based on the specified image preparation method.
+
+ Parameters:
+ - image_prep (str): A string describing the desired image preparation
+
+ Returns:
+ - torchvision.transforms.Compose: A composable sequence of transformations to be applied to images.
+ """
+ if image_prep == "resized_crop_512":
+ T = transforms.Compose([
+ transforms.Resize(512, interpolation=transforms.InterpolationMode.LANCZOS),
+ transforms.CenterCrop(512),
+ ])
+ elif image_prep == "resize_286_randomcrop_256x256_hflip":
+ T = transforms.Compose([
+ transforms.Resize((286, 286), interpolation=Image.LANCZOS),
+ transforms.RandomCrop((256, 256)),
+ transforms.RandomHorizontalFlip(),
+ ])
+ elif image_prep in ["resize_256", "resize_256x256"]:
+ T = transforms.Compose([
+ transforms.Resize((256, 256), interpolation=Image.LANCZOS)
+ ])
+ elif image_prep in ["resize_512", "resize_512x512"]:
+ T = transforms.Compose([
+ transforms.Resize((512, 512), interpolation=Image.LANCZOS)
+ ])
+ elif image_prep == "no_resize":
+ T = transforms.Lambda(lambda x: x)
+ return T
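+
+ # Example (hypothetical path):
+ #   T = build_transform("resized_crop_512")
+ #   img = T(Image.open("example.png").convert("RGB"))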
+
+
+class PairedSROnlineDataset(torch.utils.data.Dataset):
+ def __init__(self, dataset_folder, split, image_prep, deg_file_path=None, image_size=512, args=None):
+ super().__init__()
+ self.split = split
+ self.args = args
+ clip_mean = [0.48145466, 0.4578275, 0.40821073]
+ clip_std = [0.26862954, 0.26130258, 0.27577711]
+ self.clip_normalize = transforms.Normalize(mean=clip_mean, std=clip_std)
+
+ if split == 'train':
+ self.gt_folder = os.path.join(dataset_folder, "gt")
+ self.gt_list = []
+ self.gt_list += glob.glob(os.path.join(self.gt_folder, '*.png'))
+
+ self.T = build_transform(image_prep)
+ self.split = split
+
+ self.degradation = RealESRGAN_degradation(deg_file_path, device='cpu')
+ self.crop_preproc = transforms.Compose([
+ transforms.RandomCrop(image_size),
+ transforms.RandomHorizontalFlip(),
+ ])
+ elif split == 'test':
+ dataset_folder = args.testdataset_folder
+ self.input_folder = os.path.join(dataset_folder, "test_SR_bicubic")
+ self.output_folder = os.path.join(dataset_folder, "test_HR")
+
+ self.lr_list = []
+ self.gt_list = []
+ self.lr_list += glob.glob(os.path.join(self.input_folder, '*.png'))
+ self.gt_list += glob.glob(os.path.join(self.output_folder, '*.png'))
+
+ self.T = build_transform(image_prep)
+ self.split = split
+ assert len(self.lr_list) == len(self.gt_list)
+
+ def __len__(self):
+ return len(self.gt_list)
+
+ def __getitem__(self, idx):
+
+ if self.split == 'train':
+ gt_img = Image.open(self.gt_list[idx]).convert('RGB')
+ gt_img = self.crop_preproc(gt_img)
+
+ output_t, img_t, img_t_noresize = self.degradation.degrade_process(np.asarray(gt_img)/255., resize_bak=True)
+ output_t_0 = output_t
+ output_t, img_t, img_t_noresize = output_t.squeeze(0), img_t.squeeze(0), img_t_noresize.squeeze(0)
+
+ img_t = F.normalize(img_t, mean=[0.5], std=[0.5])
+ img_t_noresize = F.normalize(img_t_noresize, mean=[0.5], std=[0.5])
+ # output images scaled to -1,1
+ output_t = F.normalize(output_t, mean=[0.5], std=[0.5])
+
+ # derive fidelity/detail statistics of the GT patch for the reward; only the
+ # scalar ratios are returned below, the mask tensors are computed but unused
+ output_t_0 = output_t_0.permute(0,2,3,1).contiguous()
+ gray_gt1 = 255 * output_t_0.squeeze(0).cpu().numpy()
+ gray_gt1 = gray_gt1.astype(np.uint8)
+ gray_gt_img_org = cv2.cvtColor(gray_gt1, cv2.COLOR_BGR2GRAY)  # note: if gray_gt1 is RGB-ordered, COLOR_RGB2GRAY would match the channel order
+ # gray_gt_img_org = cv2.cvtColor(cv2.imread(self.gt_list[idx]), cv2.COLOR_BGR2GRAY)
+ gray_gt_img = gray_gt_img_org/255
+ complexity_matrix = create_complexity_matrix(gray_gt_img, patch_size=10)
+ binary_matrix, fidelity_zero_ratio, detail_one_ratio = binarize_complexity_matrix(complexity_matrix, threshold=50)
+ downsampled_edges_mask = extract_and_dilate_edges(gray_gt_img_org, threshold1=100, threshold2=200, dilation_size=3, downscale_factor=8)
+ complexity_matrix = torch.tensor(complexity_matrix).unsqueeze(0)
+ binary_matrix = torch.tensor(binary_matrix).unsqueeze(0)
+ fidelity_zero_ratio = torch.tensor(fidelity_zero_ratio)
+ detail_one_ratio = torch.tensor(detail_one_ratio)
+ downsampled_edges_mask = torch.tensor(downsampled_edges_mask)
+
+ return {
+ "HR": output_t,
+ "LR": img_t,
+ "negative_prompt": self.args.negative_prompt,
+ 'fidelity_ratio': fidelity_zero_ratio,
+ 'detail_ratio': detail_one_ratio,
+ }
+
+ elif self.split == 'test':
+ input_img = Image.open(self.lr_list[idx]).convert('RGB')
+ input_img_noresize = Image.open(self.gt_list[idx].replace('test_HR/','test_LR/')).convert('RGB')
+ output_img = Image.open(self.gt_list[idx]).convert('RGB')
+
+ # input images scaled to -1, 1
+ img_t = self.T(input_img)
+ img_t = F.to_tensor(img_t)
+
+ img_t_noresize = self.T(input_img_noresize)
+ img_t_noresize = F.to_tensor(img_t_noresize)
+
+ img_t = F.normalize(img_t, mean=[0.5], std=[0.5])
+ img_t_noresize = F.normalize(img_t_noresize, mean=[0.5], std=[0.5])
+ # output images scaled to -1,1
+ output_t = self.T(output_img)
+ output_t = F.to_tensor(output_t)
+ output_t = F.normalize(output_t, mean=[0.5], std=[0.5])
+
+ return {
+ "HR": output_t,
+ "LR": img_t,
+ "negative_prompt": self.args.negative_prompt,
+ "base_name": os.path.basename(self.lr_list[idx]),
+ }
\ No newline at end of file
diff --git a/GDPOSR/my_utils/wavelet_color_fix.py b/GDPOSR/my_utils/wavelet_color_fix.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa741ea9033fb9adf016c486268a4504514e8a1c
--- /dev/null
+++ b/GDPOSR/my_utils/wavelet_color_fix.py
@@ -0,0 +1,119 @@
+'''
+# --------------------------------------------------------------------------------
+# Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py)
+# --------------------------------------------------------------------------------
+'''
+
+import torch
+from PIL import Image
+from torch import Tensor
+from torch.nn import functional as F
+
+from torchvision.transforms import ToTensor, ToPILImage
+
+def adain_color_fix(target: Image, source: Image):
+ # Convert images to tensors
+ to_tensor = ToTensor()
+ target_tensor = to_tensor(target).unsqueeze(0)
+ source_tensor = to_tensor(source).unsqueeze(0)
+
+ # Apply adaptive instance normalization
+ result_tensor = adaptive_instance_normalization(target_tensor, source_tensor)
+
+ # Convert tensor back to image
+ to_image = ToPILImage()
+ result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
+
+ return result_image
+
+def wavelet_color_fix(target: Image, source: Image):
+ # Convert images to tensors
+ to_tensor = ToTensor()
+ target_tensor = to_tensor(target).unsqueeze(0)
+ source_tensor = to_tensor(source).unsqueeze(0)
+
+ # Apply wavelet reconstruction
+ result_tensor = wavelet_reconstruction(target_tensor, source_tensor)
+
+ # Convert tensor back to image
+ to_image = ToPILImage()
+ result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
+
+ return result_image
+
+def calc_mean_std(feat: Tensor, eps=1e-5):
+ """Calculate mean and std for adaptive_instance_normalization.
+ Args:
+ feat (Tensor): 4D tensor.
+ eps (float): A small value added to the variance to avoid
+ divide-by-zero. Default: 1e-5.
+ """
+ size = feat.size()
+ assert len(size) == 4, 'The input feature should be 4D tensor.'
+ b, c = size[:2]
+ feat_var = feat.reshape(b, c, -1).var(dim=2) + eps
+ feat_std = feat_var.sqrt().reshape(b, c, 1, 1)
+ feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1)
+ return feat_mean, feat_std
+
+def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor):
+ """Adaptive instance normalization.
+ Adjust the reference features to have a color and illumination similar
+ to those of the degraded features.
+ Args:
+ content_feat (Tensor): The reference features.
+ style_feat (Tensor): The degraded features.
+ """
+ size = content_feat.size()
+ style_mean, style_std = calc_mean_std(style_feat)
+ content_mean, content_std = calc_mean_std(content_feat)
+ normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
+ return normalized_feat * style_std.expand(size) + style_mean.expand(size)
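+
+ # AdaIN in one line: out = (x - mean(x)) / std(x) * std(y) + mean(y), computed
+ # per channel, so the content keeps its structure while adopting the style's
+ # first- and second-order color statistics.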
+
+def wavelet_blur(image: Tensor, radius: int):
+ """
+ Apply wavelet blur to the input tensor.
+ """
+ # input shape: (1, 3, H, W)
+ # convolution kernel
+ kernel_vals = [
+ [0.0625, 0.125, 0.0625],
+ [0.125, 0.25, 0.125],
+ [0.0625, 0.125, 0.0625],
+ ]
+ kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
+ # add channel dimensions to the kernel to make it a 4D tensor
+ kernel = kernel[None, None]
+ # repeat the kernel across all input channels
+ kernel = kernel.repeat(3, 1, 1, 1)
+ image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
+ # apply convolution
+ output = F.conv2d(image, kernel, groups=3, dilation=radius)
+ return output
+
+def wavelet_decomposition(image: Tensor, levels=5):
+ """
+ Apply wavelet decomposition to the input tensor.
+ This function only returns the low frequency & the high frequency.
+ """
+ high_freq = torch.zeros_like(image)
+ for i in range(levels):
+ radius = 2 ** i
+ low_freq = wavelet_blur(image, radius)
+ high_freq += (image - low_freq)
+ image = low_freq
+
+ return high_freq, low_freq
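+
+ # Telescoping identity: high_freq accumulates (image_i - blur(image_i)) while
+ # image is replaced by its blurred version at each level, so
+ # high_freq + low_freq reconstructs the original input exactly.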
+
+def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
+ """
+ Apply wavelet decomposition, so that the content will have the same color as the style.
+ """
+ # calculate the wavelet decomposition of the content feature
+ content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
+ del content_low_freq
+ # calculate the wavelet decomposition of the style feature
+ style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
+ del style_high_freq
+ # reconstruct the content feature with the style's high frequency
+ return content_high_freq + style_low_freq
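+
+ # Usage sketch: for a generated PIL image `pred` and a same-sized reference
+ # `src`, wavelet_color_fix(pred, src) keeps the high-frequency detail of `pred`
+ # and the low-frequency color of `src`; adain_color_fix instead matches the
+ # per-channel mean/std statistics.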
diff --git a/GDPOSR/train/train_GDPOSR.py b/GDPOSR/train/train_GDPOSR.py
new file mode 100644
index 0000000000000000000000000000000000000000..aae3731dfcf4d40ad946268ea079051b2b2676cb
--- /dev/null
+++ b/GDPOSR/train/train_GDPOSR.py
@@ -0,0 +1,233 @@
+import os
+import gc
+import lpips
+import clip
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.utils import set_seed
+from PIL import Image
+from torchvision import transforms
+from tqdm.auto import tqdm
+import copy
+
+import diffusers
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.optimization import get_scheduler
+
+import wandb
+from cleanfid.fid import get_folder_features, build_feature_extractor, fid_from_feats
+import sys
+sys.path.append("GDPOSR")
+from modelfile.GDPOSR import GDPOSR as GDPOSRModel
+from my_utils.training_utils_realsr import parse_args_realsr_training, PairedSROnlineDataset
+
+from pathlib import Path
+from accelerate.utils import set_seed, ProjectConfiguration
+from accelerate import DistributedDataParallelKwargs
+
+sys.path.append('GDPOSR')
+from GDPOSR.my_utils.wavelet_color_fix import adain_color_fix, wavelet_color_fix
+from diffusers.training_utils import compute_snr
+from diffusers import DDPMScheduler, AutoencoderKL
+from GDPOSR.losses.grpo import AdaptiveReward as RewardFunction
+
+from ram.models.ram_lora import ram
+from ram import inference_ram as inference
+
+
+def main(args):
+ logging_dir = Path(args.output_dir, args.logging_dir)
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ kwargs_handlers=[ddp_kwargs],
+ )
+
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ if accelerator.is_main_process:
+ os.makedirs(os.path.join(args.output_dir, "checkpoints"), exist_ok=True)
+ os.makedirs(os.path.join(args.output_dir, "eval"), exist_ok=True)
+
+ net_pix2pix = GDPOSRModel(args)
+ net_pix2pix.set_train()
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ net_pix2pix.unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available, please install it by running `pip install xformers`")
+
+ if args.gradient_checkpointing:
+ net_pix2pix.unet.enable_gradient_checkpointing()
+
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ net_lpips = lpips.LPIPS(net='vgg').cuda()
+ net_lpips.requires_grad_(False)
+ net_ARF = RewardFunction()
+ net_ARF.requires_grad_(False)
+
+ # # set adapter
+ net_pix2pix.unet.set_adapter(['default_encoder', 'default_decoder', 'default_others'])
+
+ # make the optimizer
+ layers_to_opt = []
+ for n, _p in net_pix2pix.unet.named_parameters():
+ if "lora" in n:
+ assert _p.requires_grad
+ layers_to_opt.append(_p)
+
+ optimizer = torch.optim.AdamW(layers_to_opt, lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,)
+ lr_scheduler = get_scheduler(args.lr_scheduler, optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles, power=args.lr_power,)
+
+ # make the dataloader
+ dataset_train = PairedSROnlineDataset(dataset_folder=args.dataset_folder, image_prep=args.train_image_prep, split="train", deg_file_path=args.deg_file_path, args=args)
+ dataset_val = PairedSROnlineDataset(dataset_folder=args.dataset_folder, image_prep=args.test_image_prep, split="test", deg_file_path=args.deg_file_path, args=args)
+ dl_train = torch.utils.data.DataLoader(dataset_train, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers)
+ dl_val = torch.utils.data.DataLoader(dataset_val, batch_size=1, shuffle=False, num_workers=0)
+
+ # init RAM
+ ram_transforms = transforms.Compose([
+ transforms.Resize((384, 384)),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ])
+ RAM = ram(pretrained='./ckp/ram_swin_large_14m.pth',
+ pretrained_condition=None,
+ image_size=384,
+ vit='swin_l')
+ RAM.eval()
+ RAM.to("cuda", dtype=torch.float16)
+
+ # Prepare everything with our `accelerator`.
+ net_pix2pix, optimizer, dl_train, lr_scheduler = accelerator.prepare(
+ net_pix2pix, optimizer, dl_train, lr_scheduler
+ )
+ net_lpips, net_ARF = accelerator.prepare(net_lpips, net_ARF)
+ # renorm with image net statistics
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
+ progress_bar = tqdm(range(0, args.max_train_steps), initial=0, desc="Steps",
+ disable=not accelerator.is_local_main_process,)
+
+ # start the training loop
+ global_step = 0
+ for epoch in range(0, args.num_training_epochs):
+ for step, batch in enumerate(dl_train):
+ with accelerator.accumulate(net_pix2pix):
+ x_src = batch["LR"]
+ x_tgt = batch["HR"]
+ fidelity_ratio = batch["fidelity_ratio"]
+ detail_ratio = batch["detail_ratio"]
+
+ B, C, H, W = x_src.shape
+ # image description
+ x_tgt_ram = ram_transforms(x_tgt*0.5+0.5)
+ caption_r = inference(x_tgt_ram.to(dtype=torch.float16), RAM)
+ with torch.no_grad():
+ positive_prompt = []
+ negative_prompt = []
+ for i in range(B):
+ ram_image = x_tgt[i,:,:,:].unsqueeze(0)
+ x_tgt_ram = ram_transforms(ram_image*0.5+0.5)
+ caption = inference(x_tgt_ram.to(dtype=torch.float16), RAM)
+ positive_prompt.append(f'{caption[0]}, {args.positive_prompt}')
+ negative_prompt.append(args.negative_prompt)
+ # generate some samples
+ if torch.cuda.device_count() > 1:
+ sample_images, _, _ = net_pix2pix.module.GDPOReference(x_src, positive_prompt=positive_prompt, negative_prompt=negative_prompt, args=args, groupsize=args.groupsize)
+ else:
+ sample_images, _, _ = net_pix2pix.GDPOReference(x_src, positive_prompt=positive_prompt, negative_prompt=negative_prompt, args=args, groupsize=args.groupsize)
+ # select winning and losing samples:
+ x_tgt_re = x_tgt.unsqueeze(1).repeat(1,args.groupsize,1,1,1)
+ rewards = net_ARF(sample_images, x_tgt_re, fidelity_ratio, detail_ratio)
+ rewards = rewards.cuda()
+ b_sample, g_sample, c_sample, h_sample, w_sample = sample_images.shape
+ x_src_wl = sample_images.view(b_sample*g_sample, c_sample, h_sample, w_sample)
+ ps_wl = []
+ nps_wl = []
+ for i in range(args.groupsize):
+ ps_wl += positive_prompt
+ nps_wl += negative_prompt
+ # forward pass
+ x_tgt_pred, latents_pred, model_pred, prompt_embeds, neg_prompt_embeds, noise, ref_output_image, ref_x_denoised, ref_model_pred = net_pix2pix(x_src_wl, positive_prompt=ps_wl, negative_prompt=nps_wl, args=args)
+ # GDPO
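+ # Hedged reading of the objective below (a group-wise, DPO-style loss): each
+ # group member's denoising error ||eps_hat - eps||^2 is weighted by its reward
+ # and summed over the group, for both the policy and the frozen reference
+ # UNet; -logsigmoid(scale * (policy - reference)) then rewards the policy for
+ # achieving a lower reward-weighted error than the reference.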
+ model_losses = (model_pred - noise).pow(2).mean(dim=[1,2,3])
+ # b_model, c_model, h_model, w_model = model_losses.shape
+ model_losses = model_losses.view(b_sample, g_sample)
+ model_losses = rewards * model_losses
+ model_diff = model_losses.sum(1)
+ # model_losses_w, model_losses_l = model_losses.chunk(2)
+ ref_losses = (ref_model_pred - noise).pow(2).mean(dim=[1,2,3])
+ ref_losses = ref_losses.view(b_sample, g_sample)
+ ref_losses = rewards * ref_losses
+ ref_diff = ref_losses.sum(1)
+ scale_term = -0.5 * 5000  # hardcoded inverse temperature for the DPO-style objective (args.beta_dpo is not applied here)
+ inside_term = scale_term * (model_diff - ref_diff)
+ implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0)
+ gdpo_loss = -1 * F.logsigmoid(inside_term).mean()
+ loss = gdpo_loss
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(layers_to_opt, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
+
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ logs = {}
+ # log all the losses
+ logs["loss"] = gdpo_loss.detach().item()
+ progress_bar.set_postfix(**logs)
+
+ # checkpoint the model
+ if global_step % args.checkpointing_steps == 1:
+ outf = os.path.join(args.output_dir, "checkpoints", f"model_{global_step}.pkl")
+ accelerator.unwrap_model(net_pix2pix).save_model(outf)
+
+ accelerator.log(logs, step=global_step)
+
+
+if __name__ == "__main__":
+ args = parse_args_realsr_training()
+ main(args)
diff --git a/GDPOSR/train/train_NAOSD.py b/GDPOSR/train/train_NAOSD.py
new file mode 100644
index 0000000000000000000000000000000000000000..94f2324df9cded348bd214c294b6c745def0e9da
--- /dev/null
+++ b/GDPOSR/train/train_NAOSD.py
@@ -0,0 +1,256 @@
+import os
+import gc
+import lpips
+import clip
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.utils import set_seed
+from PIL import Image
+from torchvision import transforms
+from tqdm.auto import tqdm
+import copy
+
+import diffusers
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.optimization import get_scheduler
+
+import wandb
+from cleanfid.fid import get_folder_features, build_feature_extractor, fid_from_feats
+import sys
+sys.path.append("GDPOSR")
+from modelfile.GDPOSR import VSD, NAOSD
+from my_utils.training_utils_realsr import parse_args_realsr_training, PairedSROnlineDataset
+
+from pathlib import Path
+from accelerate.utils import set_seed, ProjectConfiguration
+from accelerate import DistributedDataParallelKwargs
+
+sys.path.append('GDPOSR')
+from GDPOSR.my_utils.wavelet_color_fix import adain_color_fix, wavelet_color_fix
+from diffusers.training_utils import compute_snr
+from diffusers import DDPMScheduler, AutoencoderKL
+
+from ram.models.ram_lora import ram
+from ram import inference_ram as inference
+
+
+def main(args):
+ logging_dir = Path(args.output_dir, args.logging_dir)
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ kwargs_handlers=[ddp_kwargs],
+ )
+
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ if accelerator.is_main_process:
+ os.makedirs(os.path.join(args.output_dir, "checkpoints"), exist_ok=True)
+ os.makedirs(os.path.join(args.output_dir, "eval"), exist_ok=True)
+
+ net_pix2pix = NAOSD(args)
+ net_pix2pix.set_train()
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ net_pix2pix.unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available, please install it by running `pip install xformers`")
+
+ if args.gradient_checkpointing:
+ net_pix2pix.unet.enable_gradient_checkpointing()
+
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ # init vsd model
+ net_disc = VSD(args=args, accelerator=accelerator)
+ net_disc.set_train()
+
+ net_lpips = lpips.LPIPS(net='vgg').cuda()
+ net_lpips.requires_grad_(False)
+
+ # # set adapter
+ if args.use_vae_encode_lora and (not args.use_vae_decode_lora):
+ print('==== Use Lora at VAE Encoder ====')
+ net_pix2pix.vae.set_adapter(['default_encoder'])
+ elif (not args.use_vae_encode_lora) and args.use_vae_decode_lora:
+ print('==== Use Lora at VAE Decoder ====')
+ net_pix2pix.vae.set_adapter(['default_decoder'])
+ elif args.use_vae_encode_lora and args.use_vae_decode_lora:
+ print('==== Use Lora at VAE En&Decoder ====')
+ net_pix2pix.vae.set_adapter(['default_encoder', 'default_decoder'])
+ else:
+ print('==== Use Fix VAE ====')
+ net_pix2pix.vae.disable_adapters()
+ net_pix2pix.unet.set_adapter(['default_encoder', 'default_decoder', 'default_others'])
+
+ # make the optimizer
+ layers_to_opt = []
+ for n, _p in net_pix2pix.unet.named_parameters():
+ if "lora" in n:
+ assert _p.requires_grad
+ layers_to_opt.append(_p)
+ layers_to_opt += list(net_pix2pix.unet.conv_in.parameters())
+ for n, _p in net_pix2pix.vae.named_parameters():
+ if "lora" in n:
+ # assert _p.requires_grad
+ layers_to_opt.append(_p)
+
+ optimizer = torch.optim.AdamW(layers_to_opt, lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,)
+ lr_scheduler = get_scheduler(args.lr_scheduler, optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles, power=args.lr_power,)
+
+ layers_to_opt_disc = []
+ for n, _p in net_disc.unet_update.named_parameters():
+ if "lora" in n:
+ assert _p.requires_grad
+ layers_to_opt_disc.append(_p)
+ optimizer_disc = torch.optim.AdamW(layers_to_opt_disc, lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,)
+ lr_scheduler_disc = get_scheduler(args.lr_scheduler, optimizer=optimizer_disc,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles, power=args.lr_power)
+
+ # make the dataloader
+ dataset_train = PairedSROnlineDataset(dataset_folder=args.dataset_folder, image_prep=args.train_image_prep, split="train", deg_file_path=args.deg_file_path, args=args)
+ dataset_val = PairedSROnlineDataset(dataset_folder=args.dataset_folder, image_prep=args.test_image_prep, split="test", deg_file_path=args.deg_file_path, args=args)
+ dl_train = torch.utils.data.DataLoader(dataset_train, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers)
+ dl_val = torch.utils.data.DataLoader(dataset_val, batch_size=1, shuffle=False, num_workers=0)
+
+ # init RAM
+ ram_transforms = transforms.Compose([
+ transforms.Resize((384, 384)),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ])
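+    # load the pretrained RAM (Swin-L) tagger; it captions the HR images to build the prompts,
+    # and expects 384x384, ImageNet-normalized inputs (see ram_transforms above)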
+ RAM = ram(pretrained='./ckp/ram_swin_large_14m.pth',
+ pretrained_condition=None,
+ image_size=384,
+ vit='swin_l')
+ RAM.eval()
+ RAM.to("cuda", dtype=torch.float16)
+
+ # Prepare everything with our `accelerator`.
+ net_pix2pix, net_disc, optimizer, optimizer_disc, dl_train, lr_scheduler, lr_scheduler_disc = accelerator.prepare(
+ net_pix2pix, net_disc, optimizer, optimizer_disc, dl_train, lr_scheduler, lr_scheduler_disc
+ )
+ net_lpips = accelerator.prepare(net_lpips)
+    # resolve the weight dtype from the accelerator's mixed-precision setting
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+    # We need to initialize the trackers we use and store our configuration.
+    # The trackers are initialized automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
+ progress_bar = tqdm(range(0, args.max_train_steps), initial=0, desc="Steps",
+ disable=not accelerator.is_local_main_process,)
+
+ # start the training loop
+ global_step = 0
+ for epoch in range(0, args.num_training_epochs):
+ for step, batch in enumerate(dl_train):
+ l_acc = [net_pix2pix, net_disc]
+ with accelerator.accumulate(*l_acc):
+ x_src = batch["LR"]
+ x_tgt = batch["HR"]
+ B, C, H, W = x_src.shape
+                # image description: tag each HR image with RAM to build its positive prompt
+                # (the per-image captions are computed once, under no_grad, in the loop below)
+ with torch.no_grad():
+ positive_prompt = []
+ negative_prompt = []
+ for i in range(B):
+ ram_image = x_tgt[i,:,:,:].unsqueeze(0)
+ x_tgt_ram = ram_transforms(ram_image*0.5+0.5)
+ caption = inference(x_tgt_ram.to(dtype=torch.float16), RAM)
+ positive_prompt.append(f'{caption[0]}, {args.positive_prompt}')
+ negative_prompt.append(args.negative_prompt)
+ # forward pass
+ x_tgt_pred, latents_pred, prompt_embeds, neg_prompt_embeds, noise = net_pix2pix(x_src, positive_prompt=positive_prompt, negative_prompt=negative_prompt, args=args)
+ # Reconstruction loss
+ loss_l2 = F.mse_loss(x_tgt_pred.float(), x_tgt.float(), reduction="mean") * args.lambda_l2
+ loss_lpips = net_lpips(x_tgt_pred.float(), x_tgt.float()).mean() * args.lambda_lpips
+ loss = loss_l2 + loss_lpips
+                # KL (VSD) loss: match the generator's latents to the frozen teacher UNet's distribution
+                disc = net_disc.module if torch.cuda.device_count() > 1 else net_disc
+                loss_kl = disc.distribution_matching_loss(disc.unet_fix, disc.unet_update, disc.sched, latents_pred, prompt_embeds, neg_prompt_embeds, args) * args.lambda_vsd
+                loss = loss + loss_kl
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(layers_to_opt, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
+
+ """
+ Disc loss: let lora model closed to generator
+ """
+ if torch.cuda.device_count() > 1:
+ loss_d = net_disc.module.compute_lora_loss(latents_pred, prompt_embeds, args)*args.lambda_vsd_lora
+ else:
+ loss_d = net_disc.compute_lora_loss(latents_pred, prompt_embeds, args)*args.lambda_vsd_lora
+ accelerator.backward(loss_d)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(net_disc.parameters(), args.max_grad_norm)
+ optimizer_disc.step()
+ lr_scheduler_disc.step()
+ optimizer_disc.zero_grad(set_to_none=args.set_grads_to_none)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ logs = {}
+ # log all the losses
+ logs["loss_d"] = loss_d.detach().item()
+ logs["loss_kl"] = loss_kl.detach().item()
+ logs["loss_l2"] = loss_l2.detach().item()
+ logs["loss_lpips"] = loss_lpips.detach().item()
+ progress_bar.set_postfix(**logs)
+
+                    # checkpoint the model (== 1 so the very first synced step also writes a checkpoint)
+                    if global_step % args.checkpointing_steps == 1:
+ outf = os.path.join(args.output_dir, "checkpoints", f"model_{global_step}.pkl")
+ accelerator.unwrap_model(net_pix2pix).save_model(outf)
+
+ accelerator.log(logs, step=global_step)
+
+
+if __name__ == "__main__":
+ args = parse_args_realsr_training()
+ main(args)
diff --git a/ram/__init__.py b/ram/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..31d536a683944674c92421f6bf7b0d838924b686
--- /dev/null
+++ b/ram/__init__.py
@@ -0,0 +1,2 @@
+from .inference import inference_tag2text, inference_ram, inference_ram_openset
+from .transform import get_transform
diff --git a/ram/configs/condition_config.json b/ram/configs/condition_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..f4670c7d2c9c4007c8e5dffa14ee56f2ddf2d21c
--- /dev/null
+++ b/ram/configs/condition_config.json
@@ -0,0 +1,3 @@
+{
+ "nf": 64
+ }
\ No newline at end of file
diff --git a/ram/configs/med_config.json b/ram/configs/med_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..49d64f890cc38d558c4fd3bab048cc521a69a2be
--- /dev/null
+++ b/ram/configs/med_config.json
@@ -0,0 +1,21 @@
+{
+ "architectures": [
+ "BertModel"
+ ],
+ "attention_probs_dropout_prob": 0.1,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 768,
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "layer_norm_eps": 1e-12,
+ "max_position_embeddings": 512,
+ "model_type": "bert",
+ "num_attention_heads": 12,
+ "num_hidden_layers": 12,
+ "pad_token_id": 0,
+ "type_vocab_size": 2,
+ "vocab_size": 30524,
+ "encoder_width": 768,
+ "add_cross_attention": true
+ }
\ No newline at end of file
diff --git a/ram/configs/q2l_config.json b/ram/configs/q2l_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..a8eba56c27769cadd1506c8e88fe33aced92668f
--- /dev/null
+++ b/ram/configs/q2l_config.json
@@ -0,0 +1,22 @@
+{
+ "architectures": [
+ "BertModel"
+ ],
+ "attention_probs_dropout_prob": 0.1,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 768,
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "layer_norm_eps": 1e-12,
+ "max_position_embeddings": 512,
+ "model_type": "bert",
+ "num_attention_heads": 4,
+ "num_hidden_layers": 2,
+ "pad_token_id": 0,
+ "type_vocab_size": 2,
+ "vocab_size": 30522,
+ "encoder_width": 768,
+ "add_cross_attention": true,
+ "add_tag_cross_attention": false
+ }
\ No newline at end of file
diff --git a/ram/configs/swin/config_swinB_384.json b/ram/configs/swin/config_swinB_384.json
new file mode 100644
index 0000000000000000000000000000000000000000..d2f3e0724319655e7a084d602db3712abda746ee
--- /dev/null
+++ b/ram/configs/swin/config_swinB_384.json
@@ -0,0 +1,9 @@
+{
+ "ckpt": "pretrain_model/swin_base_patch4_window7_224_22k.pth",
+ "vision_width": 1024,
+ "image_res": 384,
+ "window_size": 12,
+ "embed_dim": 128,
+ "depths": [ 2, 2, 18, 2 ],
+ "num_heads": [ 4, 8, 16, 32 ]
+ }
\ No newline at end of file
diff --git a/ram/configs/swin/config_swinL_384.json b/ram/configs/swin/config_swinL_384.json
new file mode 100644
index 0000000000000000000000000000000000000000..e6443a2d209fef96f4a183b7499323976f3a88e5
--- /dev/null
+++ b/ram/configs/swin/config_swinL_384.json
@@ -0,0 +1,9 @@
+{
+ "ckpt": "pretrain_model/swin_large_patch4_window12_384_22k.pth",
+ "vision_width": 1536,
+ "image_res": 384,
+ "window_size": 12,
+ "embed_dim": 192,
+ "depths": [ 2, 2, 18, 2 ],
+ "num_heads": [ 6, 12, 24, 48 ]
+ }
\ No newline at end of file
diff --git a/ram/configs/swin/config_swinL_444.json b/ram/configs/swin/config_swinL_444.json
new file mode 100644
index 0000000000000000000000000000000000000000..65b5c5e7d899a746a2995b19cc99ee93df70099b
--- /dev/null
+++ b/ram/configs/swin/config_swinL_444.json
@@ -0,0 +1,9 @@
+{
+ "ckpt": "pretrain_model/swin_large_patch4_window12_384_22k.pth",
+ "vision_width": 1536,
+ "image_res": 444,
+ "window_size": 12,
+ "embed_dim": 192,
+ "depths": [ 2, 2, 18, 2 ],
+ "num_heads": [ 6, 12, 24, 48 ]
+ }
\ No newline at end of file
diff --git a/ram/data/ram_tag_list.txt b/ram/data/ram_tag_list.txt
new file mode 100644
index 0000000000000000000000000000000000000000..49c840b71915f639fb79cd83ac4a3e313cfbc2b1
--- /dev/null
+++ b/ram/data/ram_tag_list.txt
@@ -0,0 +1,4585 @@
+3D CG rendering
+3D glasses
+abacus
+abalone
+monastery
+belly
+academy
+accessory
+accident
+accordion
+acorn
+acrylic paint
+act
+action
+action film
+activity
+actor
+adaptation
+add
+adhesive tape
+adjust
+adult
+adventure
+advertisement
+antenna
+aerobics
+spray can
+afro
+agriculture
+aid
+air conditioner
+air conditioning
+air sock
+aircraft cabin
+aircraft model
+air field
+air line
+airliner
+airman
+plane
+airplane window
+airport
+airport runway
+airport terminal
+airship
+airshow
+aisle
+alarm
+alarm clock
+mollymawk
+album
+album cover
+alcohol
+alcove
+algae
+alley
+almond
+aloe vera
+alp
+alpaca
+alphabet
+german shepherd
+altar
+amber
+ambulance
+bald eagle
+American shorthair
+amethyst
+amphitheater
+amplifier
+amusement park
+amusement ride
+anchor
+ancient
+anemone
+angel
+angle
+animal
+animal sculpture
+animal shelter
+animation
+animation film
+animator
+anime
+ankle
+anklet
+anniversary
+trench coat
+ant
+antelope
+antique
+antler
+anvil
+apartment
+ape
+app
+app icon
+appear
+appearance
+appetizer
+applause
+apple
+apple juice
+apple pie
+apple tree
+applesauce
+appliance
+appointment
+approach
+apricot
+apron
+aqua
+aquarium
+aquarium fish
+aqueduct
+arcade
+arcade machine
+arch
+arch bridge
+archaelogical excavation
+archery
+archipelago
+architect
+architecture
+archive
+archway
+area
+arena
+argument
+arm
+armadillo
+armband
+armchair
+armoire
+armor
+army
+army base
+army tank
+array
+arrest
+arrow
+art
+art exhibition
+art gallery
+art print
+art school
+art studio
+art vector illustration
+artichoke
+article
+artifact
+artist
+artists loft
+ash
+ashtray
+asia temple
+asparagus
+asphalt road
+assemble
+assembly
+assembly line
+association
+astronaut
+astronomer
+athlete
+athletic
+atlas
+atm
+atmosphere
+atrium
+attach
+fighter jet
+attend
+attraction
+atv
+eggplant
+auction
+audi
+audio
+auditorium
+aurora
+author
+auto factory
+auto mechanic
+auto part
+auto show
+auto showroom
+car battery
+automobile make
+automobile model
+motor vehicle
+autumn
+autumn forest
+autumn leave
+autumn park
+autumn tree
+avatar
+avenue
+aviator sunglasses
+avocado
+award
+award ceremony
+award winner
+shed
+ax
+azalea
+baboon
+baby
+baby bottle
+baby carriage
+baby clothe
+baby elephant
+baby food
+baby seat
+baby shower
+back
+backdrop
+backlight
+backpack
+backyard
+bacon
+badge
+badger
+badlands
+badminton
+badminton racket
+bag
+bagel
+bagpipe
+baguette
+bait
+baked goods
+baker
+bakery
+baking
+baking sheet
+balance
+balance car
+balcony
+ball
+ball pit
+ballerina
+ballet
+ballet dancer
+ballet skirt
+balloon
+balloon arch
+baseball player
+ballroom
+bamboo
+bamboo forest
+banana
+banana bread
+banana leaf
+banana tree
+band
+band aid
+bandage
+headscarf
+bandeau
+bangs
+bracelet
+balustrade
+banjo
+bank
+bank card
+bank vault
+banknote
+banner
+banquet
+banquet hall
+banyan tree
+baozi
+baptism
+bar
+bar code
+bar stool
+barbecue
+barbecue grill
+barbell
+barber
+barber shop
+barbie
+barge
+barista
+bark
+barley
+barn
+barn owl
+barn door
+barrel
+barricade
+barrier
+handcart
+bartender
+baseball
+baseball base
+baseball bat
+baseball hat
+baseball stadium
+baseball game
+baseball glove
+baseball pitcher
+baseball team
+baseball uniform
+basement
+basil
+basin
+basket
+basket container
+basketball
+basketball backboard
+basketball coach
+basketball court
+basketball game
+basketball hoop
+basketball player
+basketball stadium
+basketball team
+bass
+bass guitar
+bass horn
+bassist
+bat
+bath
+bath heater
+bath mat
+bath towel
+swimwear
+bathrobe
+bathroom
+bathroom accessory
+bathroom cabinet
+bathroom door
+bathroom mirror
+bathroom sink
+toilet paper
+bathroom window
+batman
+wand
+batter
+battery
+battle
+battle rope
+battleship
+bay
+bay bridge
+bay window
+bayberry
+bazaar
+beach
+beach ball
+beach chair
+beach house
+beach hut
+beach towel
+beach volleyball
+lighthouse
+bead
+beagle
+beak
+beaker
+beam
+bean
+bean bag chair
+beanbag
+bear
+bear cub
+beard
+beast
+beat
+beautiful
+beauty
+beauty salon
+beaver
+bed
+bedcover
+bed frame
+bedroom
+bedding
+bedpan
+bedroom window
+bedside lamp
+bee
+beech tree
+beef
+beekeeper
+beeper
+beer
+beer bottle
+beer can
+beer garden
+beer glass
+beer hall
+beet
+beetle
+beige
+clock
+bell pepper
+bell tower
+belt
+belt buckle
+bench
+bend
+bengal tiger
+bento
+beret
+berry
+berth
+beverage
+bib
+bibimbap
+bible
+bichon
+bicycle
+bicycle helmet
+bicycle wheel
+biker
+bidet
+big ben
+bike lane
+bike path
+bike racing
+bike ride
+bikini
+bikini top
+bill
+billard
+billboard
+billiard table
+bin
+binder
+binocular
+biology laboratory
+biplane
+birch
+birch tree
+bird
+bird bath
+bird feeder
+bird house
+bird nest
+birdbath
+bird cage
+birth
+birthday
+birthday cake
+birthday candle
+birthday card
+birthday party
+biscuit
+bishop
+bison
+bit
+bite
+black
+black sheep
+blackberry
+blackbird
+blackboard
+blacksmith
+blade
+blanket
+sports coat
+bleacher
+blender
+blessing
+blind
+eye mask
+flasher
+snowstorm
+block
+blog
+blood
+bloom
+blossom
+blouse
+blow
+hair drier
+blowfish
+blue
+blue artist
+blue jay
+blue sky
+blueberry
+bluebird
+pig
+board
+board eraser
+board game
+boardwalk
+boat
+boat deck
+boat house
+paddle
+boat ride
+bobfloat
+bobcat
+body
+bodyboard
+bodybuilder
+boiled egg
+boiler
+bolo tie
+bolt
+bomb
+bomber
+bonasa umbellu
+bone
+bonfire
+bonnet
+bonsai
+book
+book cover
+bookcase
+folder
+bookmark
+bookshelf
+bookstore
+boom microphone
+boost
+boot
+border
+Border collie
+botanical garden
+bottle
+bottle cap
+bottle opener
+bottle screw
+bougainvillea
+boulder
+bouquet
+boutique
+boutique hotel
+bow
+bow tie
+bow window
+bowl
+bowling
+bowling alley
+bowling ball
+bowling equipment
+box
+box girder bridge
+box turtle
+boxer
+underdrawers
+boxing
+boxing glove
+boxing ring
+boy
+brace
+bracket
+braid
+brain
+brake
+brake light
+branch
+brand
+brandy
+brass
+brass plaque
+bread
+breadbox
+break
+breakfast
+seawall
+chest
+brewery
+brick
+brick building
+wall
+brickwork
+wedding dress
+bride
+groom
+bridesmaid
+bridge
+bridle
+briefcase
+bright
+brim
+broach
+broadcasting
+broccoli
+bronze
+bronze medal
+bronze sculpture
+bronze statue
+brooch
+creek
+broom
+broth
+brown
+brown bear
+brownie
+brunch
+brunette
+brush
+coyote
+brussels sprout
+bubble
+bubble gum
+bubble tea
+bucket cabinet
+shield
+bud
+buddha
+buffalo
+buffet
+bug
+build
+builder
+building
+building block
+building facade
+building material
+lamp
+bull
+bulldog
+bullet
+bullet train
+bulletin board
+bulletproof vest
+bullfighting
+megaphone
+bullring
+bumblebee
+bumper
+roll
+bundle
+bungee
+bunk bed
+bunker
+bunny
+buoy
+bureau
+burial chamber
+burn
+burrito
+bus
+bus driver
+bus interior
+bus station
+bus stop
+bus window
+bush
+business
+business card
+business executive
+business suit
+business team
+business woman
+businessman
+bust
+butcher
+butchers shop
+butte
+butter
+cream
+butterfly
+butterfly house
+button
+buttonwood
+buy
+taxi
+cabana
+cabbage
+cabin
+cabin car
+cabinet
+cabinetry
+cable
+cable car
+cactus
+cafe
+canteen
+cage
+cake
+cake stand
+calculator
+caldron
+calendar
+calf
+call
+phone box
+calligraphy
+calm
+camcorder
+camel
+camera
+camera lens
+camouflage
+camp
+camper
+campfire
+camping
+campsite
+campus
+can
+can opener
+canal
+canary
+cancer
+candle
+candle holder
+candy
+candy bar
+candy cane
+candy store
+cane
+jar
+cannon
+canopy
+canopy bed
+cantaloupe
+cantilever bridge
+canvas
+canyon
+cap
+cape
+cape cod
+cappuccino
+capsule
+captain
+capture
+car
+car dealership
+car door
+car interior
+car logo
+car mirror
+parking lot
+car seat
+car show
+car wash
+car window
+caramel
+card
+card game
+cardboard
+cardboard box
+cardigan
+cardinal
+cargo
+cargo aircraft
+cargo ship
+caribbean
+carnation
+carnival
+carnivore
+carousel
+carp
+carpenter
+carpet
+slipper
+house finch
+coach
+dalmatian
+aircraft carrier
+carrot
+carrot cake
+carry
+cart
+carton
+cartoon
+cartoon character
+cartoon illustration
+cartoon style
+carve
+case
+cash
+cashew
+casino
+casserole
+cassette
+cassette deck
+plaster bandage
+casting
+castle
+cat
+cat bed
+cat food
+cat furniture
+cat tree
+catacomb
+catamaran
+catamount
+catch
+catcher
+caterpillar
+catfish
+cathedral
+cattle
+catwalk
+catwalk show
+cauliflower
+cave
+caviar
+CD
+CD player
+cedar
+ceiling
+ceiling fan
+celebrate
+celebration
+celebrity
+celery
+cello
+smartphone
+cement
+graveyard
+centerpiece
+centipede
+ceramic
+ceramic tile
+cereal
+ceremony
+certificate
+chain
+chain saw
+chair
+chairlift
+daybed
+chalet
+chalice
+chalk
+chamber
+chameleon
+champagne
+champagne flute
+champion
+championship
+chandelier
+changing table
+channel
+chap
+chapel
+character sculpture
+charcoal
+charge
+charger
+chariot
+charity
+charity event
+charm
+graph
+chase
+chassis
+check
+checkbook
+chessboard
+checklist
+cheer
+cheerlead
+cheese
+cheeseburger
+cheesecake
+cheetah
+chef
+chemical compound
+chemist
+chemistry
+chemistry lab
+cheongsam
+cherry
+cherry blossom
+cherry tomato
+cherry tree
+chess
+chestnut
+chicken
+chicken breast
+chicken coop
+chicken salad
+chicken wing
+garbanzo
+chiffonier
+chihuahua
+child
+child actor
+childs room
+chile
+chili dog
+chimney
+chimpanzee
+chinaware
+chinese cabbage
+chinese garden
+chinese knot
+chinese rose
+chinese tower
+chip
+chipmunk
+chisel
+chocolate
+chocolate bar
+chocolate cake
+chocolate chip
+chocolate chip cookie
+chocolate milk
+chocolate mousse
+truffle
+choir
+kitchen knife
+cutting board
+chopstick
+christmas
+christmas ball
+christmas card
+christmas decoration
+christmas dinner
+christmas eve
+christmas hat
+christmas light
+christmas market
+christmas ornament
+christmas tree
+chrysanthemum
+church
+church tower
+cider
+cigar
+cigar box
+cigarette
+cigarette case
+waistband
+cinema
+photographer
+cinnamon
+circle
+circuit
+circuit board
+circus
+water tank
+citrus fruit
+city
+city bus
+city hall
+city nightview
+city park
+city skyline
+city square
+city street
+city wall
+city view
+clam
+clarinet
+clasp
+class
+classic
+classroom
+clavicle
+claw
+clay
+pottery
+clean
+clean room
+cleaner
+cleaning product
+clear
+cleat
+clementine
+client
+cliff
+climb
+climb mountain
+climber
+clinic
+clip
+clip art
+clipboard
+clipper
+clivia
+cloak
+clogs
+close-up
+closet
+cloth
+clothe
+clothing
+clothespin
+clothesline
+clothing store
+cloud
+cloud forest
+cloudy
+clover
+joker
+clown fish
+club
+clutch
+clutch bag
+coal
+coast
+coat
+coatrack
+cob
+cock
+cockatoo
+cocker
+cockpit
+roach
+cocktail
+cocktail dress
+cocktail shaker
+cocktail table
+cocoa
+coconut
+coconut tree
+coffee
+coffee bean
+coffee cup
+coffee machine
+coffee shop
+coffeepot
+coffin
+cognac
+spiral
+coin
+coke
+colander
+cold
+slaw
+collaboration
+collage
+collection
+college student
+sheepdog
+crash
+color
+coloring book
+coloring material
+pony
+pillar
+comb
+combination lock
+comic
+comedy
+comedy film
+comet
+comfort
+comfort food
+comic book
+comic book character
+comic strip
+commander
+commentator
+community
+commuter
+company
+compass
+compete
+contest
+competitor
+composer
+composition
+compost
+computer
+computer box
+computer chair
+computer desk
+keyboard
+computer monitor
+computer room
+computer screen
+computer tower
+concept car
+concert
+concert hall
+conch
+concrete
+condiment
+condom
+condominium
+conductor
+cone
+meeting
+conference center
+conference hall
+meeting room
+confetti
+conflict
+confluence
+connect
+connector
+conservatory
+constellation
+construction site
+construction worker
+contain
+container
+container ship
+continent
+profile
+contract
+control
+control tower
+convenience store
+convention
+conversation
+converter
+convertible
+transporter
+cook
+cooking
+cooking spray
+cooker
+cool
+cooler
+copper
+copy
+coral
+coral reef
+rope
+corded phone
+liquor
+corgi
+cork
+corkboard
+cormorant
+corn
+corn field
+cornbread
+corner
+trumpet
+cornice
+cornmeal
+corral
+corridor
+corset
+cosmetic
+cosmetics brush
+cosmetics mirror
+cosplay
+costume
+costumer film designer
+infant bed
+cottage
+cotton
+cotton candy
+couch
+countdown
+counter
+counter top
+country artist
+country house
+country lane
+country pop artist
+countryside
+coupe
+couple
+couple photo
+courgette
+course
+court
+courthouse
+courtyard
+cousin
+coverall
+cow
+cowbell
+cowboy
+cowboy boot
+cowboy hat
+crab
+crabmeat
+crack
+cradle
+craft
+craftsman
+cranberry
+crane
+crape
+crapper
+crate
+crater lake
+lobster
+crayon
+cream cheese
+cream pitcher
+create
+creature
+credit card
+crescent
+croissant
+crest
+crew
+cricket
+cricket ball
+cricket team
+cricketer
+crochet
+crock pot
+crocodile
+crop
+crop top
+cross
+crossbar
+crossroad
+crosstalk
+crosswalk
+crouton
+crow
+crowbar
+crowd
+crowded
+crown
+crt screen
+crucifix
+cruise
+cruise ship
+cruiser
+crumb
+crush
+crutch
+crystal
+cub
+cube
+cucumber
+cue
+cuff
+cufflink
+cuisine
+farmland
+cup
+cupcake
+cupid
+curb
+curl
+hair roller
+currant
+currency
+curry
+curtain
+curve
+pad
+customer
+cut
+cutlery
+cycle
+cycling
+cyclone
+cylinder
+cymbal
+cypress
+cypress tree
+dachshund
+daffodil
+dagger
+dahlia
+daikon
+dairy
+daisy
+dam
+damage
+damp
+dance
+dance floor
+dance room
+dancer
+dandelion
+dark
+darkness
+dart
+dartboard
+dashboard
+date
+daughter
+dawn
+day bed
+daylight
+deadbolt
+death
+debate
+debris
+decanter
+deck
+decker bus
+decor
+decorate
+decorative picture
+deer
+defender
+deity
+delicatessen
+deliver
+demolition
+monster
+demonstration
+den
+denim jacket
+dentist
+department store
+depression
+derby
+dermopathy
+desert
+desert road
+design
+designer
+table
+table lamp
+desktop
+desktop computer
+dessert
+destruction
+detective
+detergent
+dew
+dial
+diamond
+diaper
+diaper bag
+journal
+die
+diet
+excavator
+number
+digital clock
+dill
+dinner
+rowboat
+dining room
+dinner party
+dinning table
+dinosaur
+dip
+diploma
+direct
+director
+dirt
+dirt bike
+dirt field
+dirt road
+dirt track
+disaster
+disciple
+disco
+disco ball
+discotheque
+disease
+plate
+dish antenna
+dish washer
+dishrag
+dishes
+dishsoap
+Disneyland
+dispenser
+display
+display window
+trench
+dive
+diver
+diving board
+paper cup
+dj
+doberman
+dock
+doctor
+document
+documentary
+dog
+dog bed
+dog breed
+dog collar
+dog food
+dog house
+doll
+dollar
+dollhouse
+dolly
+dolphin
+dome
+domicile
+domino
+donkey
+donut
+doodle
+door
+door handle
+doormat
+doorplate
+doorway
+dormitory
+dough
+downtown
+dozer
+drag
+dragon
+dragonfly
+drain
+drama
+drama film
+draw
+drawer
+drawing
+drawing pin
+pigtail
+dress
+dress hat
+dress shirt
+dress shoe
+dress suit
+dresser
+dressing room
+dribble
+drift
+driftwood
+drill
+drink
+drinking water
+drive
+driver
+driveway
+drone
+drop
+droplight
+dropper
+drought
+medicine
+pharmacy
+drum
+drummer
+drumstick
+dry
+duchess
+duck
+duckbill
+duckling
+duct tape
+dude
+duet
+duffel
+canoe
+dumbbell
+dumpling
+dune
+dunk
+durian
+dusk
+dust
+garbage truck
+dustpan
+duvet
+DVD
+dye
+eagle
+ear
+earmuff
+earphone
+earplug
+earring
+earthquake
+easel
+easter
+easter bunny
+easter egg
+eat
+restaurant
+eclair
+eclipse
+ecosystem
+edit
+education
+educator
+eel
+egg
+egg roll
+egg tart
+eggbeater
+egret
+Eiffel tower
+elastic band
+senior
+electric chair
+electric drill
+electrician
+electricity
+electron
+electronic
+elephant
+elevation map
+elevator
+elevator car
+elevator door
+elevator lobby
+elevator shaft
+embankment
+embassy
+embellishment
+ember
+emblem
+embroidery
+emerald
+emergency
+emergency service
+emergency vehicle
+emotion
+Empire State Building
+enamel
+enclosure
+side table
+energy
+engagement
+engagement ring
+engine
+engine room
+engineer
+engineering
+english shorthair
+ensemble
+enter
+entertainer
+entertainment
+entertainment center
+entrance
+entrance hall
+envelope
+equestrian
+equipment
+eraser
+erhu
+erosion
+escalator
+escargot
+espresso
+estate
+estuary
+eucalyptus tree
+evening
+evening dress
+evening light
+evening sky
+evening sun
+event
+evergreen
+ewe
+excavation
+exercise
+exhaust hood
+exhibition
+exit
+explorer
+explosion
+extension cord
+extinguisher
+extractor
+extrude
+eye
+eye shadow
+eyebrow
+eyeliner
+fabric
+fabric store
+facade
+face
+face close-up
+face powder
+face towel
+facial tissue holder
+facility
+factory
+factory workshop
+fair
+fairground
+fairy
+falcon
+fall
+family
+family car
+family photo
+family room
+fan
+fang
+farm
+farmer
+farmer market
+farmhouse
+fashion
+fashion accessory
+fashion designer
+fashion girl
+fashion illustration
+fashion look
+fashion model
+fashion show
+fast food
+fastfood restaurant
+father
+faucet
+fault
+fauna
+fawn
+fax
+feast
+feather
+fedora
+feed
+feedbag
+feeding
+feeding chair
+feline
+mountain lion
+fence
+fender
+fern
+ferret
+ferris wheel
+ferry
+fertilizer
+festival
+fiber
+fiction
+fiction book
+field
+field road
+fig
+fight
+figure skater
+figurine
+file
+file photo
+file cabinet
+fill
+film camera
+film director
+film format
+film premiere
+film producer
+filming
+filter
+fin
+hand
+finish line
+fir
+fir tree
+fire
+fire alarm
+fire department
+fire truck
+fire escape
+fire hose
+fire pit
+fire station
+firecracker
+fireman
+fireplace
+firework
+firework display
+first-aid kit
+fish
+fish boat
+fish market
+fish pond
+fishbowl
+fisherman
+fishing
+fishing boat
+fishing net
+fishing pole
+fishing village
+fitness
+fitness course
+five
+fixture
+fjord
+flag
+flag pole
+flake
+flame
+flamingo
+flannel
+flap
+flare
+flash
+flask
+flat
+flatfish
+flavor
+flea
+flea market
+fleet
+flight
+flight attendant
+flip
+flip-flop
+flipchart
+float
+flock
+flood
+floor
+floor fan
+floor mat
+floor plan
+floor window
+floral arrangement
+florist
+floss
+flour
+flow
+flower
+flower basket
+flower bed
+flower box
+flower field
+flower girl
+flower market
+fluid
+flush
+flute
+fly
+fly fishing
+flyer
+horse
+foam
+fog
+foggy
+foie gra
+foil
+folding chair
+leaf
+folk artist
+folk dance
+folk rock artist
+fondant
+hotpot
+font
+food
+food coloring
+food court
+food processor
+food stand
+food truck
+foosball
+foot
+foot bridge
+football
+football coach
+football college game
+football match
+football field
+football game
+football helmet
+football player
+football stadium
+football team
+path
+footprint
+footrest
+footstall
+footwear
+forbidden city
+ford
+forehead
+forest
+forest fire
+forest floor
+forest path
+forest road
+forge
+fork
+forklift
+form
+formal garden
+formation
+formula 1
+fort
+fortification
+forward
+fossil
+foundation
+fountain
+fountain pen
+fox
+frame
+freckle
+highway
+lorry
+French
+French bulldog
+French fries
+French toast
+freshener
+fridge
+fried chicken
+fried egg
+fried rice
+friendship
+frisbee
+frog
+frost
+frosting
+frosty
+frozen
+fruit
+fruit cake
+fruit dish
+fruit market
+fruit salad
+fruit stand
+fruit tree
+fruits shop
+fry
+frying pan
+fudge
+fuel
+fume hood
+fun
+funeral
+fungi
+funnel
+fur
+fur coat
+furniture
+futon
+gadget
+muzzle
+galaxy
+gallery
+game
+game board
+game controller
+ham
+gang
+garage
+garage door
+garage kit
+garbage
+garden
+garden asparagus
+garden hose
+garden spider
+gardener
+gardening
+garfield
+gargoyle
+wreath
+garlic
+garment
+gas
+gas station
+gas stove
+gasmask
+collect
+gathering
+gauge
+gazebo
+gear
+gecko
+geisha
+gel
+general store
+generator
+geranium
+ghost
+gift
+gift bag
+gift basket
+gift box
+gift card
+gift shop
+gift wrap
+gig
+gin
+ginger
+gingerbread
+gingerbread house
+ginkgo tree
+giraffe
+girl
+give
+glacier
+gladiator
+glass bead
+glass bottle
+glass bowl
+glass box
+glass building
+glass door
+glass floor
+glass house
+glass jar
+glass plate
+glass table
+glass vase
+glass wall
+glass window
+glasses
+glaze
+glider
+earth
+glove
+glow
+glue pudding
+go
+go for
+goal
+goalkeeper
+goat
+goat cheese
+gobi
+goggles
+gold
+gold medal
+Golden Gate Bridge
+golden retriever
+goldfish
+golf
+golf cap
+golf cart
+golf club
+golf course
+golfer
+goose
+gorilla
+gothic
+gourd
+government
+government agency
+gown
+graduate
+graduation
+grain
+grampus
+grand prix
+grandfather
+grandmother
+grandparent
+granite
+granola
+grape
+grapefruit
+wine
+grass
+grasshopper
+grassland
+grassy
+grater
+grave
+gravel
+gravestone
+gravy
+gravy boat
+gray
+graze
+grazing
+green
+greenery
+greet
+greeting
+greeting card
+greyhound
+grid
+griddle
+grill
+grille
+grilled eel
+grind
+grinder
+grits
+grocery bag
+grotto
+ground squirrel
+group
+group photo
+grove
+grow
+guacamole
+guard
+guard dog
+guest house
+guest room
+guide
+guinea pig
+guitar
+guitarist
+gulf
+gull
+gun
+gundam
+gurdwara
+guzheng
+gym
+gymnast
+habitat
+hacker
+hail
+hair
+hair color
+hair spray
+hairbrush
+haircut
+hairgrip
+hairnet
+hairpin
+hairstyle
+half
+hall
+halloween
+halloween costume
+halloween pumpkin
+halter top
+hamburg
+hamburger
+hami melon
+hammer
+hammock
+hamper
+hamster
+hand dryer
+hand glass
+hand towel
+handbag
+handball
+handcuff
+handgun
+handkerchief
+handle
+handsaw
+handshake
+handstand
+handwriting
+hanfu
+hang
+hangar
+hanger
+happiness
+harbor
+harbor seal
+hard rock artist
+hardback book
+safety helmet
+hardware
+hardware store
+hardwood
+hardwood floor
+mouth organ
+pipe organ
+harpsichord
+harvest
+harvester
+hassock
+hat
+hatbox
+hautboy
+hawthorn
+hay
+hayfield
+hazelnut
+head
+head coach
+headlight
+headboard
+headdress
+headland
+headquarter
+hearing
+heart
+heart shape
+heat
+heater
+heather
+hedge
+hedgehog
+heel
+helicopter
+heliport
+helmet
+help
+hen
+henna
+herb
+herd
+hermit crab
+hero
+heron
+hibiscus
+hibiscus flower
+hide
+high bar
+high heel
+highland
+highlight
+hike
+hiker
+hiking boot
+hiking equipment
+hill
+hill country
+hill station
+hillside
+hindu temple
+hinge
+hip
+hip hop artist
+hippo
+historian
+historic
+history
+hockey
+hockey arena
+hockey game
+hockey player
+hockey stick
+hoe
+hole
+vacation
+holly
+holothurian
+home
+home appliance
+home base
+home decor
+home interior
+home office
+home theater
+homework
+hummus
+honey
+beehive
+honeymoon
+hood
+hoodie
+hook
+jump
+horizon
+hornbill
+horned cow
+hornet
+horror
+horror film
+horse blanket
+horse cart
+horse farm
+horse ride
+horseback
+horseshoe
+hose
+hospital
+hospital bed
+hospital room
+host
+inn
+hot
+hot air balloon
+hot dog
+hot sauce
+hot spring
+hotel
+hotel lobby
+hotel room
+hotplate
+hourglass
+house
+house exterior
+houseplant
+hoverboard
+howler
+huddle
+hug
+hula hoop
+person
+humidifier
+hummingbird
+humpback whale
+hunt
+hunting lodge
+hurdle
+hurricane
+husky
+hut
+hyaena
+hybrid
+hydrangea
+hydrant
+seaplane
+ice
+ice bag
+polar bear
+ice cave
+icecream
+ice cream cone
+ice cream parlor
+ice cube
+ice floe
+ice hockey player
+ice hockey team
+lollipop
+ice maker
+rink
+ice sculpture
+ice shelf
+skate
+ice skating
+iceberg
+icicle
+icing
+icon
+id photo
+identity card
+igloo
+light
+iguana
+illuminate
+illustration
+image
+impala
+incense
+independence day
+individual
+indoor
+indoor rower
+induction cooker
+industrial area
+industry
+infantry
+inflatable boat
+information desk
+infrastructure
+ingredient
+inhalator
+injection
+injury
+ink
+inking pad
+inlet
+inscription
+insect
+install
+instrument
+insulated cup
+interaction
+interior design
+website
+intersection
+interview
+invertebrate
+invitation
+ipad
+iphone
+ipod
+iris
+iron
+ironing board
+irrigation system
+island
+islet
+isopod
+ivory
+ivy
+izakaya
+jack
+jackcrab
+jacket
+jacuzzi
+jade
+jaguar
+jail cell
+jam
+japanese garden
+jasmine
+jaw
+jay
+jazz
+jazz artist
+jazz fusion artist
+jeans
+jeep
+jelly
+jelly bean
+jellyfish
+jet
+motorboat
+jewel
+jewellery
+jewelry shop
+jigsaw puzzle
+rickshaw
+jockey
+jockey cap
+jog
+joint
+journalist
+joystick
+judge
+jug
+juggle
+juice
+juicer
+jujube
+jump rope
+jumpsuit
+jungle
+junkyard
+kale
+kaleidoscope
+kangaroo
+karaoke
+karate
+karting
+kasbah
+kayak
+kebab
+key
+keycard
+khaki
+kick
+kilt
+kimono
+kindergarden classroom
+kindergarten
+king
+king crab
+kiss
+kit
+kitchen
+kitchen cabinet
+kitchen counter
+kitchen floor
+kitchen hood
+kitchen island
+kitchen sink
+kitchen table
+kitchen utensil
+kitchen window
+kitchenware
+kite
+kiwi
+knee pad
+kneel
+knife
+rider
+knit
+knitting needle
+knob
+knocker
+knot
+koala
+koi
+ktv
+laboratory
+lab coat
+label
+labrador
+maze
+lace
+lace dress
+ladder
+ladle
+ladybird
+lagoon
+lake
+lake district
+lake house
+lakeshore
+lamb
+lamb chop
+lamp post
+lamp shade
+spear
+land
+land vehicle
+landfill
+landing
+landing deck
+landmark
+landscape
+landslide
+lanyard
+lantern
+lap
+laptop
+laptop keyboard
+larva
+lasagne
+laser
+lash
+lasso
+latch
+latex
+latte
+laugh
+launch
+launch event
+launch party
+laundromat
+laundry
+laundry basket
+laundry room
+lava
+lavender
+lawn
+lawn wedding
+lawyer
+lay
+lead
+lead singer
+lead to
+leader
+leak
+lean
+learn
+leash
+leather
+leather jacket
+leather shoe
+speech
+lecture hall
+lecture room
+ledge
+leftover
+leg
+legend
+legging
+legislative chamber
+lego
+legume
+lemon
+lemon juice
+lemonade
+lemur
+lens
+lens flare
+lentil
+leopard
+leotard
+tights
+leprechaun
+lesson
+letter
+mailbox
+letter logo
+lettering
+lettuce
+level
+library
+license
+license plate
+lichen
+lick
+lid
+lie
+life belt
+life jacket
+lifeboat
+lifeguard
+lift
+light fixture
+light show
+light switch
+lighting
+lightning
+lightning rod
+lilac
+lily
+limb
+lime
+limestone
+limo
+line
+line art
+line up
+linen
+liner
+lion
+lip balm
+lipstick
+liquid
+liquor store
+list
+litchi
+live
+livestock
+living room
+living space
+lizard
+load
+loading dock
+loafer
+hallway
+locate
+lock
+lock chamber
+locker
+loft
+log
+log cabin
+logo
+loki
+long hair
+longboard
+loom
+loop
+lose
+lottery
+lotus
+love
+loveseat
+luggage
+lumber
+lumberjack
+lunch
+lunch box
+lush
+luxury
+luxury yacht
+mac
+macadamia
+macaque
+macaroni
+macaw
+machete
+machine
+machine gun
+magazine
+magic
+magician
+magnet
+magnifying glass
+magnolia
+magpie
+mahjong
+mahout
+maid
+chain mail
+mail slot
+make
+makeover
+makeup artist
+makeup tool
+mallard
+mallard duck
+mallet
+mammal
+mammoth
+man
+management
+manager
+manatee
+mandala
+mandarin orange
+mandarine
+mane
+manga
+manger
+mango
+mangosteen
+mangrove
+manhattan
+manhole
+manhole cover
+manicure
+mannequin
+manor house
+mansion
+mantid
+mantle
+manufactured home
+manufacturing
+manuscript
+map
+maple
+maple leaf
+maple syrup
+maraca
+marathon
+marble
+march
+marching band
+mare
+marigold
+marine
+marine invertebrate
+marine mammal
+puppet
+mark
+market
+market square
+market stall
+marriage
+martial
+martial artist
+martial arts gym
+martini
+martini glass
+mascara
+mascot
+mashed potato
+masher
+mask
+massage
+mast
+mat
+matador
+match
+matchbox
+material
+mattress
+mausoleum
+maxi dress
+meal
+measuring cup
+measuring tape
+meat
+meatball
+mechanic
+mechanical fan
+medal
+media
+medical equipment
+medical image
+medical staff
+medicine cabinet
+medieval
+medina
+meditation
+meerkat
+meet
+melon
+monument
+menu
+mermaid
+net
+mess
+messenger bag
+metal
+metal artist
+metal detector
+meter
+mezzanine
+microphone
+microscope
+microwave
+midnight
+milestone
+military uniform
+milk
+milk can
+milk tea
+milkshake
+mill
+mine
+miner
+mineral
+mineral water
+miniskirt
+miniature
+minibus
+minister
+minivan
+mint
+mint candy
+mirror
+miss
+missile
+mission
+mistletoe
+mix
+mixer
+mixing bowl
+mixture
+moat
+mobility scooter
+model
+model car
+modern
+modern tower
+moisture
+mold
+molding
+mole
+monarch
+money
+monitor
+monk
+monkey
+monkey wrench
+monochrome
+monocycle
+monster truck
+moon
+moon cake
+moonlight
+moor
+moose
+swab
+moped
+morning
+morning fog
+morning light
+morning sun
+mortar
+mosaic
+mosque
+mosquito
+moss
+motel
+moth
+mother
+motherboard
+motif
+sport
+motor
+motorbike
+motorcycle
+motorcycle helmet
+motorcycle racer
+motorcyclist
+motorsport
+mound
+mountain
+mountain bike
+mountain biker
+mountain biking
+mountain gorilla
+mountain lake
+mountain landscape
+mountain pass
+mountain path
+mountain range
+mountain river
+mountain snowy
+mountain stream
+mountain view
+mountain village
+mountaineer
+mountaineering bag
+mouse
+mousepad
+mousetrap
+mouth
+mouthwash
+move
+movie poster
+movie ticket
+mower
+mp3 player
+mr
+mud
+muffin
+mug
+mulberry
+mulch
+mule
+municipality
+mural
+muscle
+muscle car
+museum
+mushroom
+music
+music festival
+music stool
+music studio
+music video performer
+musical keyboard
+musician
+mussel
+mustard
+mythology
+nacho
+nail polish
+nailfile
+nanny
+napkin
+narrow
+national flag
+nativity scene
+natural history museum
+nature
+nature reserve
+navigation
+navratri
+navy
+nebula
+neck
+neckband
+necklace
+neckline
+nectar
+nectarine
+needle
+neighbor
+neighbourhood
+neon
+neon light
+nerve
+nest
+new year
+newborn
+newfoundland
+newlywed
+news
+news conference
+newsstand
+night
+night market
+night sky
+night view
+nightclub
+nightstand
+noodle
+nose
+noseband
+note
+notebook
+notepad
+notepaper
+notice
+number icon
+nun
+nurse
+nursery
+nursing home
+nut
+nutcracker
+oak
+oak tree
+oar
+oasis
+oast house
+oatmeal
+oats
+obelisk
+observation tower
+observatory
+obstacle course
+sea
+octopus
+offer
+office
+office building
+office chair
+office cubicle
+office desk
+office supply
+office window
+officer
+official
+oil
+oil lamp
+oil painting
+oilrig
+okra
+old photo
+olive
+olive oil
+olive tree
+omelet
+onion
+onion ring
+opal
+open
+opening
+opening ceremony
+opera
+opera house
+operate
+operating room
+operation
+optical shop
+orangutan
+orange
+orange juice
+orange tree
+orangery
+orbit
+orchard
+orchestra pit
+orchid
+order
+organization
+origami
+ornament
+osprey
+ostrich
+otter
+out
+outcrop
+outdoor
+outhouse
+electric outlet
+outline
+oval
+oven
+overall
+overcoat
+overpass
+owl
+oyster
+teething ring
+pack
+package
+paddock
+police van
+padlock
+paella
+pagoda
+pain
+paint brush
+painter
+paisley bandanna
+palace
+palette
+paling
+pall
+palm tree
+pan
+pancake
+panda
+panel
+panorama
+pansy
+pant
+pantry
+pants
+pantyhose
+papaya
+paper
+paper bag
+paper cutter
+paper lantern
+paper plate
+paper towel
+paperback book
+paperweight
+parachute
+parade
+paradise
+parrot
+paramedic
+paraquet
+parasail
+paratrooper
+parchment
+parish
+park
+park bench
+parking
+parking garage
+parking meter
+parking sign
+parliament
+parsley
+participant
+partner
+partridge
+party
+party hat
+pass
+passage
+passbook
+passenger
+passenger ship
+passenger train
+passion fruit
+passport
+pasta
+paste
+pastry
+pasture
+patch
+patient
+pattern
+pavement
+pavilion
+paw
+pay
+payphone
+pea
+peace
+peach
+peacock
+peak
+peanut
+peanut butter
+pear
+pearl
+pebble
+pecan
+pedestrian
+pedestrian bridge
+pedestrian street
+peel
+peeler
+pegboard
+pegleg
+pelican
+pen
+penalty kick
+pencil
+pencil case
+pencil sharpener
+pencil skirt
+pendant
+pendulum
+penguin
+peninsula
+pennant
+penny
+piggy bank
+peony
+pepper
+pepper grinder
+peppercorn
+pepperoni
+perch
+perform
+performance
+performance arena
+perfume
+pergola
+persian cat
+persimmon
+personal care
+personal flotation device
+pest
+pet
+pet shop
+pet store
+petal
+petunia
+church bench
+pheasant
+phenomenon
+philosopher
+phone
+phonebook
+record player
+photo
+photo booth
+photo frame
+photography
+physicist
+physics laboratory
+pianist
+piano
+plectrum
+pick up
+pickle
+picnic
+picnic area
+picnic basket
+picnic table
+picture
+picture frame
+pie
+pigeon
+pilgrim
+tablet
+pillow
+pilot
+pilot boat
+pin
+pine
+pine cone
+pine forest
+pine nut
+pineapple
+table tennis table
+table tennis
+pink
+pint
+pipa
+pipe
+pipe bowl
+pirate
+pirate flag
+pirate ship
+pistachio
+ski slope
+pocket bread
+pitaya
+pitbull
+pitch
+pitcher
+pitcher plant
+pitchfork
+pizza
+pizza cutter
+pizza pan
+pizzeria
+placard
+place
+place mat
+plaid
+plain
+plan
+planet
+planet earth
+plank
+plant
+plantation
+planting
+plaque
+plaster
+plastic
+plasticine
+plateau
+platform
+platinum
+platter
+play
+play badminton
+play baseball
+play basketball
+play billiard
+play football
+play pong
+play tennis
+play volleyball
+player
+playground
+playhouse
+playing card
+playing chess
+playing golf
+playing mahjong
+playingfield
+playpen
+playroom
+plaza
+plier
+plot
+plow
+plug
+plug hat
+plum
+plumber
+plumbing fixture
+plume
+plywood
+pocket
+pocket watch
+pocketknife
+pod
+podium
+poetry
+poinsettia
+point
+pointer
+poker card
+poker chip
+poker table
+pole
+polecat
+police
+police car
+police dog
+police station
+politician
+polka dot
+pollen
+pollution
+polo
+polo neck
+polo shirt
+pomegranate
+pomeranian
+poncho
+pond
+ponytail
+poodle
+pool
+pop
+pop artist
+popcorn
+pope
+poppy
+porcelain
+porch
+pork
+porridge
+portable battery
+portal
+portfolio
+porthole
+portrait
+portrait session
+pose
+possum
+post
+post office
+stamp
+postcard
+poster
+poster page
+pot
+potato
+potato chip
+potato salad
+potholder
+potty
+pouch
+poultry
+pound
+pour
+powder
+power line
+power plugs and sockets
+power see
+power station
+practice
+Prague Castle
+prayer
+preacher
+premiere
+prescription
+show
+presentation
+president
+press room
+pressure cooker
+pretzel
+prince
+princess
+print
+printed page
+printer
+printing
+prison
+produce
+product
+profession
+professional
+professor
+project picture
+projection screen
+projector
+prom
+promenade
+propeller
+prophet
+proposal
+protective suit
+protest
+protester
+publication
+publicity portrait
+ice hockey
+pudding
+puddle
+puff
+puffin
+pug
+pull
+pulpit
+pulse
+pump
+pumpkin
+pumpkin pie
+pumpkin seed
+punch bag
+punch
+student
+purple
+push
+putt
+puzzle
+tower
+pyramid
+python
+qr code
+quail
+quarry
+quarter
+quartz
+queen
+quesadilla
+queue
+quiche
+quilt
+quilting
+quote
+rabbit
+raccoon
+race
+race track
+raceway
+race car
+racket
+radar
+radiator
+radio
+raft
+rag doll
+rail
+railcar
+railroad
+railroad bridge
+railway line
+railway station
+rain
+rain boot
+rainbow
+rainbow trout
+raincoat
+rainforest
+rainy
+raisin
+rake
+ram
+ramp
+rapeseed
+rapid
+rapper
+raspberry
+rat
+ratchet
+raven
+ravine
+ray
+razor
+razor blade
+read
+reading
+reamer
+rear
+rear light
+rear view
+rearview mirror
+receipt
+receive
+reception
+recipe
+record
+record producer
+recorder
+recording studio
+recreation room
+recreational vehicle
+rectangle
+recycling
+recycling bin
+red
+red carpet
+red flag
+red panda
+red wine
+redwood
+reed
+reef
+reel
+referee
+reflect
+reflection
+reflector
+register
+rein
+reindeer
+relax
+release
+relief
+religion
+religious
+relish
+remain
+remodel
+remote
+remove
+repair
+repair shop
+reptile
+rescue
+rescuer
+research
+researcher
+reservoir
+residence
+residential neighborhood
+resin
+resort
+resort town
+restaurant kitchen
+restaurant patio
+restroom
+retail
+retriever
+retro
+reveal
+rhinoceros
+rhododendron
+rib
+ribbon
+rice
+rice cooker
+rice field
+ride
+ridge
+riding
+rifle
+rim
+ring
+riot
+ripple
+rise
+rise building
+river
+river bank
+river boat
+river valley
+riverbed
+road
+road sign
+road trip
+roadside
+roast chicken
+robe
+robin
+robot
+stone
+rock arch
+rock artist
+rock band
+rock climber
+rock climbing
+rock concert
+rock face
+rock formation
+rocker
+rocket
+rocking chair
+rocky
+rodent
+rodeo
+rodeo arena
+roe
+roe deer
+roller
+coaster
+roller skate
+roller skates
+rolling pin
+romance
+romantic
+roof
+roof garden
+room
+room divider
+root
+root beer
+rope bridge
+rosary
+rose
+rosemary
+rosy cloud
+rottweiler
+round table
+router
+row
+rowan
+royal
+rubber stamp
+rubble
+rubik's cube
+ruby
+ruffle
+rugby
+rugby ball
+rugby player
+ruins
+ruler
+rum
+run
+runner
+running shoe
+rural
+rust
+rustic
+rye
+sack
+saddle
+saddlebag
+safari
+safe
+safety vest
+sage
+sail
+sailboat
+sailing
+sailor
+squirrel monkey
+sake
+salad
+salad bowl
+salamander
+salami
+sale
+salmon
+salon
+salsa
+salt
+salt and pepper shakers
+salt lake
+salt marsh
+salt shaker
+salute
+samoyed
+samurai
+sand
+sand bar
+sand box
+sand castle
+sand sculpture
+sandal
+sandwich
+sanitary napkin
+santa claus
+sapphire
+sardine
+sari
+sashimi
+satay
+satchel
+satellite
+satin
+sauce
+saucer
+sauna
+sausage
+savanna
+saw
+sawbuck
+sax
+saxophonist
+scaffold
+scale
+scale model
+scallop
+scar
+strawman
+scarf
+scene
+scenery
+schnauzer
+school
+school bus
+school uniform
+schoolhouse
+schooner
+science
+science fiction film
+science museum
+scientist
+scissors
+wall lamp
+scone
+scoop
+scooter
+score
+scoreboard
+scorpion
+scout
+scrambled egg
+scrap
+scraper
+scratch
+screen
+screen door
+screenshot
+screw
+screwdriver
+scroll
+scrub
+scrubbing brush
+sculptor
+sculpture
+sea cave
+sea ice
+sea lion
+sea turtle
+sea urchin
+seabass
+seabed
+seabird
+seafood
+seahorse
+seal
+sea view
+seashell
+seaside resort
+season
+seat
+seat belt
+seaweed
+secretary
+security
+sedan
+see
+seed
+seesaw
+segway
+selfie
+sell
+seminar
+sense
+sensor
+server
+server room
+service
+set
+sewing machine
+shadow
+shake
+shaker
+shampoo
+shape
+share
+shark
+sharpener
+sharpie
+shaver
+shaving cream
+shawl
+shear
+shears
+sheep
+sheet
+sheet music
+shelf
+shell
+shellfish
+shelter
+shelve
+shepherd
+sherbert
+shiba inu
+shine
+shipping
+shipping container
+shipwreck
+shipyard
+shirt
+shirtless
+shoal
+shoe
+shoe box
+shoe shop
+shoe tree
+shoot
+shooting basketball guard
+shop window
+shopfront
+shopper
+shopping
+shopping bag
+shopping basket
+shopping cart
+mall
+shopping street
+shore
+shoreline
+short
+short hair
+shorts
+shot glass
+shotgun
+shoulder
+shoulder bag
+shovel
+showcase
+shower
+shower cap
+shower curtain
+shower door
+shower head
+shredder
+shrew
+shrimp
+shrine
+shrub
+shutter
+siamese
+siberia
+sibling
+side
+side cabinet
+side dish
+sidecar
+sideline
+siding
+sign
+signage
+signal
+signature
+silk
+silk stocking
+silo
+silver
+silver medal
+silverware
+sing
+singe
+singer
+sink
+sip
+sit
+sitting
+skate park
+skateboard
+skateboarder
+skater
+skating rink
+skeleton
+sketch
+skewer
+ski
+ski boot
+ski equipment
+ski jacket
+ski lift
+ski pole
+ski resort
+snowboard
+skier
+skiing shoes
+skin
+skull
+skullcap
+sky
+sky tower
+skylight
+skyline
+skyscraper
+slalom
+slate
+sleigh
+sleep
+sleeping bag
+sleepwear
+sleeve
+slice
+slide
+slider
+sling
+slope
+slot
+slot machine
+sloth
+slow cooker
+slug
+slum
+smell
+smile
+smoke
+snack
+snail
+snake
+snapper
+snapshot
+snorkel
+snout
+snow
+snow leopard
+snow mountain
+snowball
+snowboarder
+snowfield
+snowflake
+snowman
+snowmobile
+snowplow
+snowshoe
+snowy
+soap
+soap bubble
+soap dispenser
+soccer goalkeeper
+socialite
+sock
+socket
+soda
+softball
+software
+solar battery
+soldier
+solo
+solution
+sombrero
+song
+sound
+soup
+soup bowl
+soupspoon
+sour cream
+souvenir
+soybean milk
+spa
+space
+space shuttle
+space station
+spacecraft
+spaghetti
+span
+wrench
+spark
+sparkle
+sparkler
+sparkling wine
+sparrow
+spatula
+speaker
+spectator
+speech bubble
+speed limit
+speed limit sign
+speedboat
+speedometer
+sphere
+spice
+spice rack
+spider
+spider web
+spike
+spin
+spinach
+spire
+splash
+sponge
+spoon
+sport association
+sport equipment
+sport team
+sports ball
+sports equipment
+sports meet
+sportswear
+dot
+spray
+spread
+spring
+spring roll
+sprinkle
+sprinkler
+sprout
+spruce
+spruce forest
+squad
+square
+squash
+squat
+squeeze
+squid
+squirrel
+water gun
+stab
+stable
+stack
+stadium
+staff
+stage
+stage light
+stagecoach
+stain
+stainless steel
+stair
+stairs
+stairwell
+stall
+stallion
+stand
+standing
+staple
+stapler
+star
+stare
+starfish
+starfruit
+starling
+state park
+state school
+station
+stationary bicycle
+stationery
+statue
+steak
+steak knife
+steam
+steam engine
+steam locomotive
+steam train
+steamed bread
+steel
+steering wheel
+stem
+stencil
+step stool
+stereo
+stethoscope
+stew
+stick
+stick insect
+sticker
+still life
+stilt
+stingray
+stir
+stirrer
+stirrup
+sew
+stock
+stocking
+stomach
+stone building
+stone carving
+stone house
+stone mill
+stool
+stop
+stop at
+stop light
+stop sign
+stop watch
+traffic light
+storage box
+storage room
+tank
+store
+storefront
+stork
+storm
+storm cloud
+stormy
+stove
+poker
+straddle
+strainer
+strait
+strap
+straw
+straw hat
+strawberry
+stream
+street art
+street artist
+street corner
+street dog
+street food
+street light
+street market
+street photography
+street scene
+street sign
+street vendor
+stretch
+stretcher
+strike
+striker
+string
+string cheese
+strip
+stripe
+stroll
+structure
+studio
+studio shot
+stuff
+stuffed animal
+stuffed toy
+stuffing
+stump
+stunning
+stunt
+stupa
+style
+stylus
+submarine
+submarine sandwich
+submarine water
+suburb
+subway
+subway station
+subwoofer
+succulent
+suede
+sugar
+sugar bowl
+sugar cane
+sugar cube
+suit
+suite
+summer
+summer evening
+summit
+sun
+sun hat
+sunbathe
+sunday
+sundial
+sunflower
+sunflower field
+sunflower seed
+sunglasses
+sunny
+sunrise
+sunset
+sunshade
+sunshine
+super bowl
+sports car
+superhero
+supermarket
+supermarket shelf
+supermodel
+supporter
+surf
+surface
+surfboard
+surfer
+surgeon
+surgery
+surround
+sushi
+sushi bar
+suspenders
+suspension
+suspension bridge
+suv
+swallow
+swallowtail butterfly
+swamp
+swan
+swan boat
+sweat pant
+sweatband
+sweater
+sweatshirt
+sweet
+sweet potato
+swim
+swim cap
+swimmer
+swimming hole
+swimming pool
+swing
+swing bridge
+swinge
+swirl
+switch
+swivel chair
+sword
+swordfish
+symbol
+symmetry
+synagogue
+syringe
+syrup
+system
+t shirt
+t-shirt
+tabasco sauce
+tabby
+table tennis racket
+table top
+tablecloth
+tablet computer
+tableware
+tachometer
+tackle
+taco
+tae kwon do
+tai chi
+tail
+tailor
+take
+takeoff
+talk
+tambourine
+tan
+tangerine
+tape
+tapestry
+tarmac
+taro
+tarp
+tart
+tassel
+taste
+tatami
+tattoo
+tattoo artist
+tavern
+tea
+tea bag
+tea party
+tea plantation
+tea pot
+tea set
+teach
+teacher
+teacup
+teal
+team photo
+team presentation
+tear
+technician
+technology
+teddy
+tee
+teenager
+telegraph pole
+zoom lens
+telescope
+television
+television camera
+television room
+television studio
+temperature
+temple
+tempura
+tennis
+tennis court
+tennis match
+tennis net
+tennis player
+tennis racket
+tent
+tequila
+terminal
+terrace
+terrain
+terrarium
+territory
+test
+test match
+test tube
+text
+text message
+textile
+texture
+thanksgiving
+thanksgiving dinner
+theater
+theatre actor
+therapy
+thermometer
+thermos
+thermos bottle
+thermostat
+thicket
+thimble
+thing
+thinking
+thistle
+throne
+throne room
+throw
+throw pillow
+thunder
+thunderstorm
+thyme
+tiara
+tick
+ticket
+ticket booth
+tide pool
+tie
+tiger
+tight
+tile
+tile flooring
+tile roof
+tile wall
+tin
+tinfoil
+tinsel
+tiramisu
+tire
+tissue
+toast
+toaster
+tobacco
+tobacco pipe
+toddler
+toe
+tofu
+toilet bowl
+toilet seat
+toiletry
+tokyo tower
+tomato
+tomato sauce
+tomato soup
+tomb
+tong
+tongs
+tool
+toolbox
+toothbrush
+toothpaste
+toothpick
+topiary garden
+topping
+torch
+tornado
+tortilla
+tortoise
+tote bag
+totem pole
+totoro
+toucan
+touch
+touchdown
+tour
+tour bus
+tour guide
+tourist
+tourist attraction
+tournament
+tow truck
+towel
+towel bar
+tower block
+tower bridge
+town
+town square
+toy
+toy car
+toy gun
+toyshop
+track
+tractor
+trade
+tradition
+traditional
+traffic
+traffic cone
+traffic congestion
+traffic jam
+traffic sign
+trail
+trailer
+trailer truck
+train
+train bridge
+train car
+train interior
+train track
+train window
+trainer
+training
+training bench
+training ground
+trolley
+trampoline
+transformer
+transparency
+travel
+tray
+treadmill
+treat
+tree
+tree branch
+tree farm
+tree frog
+tree house
+tree root
+tree trunk
+trial
+triangle
+triathlon
+tribe
+tributary
+trick
+tricycle
+trim
+trio
+tripod
+trombone
+troop
+trophy
+trophy cup
+tropic
+trout
+truck
+truck driver
+tub
+tube
+tugboat
+tulip
+tuna
+tundra
+tunnel
+turbine
+turkey
+turn
+turnip
+turquoise
+turret
+turtle
+tusk
+tv actor
+tv cabinet
+tv drama
+tv genre
+tv personality
+tv show
+tv sitcom
+tv tower
+twig
+twilight
+twin
+twine
+twist
+type
+type on
+typewriter
+ukulele
+ultraman
+umbrella
+underclothes
+underwater
+unicorn
+uniform
+universe
+university
+up
+urban
+urinal
+urn
+use
+utensil
+utility room
+vacuum
+valley
+valve
+vampire
+van
+vanilla
+vanity
+variety
+vase
+vault
+vector cartoon illustration
+vector icon
+vegetable
+vegetable garden
+vegetable market
+vegetation
+vehicle
+veil
+vein
+velvet
+vending machine
+vendor
+vent
+vespa
+vessel
+vest
+vet
+veteran
+veterinarians office
+viaduct
+video
+video camera
+video game
+videotape
+view mirror
+vigil
+villa
+village
+vine
+vinegar
+vineyard
+violence
+violet
+violin
+violinist
+violist
+vision
+visor
+vodka
+volcano
+volleyball
+volleyball court
+volleyball player
+volunteer
+voyage
+vulture
+waffle
+waffle iron
+wagon
+wagon wheel
+waist
+waiter
+waiting hall
+waiting room
+walk
+walking
+walking cane
+wall clock
+wallpaper
+walnut
+walrus
+war
+warehouse
+warm
+warning sign
+warrior
+warship
+warthog
+wash
+washer
+washing
+washing machine
+wasp
+waste
+waste container
+watch
+water
+water bird
+water buffalo
+water cooler
+water drop
+water feature
+water heater
+water level
+water lily
+water park
+water pipe
+water purifier
+water ski
+water sport
+water surface
+water tower
+watercolor
+watercolor illustration
+watercolor painting
+waterfall
+watering can
+watermark overlay stamp
+watermelon
+waterproof jacket
+waterway
+wave
+wax
+weapon
+wear
+weather
+vane
+web
+webcam
+wedding
+wedding ring
+wedding bouquet
+wedding cake
+wedding couple
+wedding invitation
+wedding party
+wedding photo
+wedding photographer
+wedding photography
+wedding reception
+wedge
+weed
+weight
+weight scale
+welder
+well
+western food
+western restaurant
+wet
+wet bar
+wet suit
+wetland
+wetsuit
+whale
+whale shark
+wheat
+wheat field
+wheel
+wheelchair
+wheelie
+whipped cream
+whisk
+whisker
+whiskey
+whistle
+white
+white house
+white wine
+whiteboard
+wicket
+wide
+wield
+wig
+Wii
+Wii controller
+wild
+wildebeest
+wildfire
+wildflower
+wildlife
+willow
+wind
+wind chime
+wind farm
+wind turbine
+windmill
+window
+window box
+window display
+window frame
+window screen
+window seat
+window sill
+wiper
+windshield
+windy
+wine bottle
+wine cooler
+wine cabinet
+wine cellar
+wine glass
+wine rack
+wine tasting
+winery
+wing
+winter
+winter melon
+winter morning
+winter scene
+winter sport
+winter storm
+wire
+wisteria
+witch
+witch hat
+wok
+wolf
+woman
+wood
+wood duck
+wood floor
+wood wall
+wood-burning stove
+wooden spoon
+woodland
+woodpecker
+woodworking plane
+wool
+job
+work card
+workbench
+worker
+workplace
+workshop
+world
+worm
+worship
+wound
+wrap
+wrap dress
+wrapping paper
+wrestle
+wrestler
+wrinkle
+wristband
+write
+writer
+writing
+writing brush
+writing desk
+yacht
+yak
+yard
+yellow
+yoga
+yoga mat
+yoghurt
+yoke
+yolk
+youth
+youth hostel
+yurt
+zebra
+zebra crossing
+zen garden
+zip
+zipper
+zombie
+zongzi
+zoo
\ No newline at end of file
diff --git a/ram/data/ram_tag_list_chinese.txt b/ram/data/ram_tag_list_chinese.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3f61dc0b84ed58e019d7e331555ef438f2ded2de
--- /dev/null
+++ b/ram/data/ram_tag_list_chinese.txt
@@ -0,0 +1,4585 @@
+三维CG渲染
+3d眼镜
+算盘
+鲍鱼
+修道院
+肚子
+学院
+附件
+事故
+手风琴
+橡子
+丙烯颜料
+表演
+行动
+动作电影
+活动
+演员
+改编本
+添加
+胶带
+调整
+成人
+冒险
+广告
+天线
+有氧运动
+喷雾罐
+爆炸头
+农业
+帮助
+空调
+空调系统
+风向标
+飞机客舱
+飞机模型
+机场
+航线
+客机
+飞行员
+飞机
+飞机窗口
+机场
+机场跑道
+航站楼
+飞艇
+航展
+过道
+警报
+闹钟
+信天翁
+唱片
+唱片封面
+酒精
+壁龛
+水藻
+胡同/球道
+杏仁
+芦荟
+高山
+羊驼
+字母表
+德国牧羊犬
+圣坛
+琥珀
+救护车
+秃鹰
+美国短毛猫
+紫水晶
+圆形剧场
+扩音器
+游乐园
+游乐设施
+锚
+古老的
+海葵
+天使
+角
+动物
+动物雕塑
+动物收容所
+动画片
+动画电影
+动画师
+动漫
+脚踝
+短袜
+周年庆
+风衣
+蚂蚁
+羚羊
+古董
+鹿角
+铁砧
+公寓
+猿
+应用程序
+应用图标
+出现
+外观
+开胃菜
+掌声
+苹果
+苹果汁
+苹果派
+苹果树
+苹果酱
+设备
+约定
+通道
+杏子
+围裙
+浅绿色
+水族馆
+观赏鱼
+渡槽
+游乐中心
+商场游戏机
+拱门
+拱桥
+考古现场
+射箭
+群岛
+建筑师
+建筑设计
+档案
+拱门
+地区
+竞技场
+争论
+手臂
+穿山甲
+臂章
+扶手椅
+衣柜
+盔甲
+军队
+军事基地
+坦克
+阵列
+逮捕
+箭头
+艺术
+艺术展
+美术馆
+艺术印刷品
+艺术学校
+艺术工作室
+艺术矢量插图
+洋蓟
+文章
+手工艺品
+艺术家
+艺术阁楼
+灰
+烟灰缸
+亚洲寺庙
+芦笋
+沥青道路
+组装
+集会
+生产流水线
+协会
+宇航员
+天文学家
+运动员
+运动
+地图集
+自助取款机
+大气层
+中庭
+连接
+战斗机
+参加
+吸引力
+全地形车
+茄子
+拍卖
+奥迪汽车
+音频
+礼堂
+极光
+作者
+汽车厂
+汽车修理工
+汽车零件
+车展
+汽车展厅
+汽车电池
+汽车制造
+汽车模型
+汽车
+秋天
+秋天的森林
+秋天的叶子
+秋天的公园
+秋天的树
+阿凡达
+林荫大道
+飞行员太阳镜
+牛油果
+奖品
+颁奖典礼
+获奖者
+棚
+斧头
+杜鹃花
+狒狒
+婴儿
+奶瓶
+婴儿车
+婴儿衣服
+小象
+婴儿食品
+婴儿座椅
+迎婴派对
+背后/后面
+背景
+背光
+背包
+后院
+培根
+徽章
+獾
+荒地
+羽毛球运动
+羽毛球拍
+袋子
+面包圈
+风笛
+法棍
+诱饵
+焙烤食品
+面包师
+面包店
+烘焙
+烤盘
+平衡
+平衡车
+阳台
+球
+球池
+芭蕾舞女演员
+芭蕾舞
+芭蕾舞演员
+芭蕾舞裙
+气球
+气球拱门
+棒球手
+舞厅
+竹子
+竹林
+香蕉
+香蕉面包
+香蕉叶子
+香蕉树
+乐队
+创可贴
+绷带
+头巾
+束发带
+刘海
+手镯
+栏杆
+五弦琴
+银行
+银行卡
+银行金库
+纸币
+横幅/旗帜
+宴会
+宴会厅
+榕树
+包子
+洗礼
+酒吧
+条形码
+高脚凳
+烧烤
+烧烤架
+杠铃
+理发师
+理发店
+芭比娃娃
+驳船
+咖啡师
+树皮
+大麦
+谷仓
+仓鸮
+挡光板
+桶
+路障
+屏障
+手推车
+酒保
+棒球
+棒球基地
+棒球棒
+棒球帽
+棒球场
+棒球比赛
+棒球手套
+棒球投手
+棒球队
+棒球制服
+地下室
+罗勒
+水盆
+篮子
+篮子
+篮球
+篮球篮板
+篮球教练
+篮球场
+篮球比赛
+篮球框
+篮球运动员
+篮球馆
+篮球队
+贝斯
+低音吉他
+低音喇叭
+贝斯手
+球棒/球拍
+浴室
+水浴加热器
+浴垫
+浴巾
+泳装
+浴袍
+浴室
+浴室配件
+浴室柜
+浴室门
+浴室镜子
+浴室水槽
+卫生纸
+浴室窗户
+蝙蝠侠
+棒子
+接连猛打/击球员
+电池
+战斗
+战绳
+战舰
+海湾
+海湾大桥
+凸窗
+杨梅
+集市
+海滩
+沙滩球
+沙滩椅
+海滨别墅
+海滩小屋
+沙滩毛巾
+沙滩排球
+灯塔
+珠子
+比格犬
+鸟嘴
+烧杯
+横梁
+豆子
+豆袋椅
+豆袋
+熊
+幼熊
+胡子
+野兽
+击打/击败
+美丽的
+美丽
+美容院
+海狸
+床
+床单
+床架
+卧室
+床上用品
+便盆
+卧室窗户
+床头灯
+蜜蜂
+山毛榉
+牛肉
+养蜂人
+蜂鸣器
+啤酒
+啤酒瓶
+啤酒罐
+啤酒花园
+啤酒杯
+啤酒馆
+甜菜
+甲虫
+米色
+时钟
+甜椒
+钟楼
+皮带
+皮带扣
+长凳
+弯曲
+孟加拉虎
+盒饭
+贝雷帽
+浆果
+停泊位
+饮料
+围嘴
+拌饭
+圣经
+比熊
+自行车
+自行车头盔
+自行车车轮
+自行车骑士
+坐浴盆
+大本钟
+自行车道
+自行车道
+自行车赛
+骑车
+比基尼
+比基尼上衣
+账单
+台球
+广告牌
+台球台
+垃圾箱
+活页夹
+双筒望远镜
+生物学实验室
+双翼飞机
+桦木
+桦树
+鸟
+鸟池
+喂鸟器
+鸟舍
+鸟巢
+鸟池
+鸟笼
+出生
+生日
+生日蛋糕
+生日蜡烛
+生日贺卡
+生日聚会
+饼干
+主教
+野牛
+钻头
+咬
+黑色
+黑山羊
+黑莓
+乌鸦
+黑板
+铁匠
+叶片/刀片
+毯子/覆盖层
+运动外套
+看台
+搅拌机
+祝福
+窗帘
+眼罩
+闪光
+暴风雪
+块
+博客
+血
+开花
+花
+女装衬衫
+吹
+吹风机
+河豚
+蓝色
+蓝色艺术家
+蓝松鸦
+蓝天
+蓝莓
+蓝知更鸟
+猪
+板子
+板擦
+棋盘游戏
+木板路
+船
+船甲板
+船屋
+桨
+乘船
+浮标
+山猫
+躯干
+身体冲浪板
+健美运动员
+水煮鸡蛋
+锅炉
+饰扣式领带
+门闩
+炸弹
+轰炸机
+披肩榛鸡
+骨骼
+篝火
+阀盖
+盆景
+书
+书籍封面
+书柜
+文件夹
+书签
+书架
+书店
+远程拾音器
+推动
+靴子
+边界
+边境牧羊犬
+植物园
+瓶
+瓶盖
+开瓶器
+螺旋开瓶器
+三角梅
+巨石
+花束
+时装店
+精品酒店
+鞠躬/蝴蝶结
+领结
+弓形窗
+碗
+保龄球运动
+保龄球馆
+保龄球
+保龄球设备
+盒子
+箱形梁桥
+箱龟
+拳击手
+内裤
+拳击
+拳击手套
+拳击台
+男孩
+支撑物
+支架
+辫子
+大脑
+刹车
+刹车灯
+树枝
+商标
+白兰地
+黄铜
+黄铜牌匾
+面包
+面包箱
+休息
+早餐
+防浪堤
+胸部
+啤酒厂
+砖块
+砖建筑物
+墙
+砖块
+婚纱
+新娘
+新郎
+伴娘
+桥
+缰绳
+公文包
+明亮的
+边沿
+钻头
+广播
+西兰花
+青铜
+铜牌
+青铜雕塑
+青铜雕像
+胸针
+小溪
+扫帚
+肉汤
+棕色
+棕熊
+巧克力蛋糕
+早午餐
+浅黑肤色的女人
+刷子
+郊狼
+包菜
+气泡
+泡泡糖
+珍珠奶茶
+斗柜
+盾牌
+芽
+佛
+水牛
+自助餐
+昆虫
+建造
+建造者
+建筑
+积木
+建筑立面
+建筑材料
+灯
+牛
+斗牛犬
+子弹
+动车
+公告栏
+防弹背心
+斗牛
+扩音器
+斗牛场
+大黄蜂
+保险杠
+卷/地形起伏
+捆
+蹦极
+双层床
+地堡/击球
+兔子
+浮标
+书桌
+墓室
+燃烧
+玉米煎饼
+公交车
+公交车司机
+公交车内部
+公交车站
+公交车站
+公交车窗户
+灌木
+商业
+名片
+业务主管
+商务西装
+业务团队
+女商人
+商人
+半身像
+屠夫
+肉铺
+孤峰
+黄油
+奶油
+蝴蝶
+蝴蝶馆
+按钮
+梧桐树
+购买
+出租车
+小屋
+卷心菜
+小屋/机舱
+守车
+储藏柜
+橱柜
+电缆
+缆车
+仙人掌
+咖啡馆
+食堂
+笼子
+蛋糕
+蛋糕台
+计算器
+大锅
+日历
+小腿
+通话
+电话亭
+书法
+平静的
+摄像机
+骆驼
+相机
+相机镜头
+迷彩
+露营
+露营者
+篝火
+露营
+营地
+校园
+罐
+开罐器
+运河
+金丝雀
+癌症
+蜡烛
+烛台
+糖果
+块状糖
+柺杖糖
+糖果店
+拐杖
+罐子
+大炮
+树冠/顶棚
+四柱床
+香瓜
+悬臂桥
+帆布
+峡谷
+帽子
+斗篷
+科德角
+卡布奇诺
+胶囊
+队长
+捕获
+车
+汽车经销商
+车门
+汽车内饰
+车标
+后视镜
+停车场
+汽车座椅
+车展
+洗车
+车窗
+焦糖
+卡片
+纸牌游戏
+纸板
+纸板盒
+羊毛衫
+红衣凤头鸟
+货物
+货运飞机
+货船
+加勒比
+康乃馨
+狂欢节
+食肉动物
+旋转木马
+鲤鱼
+木匠
+地毯
+拖鞋
+红雀
+长途客车
+斑点狗
+航空母舰
+胡萝卜
+胡萝卜蛋糕
+携带
+手推车
+纸箱/纸盒
+卡通
+卡通人物
+卡通插图
+卡通风格
+雕刻
+容器
+现金
+腰果
+赌场
+砂锅
+磁带
+盒式录音机
+石膏绷带
+铸造
+城堡
+猫
+猫窝
+猫粮
+猫器具
+猫架
+地下墓穴
+双体船
+美洲狮
+握着/抓着
+捕手
+毛毛虫
+鲶鱼
+教堂
+牛
+猫步
+走秀
+菜花
+洞穴
+鱼子酱
+光盘
+CD播放器
+雪松
+天花板
+吊扇
+庆祝
+庆典
+名人
+芹菜
+大提琴
+手机
+水泥
+墓地
+中心装饰品
+蜈蚣
+陶瓷
+瓷砖
+麦片
+仪式
+证书
+链条
+链锯
+椅子
+升降椅
+躺椅
+木屋
+圣杯
+粉笔
+房间
+变色龙
+香槟酒
+香槟杯
+冠军
+锦标赛
+吊灯
+婴儿换尿布台
+通道
+皴裂处
+小教堂
+人物雕塑
+木炭
+充电
+充电器
+战车
+慈善机构
+慈善活动
+魅力
+图表
+追逐
+底盘
+检查/支票
+支票簿
+棋盘
+检查表
+欢呼声
+鼓励/啦啦队
+奶酪
+奶酪汉堡
+奶酪蛋糕
+猎豹
+厨师
+化合物
+化学家
+化学
+化学实验室
+旗袍
+樱桃
+樱花
+樱桃番茄
+樱桃树
+国际象棋
+栗子
+鸡
+鸡胸肉
+鸡笼
+鸡肉沙拉
+鸡翅
+鹰嘴豆
+小衣橱
+吉娃娃
+孩子
+童星
+孩子的房间
+红番椒
+辣热狗
+烟囱
+黑猩猩
+瓷器
+白菜
+中国园林
+中国结
+月季
+中国塔
+炸薯条/炸薯条
+花栗鼠
+凿子
+巧克力
+巧克力棒
+巧克力蛋糕
+巧克力碎片
+巧克力饼干
+巧克力牛奶
+巧克力慕斯
+松露
+唱诗班
+厨房刀
+砧板
+筷子
+圣诞节
+圣诞球
+圣诞贺卡
+圣诞装饰
+圣诞晚宴
+平安夜
+圣诞帽
+圣诞灯
+圣诞市场
+圣诞装饰
+圣诞树
+菊花
+教堂
+教堂塔
+苹果酒
+雪茄
+雪茄盒
+香烟
+烟盒
+腰带
+电影院
+摄影师
+肉桂
+圆
+电路
+电路板
+马戏团
+水箱
+柑橘类水果
+城市
+城市公交
+市政厅
+城市夜景
+城市公园
+城市天际线
+城市广场
+城市街道
+城墙
+城市景观
+蛤蜊
+单簧管
+扣子
+班级
+经典
+教室
+锁骨
+爪子
+黏土
+陶器
+清洁
+洁净室
+清洁工人
+清洁用品
+清晰的
+栓
+克莱门氏小柑橘
+客户端
+悬崖
+爬
+爬山
+登山者
+诊所
+夹子
+剪贴画
+剪贴板
+快速帆船
+君子兰
+斗篷
+木底鞋
+特写
+壁橱
+布
+穿衣
+衣服
+晒衣夹
+晒衣绳
+服装店
+云
+云雾森林
+多云
+三叶草
+小丑
+小丑鱼
+俱乐部
+离合器
+手拿包
+煤炭
+海岸
+外套
+衣帽架
+玉米
+公鸡
+凤头鹦鹉
+可卡犬
+驾驶
+蟑螂
+鸡尾酒
+小礼服
+鸡尾酒调制器
+鸡尾酒桌
+可可
+椰子
+椰子树
+咖啡
+咖啡豆
+咖啡杯
+咖啡机
+咖啡店
+咖啡壶
+棺材
+法国白兰地
+螺旋
+硬币
+可口可乐
+滤器
+冷的
+卷心菜沙拉
+合作
+拼贴画
+收藏品
+大学生
+牧羊犬
+碰撞
+颜色
+涂色书
+染色材料
+矮种马
+柱子
+梳子
+密码锁
+喜剧演员
+喜剧
+喜剧电影
+彗星
+舒服
+安慰食物
+漫画书
+漫画人物
+连环画
+指挥官
+评论员
+社区
+通勤
+公司
+指南针
+比赛
+比赛
+竞争者
+作曲家
+作文
+堆肥
+电脑
+电脑机箱
+电脑椅
+电脑桌
+键盘
+计算机显示器
+计算机房
+电脑屏幕
+机箱
+概念车
+音乐会
+音乐厅
+贝壳
+混凝土
+调味品
+避孕套
+独立产权的公寓
+指挥
+锥形物
+会议
+会议中心
+会议厅
+会议室
+五彩纸屑
+冲突
+合流
+连接
+连接器
+温室
+星座
+建筑工地
+建筑工人
+包含
+容器
+集装箱船
+大陆
+轮廓
+合同
+控制
+控制塔
+便利店
+集会
+交谈
+转换器
+可转换的
+输送机
+厨师/烹饪
+烹饪
+烹饪喷雾剂
+炊具
+凉的
+冷却器
+铜
+一本/一册
+珊瑚
+珊瑚礁
+粗绳
+有线电话
+酒
+威尔士矮脚狗
+瓶塞
+软木板
+鸬鹚
+玉米
+玉米田
+玉米面包
+角落
+小号
+飞檐
+燕麦片
+围栏
+走廊
+紧身衣
+化妆品
+化妆刷
+化妆镜
+角色扮演
+服装
+服装电影设计师
+婴儿床
+小屋
+棉花
+棉花糖
+沙发
+倒计时
+柜台
+台面
+最佳乡村歌手
+乡村别墅
+乡村公路
+乡村流行歌手
+农村
+双门小轿车
+夫妇/两人/几个
+情侣写真
+小胡瓜
+课程
+球场
+法院
+院子
+堂兄弟
+工作服
+奶牛
+母牛的颈铃
+牛仔
+牛仔靴
+牛仔帽
+螃蟹
+蟹肉
+裂纹
+摇篮
+工艺
+工匠
+蔓越莓
+起重机
+黑纱
+厕所
+板条箱
+火山口湖
+龙虾
+蜡笔
+奶油乳酪
+奶油罐
+创建
+生物
+信用卡
+新月形
+新月形面包
+山顶
+全体船员
+蟋蟀
+板球用球
+板球队
+板球队员
+钩边
+克罗克电锅
+鳄鱼
+庄稼
+露脐上衣
+交叉
+横木
+十字路口
+相声
+人行横道
+油煎面包块
+乌鸦
+撬棍
+人群
+拥挤的
+皇冠
+阴极射线管屏幕
+耶稣受难像
+巡游
+游轮
+巡洋艇
+面包屑
+压坏
+拐杖
+水晶
+幼兽
+立方体
+黄瓜
+球杆
+袖口
+袖扣
+烹饪
+农田
+杯子
+纸杯蛋糕
+丘比特
+马路牙子
+旋度
+卷发器
+无籽葡萄干
+货币
+咖喱
+窗帘
+曲线
+软垫
+顾客
+切
+餐具
+自行车
+骑自行车
+龙卷风
+汽缸
+铙钹
+柏树
+柏树
+达克斯猎狗
+水仙花
+匕首
+大丽花
+萝卜
+乳制品
+雏菊
+大坝
+损害
+潮湿的
+跳舞
+舞池
+舞蹈室
+舞者
+蒲公英
+黑暗
+黑暗
+飞镖
+圆靶
+指示板
+日期
+女儿
+黎明
+天床上
+日光
+门栓
+死亡
+辩论
+碎片
+玻璃水瓶
+甲板
+双层巴士
+装饰
+装修/装饰
+装饰画
+鹿
+后卫
+神
+熟食
+投递
+拆迁
+怪兽
+演示
+兽窝/休闲室
+牛仔夹克
+牙医
+百货商店
+抑郁症
+德比
+皮肤病
+沙漠
+沙漠公路
+设计
+设计师
+桌子/表格
+台灯
+桌面
+台式电脑
+甜点
+破坏
+侦探
+洗涤剂
+露水
+仪表盘
+钻石
+尿布
+尿布包
+杂志
+死
+饮食
+挖掘机
+数字
+数字时钟
+莳萝
+晚餐
+小船
+餐厅
+晚宴
+餐桌
+恐龙
+浸
+文凭
+指引
+导演
+尘埃
+越野摩托车
+泥土地
+泥土路
+泥路/土路
+灾难
+信徒
+迪斯科舞厅
+迪斯科灯秋
+迪斯科舞厅
+疾病
+盘子
+碟形天线
+洗碗机
+抹布
+菜肴
+洗碗液
+迪斯尼乐园
+自动售货机
+展示
+陈列窗
+壕沟
+潜水
+潜水员
+跳水板
+纸杯
+流行音乐播音员
+杜宾犬
+码头
+医生
+文件
+纪录片
+狗
+狗窝
+犬种
+狗项圈
+狗粮
+狗窝
+洋娃娃
+美元
+玩偶之家
+洋娃娃
+海豚
+穹顶
+住宅
+多米诺骨牌
+驴
+甜甜圈
+涂鸦
+门
+门把手
+受气包
+门牌
+门口
+宿舍
+面团
+市中心
+推土机
+拖
+龙
+蜻蜓
+排水沟
+剧本
+戏剧电影
+画
+抽屉里
+图画/画画
+图钉
+辫子
+连衣裙/特定场合的服装
+礼帽
+正装衬衫
+皮鞋
+大礼服
+梳妆台
+更衣室
+运球
+漂移
+浮木
+钻
+饮品/喝
+饮用水
+开车
+司机
+车道
+无人机
+水滴/下降
+吊灯
+滴管
+干旱
+药物
+药店
+鼓
+鼓手
+鸡腿
+干的
+公爵夫人
+鸭子
+鸭嘴兽
+小鸭子
+布基胶带
+伙计
+二重唱
+粗呢
+独木舟
+哑铃
+饺子
+沙丘
+扣篮
+榴莲
+黄昏
+灰尘
+垃圾车
+簸箕
+羽绒被
+DVD
+染料
+鹰
+耳朵
+御寒耳罩
+耳机
+耳塞
+耳环
+地震
+画架
+复活节
+复活节兔子
+复活节彩蛋
+吃
+餐厅
+泡芙
+日食
+生态系统
+编辑
+教育
+教育家
+鳗鱼
+蛋
+蛋卷
+蛋挞
+打蛋器
+白鹭
+埃菲尔铁塔
+橡皮筋
+上级
+电椅
+电钻
+电工
+电
+电子
+电子器件
+大象
+高度图
+电梯
+电梯轿厢
+电梯门
+电梯大堂
+电梯井
+路堤
+大使馆
+装饰
+灰烬
+会徽
+刺绣
+翡翠
+紧急
+紧急服务
+紧急车辆
+情感
+帝国大厦
+搪瓷
+外壳/围墙
+茶几
+能源
+订婚
+订婚戒指
+引擎
+机舱
+工程师
+工程
+英国短毛猫
+乐团
+回车键
+演艺人员
+娱乐
+娱乐中心
+入口
+入口大厅
+信封
+马术
+设备
+橡皮擦
+二胡
+侵蚀
+自动扶梯
+食用蜗牛
+浓缩咖啡
+房地产
+河口
+桉树
+晚上
+晚礼服
+夜光
+傍晚天空
+晚上的太阳
+事件
+常绿的
+母羊
+挖掘
+运动
+排气罩
+展览
+出口
+探险者
+爆炸
+延长线
+灭火器
+排气扇
+挤压
+眼睛
+眼影
+眉
+眼线笔
+布料
+纺织品商店
+外观
+脸
+脸部特写
+蜜粉
+毛巾
+面巾纸架
+设施
+工厂
+工厂车间
+集市
+露天市场
+仙女
+猎鹰
+秋天
+家庭
+家庭轿车
+全家福
+家庭房
+风扇/扇子
+尖牙
+农场
+农民
+农民市场
+农舍
+时尚
+时尚配饰
+时装设计师
+时尚的女孩
+时装插图
+时装大片
+时装模特
+时装表演
+快餐
+西式快餐
+父亲
+水龙头
+故障
+动物
+小鹿
+传真
+宴会
+羽毛
+软呢帽
+饲料
+一餐
+饲养
+喂养的椅子
+猫科
+美洲狮
+栅栏
+芬达
+蕨类植物
+雪貂
+摩天轮
+渡船
+肥料
+节日
+纤维
+小说
+小说书
+田野/场地/野外
+田间道路
+无花果
+打架
+花样滑冰运动员
+小雕像
+文件
+档案照片
+文件柜
+填满
+胶片相机
+电影导演
+电影格式
+电影首映礼
+电影制片人
+拍摄
+过滤器
+鳍
+手
+终点线
+冷杉
+冷杉树
+火
+火灾报警
+消防部门
+消防车
+消防通道
+消防水带
+火坑
+消防站
+爆竹
+消防队员
+壁炉
+烟花
+烟花表演
+急救箱
+鱼
+鱼船
+海鲜市场
+鱼塘
+鱼缸
+渔夫
+钓鱼
+渔船
+渔网
+钓鱼
+渔村
+健身
+健身课程
+五个
+固定装置
+峡湾
+国旗
+旗杆
+小薄片
+火焰
+火烈鸟
+法兰绒
+拍打
+耀斑
+闪光
+烧瓶
+平
+比目鱼
+风味
+跳蚤
+跳蚤市场
+舰队
+飞行
+空中乘务员
+翻转
+触发器
+翻转图
+浮动
+群
+洪水
+地板/地面
+落地扇
+脚垫
+楼层平面图
+落地窗
+插花艺术
+花店
+牙线
+面粉
+流动
+花
+花篮
+花坛
+花箱
+花田
+花童
+花卉市场
+流体
+冲洗
+长笛
+飞
+飞行钓鱼
+传单
+马
+泡沫
+雾
+多雾的
+鹅肝酱
+箔纸
+折椅
+树叶
+民间艺术家
+民间舞蹈
+民间摇滚艺术家
+方旦糖
+火锅
+圣洗池
+食物
+食用色素
+美食广场
+食品加工机
+小吃摊
+快餐车
+桌上足球
+脚
+人行桥
+足球
+足球教练
+大学橄榄球赛
+足球比赛
+足球场
+足球比赛
+橄榄球头盔
+足球运动员
+足球场
+足球队
+小路
+脚印
+脚踏板
+台座
+鞋子
+故宫
+浅滩
+额头
+森林
+森林大火
+森林地面
+森林小路
+森林公路
+锻造
+餐叉
+叉车
+表格
+园林
+队列/形成物
+F1方程式赛车
+堡垒
+碉堡
+追逐
+化石
+粉底
+喷泉
+钢笔
+狐狸
+框架
+雀斑
+高速公路
+卡车
+法国
+法国斗牛犬
+薯条
+法式吐司
+化妆水
+冰箱
+炸鸡
+煎蛋
+炒饭
+友谊
+飞盘
+青蛙
+霜
+结霜
+严寒
+结冰
+水果
+水果蛋糕
+水果盘
+水果市场
+水果沙拉
+水果摊
+果树
+水果商店
+油炸食品
+煎锅
+软糖
+燃料
+吸烟罩
+有趣的
+葬礼
+真菌
+漏斗
+毛皮衣服
+毛皮大衣
+家具
+蒲团
+小工具
+枪口
+星云/星系
+美术馆
+游戏
+游戏棋盘
+游戏手柄
+火腿
+团伙
+车库
+车库门
+手工模型
+垃圾
+花园
+花园芦笋
+橡胶软管
+花园蜘蛛
+园丁
+园艺
+加菲猫
+滴水嘴
+花环
+大蒜
+衣服
+气体
+加油站
+煤气炉
+防毒面具
+收集
+聚集
+测量仪器
+露台
+齿轮
+壁虎
+艺妓
+凝胶
+百货商店
+发电机
+天竺葵
+幽灵
+礼物
+礼品袋
+礼品篮
+礼物盒
+礼品卡
+礼品商店
+礼物包装
+演唱会
+杜松子酒
+姜
+姜饼
+姜饼屋
+银杏树
+长颈鹿
+女孩
+给
+冰川
+角斗士
+玻璃珠
+玻璃瓶
+玻璃碗
+玻璃箱
+玻璃建筑
+玻璃门
+玻璃地板
+玻璃屋
+玻璃罐
+玻璃板
+玻璃桌子
+玻璃花瓶
+玻璃墙
+玻璃窗
+眼镜
+光滑面
+滑翔机
+地球
+手套
+发光
+汤圆
+去
+袭击
+球门
+守门员
+山羊
+羊奶酪
+戈壁
+护目镜/墨镜
+黄金
+金牌
+金门大桥
+金毛猎犬
+金鱼
+高尔夫运动
+高尔夫球帽
+高尔夫球车
+高尔夫球杆
+高尔夫球场
+高尔夫球手
+鹅
+大猩猩
+哥特式
+葫芦
+政府
+政府机构
+礼服
+毕业生
+毕业典礼
+谷物
+逆戟鲸
+大奖赛
+祖父
+祖母
+祖父母
+花岗岩
+格兰诺拉麦片
+葡萄
+西柚
+葡萄酒
+草
+蚱蜢
+草原
+长满草的
+擦菜器
+坟墓
+碎石
+墓碑
+肉汁
+调味汁瓶
+灰色
+吃草
+放牧
+绿色
+绿色植物
+欢迎
+问候
+贺卡
+灰狗
+网格
+筛子
+烧烤架
+格栅
+烤鳗鱼
+磨
+研磨机
+粗燕麦粉
+杂货袋
+洞穴
+地松鼠
+群体
+合影
+小树林
+生长
+牛油果酱
+警卫
+看门狗
+宾馆
+客房
+指南
+豚鼠
+吉他
+吉他手
+海湾
+海鸥
+枪
+高达
+谒师所
+古筝
+健身房
+体操运动员
+栖息地
+黑客
+冰雹
+头发
+头发颜色
+发胶
+毛刷
+发型
+发夹
+发网
+发夹
+发型
+一半
+礼堂
+万圣节
+万圣节服装
+万圣节南瓜
+露背装
+汉堡
+汉堡包
+哈密瓜
+锤子
+吊床
+阻碍
+仓鼠
+烘手机
+放大镜
+擦手巾
+手提包
+手球
+手铐
+手枪
+手帕
+把手
+手锯
+握手
+倒立
+手写
+汉服
+悬挂
+飞机库
+衣架
+幸福
+海港
+斑海豹
+硬摇滚艺术家
+精装书
+建筑工人
+硬件
+五金店
+硬木
+硬木地板
+口琴
+管风琴
+羽管键琴
+收获
+收割机
+坐垫/搁脚凳/草丛
+帽子
+帽盒
+双簧管
+山楂
+干草
+干草地
+榛子
+头
+主教练
+大灯
+床头板
+头饰
+海岬
+总部
+听力
+心脏
+心形
+热能
+加热器
+帚石楠
+树篱
+刺猬
+脚后跟
+直升机
+直升机机场
+头盔
+帮助
+母鸡
+指甲花
+药草
+兽群
+寄居蟹
+英雄
+苍鹭
+芙蓉花
+芙蓉花
+隐藏/隐蔽处
+高杠
+高跟鞋
+高地
+突出
+徒步旅行
+徒步旅行者
+徒步靴
+登山设备
+山丘
+丘陵地
+别墅
+山坡
+印度教寺庙
+铰链
+臀部
+嘻哈艺人
+河马
+历史学家
+历史遗迹
+历史
+曲棍球
+冰球馆
+曲棍球比赛
+曲棍球运动员
+曲棍球棒
+锄头
+洞
+假日
+冬青树
+海参
+家/住宅
+家用电器
+基地
+家居装饰
+室内设计
+内政部
+家庭影院
+家庭作业
+鹰嘴豆泥
+蜂蜜
+蜂窝
+蜜月
+风帽
+连帽衫
+挂钩/勾住
+跳
+地平线
+犀鸟
+长角牛
+大黄蜂
+震惊
+恐怖电影
+马鞍褥
+马车
+马场
+骑马
+马背
+马蹄铁
+软管
+医院
+医院病床
+病房
+主持人
+小旅馆
+热
+热气球
+热狗
+辣椒酱
+温泉
+旅馆
+酒店大堂
+酒店房间
+电炉
+沙漏
+房子
+房子外部
+室内植物
+悬滑板
+吼
+蜷缩
+拥抱
+呼啦圈
+人
+增湿器
+蜂鸟
+座头鲸
+打猎
+狩猎小屋
+障碍
+飓风
+哈士奇
+小屋
+鬣狗
+混合物
+绣球花
+消火栓
+水上飞机
+冰
+冰袋
+北极熊
+冰洞
+冰淇淋
+冰淇淋蛋卷
+冰淇淋商店
+冰块
+浮冰
+冰球运动员
+冰球队
+棒棒糖
+制冰机
+溜冰场
+冰雕
+冰架
+溜冰鞋
+滑冰
+冰山
+冰柱
+糖衣/酥皮
+图标
+身份证照片
+身份证
+冰屋
+光/灯光/光线
+鬣蜥蜴
+照亮
+插图
+形象
+黑斑羚
+熏香
+独立日
+个人
+室内
+划船器
+电磁炉
+工业区
+工业
+步兵
+充气艇
+服务台
+基础设施
+成分
+吸入器
+注射
+受伤
+墨水
+印泥
+小湖湾
+题词
+昆虫
+安装
+乐器/器械
+绝缘杯
+互动
+室内设计
+网站
+十字路口
+面试
+无脊椎动物
+邀请
+平板电脑
+苹果手机
+苹果音乐播放器
+虹膜
+铁
+熨衣板
+灌溉系统
+岛
+小岛
+等足类动物
+象牙
+常青藤
+居酒屋
+千斤顶
+帝王蟹/蟹
+夹克衫
+按摩浴缸
+玉
+美洲虎
+监狱牢房
+果酱
+日式花园
+茉莉花
+下巴
+松鸦
+爵士乐
+爵士乐艺术家
+爵士融合艺术家
+牛仔裤
+吉普车
+果冻
+果冻豆
+水母
+喷气式飞机
+摩托艇
+珠宝
+珠宝
+珠宝店
+拼图游戏
+人力车
+赛马骑师
+赛马帽
+慢跑
+联合的
+记者
+操纵杆
+法官
+水壶
+玩杂耍
+果汁
+榨汁器
+枣子
+跳绳
+连身裤
+丛林
+废品堆放场
+羽衣甘蓝
+万花筒
+袋鼠
+卡拉ok
+空手道
+卡丁车运动
+旧城区
+皮船
+烤肉串
+按键/钥匙
+门卡
+卡其色
+踢
+苏格兰裙
+和服
+幼儿园教室
+幼儿园
+国王
+帝王蟹
+亲吻
+工具包
+厨房
+厨房橱柜
+厨房台面
+厨房地板
+厨房抽油烟机
+厨房岛
+厨房水槽
+厨房桌子
+厨房用具
+厨房窗户
+厨房用具
+风筝
+猕猴桃
+护膝
+跪下
+餐刀
+骑手
+编织
+编织针
+球形把手
+门环
+结
+考拉
+锦鲤
+ktv
+实验室
+实验室外套
+标签
+拉布拉多
+迷宫
+网眼织物
+蕾丝连衣裙
+梯子
+长柄杓
+瓢虫
+环礁湖
+湖泊
+湖区
+湖边小屋
+湖岸
+羊肉
+羊排
+灯柱
+灯罩
+矛
+土地
+陆地车辆
+废物填埋
+着陆
+降落甲板
+地标
+风景
+山崩
+挂带
+灯笼
+腿/大腿
+笔记本电脑
+笔记本键盘
+幼体
+烤宽面条
+激光
+睫毛
+套索
+门闩
+乳胶
+拿铁咖啡
+笑
+发射
+发布会
+举办会议
+自助洗衣店
+洗衣房
+洗衣篮
+洗衣房
+熔岩
+薰衣草
+草坪
+草坪婚礼
+律师
+躺
+引领
+主唱
+通向
+领袖
+泄漏
+倾斜/倚靠
+学习
+皮带
+皮革
+皮夹克
+皮鞋
+演讲
+演讲厅
+教学室
+窗台
+剩饭
+腿
+传说
+紧身裤/秋裤
+立法院
+乐高
+豆类
+柠檬
+柠檬汁
+柠檬水
+狐猴
+镜头
+眩光
+扁豆
+豹
+紧身连衣裤
+紧身裤袜
+小妖精
+课程
+信函
+信箱
+信的标志
+刻字
+生菜
+水平
+图书馆
+许可证
+车牌
+地衣
+舔
+盖子
+躺着
+安全带
+救生衣
+救生艇
+救生员
+提起
+灯具
+灯光秀
+电灯开关
+照明/照明设备
+闪电
+避雷针
+淡紫色
+百合
+肢体
+石灰
+石灰石
+豪华轿车
+线条
+艺术线条
+排队
+亚麻
+邮轮
+狮子
+润唇膏
+口红
+液体
+酒类商店
+列表
+荔枝
+生活
+家畜
+客厅
+生活空间
+蜥蜴
+负载
+装卸码头
+游手好闲的人
+走廊
+定位
+锁
+闸室
+储物柜
+阁楼
+原木
+小木屋
+标志
+洛基
+长头发
+冲浪板
+隐约显现/织布机
+环状
+遗失
+彩票
+莲花
+爱
+双人沙发
+行李
+木材
+伐木工人
+午餐
+午餐盒
+郁郁葱葱的
+奢侈品
+豪华游艇
+雨衣
+澳洲胡桃
+短尾猿
+通心粉
+金刚鹦鹉
+弯刀
+机器
+机枪
+杂志
+魔法
+魔术师
+磁铁
+放大镜
+木兰花
+喜鹊
+麻将
+象夫
+女仆
+邮件
+邮件槽
+制作
+改造
+化妆师
+化妆工具
+野鸭
+野鸭
+槌棒
+哺乳动物
+猛犸象
+男人
+管理
+经理
+海牛
+曼荼罗
+橘子
+普通话
+鬃毛
+漫画
+食槽
+芒果
+山竹果
+红树林
+曼哈顿
+检修孔
+井盖
+修指甲
+人体模型
+庄园主宅
+大厦
+螳螂
+地幔
+活动房层
+制造业
+手稿
+地图
+枫木
+枫叶
+枫糖浆
+沙球
+马拉松
+大理石
+行进
+行进乐队
+母马
+金盏花
+水兵
+海洋无脊椎动物
+海洋哺乳动物
+木偶
+标志
+集市
+市场广场
+市场摊位
+结婚
+武术
+武术家
+武术馆
+马提尼
+马丁尼酒杯
+睫毛膏
+吉祥物
+土豆泥
+搅碎机
+面具/口罩
+按摩
+桅杆
+地垫
+斗牛士
+比赛
+火柴盒
+衣料
+床垫
+陵墓
+长裙
+一餐
+量杯
+卷尺
+肉类
+肉丸
+机械师
+机械风扇
+奖牌
+媒体
+医疗设备
+医学图像
+医务人员
+医药箱
+中世纪的
+麦地那市
+冥想
+猫鼬
+赛事
+香瓜
+纪念碑
+菜单
+美人鱼
+网
+肮脏
+信使袋
+金属
+金属艺术家
+金属探测器
+计量器
+中层楼
+麦克风
+显微镜
+微波炉
+午夜
+里程碑
+军装
+牛奶
+牛奶罐
+奶茶
+奶昔
+磨坊
+矿井
+矿工
+矿物质
+矿泉水
+迷你
+微缩模型
+面包车
+部长
+小型货车
+薄荷
+薄荷糖
+镜子
+小姐
+投掷物
+任务
+槲寄生
+混合
+搅拌机
+搅拌碗
+混合物
+护城河
+电动踏板车
+模型/模特
+汽车模型
+现代
+现代大厦
+潮湿
+模具
+模具
+鼹鼠
+君主
+钱
+监控器
+和尚
+猴子
+活动扳手
+黑白照片
+独轮脚踏车
+怪物卡车
+月亮
+月饼
+月光
+沼泽
+驼鹿
+拖把
+助力车
+早晨
+晨雾
+晨光
+朝阳
+砂浆
+马赛克
+清真寺
+蚊子
+藓类植物
+汽车旅馆
+蛾
+母亲
+主板
+主题
+动作
+电动机
+摩托车
+摩托车
+摩托车头盔
+摩托车赛车手
+骑摩托车的人
+赛车运动
+土堆
+山
+山地自行车
+山地自行车员
+山地自行车运动
+山地大猩猩
+山湖
+山景观
+山口
+山路
+山脉
+山区河流
+山雪
+山间溪流
+山景城
+山村
+登山者
+登山包
+鼠标/鼠
+鼠标垫
+捕鼠器
+嘴
+漱口水
+移动
+电影海报
+电影票
+割草机
+mp3播放器
+先生
+泥
+松饼
+马克杯
+桑树
+覆盖物
+骡子
+直辖市
+壁画
+肌肉
+肌肉车
+博物馆
+蘑菇
+音乐
+音乐节
+音乐凳子
+音乐工作室
+音乐录影带表演者
+音乐键盘
+音乐家
+贻贝
+芥末
+神话
+烤干酪辣味玉米片
+指甲油
+指甲锉
+保姆
+餐巾
+狭窄的
+国旗
+基督诞生的场景
+自然历史博物馆
+自然
+自然保护区
+导航
+九夜节
+海军
+星云
+脖子
+围颈带/领口
+项链
+领口
+花蜜
+油桃
+针状物
+邻居
+与某处邻近的地区
+霓虹灯
+霓虹灯
+神经
+巢
+新年
+新生的
+纽芬兰
+新婚
+新闻
+记者招待会
+报摊
+晚上
+夜市
+夜空
+夜景
+夜总会
+床头柜
+面条
+鼻子
+鼻羁
+注解
+笔记本
+记事本
+信纸
+公告
+数字图标
+修女
+护士
+托儿所
+养老院
+螺母
+胡桃夹子
+橡木
+橡树
+桨
+绿洲
+烘干室
+燕麦片
+燕麦
+方尖塔
+观察塔
+天文台
+超越障碍训练场
+海洋
+章鱼
+提供
+办公室
+办公大楼
+办公椅
+办公室隔间
+办公桌
+办公用品
+办公室的窗户
+军官
+行政官员
+石油
+油灯
+油画
+石油钻台
+秋葵
+老照片
+橄榄
+橄榄油
+橄榄树
+煎蛋卷
+洋葱
+洋葱圈
+蛋白石
+开阔的/张开
+开始
+开幕式
+歌剧
+歌剧院
+操作
+手术室
+操作
+眼镜店
+猩猩
+橙子/橙色
+橙汁
+橙树
+橘园
+轨道
+果园
+乐池
+兰花
+订单
+组织
+折纸
+点缀
+鱼鹰
+鸵鸟
+水獭
+外面的
+露头
+户外
+厕所
+电源插头
+大纲
+椭圆形
+烤箱
+整体
+大衣
+天桥
+猫头鹰
+牡蛎
+橡皮环
+包裹
+包/包装/包裹
+围场
+警车
+挂锁
+肉菜饭
+宝塔
+疼痛
+油漆刷
+画家
+佩斯利印花大手帕
+宫殿
+调色板
+栅栏
+棺罩
+棕榈树
+平底锅
+煎饼
+熊猫
+面板
+全景
+三色堇
+喘息
+储藏室
+裤子
+连裤袜
+木瓜
+纸
+纸袋
+切纸机
+纸灯笼
+纸盘子
+纸巾
+平装书
+压纸器
+降落伞
+游行
+天堂
+鹦鹉
+护理人员
+长尾小鹦鹉
+滑翔伞
+伞兵
+羊皮纸
+教区
+公园
+公园长椅
+停车
+停车场
+停车费
+停车标志
+议会
+欧芹/香菜
+参与者
+合作伙伴
+帕特里奇
+聚会
+派对帽
+通过
+通道
+存折
+乘客
+客船
+旅客列车
+百香果
+护照
+面食
+粘贴
+糕点
+牧场
+补丁
+病人
+图案/款式
+人行道/硬路面
+大帐篷
+爪子
+支付
+付费电话
+豌豆
+和平
+桃子
+孔雀
+山峰/尖顶
+花生
+花生酱
+梨
+珍珠
+卵石
+山核桃
+行人
+人行天桥
+步行街
+果皮
+削皮器
+小钉板
+木质腿
+鹈鹕
+笔/围栏
+点球
+铅笔
+铅笔盒
+卷笔刀
+铅笔裙
+吊坠
+钟摆
+企鹅
+半岛
+锦标旗
+便士
+储蓄罐
+牡丹
+胡椒/辣椒
+胡椒研磨机
+胡椒子
+意大利辣香肠
+栖息/鲈鱼
+表演
+表演
+表演舞台
+香水
+绿廊
+波斯猫
+柿子
+个人护理
+个人漂浮装置
+害虫
+宠物
+宠物店
+宠物店
+花瓣
+佩妮
+教堂的长椅
+野鸡
+现象
+哲学家
+电话
+电话簿
+留声机
+照片
+照相亭
+相框
+摄影
+物理学家
+物理实验室
+钢琴家
+钢琴
+选择
+捡起
+泡菜
+野餐
+野餐区
+野餐篮
+野餐桌
+图片
+相框
+馅饼
+鸽子
+朝圣者
+药片
+枕头
+飞行员
+领航艇
+别针
+松树
+松果
+松林
+松子
+菠萝
+乒乓球桌
+乒乓球
+粉色
+一品脱的量
+琵琶
+管子
+管碗
+海盗
+海盗旗
+海盗船
+阿月浑子
+滑雪场
+口袋里的面包
+火龙果
+斗牛犬
+球场
+大水罐
+猪笼草
+干草叉
+披萨
+披萨刀
+比萨锅
+披萨店
+招牌
+地方
+餐具垫
+格子
+平原
+示意图
+行星
+行星地球
+厚木板
+植物
+种植园
+种植
+匾额
+石膏
+塑料
+橡皮泥
+高原
+平台
+白金
+大浅盘
+玩/演奏/运动
+打羽毛球
+打棒球
+打篮球
+玩台球
+踢足球
+玩乒乓球
+打网球
+打排球
+选手/运动员
+操场
+剧场
+扑克牌
+下棋
+打高尔夫球
+打麻将
+运动场
+护栏
+游戏室
+广场
+钳子
+故事情节
+犁
+插头
+插头帽
+李子
+水管工
+卫生洁具
+羽毛
+夹板
+口袋
+怀表
+随身小折刀
+圆荚体
+乐队指挥台
+诗歌
+一品红
+指/朝向
+指针
+扑克卡
+筹码
+扑克表
+杆/柱
+臭猫
+警察
+警车
+警犬
+警察局
+政治家
+圆点
+花粉
+污染
+马球
+马球领
+马球衬衫
+石榴
+波美拉尼亚的
+雨披
+池塘
+马尾辫
+贵宾犬
+池
+流行
+流行艺术家
+爆米花
+教皇
+罂粟
+瓷
+玄关
+猪肉
+粥
+便携式电池
+门户网站
+投资组合
+汽门
+肖像
+肖像会话
+摆姿势拍照
+负鼠
+帖子
+邮局
+邮票
+明信片
+海报
+海报页
+锅/罐/陶盆
+土豆
+土豆片
+土豆沙拉
+布垫子
+便壶
+袋
+家禽
+英镑
+倾泻
+粉末
+电源线
+电源插头及插座
+权力看
+电站
+练习
+布拉格城堡
+祈祷
+牧师
+首映
+处方
+显示
+演讲
+总统
+新闻发布室
+高压锅
+椒盐卷饼
+王子
+公主
+打印
+打印页面
+打印机
+印刷
+监狱
+农产品/生产
+产品
+职业
+专业的
+教授
+项目图片
+投影屏幕
+投影仪
+毕业舞会
+散步
+螺旋桨
+先知
+建议
+防护服
+抗议
+抗议者
+出版
+宣传画像
+冰上曲棍球
+布丁
+水坑
+泡芙
+角嘴海雀
+哈巴狗
+拉
+讲坛
+脉冲
+泵
+南瓜
+南瓜饼
+南瓜种子
+拳击吊袋
+拳头猛击/穿孔
+学生
+紫色
+推
+轻轻一击
+谜题
+塔
+金字塔
+大蟒
+二维码
+鹌鹑
+采石场
+季度
+石英
+女王
+油炸玉米粉饼
+队列
+乳蛋饼
+被子
+绗缝
+引用
+兔子
+浣熊
+比赛
+赛道
+水沟/跑道
+赛车
+球拍
+雷达
+散热器
+广播
+木筏/橡皮艇
+布娃娃
+栏杆/铁轨
+轨道车
+铁道
+铁路桥梁
+轨道线
+火车站
+雨
+雨靴
+彩虹
+虹鳟鱼
+雨衣
+热带雨林
+多雨的
+葡萄干
+耙子
+公羊
+斜坡
+油菜籽
+快速
+说唱歌手
+树莓
+老鼠
+棘轮
+乌鸦
+峡谷
+雷
+剃须刀
+锋利的
+阅读
+阅读材料
+钻孔器
+后面
+尾灯
+后视图
+后视镜
+收据
+收到
+接待
+配方
+记录
+唱片制作人
+记录器/竖笛
+录音室
+娱乐室
+休闲车
+矩形
+回收
+回收站
+红色
+红地毯
+红旗
+红熊猫
+红酒
+红木
+芦苇
+礁石
+卷轴
+裁判
+倒影
+倒影
+反射器
+注册
+控制
+驯鹿
+放松
+释放
+救援
+宗教
+宗教的
+享受
+保持
+改造
+遥控器
+移除
+修复
+维修店
+爬行动物
+救援
+救助者
+研究
+研究员
+储层
+住宅
+居民区
+树脂
+度假胜地
+度假小镇
+餐厅的厨房
+餐厅的露台
+厕所
+零售
+寻回犬
+制动火箭
+揭示
+犀牛
+杜鹃
+肋骨
+丝带
+大米
+电饭煲
+稻田
+骑/搭乘
+脊
+骑马
+步枪
+边缘
+环/戒指
+暴乱
+涟漪
+上升
+高层建筑
+河
+河岸
+河船
+河谷
+河床
+路
+路标
+公路旅行
+路边
+烤鸡
+长袍
+罗宾
+机器人
+石头
+岩石拱
+摇滚艺术家
+摇滚乐队
+攀岩者
+攀岩
+摇滚音乐会
+岩石表面
+岩层
+摇滚歌手
+火箭
+摇椅
+岩石
+啮齿动物
+牛仔竞技表演
+竞技舞台
+罗伊
+狍子
+辊
+过山车
+轮式溜冰鞋
+溜冰鞋
+擀面杖
+浪漫
+浪漫的
+屋顶
+屋顶花园
+房间
+房间分频器
+根
+根啤酒
+绳索桥
+念珠
+玫瑰
+迷迭香
+玫瑰色的云
+罗特韦尔犬
+圆桌
+路由器
+行
+罗文
+皇家
+橡皮图章
+废墟
+魔方
+红宝石
+莱夫
+橄榄球
+橄榄球
+橄榄球运动员
+毁坏
+尺
+朗姆酒
+跑
+跑步者
+跑步鞋
+农村的
+锈
+乡村的
+黑麦
+袋
+鞍
+鞍囊
+旅行
+安全
+安全背心
+圣人
+帆
+帆船
+航行
+水手
+松鼠猴
+缘故
+沙拉
+沙拉碗
+火蜥蜴
+意大利蒜味腊肠
+出售
+三文鱼
+沙龙
+萨尔萨舞
+盐
+盐和胡椒瓶
+盐湖
+盐沼
+盐瓶
+敬礼
+萨莫耶德人
+武士
+沙子
+沙洲
+砂箱
+沙堡
+沙雕
+凉鞋
+三明治
+卫生巾
+圣诞老人
+蓝宝石
+沙丁鱼
+莎丽
+生鱼片
+沙爹
+书包
+卫星
+缎
+酱汁
+碟子
+桑拿
+香肠
+稀树大草原
+锯
+锯木架
+萨克斯管
+萨克斯手
+脚手架
+秤/标尺
+比例模型
+扇贝
+疤痕
+稻草人
+围巾
+场景
+风景
+雪纳瑞犬
+学校
+校车
+校服
+校舍
+纵帆船
+科学
+科幻电影
+科学博物馆
+科学家
+剪刀
+壁灯
+司康饼
+勺子
+踏板车/摩托车
+分数
+记分板
+蝎子
+童子军
+炒蛋
+废弃
+刮板
+刮伤
+屏幕
+纱门
+截图
+螺杆
+螺丝刀
+长卷纸/卷轴
+擦洗
+硬毛刷
+雕塑家
+雕塑
+海洞穴
+海冰
+海狮
+海龟
+海胆
+尖吻鲈
+海底
+海鸟
+海鲜
+海马
+海豹
+海景
+海贝
+海滨度假胜地
+季节
+座位
+安全带
+海藻
+秘书
+安全
+小轿车
+看到
+种子
+跷跷板
+赛格威
+自拍
+出售
+研讨会
+感觉
+传感器
+服务器
+服务器机房
+服务
+集
+缝纫机
+影子
+摇
+瓶
+洗发水
+形状
+分享
+鲨鱼
+卷笔刀
+记号笔
+剃须刀
+剃须膏
+披肩/围巾
+剪切
+剪刀
+羊
+床单
+乐谱
+架子
+贝壳
+贝类
+避难所
+搁置
+牧羊人
+果子露
+柴犬
+发光
+航运
+集装箱
+海难
+船厂
+衬衫
+赤膊的
+浅滩
+鞋
+鞋盒
+鞋店
+鞋楦
+射击
+得分篮球后卫
+商店橱窗
+门面
+购物者
+购物
+购物袋
+购物篮
+购物车
+购物中心
+购物街
+海岸
+海岸线
+短的
+短发
+短裤
+小酒杯
+散弹枪
+肩膀
+单肩包
+铲
+陈列柜
+淋浴
+浴帽
+浴帘
+淋浴门
+淋浴头
+碎纸机
+泼妇
+虾
+神社
+灌木
+快门
+暹罗猫
+西伯利亚
+兄弟姐妹
+侧面
+边柜
+配菜
+边车
+边线
+壁板
+标志
+指示牌
+信号
+签名
+丝绸
+丝袜
+筒仓
+银
+银牌
+银器
+唱歌
+烧焦
+歌手
+水槽
+啜
+坐/放置/坐落
+坐着
+滑板公园
+滑板
+滑板者
+溜冰者
+溜冰场
+骨架
+草图
+串串
+滑雪
+滑雪靴
+滑雪设备
+滑雪服
+滑雪缆车
+滑雪杖
+滑雪胜地
+滑雪板
+滑雪
+滑雪鞋
+皮肤
+头骨
+无边便帽
+天空
+天空塔
+天窗
+天际线
+摩天大楼
+激流回旋
+石板
+雪橇
+睡眠
+睡袋
+睡衣
+袖子
+片
+滑动
+滑块
+吊索
+坡
+投币口
+老虎机
+树懒
+慢炖锅
+鼻涕虫
+贫民窟
+气味
+微笑
+烟雾/抽烟
+零食
+蜗牛
+蛇
+鲷鱼
+快照
+通气管
+鼻子
+雪
+雪豹
+雪山
+雪球
+单板滑雪者
+雪原
+雪花
+雪人
+雪地摩托
+雪犁
+雪鞋
+雪
+肥皂
+肥皂泡
+给皂器
+足球守门员
+社会名流
+短袜
+插座
+苏打水
+垒球
+软件
+太阳能电池阵列
+士兵
+独奏
+解决方案
+宽边帽
+歌曲
+声音
+汤
+汤碗
+汤匙
+酸奶油
+纪念品
+豆浆
+水疗中心
+空间
+航天飞机
+空间站
+宇宙飞船
+意大利面
+横跨
+扳手
+火花
+闪耀
+烟火
+起泡葡萄酒
+麻雀
+抹刀
+扬声器
+观众
+会话框
+速度限制
+限速标志
+快艇
+车速表
+球
+香料
+调料架
+蜘蛛
+蜘蛛网
+扣球
+旋转
+菠菜
+尖塔
+飞溅
+海绵
+勺子
+体育协会
+运动器材
+运动团队
+体育球
+体育器材
+运动会
+运动服装
+点
+喷雾
+伸展
+春天
+春卷
+撒
+洒水器
+发芽
+云杉
+云杉森林
+队
+广场
+南瓜
+蹲
+挤
+鱿鱼
+松鼠
+水枪
+刺
+稳定的
+(码放整齐的)一叠
+体育场
+工作人员
+舞台
+舞台灯
+驿马车
+弄脏
+不锈钢
+楼梯
+楼梯
+楼梯间
+摊位/小隔间
+种马
+站/矗立/摊位
+站
+主食
+订书机
+星星
+盯着
+海星
+杨桃
+燕八哥
+州立公园
+公立学校
+车站
+固定自行车
+文具
+雕像
+牛排
+牛排刀
+蒸汽
+蒸汽机
+蒸汽机车
+蒸汽火车
+馒头
+钢
+方向盘
+(花草的)茎
+模版
+梯凳
+立体声
+听诊器
+炖
+戳/条状物
+竹节虫
+贴纸
+静物画
+高跷
+黄貂鱼
+搅拌
+搅拌器
+镫
+缝
+股票
+长筒袜
+腹部
+石头建筑
+石雕
+石屋
+石磨
+凳子
+停止
+停在
+红灯
+停车标志
+秒表
+红绿灯
+存储箱
+储藏室
+罐/蓄水池
+商店
+店面
+鹳
+风暴
+暴风云
+狂风暴雨的
+炉子
+扑克
+跨骑
+过滤器
+海峡
+带
+稻草/吸管
+草帽
+草莓
+溪流
+街头艺术
+街头艺术家
+街角
+流浪狗
+街头食品
+路灯
+街市场
+街头摄影
+街景
+路标
+街头小贩
+拉伸
+担架
+罢工
+前锋
+细绳
+芝士条
+带子
+条纹
+漫步
+结构
+工作室
+影棚拍摄
+材料
+填充玩具动物
+毛绒玩具
+馅
+树桩
+惊人的
+特技
+佛塔
+风格
+手写笔
+潜艇
+潜艇形大三明治
+海底水
+郊区
+地铁
+地铁站
+低音炮
+多肉
+绒面革
+糖
+糖碗
+甘蔗
+方糖
+西装
+套房
+夏天
+夏天傍晚
+峰顶
+太阳
+太阳帽
+日光浴
+周日
+日晷
+向日葵
+向日葵田
+葵花籽
+太阳镜
+晴天
+日出
+日落
+遮阳伞
+阳光
+超级碗
+跑车
+超级英雄
+超市
+超市货架
+超模
+支持者
+冲浪
+表面
+冲浪板
+冲浪者
+外科医生
+外科手术
+环绕
+寿司
+寿司吧
+背带裤
+悬架
+吊桥
+越野车
+燕子
+燕尾蝶
+沼泽
+天鹅
+天鹅游艇
+运动裤
+防汗带
+毛衣
+运动衫
+甜的
+红薯
+游泳
+泳帽
+游泳者
+游泳洞
+游泳池
+摆动
+平转桥
+秋千
+漩涡
+开关
+转椅
+剑
+旗鱼
+象征
+对称
+犹太教堂
+注射器
+糖浆
+系统
+t恤
+t恤
+塔巴斯科辣椒酱
+虎斑
+乒乓球拍
+桌面
+桌布
+平板电脑
+餐具
+转速表
+拦截
+墨西哥煎玉米卷
+跆拳道
+太极
+尾巴
+裁缝
+拍/拿
+起飞
+说话/交谈/演讲
+手鼓
+棕褐色
+橘子
+胶带/磁带/终点线
+挂毯
+沥青碎石路面
+芋头
+篷布
+果馅饼
+流苏
+味道
+榻榻米
+纹身
+纹身艺术家
+酒馆
+茶
+茶包
+茶话会
+茶园
+茶壶
+茶具
+教
+老师
+茶杯
+水鸭
+团队合影
+团队介绍
+眼泪/撕裂/划破
+技术员
+技术
+泰迪熊
+T字形物
+青少年
+电线杆
+变焦镜头
+望远镜
+电视
+电视摄像机
+电视室
+电视演播室
+温度
+寺庙
+天妇罗
+网球
+网球场
+网球比赛
+网球网
+网球运动员
+网球拍
+帐篷
+龙舌兰酒
+终端/航站楼
+阳台
+地形
+玻璃容器
+领土
+测试
+测试赛
+试管
+文本
+短信
+纺织
+纹理
+感恩节
+感恩节晚餐
+剧院
+戏剧演员
+治疗
+温度计
+热水瓶
+暖瓶
+恒温器
+灌木丛
+顶针
+东西
+思考
+蓟
+宝座
+金銮殿
+扔
+抱枕
+雷
+雷雨
+百里香
+皇冠
+记号
+票
+售票亭
+潮池
+领带
+老虎
+紧
+瓦
+瓷砖地板
+瓦屋顶
+瓷砖墙
+锡
+锡纸
+箔
+提拉米苏
+轮胎
+纸巾
+烤面包
+烤面包机
+烟草
+烟斗
+学步的小孩
+脚趾
+豆腐
+马桶
+马桶座圈
+化妆包
+东京铁塔
+番茄
+番茄酱
+番茄汤
+墓
+钳子
+钳子
+工具
+工具箱
+牙刷
+牙膏
+牙签
+修剪成形的花园
+配料
+火炬/光源
+龙卷风
+玉米粉圆饼
+乌龟
+大手提袋
+图腾柱
+龙猫
+巨嘴鸟
+触摸
+触地
+旅行
+旅游巴士
+导游
+游客
+旅游景点
+锦标赛
+拖车
+毛巾
+毛巾杆
+大厦
+塔桥
+小镇
+城镇广场
+玩具
+玩具车
+玩具枪
+玩具店
+跑道
+拖拉机
+贸易
+传统
+传统的
+交通
+锥形交通路标
+交通拥堵
+交通堵塞
+交通标志
+小道
+预告片
+拖车
+火车
+火车桥
+火车车厢
+火车内部
+火车轨道
+火车窗口
+教练
+训练
+训练长椅
+训练场
+电车/手推车
+蹦床
+变形金刚
+透明度
+旅行
+托盘/碟子
+跑步机
+美食
+树
+树枝
+林场
+树蛙
+树屋
+树根
+树干
+试验
+三角形
+铁人三项
+部落
+支流
+戏法/特技
+三轮车
+修剪
+三人组
+三脚架
+长号
+部队
+奖杯
+奖杯
+热带
+鳟鱼
+卡车
+卡车司机
+浴缸
+管子
+拖船
+郁金香
+金枪鱼
+苔原
+隧道
+涡轮
+火鸡
+转动
+芜菁
+绿松石
+炮塔
+乌龟
+獠牙
+电视演员
+电视柜
+电视剧
+电视节目类型
+电视名人
+电视节目
+情景喜剧
+电视塔
+枝条
+黄昏
+双胞胎
+麻线
+扭
+类型
+键入
+打字机
+尤克里里
+奥特曼
+伞
+内衣
+水下
+独角兽
+制服
+宇宙
+大学
+向上
+城市
+尿壶
+瓮
+使用
+用具
+杂物间
+吸尘器/真空
+谷
+阀门
+吸血鬼
+货车
+香草
+虚荣
+种类
+花瓶/瓶
+金库
+矢量卡通插图
+矢量图标
+蔬菜
+菜园
+蔬菜市场
+植被
+车辆
+面纱
+静脉
+天鹅绒
+自动售货机
+小贩
+通风孔
+胡蜂属
+船
+背心
+兽医
+经验丰富的
+兽医办公室
+高架桥
+视频
+摄像机
+电子游戏
+录像带
+视镜
+守夜
+别墅
+村庄
+藤蔓
+醋
+葡萄园
+暴力
+紫罗兰色
+小提琴
+小提琴家
+中提琴演奏者
+愿景
+遮阳板
+伏特加
+火山
+排球
+排球场
+排球运动员
+志愿者
+航行
+秃鹰
+华夫饼干
+华夫饼机
+货车
+马车车轮
+腰
+服务员
+候机室
+等候室
+走
+步行
+手杖
+挂钟
+壁纸
+核桃
+海象
+战争
+仓库
+温暖的
+警告标志
+战士
+军舰
+疣猪
+洗
+洗衣机/垫圈
+洗
+洗衣机
+黄蜂
+浪费
+废物容器
+手表
+水
+水鸟
+水牛
+水冷却器
+水滴
+水景
+热水器
+水位
+荷花
+水上乐园
+水管
+净水器
+滑水板
+水上运动
+水面
+水塔
+水彩
+水彩插图
+水彩画
+瀑布
+喷壶
+水印叠加图章
+西瓜
+防水外套
+水路
+波浪
+蜡
+武器
+穿着
+天气
+叶片
+网
+摄像头
+婚礼
+结婚戒指
+婚礼花束
+结婚蛋糕
+新婚夫妇
+婚礼请柬
+婚礼派对
+婚纱照
+婚礼摄影师
+婚纱摄影
+婚宴
+楔
+杂草
+重量
+体重秤
+焊接工
+井
+西餐
+西餐厅
+湿
+吧台
+潜水衣
+湿地
+潜水服
+鲸鱼
+鲸鲨
+小麦
+麦田
+车轮
+轮椅
+后轮支撑车技
+生奶油
+搅拌器
+胡须
+威士忌
+哨子
+白色
+白宫
+白葡萄酒
+白板
+便门
+宽的
+挥动
+假发
+Wii
+Wii手柄
+荒野
+角马
+野火
+野花
+野生动物
+柳树
+风
+风铃
+风电场
+风力涡轮机
+风车
+窗户
+窗台花盆箱
+橱窗展示
+窗框
+纱窗
+靠窗的座位
+窗台
+雨刮器
+挡风玻璃
+有风的
+酒瓶
+冷酒器
+酒柜
+酒窖
+酒杯
+酒架
+品酒
+酒庄
+翅膀
+冬天
+冬瓜
+冬天的早晨
+冬季场景
+冬季运动
+冬季风暴
+电线
+紫藤
+巫婆
+女巫帽子
+炒锅
+狼
+女人
+木头
+林鸳鸯
+木地板
+木墙
+烧木炉
+木匙
+林地
+啄木鸟
+木工刨
+羊毛
+工作
+练习卡
+工作台
+工人
+工作场所
+车间
+世界
+蠕虫
+敬拜
+伤口
+包
+裹身裙
+包装纸
+搏斗
+摔跤手
+皱纹
+腕带
+写
+作家
+手写/字迹
+毛笔
+写字桌
+游艇
+牦牛
+院子
+黄色
+瑜伽
+瑜伽垫
+酸奶
+轭
+蛋黄
+青年
+青年旅馆
+蒙古包
+斑马
+斑马线
+禅意花园
+拉链
+拉链
+僵尸
+粽子
+动物园
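
The Chinese list above declares the same 4585-line length as the English tag list added earlier in this diff, so the two files appear to be line-aligned translations: row i of one is the translation of row i of the other. A minimal sketch of reading them as pairs follows; the English filename (`ram/data/ram_tag_list.txt`) and the pairing itself are assumptions inferred from the matching line counts, not code from this PR:

```python
# Sketch only: pair the English and Chinese RAM tag lists by line number.
# Assumes both files have exactly 4585 lines, per the @@ -0,0 +1,4585 @@ headers.
from pathlib import Path

def load_tag_pairs(data_dir="ram/data"):
    en = Path(data_dir, "ram_tag_list.txt").read_text(encoding="utf-8").splitlines()
    zh = Path(data_dir, "ram_tag_list_chinese.txt").read_text(encoding="utf-8").splitlines()
    assert len(en) == len(zh), "tag lists must stay line-aligned"
    return list(zip(en, zh))
```
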
diff --git a/ram/data/ram_tag_list_threshold.txt b/ram/data/ram_tag_list_threshold.txt
new file mode 100644
index 0000000000000000000000000000000000000000..0472b23c25903900c0dde68fffc9a6a6755f5117
--- /dev/null
+++ b/ram/data/ram_tag_list_threshold.txt
@@ -0,0 +1,4585 @@
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.71
+0.75
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.8
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.9
+0.65
+0.9
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.61
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.8
+0.7
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.82
+0.8
+0.65
+0.65
+0.8
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.9
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.8
+0.8
+0.65
+0.85
+0.8
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.77
+0.65
+0.65
+0.65
+0.9
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.89
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.78
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.9
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.65
+0.9
+0.65
+0.83
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.79
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.89
+0.65
+0.65
+0.65
+0.65
+0.65
+0.9
+0.65
+0.65
+0.86
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.79
+0.65
+0.63
+0.65
+0.87
+0.8
+0.46
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.9
+0.65
+0.65
+0.9
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.9
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.9
+0.65
+0.8
+0.65
+0.8
+0.8
+0.8
+0.65
+0.65
+0.84
+0.65
+0.65
+0.79
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.81
+0.65
+0.8
+0.65
+0.65
+0.9
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.9
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.9
+0.65
+0.65
+0.65
+0.87
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.83
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.77
+0.87
+0.65
+0.65
+0.8
+0.8
+0.65
+0.65
+0.65
+0.65
+0.85
+0.65
+0.68
+0.65
+0.8
+0.65
+0.65
+0.75
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.8
+0.8
+0.8
+0.79
+0.65
+0.85
+0.65
+0.65
+0.65
+0.9
+0.65
+0.89
+0.8
+0.65
+0.65
+0.65
+0.76
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+1
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.9
+0.65
+0.89
+0.7
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.8
+0.65
+0.65
+0.71
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.9
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.8
+0.65
+0.65
+0.8
+0.8
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.8
+0.65
+0.8
+0.65
+0.8
+0.8
+0.9
+0.65
+0.85
+0.8
+0.8
+0.8
+0.9
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.75
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.65
+0.65
+0.65
+0.63
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.8
+0.65
+0.88
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.8
+0.71
+0.65
+0.65
+0.65
+0.8
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.9
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.71
+0.65
+0.8
+0.76
+0.85
+0.8
+0.65
+0.65
+0.8
+0.65
+0.79
+0.65
+0.75
+0.65
+0.8
+0.65
+0.86
+0.65
+0.65
+0.9
+0.9
+0.65
+0.65
+0.65
+0.65
+0.65
+0.73
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.9
+0.65
+0.85
+0.65
+0.65
+0.65
+0.65
+0.8
+0.75
+0.65
+0.65
+0.65
+0.65
+0.8
+0.85
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.77
+0.65
+0.65
+0.65
+0.65
+0.65
+0.86
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.6
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.74
+0.65
+0.65
+0.67
+0.65
+0.65
+0.8
+0.65
+0.65
+0.85
+0.65
+0.8
+0.65
+0.65
+0.84
+0.8
+0.8
+0.8
+0.8
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.9
+0.9
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.8
+0.89
+0.65
+0.65
+0.65
+0.83
+0.65
+0.65
+0.65
+0.65
+0.6
+0.65
+0.8
+0.8
+0.8
+0.65
+0.65
+0.89
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.8
+0.65
+0.77
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.87
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.74
+0.65
+0.65
+0.66
+0.89
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.9
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.84
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.88
+0.65
+0.65
+0.8
+0.65
+0.65
+0.7
+0.65
+0.65
+0.65
+0.9
+0.65
+0.9
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.8
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.82
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.8
+0.9
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.8
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.8
+0.65
+0.8
+0.8
+0.65
+0.65
+0.65
+0.75
+0.65
+0.7
+0.9
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.88
+0.65
+0.65
+1
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.8
+0.65
+0.65
+0.8
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.65
+0.8
+0.8
+0.65
+0.8
+0.65
+0.65
+0.71
+0.65
+0.65
+0.65
+0.79
+0.65
+0.65
+0.65
+0.65
+0.65
+0.89
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.9
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.9
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.88
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.82
+0.65
+0.9
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.65
+0.65
+0.65
+0.8
+0.65
+0.9
+0.65
+0.65
+0.88
+0.65
+0.65
+0.65
+0.65
+0.9
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.9
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.89
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.87
+0.65
+0.66
+0.65
+0.84
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.84
+0.65
+0.65
+0.65
+0.65
+0.65
+0.9
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.5
+0.65
+0.64
+0.65
+0.65
+0.8
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.81
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.8
+0.8
+0.65
+0.65
+0.65
+0.8
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.8
+0.65
+0.84
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.9
+0.65
+0.8
+0.65
+0.85
+0.65
+0.65
+0.65
+0.9
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.73
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.8
+0.65
+0.65
+0.86
+0.65
+0.65
+0.65
+0.65
+0.87
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.8
+0.8
+0.65
+0.9
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.82
+0.8
+0.65
+0.65
+0.65
+0.84
+0.9
+0.9
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.64
+0.65
+0.65
+0.65
+0.8
+0.8
+0.87
+0.65
+0.65
+0.78
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.8
+0.8
+0.65
+0.8
+0.65
+0.65
+0.65
+0.8
+0.65
+0.9
+0.65
+0.65
+0.8
+0.65
+0.85
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.74
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.88
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.83
+0.89
+0.89
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.9
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.86
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.8
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.85
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.9
+0.65
+0.65
+0.9
+0.65
+0.65
+0.65
+0.9
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.86
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.87
+0.8
+0.84
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.81
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.8
+0.8
+0.65
+0.7
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.82
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.87
+0.65
+0.9
+0.8
+0.65
+0.65
+0.65
+0.9
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.7
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.9
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.85
+0.65
+0.65
+0.65
+0.65
+0.65
+0.73
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.9
+0.65
+0.89
+0.8
+0.65
+0.9
+0.65
+1
+0.65
+0.65
+0.65
+0.65
+0.9
+0.65
+0.65
+0.65
+0.65
+0.89
+0.89
+0.65
+0.65
+0.65
+0.8
+0.75
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.8
+0.8
+0.8
+0.65
+0.65
+0.88
+0.65
+0.8
+0.65
+0.65
+0.8
+0.85
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.9
+0.57
+0.65
+0.8
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.9
+0.8
+0.8
+0.79
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.8
+0.89
+0.8
+0.65
+0.8
+0.65
+0.8
+0.65
+0.81
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.89
+0.65
+0.65
+0.65
+0.65
+0.65
+0.89
+0.84
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.9
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.89
+0.65
+0.8
+0.83
+0.65
+0.65
+0.8
+0.65
+0.65
+0.72
+0.65
+0.65
+0.65
+0.8
+0.8
+0.65
+0.8
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+1
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.9
+0.65
+0.65
+0.89
+0.65
+0.65
+0.65
+0.65
+0.9
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.8
+0.65
+0.69
+0.8
+0.65
+0.65
+0.65
+0.9
+0.65
+0.65
+0.65
+0.65
+0.71
+0.65
+0.65
+0.65
+0.88
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.85
+0.65
+0.8
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.87
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.9
+0.8
+0.9
+0.65
+0.8
+0.8
+0.65
+0.65
+0.8
+0.8
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.8
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.8
+0.8
+0.65
+0.85
+0.65
+0.65
+0.8
+0.65
+0.89
+0.65
+0.65
+0.9
+0.8
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.86
+0.65
+0.77
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.9
+0.65
+0.8
+0.65
+0.65
+0.65
+0.9
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.9
+0.65
+0.65
+0.8
+0.8
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.8
+0.8
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.89
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.75
+0.8
+0.65
+0.8
+0.88
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.88
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.82
+0.65
+0.65
+0.8
+0.65
+0.8
+0.65
+0.9
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.83
+0.65
+0.65
+0.92
+0.89
+0.8
+0.8
+0.65
+0.65
+0.65
+0.65
+0.75
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.85
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.87
+0.65
+0.79
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.83
+0.8
+0.65
+0.65
+0.8
+0.8
+0.65
+0.7
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.8
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.9
+0.8
+0.65
+0.65
+0.65
+0.65
+0.7
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.87
+0.65
+0.65
+0.65
+0.65
+0.8
+0.82
+0.65
+0.8
+0.65
+0.65
+0.9
+0.65
+0.65
+0.65
+0.65
+0.65
+1
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.64
+0.65
+0.65
+0.63
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.76
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.65
+0.8
+0.65
+0.75
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.87
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.82
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.89
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.9
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.9
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.9
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.9
+0.8
+0.65
+0.73
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.86
+0.65
+0.9
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.9
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.8
+0.9
+0.65
+0.9
+0.65
+0.65
+0.65
+0.65
+0.86
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.86
+0.65
+0.8
+0.8
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.69
+0.65
+0.65
+0.65
+0.65
+0.65
+0.88
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.72
+0.65
+0.65
+0.8
+0.65
+0.8
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.9
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.9
+0.9
+0.8
+0.8
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.45
+0.8
+0.65
+0.88
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.8
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.65
+0.8
+0.51
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.66
+0.65
+0.8
+0.9
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.8
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.81
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.75
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.66
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.9
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.9
+0.8
+0.65
+0.85
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.9
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.81
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.89
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.79
+0.75
+0.65
+0.65
+0.8
+0.65
+0.67
+0.8
+0.8
+0.86
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.81
+0.8
+0.65
+0.65
+0.9
+0.65
+0.79
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.77
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.8
+0.8
+0.65
+0.74
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.6
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.89
+0.8
+0.65
+0.65
+0.88
+0.65
+0.65
+0.65
+0.9
+0.75
+0.65
+0.65
+0.65
+0.8
+0.6
+0.65
+0.65
+0.65
+0.9
+0.65
+0.65
+0.65
+0.84
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.9
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.65
+0.8
+0.8
+0.8
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.85
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.63
+0.65
+0.65
+0.65
+0.7
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.9
+0.9
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.84
+0.65
+0.65
+0.8
+0.65
+0.81
+0.8
+0.8
+0.8
+0.82
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.65
+0.8
+0.65
+0.88
+0.65
+0.8
+0.65
+0.7
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.8
+0.65
+1
+0.8
+0.8
+0.65
+0.65
+0.65
+0.8
+0.8
+0.8
+0.65
+0.74
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.85
+0.65
+0.65
+0.65
+0.65
+0.8
+0.8
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.9
+0.86
+0.8
+0.65
+0.8
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.64
+0.65
+0.65
+0.8
+0.8
+0.65
+0.87
+0.65
+0.65
+0.8
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.87
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.7
+0.65
+0.65
+0.8
+0.65
+0.65
+0.75
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.85
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.71
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.73
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.8
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.9
+0.65
+0.65
+0.8
+0.65
+0.86
+0.65
+0.65
+0.65
+0.65
+0.9
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.75
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.88
+0.65
+0.8
+0.65
+0.8
+0.65
+0.65
+0.65
+0.9
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.81
+0.65
+0.65
+0.8
+0.65
+0.65
+0.9
+0.8
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.9
+0.65
+0.65
+0.65
+0.65
+0.7
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.77
+0.65
+0.65
+0.65
+0.65
+0.65
+0.85
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.87
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.9
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.57
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.8
+0.65
+0.65
+0.65
+0.65
+0.76
+1
+0.8
+0.65
+0.65
+0.58
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+1
+0.65
+0.8
+0.65
+0.65
+0.65
+0.8
+0.65
+0.9
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.87
+0.8
+0.9
+0.8
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.87
+0.68
+0.8
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.99
+0.8
+0.77
+0.65
+0.9
+0.65
+0.65
+0.88
+0.65
+0.65
+0.65
+0.65
+0.9
+0.65
+0.88
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.89
+0.65
+0.65
+0.8
+0.8
+0.65
+0.7
+0.65
+0.65
+0.8
+0.9
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.8
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.8
+0.8
+0.65
+0.77
+0.65
+0.65
+0.65
+0.65
+0.79
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.8
+0.65
+0.65
+0.65
+0.85
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.52
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.86
+0.65
+0.65
+0.8
+0.56
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.72
+0.65
+0.65
+0.65
+0.8
+0.8
+0.65
+0.9
+0.65
+0.65
+0.8
+0.65
+0.8
+0.6
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.8
+0.65
+0.88
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.89
+0.85
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.87
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.75
+0.65
+0.65
+0.65
+0.65
+0.54
+1
+0.65
+0.65
+0.75
+0.65
+0.75
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.9
+0.9
+0.62
+0.65
+0.65
+0.65
+0.65
+0.86
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.8
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.82
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.9
+0.74
+0.8
+0.65
+0.8
+0.8
+0.7
+0.65
+0.65
+0.65
+0.89
+0.65
+0.65
+0.8
+0.8
+0.8
+0.8
+0.65
+0.8
+0.65
+0.65
+0.65
+0.9
+0.65
+0.65
+0.65
+0.8
+0.8
+0.84
+0.8
+0.65
+0.65
+0.8
+0.75
+0.65
+0.65
+0.65
+0.89
+0.65
+0.65
+0.65
+0.65
+0.82
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.65
+0.8
+0.65
+0.65
+0.65
+0.84
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.8
+0.8
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.65
+0.8
+0.65
+0.7
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.9
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.74
+0.65
+0.8
+0.65
+0.65
+0.65
+0.9
+0.65
+0.65
+0.65
+0.65
+0.85
+0.65
+0.9
+0.9
+0.65
+0.65
+0.65
+0.63
+0.82
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.8
+0.7
+0.65
+0.65
+0.65
+0.65
+0.65
+0.9
+0.65
+0.65
+0.74
+0.9
+0.65
+0.8
+0.65
+0.65
+0.58
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.89
+0.75
+0.65
+0.65
+0.8
+0.65
+0.65
+0.88
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.89
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.8
+0.8
+0.65
+0.65
+0.8
+0.8
+0.65
+0.65
+0.87
+0.65
+0.65
+0.65
+0.8
+0.65
+0.64
+0.65
+0.65
+0.65
+0.8
+0.87
+0.65
+0.65
+0.8
+0.9
+0.65
+0.65
+0.65
+0.65
+0.8
+0.8
+0.65
+0.89
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.83
+0.65
+0.65
+0.8
+0.65
+0.9
+0.65
+0.8
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.78
+0.65
+0.8
+0.65
+0.9
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.9
+0.65
+0.88
+0.8
+0.65
+0.65
+0.65
+0.81
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.77
+0.65
+0.65
+0.65
+0.8
+0.8
+0.8
+0.8
+0.65
+0.65
+0.65
+1
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.85
+0.65
+0.65
+0.8
+0.8
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.88
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.65
+0.65
+0.65
+0.8
+0.8
+0.8
+0.65
+0.65
+0.65
+0.65
+0.68
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.89
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.8
+0.9
+0.65
+0.8
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.81
+0.65
+0.65
+0.65
+0.8
+0.85
+0.65
+0.77
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.8
+0.8
+0.9
+0.65
+0.65
+0.89
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.8
+0.65
+0.65
+0.65
+0.88
+0.8
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.82
+0.65
+0.8
+0.74
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.85
+0.65
+0.65
+0.85
+0.65
+0.65
+0.65
+0.65
+0.7
+0.7
+0.8
+0.65
+0.65
+0.65
+0.65
+0.87
+0.8
+0.65
+0.65
+0.65
+0.89
+0.85
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.7
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.9
+0.8
+0.8
+0.65
+0.66
+0.57
+0.65
+0.65
+0.65
+0.49
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.65
+0.65
+0.65
+0.8
+0.65
+0.8
+0.8
+0.86
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.89
+0.65
+0.65
+0.65
+0.65
+0.65
+0.65
+0.76
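
`ram_tag_list_threshold.txt` is also 4585 lines, one float per tag, which reads as a per-class confidence cutoff aligned with the tag lists: most entries sit at the 0.65 default, with noisier or broader tags raised toward 1.0. A hedged sketch of how such per-tag thresholds might gate sigmoid tag scores; the function and variable names here are illustrative assumptions, not APIs from this PR:

```python
# Sketch only: apply a per-tag threshold to a vector of tag probabilities.
# Assumes `probs` is a (4585,) array of sigmoid scores aligned line-for-line
# with ram_tag_list.txt and ram_tag_list_threshold.txt.
import numpy as np

def load_thresholds(path="ram/data/ram_tag_list_threshold.txt"):
    return np.loadtxt(path)  # one float per line -> shape (4585,)

def predict_tags(probs, tags, thresholds):
    keep = probs >= thresholds  # elementwise, per-tag cutoff
    return [tag for tag, k in zip(tags, keep) if k]
```
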
diff --git a/ram/data/tag_list.txt b/ram/data/tag_list.txt
new file mode 100644
index 0000000000000000000000000000000000000000..11a61b68fb9a22eec9cc52a3a2d32474323aafdb
--- /dev/null
+++ b/ram/data/tag_list.txt
@@ -0,0 +1,3429 @@
+tennis
+bear cub
+observatory
+bicycle
+hillside
+judge
+watercolor illustration
+granite
+lobster
+livery
+stone
+ceramic
+ranch
+cloth
+smile
+building
+tattoo
+cricketer
+cheek
+pear
+source
+winter
+surface
+spray
+ceremony
+magic
+curve
+container
+fair
+medicine
+baby
+tennis racquet
+ornament
+bamboo
+duckling
+song
+safari
+team presentation
+daffodil
+cross
+toothpaste
+shield
+fashion model
+capsule
+map
+creek
+glass house
+glass plate
+siding
+corner
+water buffalo
+bison
+figure skater
+diploma
+tire
+race
+cable car
+brain
+gas stove
+soap bubble
+palette
+snowboard
+school child
+trench coat
+monk
+fiber
+kitchen window
+sunglass
+coffee
+security
+strawberry
+penguin
+tree root
+loaf
+engagement ring
+lamb
+vector cartoon illustration
+sandwich
+mountain village
+shape
+charm
+fiction
+knot
+greenhouse
+sushi
+text
+disaster
+trophy
+gang
+strap
+soccer game
+cardinal
+tee
+turtle
+water surface
+grassland
+dolphin
+store
+dirt
+iceberg
+pergola
+farmer market
+publicity portrait
+tote bag
+teenage girl
+view mirror
+session
+commuter
+dressing room
+tricycle
+christmas ball
+headlight
+police
+armchair
+chart
+yacht
+saw
+printer
+rock band
+gingerbread house
+tag
+table lamp
+hockey game
+slope
+font
+wicker basket
+jewelry
+quarter
+software
+weapon
+pin
+worship
+painter
+goal
+morning light
+bike
+baseball bat
+elevator
+cuisine
+sausage
+stunt
+wrestler
+statue
+landing
+pillar
+willow tree
+sea wave
+chicken
+peanut
+muscle
+bob
+tv genre
+bathroom window
+radish
+textile
+pelican
+marketplace
+crest
+elevation map
+gift
+parish
+traffic light
+campfire
+fog
+award winner
+beach ball
+mat
+white house
+plaster
+moped
+football team
+solution
+bicyclist
+bit
+playground
+darkness
+cake
+maple leave
+mold
+cracker
+blueberry
+rubble
+container ship
+pedestrian bridge
+snail
+parrot
+form
+circuit
+highlight
+pickup truck
+koala
+rain
+system
+weather
+raincoat
+soccer team
+windshield
+thunderstorm
+mike
+bird house
+bridge
+grandfather
+restroom
+animation
+wilderness
+clown
+banana
+brown
+braid
+dining room
+kindergarten
+launch event
+purple
+school
+stairwell
+brooch
+movie poster image
+mountain river
+shelf
+wicket
+headboard
+buddha
+flower field
+dugout
+cd
+bald eagle
+lagoon
+seaweed
+agriculture
+emergency service
+maple tree
+parachute
+continent
+amusement park
+remote
+bun
+tackle
+hospital
+garage door
+birthday party
+friendship
+go
+mausoleum
+jeep
+raccoon
+step
+ice hockey team
+cigarette
+lace dress
+forest floor
+mall
+captain
+milk
+golf course
+meal
+picnic table
+sail
+volleyball
+canal
+terrace
+computer desk
+caravan
+hotel
+cheerleader
+nurse
+museum
+marsh
+fox
+plateau
+night
+twin
+letter logo
+autumn tree
+powder
+convention
+creature
+lighthouse
+shop window
+jacket
+stork
+taxi
+trade
+blackboard
+olive
+road sign
+resort
+snowflake
+cemetery
+travel
+evening dress
+picnic
+drink
+winter morning
+football player
+snack
+boxing glove
+dinner party
+airline
+swing
+port
+wheelbarrow
+bathroom sink
+sweater
+ambulance
+gear
+oil
+wii controller
+array
+home office
+car show
+mixture
+profession
+tree frog
+square
+facility
+coral reef
+sea wall
+pizza
+exhibit
+demolition
+trout
+ring
+coffee shop
+bracelet
+bean
+lip
+fencing
+landscape
+sitting
+package
+metal
+bust
+king
+hair
+window seat
+wildlife
+trunk
+greenery
+stencil
+fire hydrant
+bridesmaid
+plaza
+alps
+tower bridge
+crop top
+crossing
+cinema
+pedestrian crossing
+family
+shopping cart
+stomach
+church building
+screen door
+skater
+soccer field
+kettle
+mussel
+raindrop
+candy cane
+water lily
+flower girl
+desert
+enclosure
+christmas light
+kitchen
+caterpillar
+plaid
+bath
+bush
+mud
+ballet
+knee
+adult
+raft
+sea view
+cactus
+office chair
+overall
+rim
+scaffolding
+pig
+cover
+poster page
+sprinkle
+chandelier
+algae
+traffic
+surfboard
+book
+filming
+flash
+mansion
+camouflage
+trouser
+ticket
+weed
+cab
+trench
+elephant
+huddle
+sphere
+christmas decoration
+city
+launch
+doll
+christmas ornament
+fabric
+bikini
+biplane
+breakfast
+neighbourhood
+race track
+foliage
+avocado
+school bus
+footwear
+highway
+ocean view
+art vector illustration
+wall clock
+curtain
+teenager
+kitchen area
+robot
+tusk
+lounge chair
+beam
+paddle
+camel
+lid
+world map
+city view
+newlywed
+cargo ship
+yellow
+exhibition
+bend
+novel
+wool
+ontario
+bread
+campus
+coastline
+cutting board
+booth
+table top
+carpet
+beach chair
+workout
+street food
+fun
+costumer film designer
+gadget
+artist
+fishing village
+builder
+violinist
+iphone
+spider web
+traffic sign
+ruin
+rescue
+clipboard
+seal
+film director
+paw
+nursery
+intersection
+tomato sauce
+taste
+paddy field
+christmas tree
+wave
+stool
+watering can
+rug
+daytime
+subway station
+craft
+pine forest
+black
+planet
+motif
+christmas market
+glass window
+college
+wheat
+damage
+rectangle
+picture frame
+chess
+guest room
+street corner
+religion
+seed
+puzzle
+freeway
+beauty
+ocean
+watch
+mother
+garage
+quote
+dj
+supporter
+hip hop artist
+muffin
+eiffel tower
+cash
+firefighter
+cauliflower
+bunker
+sled
+manicure
+shark
+stall
+jungle
+family home
+tour bus
+chimney
+touchdown
+roundabout
+coyote
+street scene
+tank
+wedding dress
+mantle
+bedroom window
+coconut
+chapel
+goat
+living space
+rock wall
+polka dot
+railway
+mandala
+mango
+lesson
+mountain landscape
+team photo
+bookshelf
+meter
+bulldog
+evening sun
+stick
+card
+pink
+fish pond
+paint
+pill
+cart
+pea
+van
+album
+football college game
+mountain pass
+doughnut
+ski slope
+match
+official
+shadow
+organ
+celebration
+coin
+log cabin
+firework display
+present
+twig
+chef
+confetti
+footpath
+tour
+ponytail
+artwork
+race car
+club
+season
+hose
+pencil
+aircraft
+rock formation
+wardrobe
+participant
+politician
+engineer
+peace
+filter
+sailing boat
+water bottle
+service dog
+poodle
+loki
+statesman
+sleeping bag
+outskirt
+clock
+factory
+oak tree
+physician
+color
+room
+stairway
+company
+lady
+graph
+faucet
+tablecloth
+subway train
+chocolate chip cookie
+headquarters
+screw
+goggle
+halloween
+city street
+swirl
+cord
+forward
+bone
+bedding
+archway
+wig
+lobby
+mask
+attic
+kitchen table
+skylight
+fire
+exit
+oil painting
+passenger
+meditation
+salmon
+fedora
+rubber stamp
+orange juice
+arch
+scientist
+stroll
+manhattan
+float
+baseball uniform
+circle
+church
+decker bus
+competitor
+zoo
+basketball team
+tourist
+daughter
+silverware
+ceiling fan
+birth
+vase
+jack
+mushroom
+spiral
+cage
+limb
+salad
+ad
+control
+earth
+party
+bolt
+tractor
+barley
+wedding photo
+hawk
+warehouse
+vegetable garden
+chocolate cake
+cabbage
+floor window
+baby shower
+magnifying glass
+table
+stethoscope
+reading
+mission
+croissant
+gift box
+rocket
+forest road
+cooking
+suite
+hill country
+motorcycle
+baseball player
+angle
+drug
+sport association
+championship
+family portrait
+florist
+softball
+egret
+office
+plywood
+jockey
+mosque
+brunch
+beanie
+office building
+pattern
+calendar
+indoor
+pepper
+ledge
+trail
+fuel
+laptop computer
+tennis shoe
+deck chair
+guitarist
+barn
+surgery
+cartoon illustration
+nebula
+railroad
+mountain goat
+goose
+car door
+cheer
+liquid
+hardwood floor
+pathway
+acorn
+gull
+airliner
+couch
+lake house
+spaghetti
+promenade
+collection
+garden
+bank
+robin
+tennis ball
+peony
+gymnast
+lavender
+deck
+test
+riverside
+rapper
+domino
+bride
+mouse
+basil
+wedding couple
+ocean wave
+arm
+kitchen floor
+grove
+family member
+backyard
+raspberry
+forest fire
+officer
+hibiscus
+canyon
+composer
+signature
+olive oil
+hibiscus flower
+rose
+vector icon
+sunrise
+horseback
+motor scooter
+office worker
+tradition
+ingredient
+washing machine
+lighting
+bagel
+sailboat
+policeman
+mare
+graphic
+halloween pumpkin
+stock
+pilot
+education
+team
+body
+horse
+kimono
+bazaar
+bag
+recording studio
+parsley
+entrance
+denim
+vet
+horse farm
+charcoal
+architecture
+glass vase
+puppy
+estuary
+television show host
+city bus
+shoulder
+beast
+balance
+golfer
+roadside
+denim jacket
+stone wall
+counter top
+app icon
+toast
+head coach
+ham
+warrior
+gem
+refrigerator
+snowman
+construction worker
+coal
+website
+morning fog
+mustard
+human
+owl
+puppy dog
+piggy bank
+vegetation
+pirate
+action film
+marshmallow
+thanksgiving
+business
+disease
+signage
+greeting
+skate park
+tile
+mouth
+spinach
+vacation
+leader
+shrine
+walker
+science fiction film
+bill
+rabbit
+motor boat
+bar
+radio
+barge
+tail
+chainsaw
+gallery
+rainbow
+pasta
+padlock
+web
+pastry
+ink
+reef
+school uniform
+shawl
+treasure
+peach
+dinner table
+injury
+harbor
+witch
+car dealership
+litter
+gesture
+documentary
+marriage
+sea shell
+priest
+dome
+kit
+icon
+seaside
+bucket
+entertainment
+stable
+hat
+puddle
+sock
+shopper
+technology
+harbour
+orbit
+antler
+tube
+flag waving
+cook
+tight
+commander
+farmland
+switch
+hiker
+wedding ceremony
+award ceremony
+champion
+chopstick
+farmhouse
+performer
+spike
+accident
+cruise ship
+passenger train
+attraction
+entertainer
+rear view
+sidewalk
+parade
+racing
+plane
+ritual
+peacock
+pocket
+plum
+drop
+carrot
+floor
+sunset
+troop
+architect
+coffee table
+dust
+outline
+leather
+charity event
+heat
+whale
+laundry
+coconut tree
+crosswalk
+pony
+ant
+pipe
+string
+coat
+angel
+beef
+church tower
+dish
+pitch
+cupboard
+thermometer
+dirt field
+fireworks
+minute
+cane
+pajama
+flower garden
+autumn
+trash can
+dachshund
+banana tree
+tray
+moose
+roadway
+carnival
+antenna
+pole
+castle wall
+ram
+cattle
+hay
+cookie
+swimmer
+baseball team
+strait
+hedge
+jet
+fire pit
+octopus
+calf
+cube
+opera
+cardboard box
+tiara
+kitchen sink
+prairie
+bowl
+galaxy
+straw hat
+linen
+ski resort
+stitch
+street lamp
+motorist
+icicle
+stain
+flora
+drain
+kitchen cabinet
+decor
+bouquet
+pound
+interior design
+nail polish
+figurine
+tomb
+disc
+twist
+blouse
+ribbon
+figure
+burger
+cork
+soccer goalkeeper
+train bridge
+drinking water
+dew
+baker
+storm cloud
+tarmac
+tv drama
+sponge
+magnet
+sailor
+entry
+swan
+exercise
+sloth
+jewel
+scuba diver
+bite
+cat tree
+tent
+can
+tennis match
+ecosystem
+picket fence
+palm
+train car
+frying pan
+rally
+tablet pc
+reindeer
+image
+wolf
+chin
+conservatory
+flood water
+cityscape
+beach sand
+car park
+pavement
+farm field
+swimming
+winter storm
+stem
+pillow
+inning
+gorilla
+desk
+avenue
+fern
+money
+pearl
+train station
+skillet
+nap
+barber
+library
+freezer
+label
+rainforest
+parking sign
+mirror
+wing
+noodle
+press room
+sculpture
+tablet
+viewer
+prayer
+mini
+mechanic
+laugh
+rice field
+hand
+mustache
+mountain road
+catwalk
+conference
+cape
+installation
+musician
+stream
+machine
+speech
+crocodile
+soccer match
+town square
+passport
+post box
+point
+stone building
+motorway
+mix
+dentist
+businessperson
+happiness
+boat
+vineyard
+treadmill
+glass wall
+water droplet
+coffee mug
+graduate
+sunflower
+parliament
+shepherd
+movie
+wine
+orchard
+tulip
+motherboard
+cup
+broom
+spot
+drawing
+polo shirt
+graduation
+film producer
+moonlight
+glow
+film format
+t shirt
+rock face
+sword
+clinic
+festival day
+meadow
+staple
+pupil
+training ground
+rider
+flower
+foal
+wharf
+foot bridge
+shooting
+top
+mast
+police car
+robe
+wedding bouquet
+stop sign
+birthday cake
+glitter
+butter
+scooter
+tundra
+superhero
+pocket watch
+inscription
+youngster
+fruit tree
+movie poster
+engine
+foundation
+motorcyclist
+take
+woman
+antelope
+country artist
+road trip
+typewriter
+tuxedo
+brand
+pine
+bathroom
+paradise
+texture
+balloon
+dining table
+home
+computer screen
+actor
+clip
+tv tower
+panorama
+summit
+cat
+plot
+eagle
+dancer
+pup
+studio shot
+tear
+bird bath
+classroom
+bookstore
+city wall
+tv programme
+blade
+easel
+buttercream
+sweet
+designer
+diamond
+handshake
+herb
+corn field
+seafront
+concrete
+street artist
+gas
+stamp
+window display
+paper
+note
+pint
+quarry
+research
+fixture
+manager
+soil
+leopard
+board game
+ladder
+stop light
+island
+ramp
+football match
+icing
+drill
+currency
+summer evening
+topping
+pyramid
+pomegranate
+cell
+ivy
+squad
+scenery
+computer
+locomotive
+surf
+mascot
+dune
+path
+duck
+twilight
+wire
+bow tie
+strike
+cormorant
+car wash
+crane
+market
+philosopher
+alarm clock
+camera
+birch
+greeting card
+plain
+clay
+donut
+lock
+moth
+laboratory
+fan
+violin
+jazz fusion artist
+mountain biker
+terrain
+magazine
+pickup
+comedy film
+smartphone
+film
+bed
+microwave oven
+tournament
+lawn
+car window
+alligator
+screen
+jetty
+shopping bag
+landscape view
+cabinetry
+friendly match
+thing
+petal
+shopping center
+transport
+ballet dancer
+shoreline
+princess
+car seat
+parking meter
+green
+vodka
+band
+rock
+costume
+warning sign
+strip
+plaque
+wheelchair
+headband
+ginger
+dice
+media
+hairdresser
+press
+living room
+stove
+player
+cherry
+workshop
+carving
+embroidery
+doodle
+adventure
+rugby player
+monument
+brush
+marker
+loft
+postcard
+collage
+ball
+professor
+dresser
+gig
+festival
+blackbird
+makeup artist
+video camera
+sticker
+peak
+wildflower
+santa hat
+rodeo
+wedding photographer
+guy
+staff
+waterfall
+operation
+defender
+falcon
+haze
+individual
+gentleman
+greyhound
+rocking chair
+rice
+garbage
+platter
+chocolate
+splash
+business suit
+cheetah
+valley
+maze
+trampoline
+garland
+slalom
+unicorn
+tree stump
+painting
+romance
+fight
+alcohol
+ghost
+fondant
+spa
+shutter
+death
+demonstration
+cotton
+pier
+flea market
+history
+savannah
+fist
+aisle
+crew
+jug
+pose
+anchor
+teapot
+boat house
+business team
+tripod
+bee
+pebble
+mattress
+canvas
+hallway
+campaign
+pod
+lake district
+article
+white
+sofa
+honey
+marathon
+pancake
+tourist attraction
+wedding gown
+battle
+shelving
+sea
+sheet music
+pie
+yarn
+construction site
+flyer
+tie
+star
+lettuce
+martial artist
+dart
+straw
+reflection
+conference room
+temperature
+rugby
+mosquito
+physicist
+rock climber
+crash
+backdrop
+toilet seat
+sand castle
+water park
+toy car
+waste
+luxury
+hangar
+rv
+tree trunk
+board
+gold
+project picture
+cap
+cottage
+relief
+attire
+microscope
+battery
+roll
+line
+parking garage
+crystal
+broadcasting
+brick wall
+lab
+flooring
+meeting
+3d cg rendering
+desktop computer
+cowboy
+sailing ship
+junction
+hairstyle
+homework
+profile
+model
+flower pot
+street light
+salt lake
+maple
+space
+blizzard
+throw
+zebras
+brochure
+constellation
+beak
+kilt
+pond
+blue sky
+sneaker
+sand dune
+morning sun
+almond
+grill
+curl
+basketball girl game
+chameleon
+toilet bowl
+prince
+keyboard
+queen
+computer monitor
+writing
+crown
+basilica
+kiss
+house
+parking
+football competition
+shell
+sport equipment
+comedy
+baboon
+vendor
+rise building
+wrap
+food truck
+cat bed
+rickshaw
+flare
+teal
+nectar
+eclipse
+vehicle
+steam locomotive
+gorge
+cow
+christmas card
+demonstrator
+memorial
+towel
+jewellery
+train
+frisbee
+baseball game
+fur
+afternoon sun
+community
+sparkler
+bandage
+firework
+dollar
+pasture
+video
+bus
+tree house
+seashore
+field
+hamburger
+souvenir
+hedgehog
+worm
+pine cone
+osprey
+dinosaur
+vegetable
+junk
+poster
+army
+winger
+bundle
+stage
+growth
+wedding party
+service
+blanket
+ruler
+eye
+credit card
+castle
+diner
+hut
+elk
+hard rock artist
+nun
+dog breed
+nest
+drama film
+number icon
+water tank
+giraffe
+altar
+pavilion
+tv personality
+suv
+street vendor
+street sign
+ditch
+debris
+foam
+takeoff
+spice
+mountain lake
+tea
+orchestra
+spacecraft
+counter
+abbey
+mountain
+hydrangea
+racer
+orange tree
+tide
+cowboy hat
+rapid
+town
+wild
+herd
+vein
+driveway
+jar
+bark
+illustration
+horror film
+corn
+stroller
+industry
+mountain stream
+gym
+neckline
+pan
+client
+spectator
+eggplant
+camper
+fawn
+hoodie
+meat
+lemonade
+food market
+slum
+comic book character
+flower market
+love
+palace
+gun
+heel
+shopping street
+shooting basketball guard
+family photo
+rooftop
+laundry basket
+airport runway
+horn
+face mask
+flight
+appetizer
+violet
+country lane
+cement
+instrument
+tv actor
+spark
+celebrity
+award
+country house
+standing
+auction
+date
+engagement
+puck
+advertisement
+chair
+zebra
+driftwood
+bumblebee
+maple leaf
+bonnet
+orange
+water tower
+door
+singer
+floor plan
+discussion
+theatre
+pilgrim
+mug
+branch
+window sill
+baseball pitcher
+bakery
+lollipop
+basketball player
+toilet paper
+chalkboard
+cabin
+sign
+night sky
+cannon
+fishing net
+submarine
+suit
+fur coat
+wine bottle
+folder
+street art
+suspension bridge
+evening sky
+billboard
+postage stamp
+newspaper
+transportation
+surgeon
+light
+park
+horizon
+road
+sand bar
+trumpet
+lounge
+cloud forest
+birthday celebration
+balcony
+anime
+beehive
+umbrella
+goldfish
+baseball cap
+waterhole
+ceiling
+carousel
+backpack
+plant pot
+atmosphere
+sunflower field
+spire
+vision
+woodpecker
+chip
+pool table
+lotus flower
+cone
+humpback whale
+reservoir
+hunt
+piano
+plate
+dining area
+luggage
+skier
+dance floor
+crow
+stair
+overpass
+opera house
+bear
+jazz artist
+water
+vessel
+cast
+yard
+cathedral
+basketball hoop
+graveyard
+sound
+berry
+onlooker
+fauna
+birch tree
+retail
+hill
+skeleton
+journalist
+frost
+basket
+nail
+dusk
+trash
+dawn
+clover
+hen
+volcano
+basketball coach
+home decor
+charge
+haircut
+sense
+university
+lizard
+daisy
+tablet computer
+grass field
+prison
+metal artist
+bathroom mirror
+window frame
+chest
+flavor
+pop country artist
+market square
+monkey
+blog
+deer
+speech bubble
+dog
+independence day
+girl
+boy
+tartan
+furniture
+appliance
+office window
+fish boat
+sand box
+tv sitcom
+drama
+sleigh
+depression
+paper towel
+baseball
+protestor
+grape
+wedding cake
+invitation
+accessory
+pick
+grandparent
+racket
+tea plantation
+outdoors
+egg
+glass bowl
+sun
+organization
+lion
+panel
+station
+wallpaper
+helicopter
+salt
+vanity
+patio
+lunch
+street performer
+mountain range
+soup
+bacon
+power station
+cantilever bridge
+hummingbird
+shirt
+rope
+hip
+chalk
+pendant
+choir
+tv
+lichen
+railway bridge
+art gallery
+bartender
+wagon
+baby elephant
+accordion
+horseshoe
+building site
+clutch
+harvest
+savanna
+geranium
+business woman
+paddock
+patch
+beech tree
+war
+suburbs
+hospital bed
+motorcycle racer
+moss
+gravel
+government agency
+dollar bill
+father
+fjord
+concert
+nut
+wedding photography
+finish line
+home plate
+food
+nose
+thumb
+village
+dining room table
+bumper
+monster
+blackberry
+lime
+conflict
+gala
+wallet
+wrist
+hug
+mermaid
+lava
+lawyer
+folk rock artist
+arena
+onion
+toothbrush
+fashion
+perfume
+flip
+triangle
+woodland
+mail
+grasshopper
+studio
+wood floor
+den
+racquet
+cello
+lemur
+astronaut
+glass table
+blood
+dvd
+planter
+silver
+leash
+master bedroom
+forest
+batter
+shoe
+engraving
+opening
+product
+toe
+cocktail
+mallard duck
+bike ride
+oasis
+wedding ring
+cinematographer
+holly
+autograph
+fence
+ice cube
+cove
+pineapple
+aurora
+glass bead
+produce
+apartment building
+cob
+miniature
+cockpit
+flashlight
+frog
+sheep
+groom
+steel
+watermelon
+clip art
+paper plate
+ostrich
+contour
+mural
+cub
+paisley bandanna
+winery
+turn
+handle
+satellite
+post
+pork
+child
+asphalt
+grocery store
+vulture
+trolley
+nightclub
+brick
+trailer
+compass
+cereal
+cafe
+cartoon character
+sugar
+fiction book
+glass floor
+umpire
+guitar
+hamster
+protester
+airplane
+garment
+blazer
+railway line
+wedding
+shoe box
+parking lot
+construction
+graduation ceremony
+tram
+telescope
+copper
+pain
+autumn forest
+guest house
+partner
+crayon
+dip
+boot
+corridor
+computer keyboard
+hockey player
+chicken coop
+bus station
+gathering
+ankle
+bunk bed
+wood table
+football coach
+monarch
+pharmacy
+legging
+mannequin
+female
+train track
+stack
+canopy
+design element
+grandmother
+symbol
+beach hut
+zucchini
+bomb
+businessman
+skyscraper
+tongue
+case
+sparkle
+highland
+ballroom
+prom
+estate
+customer
+archipelago
+cheese
+debate
+carriage
+bulldozer
+pumpkin
+sitting room
+gas station
+wedding reception
+camp
+dog bed
+tower
+property
+river bed
+pop latin artist
+fridge
+wine glass
+coast
+beer
+tow truck
+fire truck
+mountain bike
+thigh
+heron
+boat ride
+gondola
+turquoise
+lake
+llama
+kitty
+tin
+waiting room
+coffee cup
+socialite
+guard
+tap
+waterway
+forehead
+list
+erosion
+box
+sea lion
+pollen
+dam
+wasp
+salon
+tennis tournament
+flower box
+aquarium
+rain cloud
+clothing store
+lead singer
+cupcake
+tortoise
+lettering
+sport facility
+dance
+dog house
+nature
+football
+rooster
+footballer
+railway track
+crowd
+fishing rod
+silhouette
+wind turbine
+sari
+bus window
+cloud
+charity
+medal
+yoga
+event
+veil
+fashion menswear milan week
+news
+knife
+print
+screen tv
+walnut
+fungus
+ice cream
+computer mouse
+play
+tribe
+picture
+video game
+business card
+music festival
+rack
+envelope
+shower
+dirt road
+mine
+oyster
+monarch butterfly
+dude
+fruit salad
+podium
+fork
+lace
+test match
+boulder
+cricket player
+staircase
+peninsula
+shopping
+popcorn
+oak
+market stall
+pine tree
+mountaineer
+student
+closet
+hood
+handstand
+centerpiece
+insect
+patient
+makeover
+tennis player
+sheet
+park bench
+apple
+organism
+hook
+turkey
+tangerine
+sibling
+shopping mall
+bird
+scarf
+smoothie
+net
+grass
+napkin
+ray
+eyebrow
+laptop keyboard
+motorbike
+woman hand
+oven
+book cover
+easter egg
+microwave
+sand
+snapshot
+soccer ball
+makeup
+knight
+bowling ball
+shower curtain
+flame
+lightning
+running
+power plant
+crib
+cartoon
+moat
+fashion girl
+wedding invitation
+bottle
+cliff
+monastery
+file photo
+apartment
+casino
+cream
+sweatshirt
+storm
+cruise
+teddy bear
+shovel
+wind farm
+writer
+dock
+professional
+hotel room
+job
+monitor
+donkey
+pass
+interview
+duchess
+mark
+plank
+beard
+zombie
+trio
+channel
+cricket team
+windmill
+vest
+diagram
+cable
+winter scene
+golden gate bridge
+buffalo
+studio portrait
+pagoda
+whiskey
+freight train
+kite
+future
+steam train
+phone box
+headset
+wood
+snowboarder
+paper bag
+slide
+grapefruit
+seating
+morning
+bronze sculpture
+theatre actor
+stump
+jean
+landmark
+jam
+waist
+watercolor
+hammock
+light fixture
+ice
+basin
+beverage
+shelter
+premiere
+mound
+ear
+bronze
+sunlight
+street
+energy
+barn door
+hike
+fleet
+claw
+beach
+pepperoni
+bin
+trainer
+buffet
+archive
+toddler
+referee
+bay window
+dove
+production company
+evening light
+gate
+farm
+reed
+fruit stand
+explorer
+snow storm
+throw pillow
+button
+display case
+bookcase
+lead
+lipstick
+basketball court
+cargo
+ensemble
+pope
+clock tower
+teen
+speaker
+rat
+laptop
+ski
+mess
+stadium
+ferry boat
+bunny
+waterfront
+downtown
+sink
+press conference
+dinner
+condiment
+thread
+audience
+grid
+car
+plastic
+people
+barbecue
+pigeon
+urinal
+seagull
+volunteer
+hockey
+fir tree
+pollution
+trial
+collar
+area
+meeting room
+circus
+yogurt
+orangutan
+viaduct
+comedian
+drone
+scissor
+pop rock artist
+biscuit
+panda
+water feature
+air balloon
+remote control
+watercolor painting
+show
+walk
+post office
+bike path
+rap gangsta artist
+microphone
+crack
+sunset sky
+glass
+tv show
+cartoon style
+stripe
+foyer
+signal
+calligraphy
+bulb
+gardener
+coffee bean
+spider
+tapestry
+city skyline
+necklace
+kitten
+traveler
+veteran
+frosting
+fry
+tennis court
+tank top
+butterfly house
+mist
+drummer
+water level
+scale
+baseball glove
+music video performer
+champagne
+camping
+clothing
+water drop
+telephone box
+pen
+morning mist
+fire engine
+porch
+opening ceremony
+style
+palm tree
+fashion show
+universe
+scratch
+axe
+ottoman
+explosion
+rib
+boutique
+game
+cucumber
+fruit
+stone bridge
+nature reserve
+track
+train window
+punch
+telephone pole
+velvet
+sauce
+moon
+contrast
+flamingo
+bat
+vending machine
+ship
+equestrian
+shade
+comforter
+pallet
+sparrow
+wii
+glaze
+grocery
+steeple
+soccer player
+contract
+advertising
+runner
+chimpanzee
+world
+seat
+project
+chihuahua
+bubble
+willow
+pedestal
+soul hip hop artist
+curb
+drawer
+leaf
+banner
+launch party
+coach
+government
+snowball
+toy
+portrait
+doctor
+whiteboard
+electronic
+tiger
+graffiti
+column
+nightstand
+whistle
+maxi dress
+bench
+wetsuit
+bird feeder
+football game
+basketball
+class
+bathroom door
+store window
+text message
+wreath
+street view
+binocular
+pet
+facade
+drought
+lemon
+new year
+night view
+airplane window
+specie
+rule
+jaw
+wheat field
+diet
+pop artist
+habitat
+screenshot
+scoreboard
+shore
+mane
+quilt
+ski lift
+orchid
+turban
+christmas
+airport
+marina
+glass door
+glass bottle
+restaurant
+conductor
+logo
+sleep
+tape
+tomato
+river bank
+lilac
+tooth
+training
+pottery
+shop
+steam engine
+mason jar
+base
+procession
+border
+shoot
+footprint
+hotdog
+bull
+stocking
+recreation
+automobile model
+design
+country pop artist
+river
+retriever
+department store
+auditorium
+sport car
+supermarket
+belt
+cricket
+window box
+dress shirt
+letter
+residence
+megaphone
+pant
+wildfire
+bird nest
+crab
+swimsuit
+candle
+funeral
+mill
+national park
+plant
+cop
+power line
+perch
+blue
+finger
+ferris wheel
+globe
+skateboard
+helmet
+movie theater
+uniform
+hammer
+material
+kid
+well
+butterfly
+sideline
+fashion fall show
+planet earth
+lift
+male
+sauna
+gray
+flour
+sand sculpture
+program
+cabinet
+infant
+wheel
+aircraft model
+dough
+garlic
+skate
+arrow
+wrapping paper
+ripple
+lamp
+iron
+banknote
+beaver
+ferry
+courtyard
+bassist
+countryside
+steak
+comfort
+boxer
+laundry room
+campsite
+brick building
+golf
+subway
+headphone
+fort
+handbag
+drum
+flood
+saddle
+bass
+labyrinth
+needle
+sun ray
+app
+menu
+president
+cardigan
+dandelion
+wetland
+ice hockey player
+number
+city hall
+fishing
+portrait session
+pug
+key
+art print
+minister
+hurdle
+emergency
+painting artist
+flag pole
+evening
+purse
+recipe
+golf ball
+coloring book
+mountain peak
+senior
+holiday
+bud
+cousin
+pantry
+lap
+skin
+flag
+tissue paper
+ridge
+wire fence
+surfer
+climber
+photograph
+sewing machine
+cooler
+actress
+apple tree
+cancer
+starfish
+automobile make
+dumbbell
+brace
+tunnel
+window
+paint artist
+composition
+school student
+condo
+convertible
+cushion
+selfie
+territory
+guide
+tree
+court
+shrimp
+stone house
+dress
+eyelash
+juice
+broccoli
+chain
+tourism
+mountain top
+concept car
+film premiere
+light bulb
+cafeteria
+badge
+flower bed
+theater
+root
+racecar driver
+basketball boy game
+glove
+skyline
+wall
+glacier
+airport terminal
+bug
+trim
+railway station
+briefcase
+flat
+fountain
+person
+lane
+asparagus
+art
+lantern
+dishwasher
+director
+snake
+lecture
+game controller
+tree branch
+pub
+bathing suit
+queue
+belly
+poppy
+bow
+pitcher
+ice cream cone
+cave
+candy
+road bridge
+host
+traffic jam
+earring
+file
+foot
+watermark overlay stamp
+mailbox
+supercar
+railing
+bedroom
+seafood
+waffle
+bronze statue
+plan
+flow
+marble
+basketball game
+automobile
+scene
+cypress tree
+soldier
+skateboarder
+glass building
+cherry tree
+pump
+grain
+wildebeest
+loop
+frame
+bathtub
+saxophone
+diver
+stalk
+lily
+bead
+alley
+flock
+family room
+manufacturing
+pointer
+worker
+navy
+potato
+teacher
+photography
+dolly
+boardwalk
+water fountain
+athlete
+side dish
+bay
+ice hockey
+phone
+hero
+face
+gold medal
+blind
+swamp
+researcher
+swim
+meatball
+iguana
+leather jacket
+jellyfish
+site
+smoke
+traffic signal
+melon
+beetle
+calculator
+skirt
+plantation
+sculptor
+barrier
+catcher
+security guard
+sketch
+awning
+steering wheel
+mountain view
+bus stop
+pool
+leg
+spotlight
+apron
+mineral
+inlet
+sleeve
+torch
+emotion
+march
+police officer
+performance
+lamp post
+fishing boat
+summer
+presentation
+saucer
+suitcase
+supermodel
+goalkeeper
+shrub
+rock artist
+document
+beach house
+man
+blue artist
+cigar
+railroad track
+gown
+mosaic
+bungalow
+alphabet
+baseball field
+shed
+pedestrian
+rail
+soap
+kitchen counter
+dessert
+dunk
+blossom
+conversation
+fruit market
+glass jar
+military
+beer bottle
+photographer
+tennis racket
+competition
+escalator
+bell tower
+stilt
+ballerina
+television
+feather
+fence post
+rear
+dahlia
+red carpet
+tub
+hole
+fortress
+pack
+telephone
+cardboard
+city park
+platform
+college student
+arch bridge
+wind
+blender
+bloom
+ice rink
+birthday
+raven
+fairy
+embankment
+hall
+flower shop
+suburb
+barrel
+biker
+steam
+dragonfly
+formation
+electricity
+business people
+symmetry
+walkway
+fisherman
+gas mask
+loch
+youth
+hanger
+dot
+fish
+street market
+animation film
+crime fiction film
+boar
+emblem
+halloween costume
+kangaroo
+couple
+spoon
+squirrel
+neon sign
+sky
+office desk
+beauty salon
+breakwater
+fashion look
+toaster
+author
+news conference
+outdoor
+canoe
+dragon
+tool
+shopping centre
+ladybug
+swimming pool
+landscaping
+ski pole
+red
+truck
+fly
+temple
+level
+sunday
+railroad bridge
+car mirror
+lawn mower
+flute
+aircraft carrier
+fashion menswear london week
+sunshine
+tile floor
+skull
+fossil
+flower arrangement
+diaper
+sea turtle
+cherry blossom
+fireman
+shack
+lens
+waiter
+animal
+basement
+snow
+autumn park
+glass box
+kick
+head
+anniversary
+vine
+back
+paper lantern
+fish tank
+cellphone
+silk
+coral
+notebook
+photo
+gazebo
+ketchup
+driver
+farmer
+bonfire
+chestnut
+photoshoot
+football field
+olive tree
+pheasant
+sandal
+toilet
+fireplace
+music
+deity
+fish market
+fig
+bell
+neck
+grave
+villa
+cyclist
+crate
+grey
+asphalt road
+soccer
+hostel
+municipality
+courthouse
+roof
+end table
+pot
+sedan
+structure
+folk artist
+sport
+sport team
+protest
+syringe
+fashion designer
+jersey
+heart shape
+kayak
+stare
+sit with
+direct
+read
+photograph
+spin
+teach
+laugh
+carve
+grow on
+warm
+watch
+stretch
+smell
+decorate
+shine
+light
+dance
+send
+park
+chase
+collect
+lead
+kiss
+lead to
+lick
+smile
+cheer
+sit
+point
+block
+rock
+drop
+cut
+ski
+wrap
+lose
+serve
+provide
+sleep
+dress
+embrace
+burn
+pack
+stir
+create
+touch
+wash
+stick
+reveal
+shop
+train
+paint
+groom
+hunt
+bloom
+play
+pay
+brush
+shoot
+hold
+picture
+carry
+sip
+contain
+turn
+pour
+pitch
+give
+add
+blow
+look in
+show
+walk
+illuminate
+kneel
+cover
+drag
+post
+present
+fit
+operate
+fish
+race
+write
+deliver
+peel
+push
+run
+sit around
+buy
+jump
+walk on
+attend
+clean
+sell
+ride on
+mount
+host
+dry
+plant
+sing
+row
+shake
+perch
+ride
+fight
+skateboard
+live
+call
+surround
+practice
+play on
+work on
+step
+relax
+hit
+fall in
+flow
+greet
+launch
+wear
+hang on
+drive
+sit in
+break
+learn
+fly
+connect
+display
+locate
+compete
+go for
+sail
+lift
+toast
+help
+run on
+reflect
+pose
+scratch
+frame
+dribble
+herd
+enter
+exit
+place
+inspect
+build
+pick
+fill
+grind
+skate
+offer
+float
+sit by
+stand
+release
+rest
+singe
+climb
+tie
+mark
+lay
+stand around
+capture
+set
+land
+swinge
+run in
+kick
+lean
+head
+sign
+approach
+swim
+close
+crash
+control
+fall
+remove
+repair
+open
+appear
+travel
+load
+miss
+check
+surf
+moor
+smoke
+drink
+board
+seat
+feed
+rise
+sit on
+swing
+grow
+strike
+date
+slide
+share
+graze
+jump in
+lie
+extrude
+roll
+move
+gather
+eat
+pull
+run through
+squeeze
+lay on
+draw
+play with
+wave
+assemble
+perform
+march
+score
+attach
+adjust
+hang
+hug
+sleep on
+throw
+live in
+talk
+pet
+work
+run with
+see
+flip
+catch
+cook
+receive
+celebrate
+look
+classic
+bridal
+indoor
+industrial
+teenage
+mini
+grassy
+aged
+long
+warm
+light
+handsome
+happy
+three
+pregnant
+circular
+urban
+silver
+ceramic
+3d
+green
+blonde
+golden
+dark
+tropical
+ripe
+deep
+fat
+musical
+giant
+medical
+medieval
+bare
+stunning
+bold
+geographical
+huge
+plastic
+foggy
+stormy
+gothic
+biological
+empty
+clear
+antique
+pink
+steep
+brown
+striped
+aerial
+rainy
+cool
+flying
+commercial
+purple
+trendy
+blank
+haired
+dead
+wooden
+flat
+high
+beige
+panoramic
+angry
+dozen
+rural
+solar
+big
+small
+stained
+thick
+many
+fresh
+clean
+strong
+abstract
+crowded
+retro
+dry
+gorgeous
+martial
+modern
+blue
+cloudy
+low
+four
+outdoor
+single
+much
+beautiful
+snowy
+pretty
+new
+short
+sunny
+closed
+rocky
+red
+two
+double
+male
+gray
+five
+colorful
+automotive
+various
+one
+old
+rusty
+tall
+wild
+narrow
+natural
+several
+frozen
+textured
+lush
+young
+hot
+mixed
+white
+float
+quiet
+round
+bright
+religious
+female
+historical
+shiny
+traditional
+tourist
+yellow
+bald
+coastal
+lovely
+little
+broken
+romantic
+wide
+royal
+rich
+open
+cute
+ancient
+cold
+political
+elderly
+gold
+full
+rustic
+metallic
+floral
+sad
+wet
+fancy
+senior
+tiny
+stylish
+large
+frosty
+orange
+transparent
+electronic
+shallow
+scared
+armed
+dirty
+historic
+black
+few
+windy
+some
+square
+ornamental
+sandy
+thin
\ No newline at end of file
diff --git a/ram/inference.py b/ram/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..182efc55098ce201cdc776236aa8c9468845cb41
--- /dev/null
+++ b/ram/inference.py
@@ -0,0 +1,57 @@
+'''
+ * The Inference of RAM and Tag2Text Models
+ * Written by Xinyu Huang
+'''
+import torch
+
+
+def inference_tag2text(image, model, input_tag="None"):
+
+ with torch.no_grad():
+ caption, tag_predict = model.generate(image,
+ tag_input=None,
+ max_length=50,
+ return_tag_predict=True)
+
+    if input_tag in ('', 'none', 'None'):
+ return tag_predict[0], None, caption[0]
+
+    # If the user specified input tags:
+ else:
+        input_tag_list = [input_tag.replace(',', ' | ')]
+
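+        # Generate the caption conditioned on the user-specified tags.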
+ with torch.no_grad():
+ caption, input_tag = model.generate(image,
+ tag_input=input_tag_list,
+ max_length=50,
+ return_tag_predict=True)
+
+ return tag_predict[0], input_tag[0], caption[0]
+
+
+def inference_ram(image, model):
+
+ with torch.no_grad():
+ tags, tags_chinese = model.generate_tag(image)
+
+    return tags[0], tags_chinese[0]
+
+
+def inference_ram_openset(image, model):
+
+ with torch.no_grad():
+ tags = model.generate_tag_openset(image)
+
+ return tags[0]
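+
+
+# Example usage (a minimal sketch; the checkpoint name, the `get_transform`
+# helper, and the CUDA device are assumptions based on the upstream RAM repo):
+#
+#   from ram import get_transform
+#   from ram.models import ram
+#
+#   model = ram(pretrained='ram_swin_large_14m.pth', image_size=384, vit='swin_l').eval().cuda()
+#   image = get_transform(image_size=384)(pil_image).unsqueeze(0).cuda()
+#   tags, tags_chinese = inference_ram(image, model)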
diff --git a/ram/models/__init__.py b/ram/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2592c66e617522fb0feefb2d8c93f95b4c9fce8
--- /dev/null
+++ b/ram/models/__init__.py
@@ -0,0 +1,2 @@
+from .ram import ram
+from .tag2text import tag2text
diff --git a/ram/models/bert.py b/ram/models/bert.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb90b794284d2262d171aa0f93fdf20854a9059b
--- /dev/null
+++ b/ram/models/bert.py
@@ -0,0 +1,1042 @@
+'''
+ * Copyright (c) 2022, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+ * Based on huggingface code base
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
+'''
+
+import math
+import os
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import torch
+from torch import Tensor, device, dtype, nn
+import torch.utils.checkpoint
+from torch.nn import CrossEntropyLoss
+import torch.nn.functional as F
+
+from transformers.activations import ACT2FN
+from transformers.file_utils import (
+ ModelOutput,
+)
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ MaskedLMOutput,
+ MultipleChoiceModelOutput,
+ NextSentencePredictorOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from transformers.modeling_utils import (
+ PreTrainedModel,
+ apply_chunking_to_forward,
+ find_pruneable_heads_and_indices,
+ prune_linear_layer,
+)
+from transformers.utils import logging
+from transformers.models.bert.configuration_bert import BertConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class BertEmbeddings_nopos(nn.Module):
+ """Construct the embeddings from word and position embeddings."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+ # self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ # self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+ # self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+
+ self.config = config
+
+ def forward(
+ self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
+ ):
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+
+ # if position_ids is None:
+ # position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+
+ embeddings = inputs_embeds
+
+ # if self.position_embedding_type == "absolute":
+ # position_embeddings = self.position_embeddings(position_ids)
+ # # print('add position_embeddings!!!!')
+ # embeddings += position_embeddings
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class BertEmbeddings(nn.Module):
+ """Construct the embeddings from word and position embeddings."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+
+ self.config = config
+
+ def forward(
+ self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
+ ):
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+
+ embeddings = inputs_embeds
+
+ if self.position_embedding_type == "absolute":
+ position_embeddings = self.position_embeddings(position_ids)
+ # print('add position_embeddings!!!!')
+ embeddings += position_embeddings
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class BertSelfAttention(nn.Module):
+ def __init__(self, config, is_cross_attention):
+ super().__init__()
+ self.config = config
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ "The hidden size (%d) is not a multiple of the number of attention "
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ if is_cross_attention:
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
+ else:
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ self.max_position_embeddings = config.max_position_embeddings
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
+ self.save_attention = False
+
+ def save_attn_gradients(self, attn_gradients):
+ self.attn_gradients = attn_gradients
+
+ def get_attn_gradients(self):
+ return self.attn_gradients
+
+ def save_attention_map(self, attention_map):
+ self.attention_map = attention_map
+
+ def get_attention_map(self):
+ return self.attention_map
+
+ def transpose_for_scores(self, x):
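+        # Reshape (batch, seq_len, all_head_size) -> (batch, num_heads, seq_len, head_size).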
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ ):
+ mixed_query_layer = self.query(hidden_states)
+
+ # If this is instantiated as a cross-attention module, the keys
+ # and values come from an encoder; the attention mask needs to be
+ # such that the encoder's padding tokens are not attended to.
+ is_cross_attention = encoder_hidden_states is not None
+
+ if is_cross_attention:
+ # print(self.key.weight.shape)
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+ attention_mask = encoder_attention_mask
+ elif past_key_value is not None:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+ else:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
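+        # Cache the key/value states; they are returned so later decoding steps can reuse them.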
+ past_key_value = (key_layer, value_layer)
+
+        # Compatibility shim for newer transformers versions, where the cached
+        # key/value batch can exceed the query batch during generation; truncate to match.
+ if key_layer.shape[0] > query_layer.shape[0]:
+ key_layer = key_layer[:query_layer.shape[0], :, :, :]
+ attention_mask = attention_mask[:query_layer.shape[0], :, :]
+ value_layer = value_layer[:query_layer.shape[0], :, :, :]
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ seq_length = hidden_states.size()[1]
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+ distance = position_ids_l - position_ids_r
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
+
+ if self.position_embedding_type == "relative_key":
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores
+ elif self.position_embedding_type == "relative_key_query":
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+ if attention_mask is not None:
+            # Apply the attention mask (precomputed for all layers in the BertModel forward() function).
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+ if is_cross_attention and self.save_attention:
+ self.save_attention_map(attention_probs)
+ attention_probs.register_hook(self.save_attn_gradients)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs_dropped = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs_dropped = attention_probs_dropped * head_mask
+
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ outputs = outputs + (past_key_value,)
+ return outputs
+
+
+class BertSelfOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertAttention(nn.Module):
+ def __init__(self, config, is_cross_attention=False):
+ super().__init__()
+ self.self = BertSelfAttention(config, is_cross_attention)
+ self.output = BertSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ ):
+ self_outputs = self.self(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ )
+ attention_output = self.output(self_outputs[0], hidden_states)
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+class BertIntermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+class BertOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertLayer(nn.Module):
+ def __init__(self, config, layer_num):
+ super().__init__()
+ self.config = config
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = BertAttention(config)
+ self.layer_num = layer_num
+ if self.config.add_cross_attention:
+ self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
+ self.intermediate = BertIntermediate(config)
+ self.output = BertOutput(config)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ mode=None,
+ ):
+
+ if mode == 'tagging':
+
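+            # Tagging mode skips self-attention entirely: the label queries
+            # only cross-attend to the encoder (image) features.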
+ assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
+
+ cross_attention_outputs = self.crossattention(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ output_attentions=output_attentions,
+ )
+ attention_output = cross_attention_outputs[0]
+ outputs = cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
+
+ present_key_value = cross_attention_outputs[-1]
+
+ else:
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+ self_attention_outputs = self.attention(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ output_attentions=output_attentions,
+ past_key_value=self_attn_past_key_value,
+ )
+ attention_output = self_attention_outputs[0]
+
+ outputs = self_attention_outputs[1:-1]
+ present_key_value = self_attention_outputs[-1]
+
+ if mode=='multimodal':
+ assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
+
+ cross_attention_outputs = self.crossattention(
+ attention_output,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ output_attentions=output_attentions,
+ )
+ attention_output = cross_attention_outputs[0]
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
+ )
+ outputs = (layer_output,) + outputs
+
+ outputs = outputs + (present_key_value,)
+
+ return outputs
+
+ def feed_forward_chunk(self, attention_output):
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ return layer_output
+
+
+class BertEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+        self.layer = nn.ModuleList([BertLayer(config, i) for i in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ mode='multimodal',
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+
+ next_decoder_cache = () if use_cache else None
+
+ for i in range(self.config.num_hidden_layers):
+ layer_module = self.layer[i]
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+ past_key_value = past_key_values[i] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ if use_cache:
+                logger.warning(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, past_key_value, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ mode=mode,
+ )
+ else:
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ mode=mode,
+ )
+
+ hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache += (layer_outputs[-1],)
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ next_decoder_cache,
+ all_hidden_states,
+ all_self_attentions,
+ all_cross_attentions,
+ ]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_decoder_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+class BertPooler(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.activation = nn.Tanh()
+
+ def forward(self, hidden_states):
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+class BertPredictionHeadTransform(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ if isinstance(config.hidden_act, str):
+ self.transform_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.transform_act_fn = config.hidden_act
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.transform_act_fn(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+
+class BertLMPredictionHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.transform = BertPredictionHeadTransform(config)
+
+ # The output weights are the same as the input embeddings, but there is
+ # an output-only bias for each token.
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+ self.decoder.bias = self.bias
+
+ def forward(self, hidden_states):
+ hidden_states = self.transform(hidden_states)
+ hidden_states = self.decoder(hidden_states)
+ return hidden_states
+
+
+class BertOnlyMLMHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.predictions = BertLMPredictionHead(config)
+
+ def forward(self, sequence_output):
+ prediction_scores = self.predictions(sequence_output)
+ return prediction_scores
+
+
+class BertPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = BertConfig
+ base_model_prefix = "bert"
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ def _init_weights(self, module):
+ """ Initialize the weights """
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+
+
+class BertModel(BertPreTrainedModel):
+ """
+    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
+    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. To behave as a decoder the model needs to be
+    initialized with the :obj:`is_decoder` argument of the configuration set to :obj:`True`. To be used in a Seq2Seq
+    model, the model needs to be initialized with both the :obj:`is_decoder` argument and :obj:`add_cross_attention`
+    set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an input to the forward pass.
+ """
+
+ def __init__(self, config, add_pooling_layer=True):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = BertEmbeddings(config)
+
+ self.encoder = BertEncoder(config)
+
+ self.pooler = BertPooler(config) if add_pooling_layer else None
+
+ self.init_weights()
+
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+
+ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
+ """
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
+
+ Arguments:
+ attention_mask (:obj:`torch.Tensor`):
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
+ input_shape (:obj:`Tuple[int]`):
+ The shape of the input to the model.
+            device (:obj:`torch.device`):
+                The device of the input to the model.
+
+        Returns:
+            :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
+ """
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ if attention_mask.dim() == 3:
+ extended_attention_mask = attention_mask[:, None, :, :]
+ elif attention_mask.dim() == 2:
+ # Provided a padding mask of dimensions [batch_size, seq_length]
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if is_decoder:
+ batch_size, seq_length = input_shape
+
+ seq_ids = torch.arange(seq_length, device=device)
+ causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
+ # in case past_key_values are used we need to add a prefix ones mask to the causal mask
+ # causal and attention masks must have same type with pytorch version < 1.3
+ causal_mask = causal_mask.to(attention_mask.dtype)
+
+ if causal_mask.shape[1] < attention_mask.shape[1]:
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
+ causal_mask = torch.cat(
+ [
+ torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
+ causal_mask,
+ ],
+                        dim=-1,
+ )
+
+ extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+ else:
+ extended_attention_mask = attention_mask[:, None, None, :]
+ else:
+ raise ValueError(
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
+ input_shape, attention_mask.shape
+ )
+ )
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and -10000.0 for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
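+        # e.g. a padding mask [1, 1, 0] becomes [0.0, 0.0, -10000.0]; after the
+        # softmax, masked positions receive (near-)zero attention weight.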
+ return extended_attention_mask
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ encoder_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ is_decoder=False,
+ mode='multimodal',
+ ):
+ r"""
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+ use_cache (:obj:`bool`, `optional`):
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+ decoding (see :obj:`past_key_values`).
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if is_decoder:
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ else:
+ use_cache = False
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ batch_size, seq_length = input_shape
+ device = input_ids.device
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ batch_size, seq_length = input_shape
+ device = inputs_embeds.device
+ elif encoder_embeds is not None:
+ input_shape = encoder_embeds.size()[:-1]
+ batch_size, seq_length = input_shape
+ device = encoder_embeds.device
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
+
+ # past_key_values_length
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+ if attention_mask is None:
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
+ device, is_decoder)
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if encoder_hidden_states is not None:
+ if type(encoder_hidden_states) == list:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
+ else:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+
+ if type(encoder_attention_mask) == list:
+ encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
+ elif encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ if encoder_embeds is None:
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ past_key_values_length=past_key_values_length,
+ )
+ else:
+ embedding_output = encoder_embeds
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ mode=mode,
+ )
+ sequence_output = encoder_outputs[0]
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPoolingAndCrossAttentions(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ past_key_values=encoder_outputs.past_key_values,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ cross_attentions=encoder_outputs.cross_attentions,
+ )
+
+
+class BertLMHeadModel(BertPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.bert = BertModel(config, add_pooling_layer=False)
+ self.cls = BertOnlyMLMHead(config)
+
+ self.init_weights()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.cls.predictions.decoder = new_embeddings
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ labels=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ return_logits=False,
+ is_decoder=True,
+ reduction='mean',
+ mode='multimodal',
+ ):
+ r"""
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
+            ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+ use_cache (:obj:`bool`, `optional`):
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+ decoding (see :obj:`past_key_values`).
+ Returns:
+ Example::
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
+ >>> import torch
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs)
+ >>> prediction_logits = outputs.logits
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ if labels is not None:
+ use_cache = False
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ is_decoder=is_decoder,
+ mode=mode,
+ )
+
+ sequence_output = outputs[0]
+ prediction_scores = self.cls(sequence_output)
+        # shapes: sequence_output (batch, seq_len, hidden_size),
+        # prediction_scores (batch, seq_len, vocab_size), labels (batch, seq_len)
+
+ if return_logits:
+ return prediction_scores[:, :-1, :].contiguous()
+
+ lm_loss = None
+ if labels is not None:
+ # we are doing next-token prediction; shift prediction scores and input ids by one
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
+ labels = labels[:, 1:].contiguous()
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
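+            # label_smoothing=0.1 moves 10% of the target probability mass to a
+            # uniform distribution over the vocabulary (a mild regularizer)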
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+ if reduction=='none':
+ lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return ((lm_loss,) + output) if lm_loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=lm_loss,
+ logits=prediction_scores,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+ input_shape = input_ids.shape
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+ if attention_mask is None:
+ attention_mask = input_ids.new_ones(input_shape)
+
+ # cut decoder_input_ids if past is used
+ if past is not None:
+ input_ids = input_ids[:, -1:]
+
+ return {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "past_key_values": past,
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
+ "is_decoder": True,
+ }
+
+ def _reorder_cache(self, past, beam_idx):
+ reordered_past = ()
+ for layer_past in past:
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+ return reordered_past
+
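+# Editorial decoding sketch (illustrative only): `model` and `tokenizer` are
+# assumed to be a BertLMHeadModel and a matching tokenizer, as in the docstring
+# example above. It shows how `past` shortens each step to a single token;
+# mode='text' is used because no encoder_hidden_states are supplied:
+#
+#     input_ids = tokenizer("a picture of", return_tensors="pt").input_ids
+#     past = None
+#     for _ in range(10):
+#         inputs = model.prepare_inputs_for_generation(input_ids, past=past)
+#         out = model(**inputs, return_dict=True, mode='text')
+#         next_id = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
+#         input_ids = torch.cat([input_ids, next_id], dim=-1)
+#         past = out.past_key_values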
+
diff --git a/ram/models/bert_lora.py b/ram/models/bert_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b01f4105846a8a2a1f1ebd96a7f713022ded9c1
--- /dev/null
+++ b/ram/models/bert_lora.py
@@ -0,0 +1,1040 @@
+'''
+ * Copyright (c) 2022, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+ * Based on huggingface code base
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
+'''
+
+import math
+import os
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import torch
+from torch import Tensor, device, dtype, nn
+import torch.utils.checkpoint
+from torch.nn import CrossEntropyLoss
+import torch.nn.functional as F
+
+from transformers.activations import ACT2FN
+from transformers.file_utils import (
+ ModelOutput,
+)
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ MaskedLMOutput,
+ MultipleChoiceModelOutput,
+ NextSentencePredictorOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from transformers.modeling_utils import (
+ PreTrainedModel,
+ apply_chunking_to_forward,
+ find_pruneable_heads_and_indices,
+ prune_linear_layer,
+)
+from transformers.utils import logging
+from transformers.models.bert.configuration_bert import BertConfig
+
+import loralib as lora
+
+
+logger = logging.get_logger(__name__)
+
+
+class BertEmbeddings_nopos(nn.Module):
+ """Construct the embeddings from word and position embeddings."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+ # self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ # self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+ # self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+
+ self.config = config
+
+ def forward(
+ self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
+ ):
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+
+ # if position_ids is None:
+ # position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+
+ embeddings = inputs_embeds
+
+ # if self.position_embedding_type == "absolute":
+ # position_embeddings = self.position_embeddings(position_ids)
+ # # print('add position_embeddings!!!!')
+ # embeddings += position_embeddings
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+
+
+class BertEmbeddings(nn.Module):
+ """Construct the embeddings from word and position embeddings."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+
+ self.config = config
+
+ def forward(
+ self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
+ ):
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+
+ embeddings = inputs_embeds
+
+ if self.position_embedding_type == "absolute":
+ position_embeddings = self.position_embeddings(position_ids)
+ # print('add position_embeddings!!!!')
+ embeddings += position_embeddings
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class BertSelfAttention(nn.Module):
+ def __init__(self, config, is_cross_attention):
+ super().__init__()
+ self.config = config
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ "The hidden size (%d) is not a multiple of the number of attention "
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ # self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ self.query = lora.Linear(config.hidden_size, self.all_head_size, r=8)
+ if is_cross_attention:
+ # self.key = nn.Linear(config.encoder_width, self.all_head_size)
+ self.key = lora.Linear(config.encoder_width, self.all_head_size, r=8)
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
+ else:
+ # self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.key = lora.Linear(config.hidden_size, self.all_head_size, r=8)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ self.max_position_embeddings = config.max_position_embeddings
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
+ self.save_attention = False
+
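+    # Editorial note on the LoRA wiring above, assuming Microsoft's `loralib`:
+    # `lora.Linear(in_features, out_features, r=8)` acts like `nn.Linear` but
+    # learns a rank-8 update B @ A on top of the frozen base weight. A typical
+    # fine-tuning setup would be:
+    #
+    #     import loralib as lora
+    #     lora.mark_only_lora_as_trainable(model)   # freeze every non-LoRA weight
+    #     torch.save(lora.lora_state_dict(model), 'lora_only.pth')  # adapters only
+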
+ def save_attn_gradients(self, attn_gradients):
+ self.attn_gradients = attn_gradients
+
+ def get_attn_gradients(self):
+ return self.attn_gradients
+
+ def save_attention_map(self, attention_map):
+ self.attention_map = attention_map
+
+ def get_attention_map(self):
+ return self.attention_map
+
+ def transpose_for_scores(self, x):
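+        # (batch, seq_len, all_head_size) -> (batch, num_heads, seq_len, head_size)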
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ ):
+ mixed_query_layer = self.query(hidden_states)
+
+ # If this is instantiated as a cross-attention module, the keys
+ # and values come from an encoder; the attention mask needs to be
+ # such that the encoder's padding tokens are not attended to.
+ is_cross_attention = encoder_hidden_states is not None
+
+ if is_cross_attention:
+ # print(self.key.weight.shape)
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+ attention_mask = encoder_attention_mask
+ elif past_key_value is not None:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+ else:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ past_key_value = (key_layer, value_layer)
+
+        # compatibility with newer transformers versions: trim cached keys/values and the mask to the current query batch size
+ if key_layer.shape[0] > query_layer.shape[0]:
+ key_layer = key_layer[:query_layer.shape[0], :, :, :]
+ attention_mask = attention_mask[:query_layer.shape[0], :, :]
+ value_layer = value_layer[:query_layer.shape[0], :, :, :]
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ seq_length = hidden_states.size()[1]
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+ distance = position_ids_l - position_ids_r
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
+
+ if self.position_embedding_type == "relative_key":
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores
+ elif self.position_embedding_type == "relative_key_query":
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+ if attention_mask is not None:
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+ if is_cross_attention and self.save_attention:
+ self.save_attention_map(attention_probs)
+ attention_probs.register_hook(self.save_attn_gradients)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs_dropped = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs_dropped = attention_probs_dropped * head_mask
+
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ outputs = outputs + (past_key_value,)
+ return outputs
+
+
+class BertSelfOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertAttention(nn.Module):
+ def __init__(self, config, is_cross_attention=False):
+ super().__init__()
+ self.self = BertSelfAttention(config, is_cross_attention)
+ self.output = BertSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ ):
+ self_outputs = self.self(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ )
+ attention_output = self.output(self_outputs[0], hidden_states)
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+class BertIntermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+class BertOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertLayer(nn.Module):
+ def __init__(self, config, layer_num):
+ super().__init__()
+ self.config = config
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = BertAttention(config)
+ self.layer_num = layer_num
+ if self.config.add_cross_attention:
+ self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
+ self.intermediate = BertIntermediate(config)
+ self.output = BertOutput(config)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ mode=None,
+ ):
+
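+        # Mode routing (editorial note): 'tagging' skips self-attention and runs
+        # cross-attention only, so the label-query embeddings attend to image
+        # features as in Query2Label; 'multimodal' runs self-attention followed
+        # by cross-attention; any other mode runs self-attention alone.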
+ if mode == 'tagging':
+
+ assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
+
+ cross_attention_outputs = self.crossattention(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ output_attentions=output_attentions,
+ )
+ attention_output = cross_attention_outputs[0]
+ outputs = cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
+
+ present_key_value = cross_attention_outputs[-1]
+
+ else:
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+ self_attention_outputs = self.attention(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ output_attentions=output_attentions,
+ past_key_value=self_attn_past_key_value,
+ )
+ attention_output = self_attention_outputs[0]
+
+ outputs = self_attention_outputs[1:-1]
+ present_key_value = self_attention_outputs[-1]
+
+ if mode=='multimodal':
+ assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
+
+ cross_attention_outputs = self.crossattention(
+ attention_output,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ output_attentions=output_attentions,
+ )
+ attention_output = cross_attention_outputs[0]
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
+ )
+ outputs = (layer_output,) + outputs
+
+ outputs = outputs + (present_key_value,)
+
+ return outputs
+
+ def feed_forward_chunk(self, attention_output):
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ return layer_output
+
+
+class BertEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ mode='multimodal',
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+
+ next_decoder_cache = () if use_cache else None
+
+ for i in range(self.config.num_hidden_layers):
+ layer_module = self.layer[i]
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+ past_key_value = past_key_values[i] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ if use_cache:
+                    logger.warning(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, past_key_value, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ mode=mode,
+ )
+ else:
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ mode=mode,
+ )
+
+ hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache += (layer_outputs[-1],)
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ next_decoder_cache,
+ all_hidden_states,
+ all_self_attentions,
+ all_cross_attentions,
+ ]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_decoder_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+class BertPooler(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.activation = nn.Tanh()
+
+ def forward(self, hidden_states):
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+class BertPredictionHeadTransform(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ if isinstance(config.hidden_act, str):
+ self.transform_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.transform_act_fn = config.hidden_act
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.transform_act_fn(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+
+class BertLMPredictionHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.transform = BertPredictionHeadTransform(config)
+
+ # The output weights are the same as the input embeddings, but there is
+ # an output-only bias for each token.
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+ self.decoder.bias = self.bias
+
+ def forward(self, hidden_states):
+ hidden_states = self.transform(hidden_states)
+ hidden_states = self.decoder(hidden_states)
+ return hidden_states
+
+
+class BertOnlyMLMHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.predictions = BertLMPredictionHead(config)
+
+ def forward(self, sequence_output):
+ prediction_scores = self.predictions(sequence_output)
+ return prediction_scores
+
+
+class BertPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = BertConfig
+ base_model_prefix = "bert"
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ def _init_weights(self, module):
+ """ Initialize the weights """
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+
+
+class BertModel(BertPreTrainedModel):
+ """
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
+    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. To behave as a decoder the model needs to be
+    initialized with the :obj:`is_decoder` argument and :obj:`add_cross_attention` set to :obj:`True`; an
+    :obj:`encoder_hidden_states` is then expected as an input to the forward pass.
+ """
+
+ def __init__(self, config, add_pooling_layer=True):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = BertEmbeddings(config)
+
+ self.encoder = BertEncoder(config)
+
+ self.pooler = BertPooler(config) if add_pooling_layer else None
+
+ self.init_weights()
+
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+
+ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
+ """
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
+
+ Arguments:
+ attention_mask (:obj:`torch.Tensor`):
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
+ input_shape (:obj:`Tuple[int]`):
+ The shape of the input to the model.
+ device: (:obj:`torch.device`):
+ The device of the input to the model.
+
+ Returns:
+            :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
+ """
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ if attention_mask.dim() == 3:
+ extended_attention_mask = attention_mask[:, None, :, :]
+ elif attention_mask.dim() == 2:
+ # Provided a padding mask of dimensions [batch_size, seq_length]
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if is_decoder:
+ batch_size, seq_length = input_shape
+
+ seq_ids = torch.arange(seq_length, device=device)
+ causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
+ # in case past_key_values are used we need to add a prefix ones mask to the causal mask
+ # causal and attention masks must have same type with pytorch version < 1.3
+ causal_mask = causal_mask.to(attention_mask.dtype)
+
+ if causal_mask.shape[1] < attention_mask.shape[1]:
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
+ causal_mask = torch.cat(
+ [
+ torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
+ causal_mask,
+ ],
+ axis=-1,
+ )
+
+ extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+ else:
+ extended_attention_mask = attention_mask[:, None, None, :]
+ else:
+ raise ValueError(
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
+ input_shape, attention_mask.shape
+ )
+ )
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and -10000.0 for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+ return extended_attention_mask
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ encoder_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ is_decoder=False,
+ mode='multimodal',
+ ):
+ r"""
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+ use_cache (:obj:`bool`, `optional`):
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+ decoding (see :obj:`past_key_values`).
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if is_decoder:
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ else:
+ use_cache = False
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ batch_size, seq_length = input_shape
+ device = input_ids.device
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ batch_size, seq_length = input_shape
+ device = inputs_embeds.device
+ elif encoder_embeds is not None:
+ input_shape = encoder_embeds.size()[:-1]
+ batch_size, seq_length = input_shape
+ device = encoder_embeds.device
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
+
+ # past_key_values_length
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+ if attention_mask is None:
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
+ device, is_decoder)
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if encoder_hidden_states is not None:
+ if type(encoder_hidden_states) == list:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
+ else:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+
+ if type(encoder_attention_mask) == list:
+ encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
+ elif encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ if encoder_embeds is None:
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ past_key_values_length=past_key_values_length,
+ )
+ else:
+ embedding_output = encoder_embeds
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ mode=mode,
+ )
+ sequence_output = encoder_outputs[0]
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPoolingAndCrossAttentions(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ past_key_values=encoder_outputs.past_key_values,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ cross_attentions=encoder_outputs.cross_attentions,
+ )
+
+
+class BertLMHeadModel(BertPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.bert = BertModel(config, add_pooling_layer=False)
+ self.cls = BertOnlyMLMHead(config)
+
+ self.init_weights()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.cls.predictions.decoder = new_embeddings
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ labels=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ return_logits=False,
+ is_decoder=True,
+ reduction='mean',
+ mode='multimodal',
+ ):
+ r"""
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
+            ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+ use_cache (:obj:`bool`, `optional`):
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+ decoding (see :obj:`past_key_values`).
+ Returns:
+ Example::
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
+ >>> import torch
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs)
+ >>> prediction_logits = outputs.logits
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ if labels is not None:
+ use_cache = False
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ is_decoder=is_decoder,
+ mode=mode,
+ )
+
+ sequence_output = outputs[0]
+ prediction_scores = self.cls(sequence_output)
+        # shapes: sequence_output (batch, seq_len, hidden_size),
+        # prediction_scores (batch, seq_len, vocab_size), labels (batch, seq_len)
+
+ if return_logits:
+ return prediction_scores[:, :-1, :].contiguous()
+
+ lm_loss = None
+ if labels is not None:
+ # we are doing next-token prediction; shift prediction scores and input ids by one
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
+ labels = labels[:, 1:].contiguous()
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+ if reduction=='none':
+ lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return ((lm_loss,) + output) if lm_loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=lm_loss,
+ logits=prediction_scores,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+ input_shape = input_ids.shape
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+ if attention_mask is None:
+ attention_mask = input_ids.new_ones(input_shape)
+
+ # cut decoder_input_ids if past is used
+ if past is not None:
+ input_ids = input_ids[:, -1:]
+
+ return {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "past_key_values": past,
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
+ "is_decoder": True,
+ }
+
+ def _reorder_cache(self, past, beam_idx):
+ reordered_past = ()
+ for layer_past in past:
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+ return reordered_past
+
+
diff --git a/ram/models/ram.py b/ram/models/ram.py
new file mode 100644
index 0000000000000000000000000000000000000000..3642b2ddc894e70bd90ead311e3833f4cda9d565
--- /dev/null
+++ b/ram/models/ram.py
@@ -0,0 +1,317 @@
+'''
+ * The Recognize Anything Model (RAM)
+ * Written by Xinyu Huang
+'''
+import json
+import warnings
+
+import numpy as np
+import torch
+from torch import nn
+
+from .bert import BertConfig, BertLMHeadModel, BertModel
+from .swin_transformer import SwinTransformer
+from .utils import *
+
+warnings.filterwarnings("ignore")
+
+
+
+class RAM(nn.Module):
+ def __init__(self,
+ med_config=f'{CONFIG_PATH}/configs/med_config.json',
+ image_size=384,
+ vit='base',
+ vit_grad_ckpt=False,
+ vit_ckpt_layer=0,
+ prompt='a picture of ',
+ threshold=0.68,
+ delete_tag_index=[],
+ tag_list=f'{CONFIG_PATH}/data/ram_tag_list.txt',
+ tag_list_chinese=f'{CONFIG_PATH}/data/ram_tag_list_chinese.txt'):
+ r""" The Recognize Anything Model (RAM) inference module.
+ RAM is a strong image tagging model, which can recognize any common category with high accuracy.
+        Described in the paper "Recognize Anything: A Strong Image Tagging Model" (https://recognize-anything.github.io/)
+
+ Args:
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
+ image_size (int): input image size
+ vit (str): model size of vision transformer
+            threshold (float): tagging threshold
+ delete_tag_index (list): delete some tags that may disturb captioning
+ """
+ super().__init__()
+
+ # create image encoder
+ if vit == 'swin_b':
+ if image_size == 224:
+ vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_224.json'
+ elif image_size == 384:
+ vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_384.json'
+ vision_config = read_json(vision_config_path)
+ assert image_size == vision_config['image_res']
+ # assert config['patch_size'] == 32
+ vision_width = vision_config['vision_width']
+
+ self.visual_encoder = SwinTransformer(
+ img_size=vision_config['image_res'],
+ patch_size=4,
+ in_chans=3,
+ embed_dim=vision_config['embed_dim'],
+ depths=vision_config['depths'],
+ num_heads=vision_config['num_heads'],
+ window_size=vision_config['window_size'],
+ mlp_ratio=4.,
+ qkv_bias=True,
+ drop_rate=0.0,
+ drop_path_rate=0.1,
+ ape=False,
+ patch_norm=True,
+ use_checkpoint=False)
+
+ elif vit == 'swin_l':
+ if image_size == 224:
+ vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_224.json'
+ elif image_size == 384:
+ vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_384.json'
+ elif image_size == 444:
+ vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_444.json'
+ vision_config = read_json(vision_config_path)
+ assert image_size == vision_config['image_res']
+ # assert config['patch_size'] == 32
+ vision_width = vision_config['vision_width']
+
+ self.visual_encoder = SwinTransformer(
+ img_size=vision_config['image_res'],
+ patch_size=4,
+ in_chans=3,
+ embed_dim=vision_config['embed_dim'],
+ depths=vision_config['depths'],
+ num_heads=vision_config['num_heads'],
+ window_size=vision_config['window_size'],
+ mlp_ratio=4.,
+ qkv_bias=True,
+ drop_rate=0.0,
+ drop_path_rate=0.1,
+ ape=False,
+ patch_norm=True,
+ use_checkpoint=False)
+
+ else:
+ self.visual_encoder, vision_width = create_vit(
+ vit, image_size, vit_grad_ckpt, vit_ckpt_layer)
+
+        # create tokenizer
+ self.tokenizer = init_tokenizer()
+
+        # Tag2Text employs an encoder-decoder architecture for image-tag-text generation: an image-tag interaction encoder and an image-tag-text decoder
+ # create image-tag interaction encoder
+ encoder_config = BertConfig.from_json_file(med_config)
+ encoder_config.encoder_width = 512
+ self.tag_encoder = BertModel(config=encoder_config,
+ add_pooling_layer=False)
+
+ # create image-tag-text decoder
+ decoder_config = BertConfig.from_json_file(med_config)
+ self.text_decoder = BertLMHeadModel(config=decoder_config)
+
+ self.delete_tag_index = delete_tag_index
+ self.prompt = prompt
+ self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1
+
+ # load tag list
+ self.tag_list = self.load_tag_list(tag_list)
+ self.tag_list_chinese = self.load_tag_list(tag_list_chinese)
+
+ # create image-tag recognition decoder
+ self.threshold = threshold
+ self.num_class = len(self.tag_list)
+ q2l_config = BertConfig.from_json_file(f'{CONFIG_PATH}/configs/q2l_config.json')
+ q2l_config.encoder_width = 512
+ self.tagging_head = BertModel(config=q2l_config,
+ add_pooling_layer=False)
+ self.tagging_head.resize_token_embeddings(len(self.tokenizer))
+ # self.label_embed = nn.Embedding(self.num_class, q2l_config.hidden_size)
+ self.label_embed = nn.Parameter(torch.zeros(self.num_class, q2l_config.encoder_width))
+
+ if q2l_config.hidden_size != 512:
+ self.wordvec_proj = nn.Linear(512, q2l_config.hidden_size)
+ else:
+ self.wordvec_proj = nn.Identity()
+
+ self.fc = nn.Linear(q2l_config.hidden_size, 1)
+
+ self.del_selfattention()
+
+        # share the weights of the lowest two layers of the "image-tag interaction encoder" with the "image-tag recognition decoder"
+        tie_encoder_decoder_weights(self.tag_encoder, self.tagging_head, '', ' ')
+ self.image_proj = nn.Linear(vision_width, 512)
+ # self.label_embed = nn.Parameter(torch.load(f'{CONFIG_PATH}/data/textual_label_embedding.pth',map_location='cpu').float())
+
+ # adjust thresholds for some tags
+ self.class_threshold = torch.ones(self.num_class) * self.threshold
+ ram_class_threshold_path = f'{CONFIG_PATH}/data/ram_tag_list_threshold.txt'
+ with open(ram_class_threshold_path, 'r', encoding='utf-8') as f:
+ ram_class_threshold = [float(s.strip()) for s in f]
+        for key, value in enumerate(ram_class_threshold):
+ self.class_threshold[key] = value
+
+ def load_tag_list(self, tag_list_file):
+ with open(tag_list_file, 'r', encoding="utf-8") as f:
+ tag_list = f.read().splitlines()
+ tag_list = np.array(tag_list)
+ return tag_list
+
+    # delete the self-attention layers of the image-tag recognition decoder to reduce computation, following Query2Label
+ def del_selfattention(self):
+ del self.tagging_head.embeddings
+ for layer in self.tagging_head.encoder.layer:
+ del layer.attention
+
+ def condition_forward(self,
+ image,
+ threshold=0.68,
+ condition_flag=None,
+ tag_input=None,
+ only_feature=True,
+ ):
+
+ label_embed = torch.nn.functional.relu(self.wordvec_proj(self.label_embed))
+
+ image_embeds = self.image_proj(self.visual_encoder(image))
+ if only_feature:
+ return image_embeds
+ else:
+ image_atts = torch.ones(image_embeds.size()[:-1],
+ dtype=torch.long).to(image.device)
+
+            # recognize image tags using the image-tag recognition decoder
+ image_cls_embeds = image_embeds[:, 0, :]
+ image_spatial_embeds = image_embeds[:, 1:, :]
+
+ bs = image_spatial_embeds.shape[0]
+ label_embed = label_embed.unsqueeze(0).repeat(bs, 1, 1)
+ tagging_embed = self.tagging_head(
+ encoder_embeds=label_embed,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=False,
+ mode='tagging',
+ )
+
+ logits = self.fc(tagging_embed[0]).squeeze(-1)
+
+ targets = torch.where(
+ torch.sigmoid(logits) > self.class_threshold.to(image.device),
+ torch.tensor(1.0).to(image.device),
+ torch.zeros(self.num_class).to(image.device))
+
+ return image_embeds, logits, targets
+
+ def generate_tag(self,
+ image,
+ threshold=0.68,
+ tag_input=None,
+ ):
+
+ label_embed = torch.nn.functional.relu(self.wordvec_proj(self.label_embed))
+
+ image_embeds = self.image_proj(self.visual_encoder(image))
+ image_atts = torch.ones(image_embeds.size()[:-1],
+ dtype=torch.long).to(image.device)
+
+        # recognize image tags with the image-tag recognition decoder
+ image_cls_embeds = image_embeds[:, 0, :]
+ image_spatial_embeds = image_embeds[:, 1:, :]
+
+ bs = image_spatial_embeds.shape[0]
+ label_embed = label_embed.unsqueeze(0).repeat(bs, 1, 1)
+ tagging_embed = self.tagging_head(
+ encoder_embeds=label_embed,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=False,
+ mode='tagging',
+ )
+
+ logits = self.fc(tagging_embed[0]).squeeze(-1)
+
+ targets = torch.where(
+ torch.sigmoid(logits) > self.class_threshold.to(image.device),
+ torch.tensor(1.0).to(image.device),
+ torch.zeros(self.num_class).to(image.device))
+
+ tag = targets.cpu().numpy()
+ tag[:,self.delete_tag_index] = 0
+ tag_output = []
+ tag_output_chinese = []
+ for b in range(bs):
+ index = np.argwhere(tag[b] == 1)
+ token = self.tag_list[index].squeeze(axis=1)
+ # tag_output.append(' | '.join(token))
+ tag_output.append(', '.join(token))
+ token_chinese = self.tag_list_chinese[index].squeeze(axis=1)
+ # tag_output_chinese.append(' | '.join(token_chinese))
+ tag_output_chinese.append(', '.join(token_chinese))
+
+
+ return tag_output, tag_output_chinese
+
+ def generate_tag_openset(self,
+ image,
+ threshold=0.68,
+ tag_input=None,
+ ):
+
+ label_embed = torch.nn.functional.relu(self.wordvec_proj(self.label_embed))
+
+ image_embeds = self.image_proj(self.visual_encoder(image))
+ image_atts = torch.ones(image_embeds.size()[:-1],
+ dtype=torch.long).to(image.device)
+
+        # recognize image tags with the image-tag recognition decoder
+ image_cls_embeds = image_embeds[:, 0, :]
+ image_spatial_embeds = image_embeds[:, 1:, :]
+
+ bs = image_spatial_embeds.shape[0]
+ label_embed = label_embed.unsqueeze(0).repeat(bs, 1, 1)
+ tagging_embed = self.tagging_head(
+ encoder_embeds=label_embed,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=False,
+ mode='tagging',
+ )
+
+ logits = self.fc(tagging_embed[0]).squeeze(-1)
+
+ targets = torch.where(
+ torch.sigmoid(logits) > self.class_threshold.to(image.device),
+ torch.tensor(1.0).to(image.device),
+ torch.zeros(self.num_class).to(image.device))
+
+ tag = targets.cpu().numpy()
+ tag[:,self.delete_tag_index] = 0
+ tag_output = []
+ for b in range(bs):
+ index = np.argwhere(tag[b] == 1)
+ token = self.tag_list[index].squeeze(axis=1)
+ tag_output.append(' | '.join(token))
+
+ return tag_output
+
+
+# load RAM pretrained model parameters
+def ram(pretrained='', **kwargs):
+ model = RAM(**kwargs)
+ if pretrained:
+ if kwargs['vit'] == 'swin_b':
+ model, msg = load_checkpoint_swinbase(model, pretrained, kwargs)
+ elif kwargs['vit'] == 'swin_l':
+ model, msg = load_checkpoint_swinlarge(model, pretrained, kwargs)
+ else:
+ model, msg = load_checkpoint(model, pretrained)
+ print('vit:', kwargs['vit'])
+# print('msg', msg)
+ return model
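+
+
+# A minimal usage sketch (an illustrative addition, not part of the released
+# checkpoint code): the checkpoint and image paths below are placeholders, and the
+# 384x384 ImageNet-style preprocessing mirrors the default image_size above.
+if __name__ == '__main__':
+    from PIL import Image
+    from torchvision import transforms
+
+    transform = transforms.Compose([
+        transforms.Resize((384, 384)),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                             std=[0.229, 0.224, 0.225]),
+    ])
+    model = ram(pretrained='path/to/ram_checkpoint.pth', vit='swin_l', image_size=384).eval()
+    image = transform(Image.open('example.jpg').convert('RGB')).unsqueeze(0)
+    with torch.no_grad():
+        tags_en, tags_zh = model.generate_tag(image)
+    print(tags_en[0])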
diff --git a/ram/models/ram_lora.py b/ram/models/ram_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..c63ed6913e0cb6fa7027670b84cc4f915c9df8b0
--- /dev/null
+++ b/ram/models/ram_lora.py
@@ -0,0 +1,344 @@
+'''
+ * The Recognize Anything Model (RAM)
+ * Written by Xinyu Huang
+'''
+import json
+import warnings
+
+import numpy as np
+import torch
+from torch import nn
+
+
+from .bert_lora import BertConfig, BertLMHeadModel, BertModel
+from .swin_transformer_lora import SwinTransformer
+from .utils import *
+
+warnings.filterwarnings("ignore")
+
+
+
+class RAMLora(nn.Module):
+ def __init__(self,
+ condition_config=f'{CONFIG_PATH}/configs/condition_config.json',
+ med_config=f'{CONFIG_PATH}/configs/med_config.json',
+ image_size=384,
+ vit='base',
+ vit_grad_ckpt=False,
+ vit_ckpt_layer=0,
+ prompt='a picture of ',
+ threshold=0.68,
+ max_threthold=0.9,
+ add_threthold=0,
+ delete_tag_index=[],
+ tag_list=f'{CONFIG_PATH}/data/ram_tag_list.txt',
+ tag_list_chinese=f'{CONFIG_PATH}/data/ram_tag_list_chinese.txt'):
+ r""" The Recognize Anything Model (RAM) inference module.
+ RAM is a strong image tagging model, which can recognize any common category with high accuracy.
+ Described in the paper " Recognize Anything: A Strong Image Tagging Model" https://recognize-anything.github.io/
+
+ Args:
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
+ image_size (int): input image size
+ vit (str): model size of vision transformer
+ threshold (int): tagging threshold
+ delete_tag_index (list): delete some tags that may disturb captioning
+ """
+ super().__init__()
+
+ # create image encoder
+ if vit == 'swin_b':
+ if image_size == 224:
+ vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_224.json'
+ elif image_size == 384:
+ vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_384.json'
+ vision_config = read_json(vision_config_path)
+ assert image_size == vision_config['image_res']
+ # assert config['patch_size'] == 32
+ vision_width = vision_config['vision_width']
+
+ self.visual_encoder = SwinTransformer(
+ img_size=vision_config['image_res'],
+ patch_size=4,
+ in_chans=3,
+ embed_dim=vision_config['embed_dim'],
+ depths=vision_config['depths'],
+ num_heads=vision_config['num_heads'],
+ window_size=vision_config['window_size'],
+ mlp_ratio=4.,
+ qkv_bias=True,
+ drop_rate=0.0,
+ drop_path_rate=0.1,
+ ape=False,
+ patch_norm=True,
+ use_checkpoint=False)
+
+ elif vit == 'swin_l':
+ if image_size == 224:
+ vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_224.json'
+ elif image_size == 384:
+ vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_384.json'
+ elif image_size == 444:
+ vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_444.json'
+ vision_config = read_json(vision_config_path)
+ assert image_size == vision_config['image_res']
+ # assert config['patch_size'] == 32
+ vision_width = vision_config['vision_width']
+
+ self.visual_encoder = SwinTransformer(
+ img_size=vision_config['image_res'],
+ patch_size=4,
+ in_chans=3,
+ embed_dim=vision_config['embed_dim'],
+ depths=vision_config['depths'],
+ num_heads=vision_config['num_heads'],
+ window_size=vision_config['window_size'],
+ mlp_ratio=4.,
+ qkv_bias=True,
+ drop_rate=0.0,
+ drop_path_rate=0.1,
+ ape=False,
+ patch_norm=True,
+ use_checkpoint=False)
+
+ else:
+ self.visual_encoder, vision_width = create_vit(
+ vit, image_size, vit_grad_ckpt, vit_ckpt_layer)
+
+ # create tokenzier
+ self.tokenizer = init_tokenizer()
+
+ # Tag2Text employ encoder-decoder architecture for image-tag-text generation: image-tag interaction encoder and image-tag-text decoder
+ # create image-tag interaction encoder
+ encoder_config = BertConfig.from_json_file(med_config)
+ encoder_config.encoder_width = 512
+ self.tag_encoder = BertModel(config=encoder_config,
+ add_pooling_layer=False)
+
+ # create image-tag-text decoder
+ decoder_config = BertConfig.from_json_file(med_config)
+ self.text_decoder = BertLMHeadModel(config=decoder_config)
+
+ self.delete_tag_index = delete_tag_index
+ self.prompt = prompt
+ self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1
+
+ # load tag list
+ self.tag_list = self.load_tag_list(tag_list)
+ self.tag_list_chinese = self.load_tag_list(tag_list_chinese)
+
+ # create image-tag recognition decoder
+ self.threshold = threshold
+ self.num_class = len(self.tag_list)
+ q2l_config = BertConfig.from_json_file(f'{CONFIG_PATH}/configs/q2l_config.json')
+ q2l_config.encoder_width = 512
+ self.tagging_head = BertModel(config=q2l_config,
+ add_pooling_layer=False)
+ self.tagging_head.resize_token_embeddings(len(self.tokenizer))
+ # self.label_embed = nn.Embedding(self.num_class, q2l_config.hidden_size)
+ self.label_embed = nn.Parameter(torch.zeros(self.num_class, q2l_config.encoder_width))
+
+ if q2l_config.hidden_size != 512:
+ self.wordvec_proj = nn.Linear(512, q2l_config.hidden_size)
+ else:
+ self.wordvec_proj = nn.Identity()
+
+ self.fc = nn.Linear(q2l_config.hidden_size, 1)
+
+ self.del_selfattention()
+
+        # share the weights of the lowest two layers between the "image-tag interaction encoder" and the "image-tag recognition decoder"
+ tie_encoder_decoder_weights(self.tag_encoder, self.tagging_head, '',
+ ' ')
+ self.image_proj = nn.Linear(vision_width, 512)
+ # self.label_embed = nn.Parameter(torch.load(f'{CONFIG_PATH}/data/textual_label_embedding.pth',map_location='cpu').float())
+
+ # adjust thresholds for some tags
+ self.class_threshold = torch.ones(self.num_class) * self.threshold
+
+        print('Loading default tag thresholds from .txt ...')
+        ram_class_threshold_path = f'{CONFIG_PATH}/data/ram_tag_list_threshold.txt'
+        with open(ram_class_threshold_path, 'r', encoding='utf-8') as f:
+            ram_class_threshold = [float(s.strip()) for s in f]
+        for key, value in enumerate(ram_class_threshold):
+ if value > max_threthold:
+ self.class_threshold[key] = value
+ else:
+ self.class_threshold[key] = min(value + add_threthold, max_threthold)
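+        # Worked example of the adjustment above: with add_threthold=0.1 and
+        # max_threthold=0.9, a per-tag default of 0.75 becomes min(0.85, 0.9) = 0.85,
+        # 0.88 is clamped to 0.9, and 0.95 already exceeds the cap and is kept as-is.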
+
+
+
+ def load_tag_list(self, tag_list_file):
+ with open(tag_list_file, 'r', encoding="utf-8") as f:
+ tag_list = f.read().splitlines()
+ tag_list = np.array(tag_list)
+ return tag_list
+
+    # delete the self-attention layers of the image-tag recognition decoder to reduce computation, following Query2Label
+ def del_selfattention(self):
+ del self.tagging_head.embeddings
+ for layer in self.tagging_head.encoder.layer:
+ del layer.attention
+
+ def generate_image_embeds(self,
+ image,
+ condition=False
+ ):
+
+ image_embeds = self.image_proj(self.visual_encoder(image))
+
+ return image_embeds
+
+ def generate_tag(self,
+ image,
+ threshold=0.68,
+ tag_input=None,
+ ):
+
+ label_embed = torch.nn.functional.relu(self.wordvec_proj(self.label_embed))
+
+ image_embeds = self.image_proj(self.visual_encoder(image))
+
+ image_atts = torch.ones(image_embeds.size()[:-1],
+ dtype=torch.long).to(image.device)
+
+        # recognize image tags with the image-tag recognition decoder
+ image_cls_embeds = image_embeds[:, 0, :]
+ image_spatial_embeds = image_embeds[:, 1:, :]
+
+ bs = image_spatial_embeds.shape[0]
+ label_embed = label_embed.unsqueeze(0).repeat(bs, 1, 1)
+ tagging_embed = self.tagging_head(
+ encoder_embeds=label_embed,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=False,
+ mode='tagging',
+ )
+
+ logits = self.fc(tagging_embed[0]).squeeze(-1)
+ targets = torch.where(
+ torch.sigmoid(logits) > self.class_threshold.to(image.device),
+ torch.tensor(1.0).to(image.device),
+ torch.zeros(self.num_class).to(image.device))
+
+ tag = targets.cpu().numpy()
+ tag[:,self.delete_tag_index] = 0
+ tag_output = []
+ tag_output_chinese = []
+ for b in range(bs):
+ index = np.argwhere(tag[b] == 1)
+ token = self.tag_list[index].squeeze(axis=1)
+ # tag_output.append(' | '.join(token))
+ tag_output.append(', '.join(token))
+ token_chinese = self.tag_list_chinese[index].squeeze(axis=1)
+ # tag_output_chinese.append(' | '.join(token_chinese))
+ tag_output_chinese.append(', '.join(token_chinese))
+
+
+ return tag_output, tag_output_chinese
+
+
+
+ def condition_forward(self,
+ image,
+ threshold=0.68,
+ condition_flag=None,
+ tag_input=None,
+ only_feature=True
+ ):
+
+ label_embed = torch.nn.functional.relu(self.wordvec_proj(self.label_embed))
+ image_embeds = self.image_proj(self.visual_encoder(image))
+
+ if only_feature:
+ return image_embeds
+ else:
+ image_atts = torch.ones(image_embeds.size()[:-1],
+ dtype=torch.long).to(image.device)
+
+            # recognize image tags with the image-tag recognition decoder
+ image_cls_embeds = image_embeds[:, 0, :]
+ image_spatial_embeds = image_embeds[:, 1:, :]
+
+ bs = image_spatial_embeds.shape[0]
+ label_embed = label_embed.unsqueeze(0).repeat(bs, 1, 1)
+ tagging_embed = self.tagging_head(
+ encoder_embeds=label_embed,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=False,
+ mode='tagging',
+ )
+
+ logits = self.fc(tagging_embed[0]).squeeze(-1)
+
+ targets = torch.where(
+ torch.sigmoid(logits) > self.class_threshold.to(image.device),
+ torch.tensor(1.0).to(image.device),
+ torch.zeros(self.num_class).to(image.device))
+
+ return image_embeds, logits, targets
+
+ def generate_tag_openset(self,
+ image,
+ threshold=0.68,
+ tag_input=None,
+ ):
+
+ label_embed = torch.nn.functional.relu(self.wordvec_proj(self.label_embed))
+
+ image_embeds = self.image_proj(self.visual_encoder(image))
+ image_atts = torch.ones(image_embeds.size()[:-1],
+ dtype=torch.long).to(image.device)
+
+        # recognize image tags with the image-tag recognition decoder
+ image_cls_embeds = image_embeds[:, 0, :]
+ image_spatial_embeds = image_embeds[:, 1:, :]
+
+ bs = image_spatial_embeds.shape[0]
+ label_embed = label_embed.unsqueeze(0).repeat(bs, 1, 1)
+ tagging_embed = self.tagging_head(
+ encoder_embeds=label_embed,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=False,
+ mode='tagging',
+ )
+
+ logits = self.fc(tagging_embed[0]).squeeze(-1)
+
+ targets = torch.where(
+ torch.sigmoid(logits) > self.class_threshold.to(image.device),
+ torch.tensor(1.0).to(image.device),
+ torch.zeros(self.num_class).to(image.device))
+
+ tag = targets.cpu().numpy()
+ tag[:,self.delete_tag_index] = 0
+ tag_output = []
+ for b in range(bs):
+ index = np.argwhere(tag[b] == 1)
+ token = self.tag_list[index].squeeze(axis=1)
+ tag_output.append(' | '.join(token))
+
+ return tag_output
+
+
+# load RAM pretrained model parameters
+def ram(pretrained='', pretrained_condition='', **kwargs):
+ model = RAMLora(**kwargs)
+
+ if pretrained:
+ if kwargs['vit'] == 'swin_b':
+ model, msg = load_checkpoint_swinbase(model, pretrained, kwargs)
+ elif kwargs['vit'] == 'swin_l':
+ model, msg = load_checkpoint_swinlarge(model, pretrained, kwargs)
+ else:
+ model, msg = load_checkpoint(model, pretrained)
+ print('vit:', kwargs['vit'])
+
+ if pretrained_condition:
+ model.load_state_dict(torch.load(pretrained_condition), strict=False)
+ print(f'load lora from {pretrained_condition}')
+
+ return model
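+
+
+# A hypothetical fine-tuning sketch (paths are placeholders; mark_only_lora_as_trainable
+# and lora_state_dict are standard loralib helpers, assumed available here):
+#
+#   import loralib as lora
+#   model = ram(pretrained='path/to/ram_checkpoint.pth', vit='swin_l', image_size=384)
+#   lora.mark_only_lora_as_trainable(model)   # freeze everything except the LoRA matrices
+#   ...                                       # fine-tune on the downstream task
+#   torch.save(lora.lora_state_dict(model), 'ram_lora_adapter.pth')
+#
+# The saved adapter can then be reloaded through the `pretrained_condition` argument above.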
diff --git a/ram/models/swin_transformer.py b/ram/models/swin_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..69bceb888697909a2eb7a22f49765a5163d8b0a0
--- /dev/null
+++ b/ram/models/swin_transformer.py
@@ -0,0 +1,696 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------
+
+import numpy as np
+from scipy import interpolate
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+
+
+def window_reverse(windows, window_size, H, W):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ window_size (int): Window size
+ H (int): Height of image
+ W (int): Width of image
+
+ Returns:
+ x: (B, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
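+
+# Shape sketch: for x of shape (2, 56, 56, 96) with window_size=7, window_partition
+# yields (2 * 8 * 8, 7, 7, 96) = (128, 7, 7, 96), and
+# window_reverse(window_partition(x, 7), 7, 56, 56) recovers the original tensor.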
+
+
+class WindowAttention(nn.Module):
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
+ It supports both of shifted and non-shifted window.
+
+ Args:
+ dim (int): Number of input channels.
+ window_size (tuple[int]): The height and width of the window.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ """
+
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size # Wh, Ww
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+
+ # define a parameter table of relative position bias
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ trunc_normal_(self.relative_position_bias_table, std=.02)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask=None):
+ """
+ Args:
+ x: input features with shape of (num_windows*B, N, C)
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+ """
+ B_, N, C = x.shape
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ nW = mask.shape[0]
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ def extra_repr(self) -> str:
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
+
+ def flops(self, N):
+ # calculate flops for 1 window with token length of N
+ flops = 0
+ # qkv = self.qkv(x)
+ flops += N * self.dim * 3 * self.dim
+ # attn = (q @ k.transpose(-2, -1))
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
+ # x = (attn @ v)
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
+ # x = self.proj(x)
+ flops += N * self.dim * self.dim
+ return flops
+
+
+class SwinTransformerBlock(nn.Module):
+ r""" Swin Transformer Block.
+
+ Args:
+ dim (int): Number of input channels.
+        input_resolution (tuple[int]): Input resolution.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+ if min(self.input_resolution) <= self.window_size:
+ # if window size is larger than input resolution, we don't partition windows
+ self.shift_size = 0
+ self.window_size = min(self.input_resolution)
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
+
+ self.norm1 = norm_layer(dim)
+ self.attn = WindowAttention(
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ if self.shift_size > 0:
+ # calculate attention mask for SW-MSA
+ H, W = self.input_resolution
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
+ h_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ w_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+ else:
+ attn_mask = None
+
+ self.register_buffer("attn_mask", attn_mask)
+
+ ## condition from LR
+ self.condition_attention = nn.Sequential(
+ nn.Linear(256, dim*2, bias=False),
+ )
+ self.condition_ffn = nn.Sequential(
+ nn.Linear(256, dim*2, bias=False),
+ )
+
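+        # zero-initialise both condition projections so the scale and shift start at zero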
+ zero_module(self.condition_attention)
+ zero_module(self.condition_ffn)
+
+ def forward(self, x, condition=None):
+ H, W = self.input_resolution
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+
+ shortcut = x
+ x = self.norm1(x)
+ x = x.view(B, H, W, C)
+
+ # add condition before attention
+ # input B,H,W,C
+ if condition is not None:
+ x = x.permute(0, 3, 1, 2) # BCHW
+ condition_attention = self.condition_attention(condition).view(-1, 2*C, 1, 1)
+ condition_attn_multiplication, condition_attn_addition = condition_attention.chunk(2, dim=1)
+            x = x * condition_attn_multiplication + condition_attn_addition
+ x = x.permute(0, 2, 3, 1)
+
+
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ else:
+ shifted_x = x
+
+ # partition windows
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
+
+ # W-MSA/SW-MSA
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ else:
+ x = shifted_x
+ x = x.view(B, H * W, C)
+
+ # FFN
+ x = shortcut + self.drop_path(x)
+ # x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+ # add condition before ffn
+ # input B,H*W,C
+ if condition is not None:
+ res = x
+ x = self.norm2(x)
+ x = x.view(B, H, W, C)
+ x = x.permute(0, 3, 1, 2) # BCHW
+ condition_ffn = self.condition_ffn(condition).view(-1, 2*C, 1, 1)
+ condition_ffn_multiplication, condition_ffn_addition = condition_ffn.chunk(2, dim=1)
+ x = x*condition_ffn_multiplication + condition_ffn_addition
+ x = x.permute(0, 2, 3, 1)
+ x = x.view(B, H*W, C)
+ x = res + self.drop_path(self.mlp(x))
+ else:
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+ def extra_repr(self) -> str:
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
+
+ def flops(self):
+ flops = 0
+ H, W = self.input_resolution
+ # norm1
+ flops += self.dim * H * W
+ # W-MSA/SW-MSA
+ nW = H * W / self.window_size / self.window_size
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
+ # mlp
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
+ # norm2
+ flops += self.dim * H * W
+ return flops
+
+
+class PatchMerging(nn.Module):
+ r""" Patch Merging Layer.
+
+ Args:
+ input_resolution (tuple[int]): Resolution of input feature.
+ dim (int): Number of input channels.
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.input_resolution = input_resolution
+ self.dim = dim
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+ self.norm = norm_layer(4 * dim)
+
+ def forward(self, x):
+ """
+ x: B, H*W, C
+ """
+ H, W = self.input_resolution
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
+
+ x = x.view(B, H, W, C)
+
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
+
+ x = self.norm(x)
+ x = self.reduction(x)
+
+ return x
+
+ def extra_repr(self) -> str:
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
+
+ def flops(self):
+ H, W = self.input_resolution
+ flops = H * W * self.dim
+ flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
+ return flops
+
+
+class BasicLayer(nn.Module):
+ """ A basic Swin Transformer layer for one stage.
+
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ window_size (int): Local window size.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ """
+
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
+
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ # build blocks
+ self.blocks = nn.ModuleList([
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
+ num_heads=num_heads, window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop, attn_drop=attn_drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+ norm_layer=norm_layer)
+ for i in range(depth)])
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
+ else:
+ self.downsample = None
+
+ def forward(self, x, condition=None):
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x, condition=condition)
+ if self.downsample is not None:
+ x = self.downsample(x)
+ return x
+
+ def extra_repr(self) -> str:
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
+
+ def flops(self):
+ flops = 0
+ for blk in self.blocks:
+ flops += blk.flops()
+ if self.downsample is not None:
+ flops += self.downsample.flops()
+ return flops
+
+
+class PatchEmbed(nn.Module):
+ r""" Image to Patch Embedding
+
+ Args:
+ img_size (int): Image size. Default: 224.
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.patches_resolution = patches_resolution
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+ if norm_layer is not None:
+ self.norm = norm_layer(embed_dim)
+ else:
+ self.norm = None
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ # FIXME look at relaxing size constraints
+ assert H == self.img_size[0] and W == self.img_size[1], \
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
+ if self.norm is not None:
+ x = self.norm(x)
+ return x
+
+ def flops(self):
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
+
+
+class SwinTransformer(nn.Module):
+ r""" Swin Transformer
+    A PyTorch implementation of: `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
+ https://arxiv.org/pdf/2103.14030
+
+ Args:
+ img_size (int | tuple(int)): Input image size. Default 224
+ patch_size (int | tuple(int)): Patch size. Default: 4
+ in_chans (int): Number of input image channels. Default: 3
+ num_classes (int): Number of classes for classification head. Default: 1000
+ embed_dim (int): Patch embedding dimension. Default: 96
+ depths (tuple(int)): Depth of each Swin Transformer layer.
+ num_heads (tuple(int)): Number of attention heads in different layers.
+ window_size (int): Window size. Default: 7
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
+ drop_rate (float): Dropout rate. Default: 0
+ attn_drop_rate (float): Attention dropout rate. Default: 0
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
+ """
+
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
+ embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
+ window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
+ use_checkpoint=False, **kwargs):
+ super().__init__()
+
+ self.num_classes = num_classes
+ self.num_layers = len(depths)
+ self.embed_dim = embed_dim
+ self.ape = ape
+ self.patch_norm = patch_norm
+ self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
+ self.mlp_ratio = mlp_ratio
+
+ # split image into non-overlapping patches
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None)
+ num_patches = self.patch_embed.num_patches
+ patches_resolution = self.patch_embed.patches_resolution
+ self.patches_resolution = patches_resolution
+
+ # absolute position embedding
+ if self.ape:
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+ trunc_normal_(self.absolute_pos_embed, std=.02)
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ # stochastic depth
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
+
+ # build layers
+ self.layers = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
+ input_resolution=(patches_resolution[0] // (2 ** i_layer),
+ patches_resolution[1] // (2 ** i_layer)),
+ depth=depths[i_layer],
+ num_heads=num_heads[i_layer],
+ window_size=window_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+ norm_layer=norm_layer,
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
+ use_checkpoint=use_checkpoint)
+ self.layers.append(layer)
+
+ self.norm = norm_layer(self.num_features)
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
+ # self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'absolute_pos_embed'}
+
+ @torch.jit.ignore
+ def no_weight_decay_keywords(self):
+ return {'relative_position_bias_table'}
+
+ def forward(self, x, idx_to_group_img=None, image_atts=None, condition=None, **kwargs):
+ x = self.patch_embed(x)
+ if self.ape:
+ x = x + self.absolute_pos_embed
+ x = self.pos_drop(x)
+
+ for layer in self.layers:
+ x = layer(x, condition=condition)
+
+ x = self.norm(x) # B L C
+
+ x_cls = self.avgpool(x.transpose(1, 2)) # B C 1
+
+ if idx_to_group_img is None:
+ return torch.cat([x_cls.transpose(1, 2), x], dim=1)
+ else:
+ x_bs = torch.gather(x, dim=0, index=idx_to_group_img.view(-1, 1, 1).expand(-1, x.shape[1], x.shape[2]))
+ weights = image_atts[:, 1:].unsqueeze(2) # B L 1
+ x_bs_cls = torch.sum((weights * x_bs).transpose(1, 2), dim=-1, keepdim=True) # B C 1
+ x_bs_cls = x_bs_cls / torch.sum(weights.transpose(1, 2), dim=-1, keepdim=True) # avgpool
+
+ return torch.cat([x_bs_cls.transpose(1, 2), x_bs], dim=1), \
+ torch.cat([x_cls.transpose(1, 2), x], dim=1)
+
+ def flops(self):
+ flops = 0
+ flops += self.patch_embed.flops()
+ for i, layer in enumerate(self.layers):
+ flops += layer.flops()
+ flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
+ flops += self.num_features * self.num_classes
+ return flops
+
+
+def interpolate_relative_pos_embed(rel_pos_bias, dst_num_pos, param_name=''):
+ # from: https://github.com/microsoft/unilm/blob/8a0a1c1f4e7326938ea7580a00d56d7f17d65612/beit/run_class_finetuning.py#L348
+
+ # rel_pos_bias: relative_position_bias_table
+ src_num_pos, num_attn_heads = rel_pos_bias.size()
+
+ num_extra_tokens = 0
+ src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
+ dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
+ if src_size != dst_size:
+ print("Position interpolate %s from %dx%d to %dx%d" % (param_name, src_size, src_size, dst_size, dst_size))
+
+ # extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
+ # rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
+
+ def geometric_progression(a, r, n):
+ return a * (1.0 - r ** n) / (1.0 - r)
+
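+    # bisection: find the geometric ratio q whose progression over src_size // 2
+    # steps covers at least dst_size // 2 positions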
+ left, right = 1.01, 1.5
+ while right - left > 1e-6:
+ q = (left + right) / 2.0
+ gp = geometric_progression(1, q, src_size // 2)
+ if gp > dst_size // 2:
+ right = q
+ else:
+ left = q
+
+ # if q > 1.090307:
+ # q = 1.090307
+
+ dis = []
+ cur = 1
+ for i in range(src_size // 2):
+ dis.append(cur)
+ cur += q ** (i + 1)
+
+ r_ids = [-_ for _ in reversed(dis)]
+
+ x = r_ids + [0] + dis
+ y = r_ids + [0] + dis
+
+ t = dst_size // 2.0
+ dx = np.arange(-t, t + 0.1, 1.0)
+ dy = np.arange(-t, t + 0.1, 1.0)
+
+ # print("Original positions = %s" % str(x))
+ # print("Target positions = %s" % str(dx))
+
+ all_rel_pos_bias = []
+
+ for i in range(num_attn_heads):
+ z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
+ f = interpolate.interp2d(x, y, z, kind='cubic')
+ all_rel_pos_bias.append(
+ torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))
+
+ rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
+
+ return rel_pos_bias
+
+def zero_module(module):
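+    # zero out every parameter of `module` in place and return it (used to
+    # initialise the condition layers above)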
+ for p in module.parameters():
+ nn.init.zeros_(p)
+ return module
\ No newline at end of file
diff --git a/ram/models/swin_transformer_lora.py b/ram/models/swin_transformer_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..2eb78226a2947d77dd2f7ff6afada533733e557c
--- /dev/null
+++ b/ram/models/swin_transformer_lora.py
@@ -0,0 +1,660 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------
+
+import numpy as np
+from scipy import interpolate
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+import loralib as lora
+
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ # self.fc1 = lora.Linear(in_features, hidden_features, r=16)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ # self.fc2 = lora.Linear(hidden_features, out_features, r=16)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+
+
+def window_reverse(windows, window_size, H, W):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ window_size (int): Window size
+ H (int): Height of image
+ W (int): Width of image
+
+ Returns:
+ x: (B, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
+
+class WindowAttention(nn.Module):
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
+ It supports both of shifted and non-shifted window.
+
+ Args:
+ dim (int): Number of input channels.
+ window_size (tuple[int]): The height and width of the window.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ """
+
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size # Wh, Ww
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+
+ # define a parameter table of relative position bias
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ # lora version
+ self.qkv = lora.MergedLinear(dim, 3*dim, r=8, enable_lora=[True, False, True])
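+        # enable_lora=[True, False, True]: only the query and value thirds of the fused
+        # qkv projection receive rank-8 LoRA updates; the key third keeps the base weight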
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ trunc_normal_(self.relative_position_bias_table, std=.02)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask=None):
+ """
+ Args:
+ x: input features with shape of (num_windows*B, N, C)
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+ """
+ B_, N, C = x.shape
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ nW = mask.shape[0]
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ def extra_repr(self) -> str:
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
+
+ def flops(self, N):
+ # calculate flops for 1 window with token length of N
+ flops = 0
+ # qkv = self.qkv(x)
+ flops += N * self.dim * 3 * self.dim
+ # attn = (q @ k.transpose(-2, -1))
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
+ # x = (attn @ v)
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
+ # x = self.proj(x)
+ flops += N * self.dim * self.dim
+ return flops
+
+
+class SwinTransformerBlock(nn.Module):
+ r""" Swin Transformer Block.
+
+ Args:
+ dim (int): Number of input channels.
+        input_resolution (tuple[int]): Input resolution.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+ if min(self.input_resolution) <= self.window_size:
+ # if window size is larger than input resolution, we don't partition windows
+ self.shift_size = 0
+ self.window_size = min(self.input_resolution)
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
+
+ self.norm1 = norm_layer(dim)
+ self.attn = WindowAttention(
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ if self.shift_size > 0:
+ # calculate attention mask for SW-MSA
+ H, W = self.input_resolution
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
+ h_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ w_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+ else:
+ attn_mask = None
+
+ self.register_buffer("attn_mask", attn_mask)
+
+ def forward(self, x):
+ H, W = self.input_resolution
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+
+ shortcut = x
+ x = self.norm1(x)
+ x = x.view(B, H, W, C)
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ else:
+ shifted_x = x
+
+ # partition windows
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
+
+ # W-MSA/SW-MSA
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ else:
+ x = shifted_x
+ x = x.view(B, H * W, C)
+
+ # FFN
+ x = shortcut + self.drop_path(x)
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+ return x
+
+ def extra_repr(self) -> str:
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
+
+ def flops(self):
+ flops = 0
+ H, W = self.input_resolution
+ # norm1
+ flops += self.dim * H * W
+ # W-MSA/SW-MSA
+ nW = H * W / self.window_size / self.window_size
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
+ # mlp
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
+ # norm2
+ flops += self.dim * H * W
+ return flops
+
+
+class PatchMerging(nn.Module):
+ r""" Patch Merging Layer.
+
+ Args:
+ input_resolution (tuple[int]): Resolution of input feature.
+ dim (int): Number of input channels.
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.input_resolution = input_resolution
+ self.dim = dim
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+ self.norm = norm_layer(4 * dim)
+
+ def forward(self, x):
+ """
+ x: B, H*W, C
+ """
+ H, W = self.input_resolution
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
+
+ x = x.view(B, H, W, C)
+
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
+
+ x = self.norm(x)
+ x = self.reduction(x)
+
+ return x
+
+ def extra_repr(self) -> str:
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
+
+ def flops(self):
+ H, W = self.input_resolution
+ flops = H * W * self.dim
+ flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
+ return flops
+
+
+class BasicLayer(nn.Module):
+ """ A basic Swin Transformer layer for one stage.
+
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ window_size (int): Local window size.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ """
+
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
+
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ # build blocks
+ self.blocks = nn.ModuleList([
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
+ num_heads=num_heads, window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop, attn_drop=attn_drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+ norm_layer=norm_layer)
+ for i in range(depth)])
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
+ else:
+ self.downsample = None
+
+ def forward(self, x):
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+ if self.downsample is not None:
+ x = self.downsample(x)
+ return x
+
+ def extra_repr(self) -> str:
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
+
+ def flops(self):
+ flops = 0
+ for blk in self.blocks:
+ flops += blk.flops()
+ if self.downsample is not None:
+ flops += self.downsample.flops()
+ return flops
+
+
+class PatchEmbed(nn.Module):
+ r""" Image to Patch Embedding
+
+ Args:
+ img_size (int): Image size. Default: 224.
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.patches_resolution = patches_resolution
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+ if norm_layer is not None:
+ self.norm = norm_layer(embed_dim)
+ else:
+ self.norm = None
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ # FIXME look at relaxing size constraints
+ assert H == self.img_size[0] and W == self.img_size[1], \
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
+ if self.norm is not None:
+ x = self.norm(x)
+ return x
+
+ def flops(self):
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
+
+
+class SwinTransformer(nn.Module):
+ r""" Swin Transformer
+    A PyTorch implementation of: `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
+ https://arxiv.org/pdf/2103.14030
+
+ Args:
+ img_size (int | tuple(int)): Input image size. Default 224
+ patch_size (int | tuple(int)): Patch size. Default: 4
+ in_chans (int): Number of input image channels. Default: 3
+ num_classes (int): Number of classes for classification head. Default: 1000
+ embed_dim (int): Patch embedding dimension. Default: 96
+ depths (tuple(int)): Depth of each Swin Transformer layer.
+ num_heads (tuple(int)): Number of attention heads in different layers.
+ window_size (int): Window size. Default: 7
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
+ drop_rate (float): Dropout rate. Default: 0
+ attn_drop_rate (float): Attention dropout rate. Default: 0
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
+ """
+
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
+ embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
+ window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
+ use_checkpoint=False, **kwargs):
+ super().__init__()
+
+ self.num_classes = num_classes
+ self.num_layers = len(depths)
+ self.embed_dim = embed_dim
+ self.ape = ape
+ self.patch_norm = patch_norm
+ self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
+ self.mlp_ratio = mlp_ratio
+
+ # split image into non-overlapping patches
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None)
+ num_patches = self.patch_embed.num_patches
+ patches_resolution = self.patch_embed.patches_resolution
+ self.patches_resolution = patches_resolution
+
+ # absolute position embedding
+ if self.ape:
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+ trunc_normal_(self.absolute_pos_embed, std=.02)
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ # stochastic depth
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
+
+ # build layers
+ self.layers = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
+ input_resolution=(patches_resolution[0] // (2 ** i_layer),
+ patches_resolution[1] // (2 ** i_layer)),
+ depth=depths[i_layer],
+ num_heads=num_heads[i_layer],
+ window_size=window_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+ norm_layer=norm_layer,
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
+ use_checkpoint=use_checkpoint)
+ self.layers.append(layer)
+
+ self.norm = norm_layer(self.num_features)
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
+ # self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'absolute_pos_embed'}
+
+ @torch.jit.ignore
+ def no_weight_decay_keywords(self):
+ return {'relative_position_bias_table'}
+
+ def forward(self, x, idx_to_group_img=None, image_atts=None, **kwargs):
+ x = self.patch_embed(x)
+ if self.ape:
+ x = x + self.absolute_pos_embed
+ x = self.pos_drop(x)
+
+ for layer in self.layers:
+ x = layer(x)
+
+ x = self.norm(x) # B L C
+
+ x_cls = self.avgpool(x.transpose(1, 2)) # B C 1
+
+ if idx_to_group_img is None:
+ return torch.cat([x_cls.transpose(1, 2), x], dim=1)
+ else:
+ x_bs = torch.gather(x, dim=0, index=idx_to_group_img.view(-1, 1, 1).expand(-1, x.shape[1], x.shape[2]))
+ weights = image_atts[:, 1:].unsqueeze(2) # B L 1
+ x_bs_cls = torch.sum((weights * x_bs).transpose(1, 2), dim=-1, keepdim=True) # B C 1
+ x_bs_cls = x_bs_cls / torch.sum(weights.transpose(1, 2), dim=-1, keepdim=True) # avgpool
+
+ return torch.cat([x_bs_cls.transpose(1, 2), x_bs], dim=1), \
+ torch.cat([x_cls.transpose(1, 2), x], dim=1)
+
+ def flops(self):
+ flops = 0
+ flops += self.patch_embed.flops()
+ for i, layer in enumerate(self.layers):
+ flops += layer.flops()
+ flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
+ flops += self.num_features * self.num_classes
+ return flops
+
+
+def interpolate_relative_pos_embed(rel_pos_bias, dst_num_pos, param_name=''):
+ # from: https://github.com/microsoft/unilm/blob/8a0a1c1f4e7326938ea7580a00d56d7f17d65612/beit/run_class_finetuning.py#L348
+
+ # rel_pos_bias: relative_position_bias_table
+ src_num_pos, num_attn_heads = rel_pos_bias.size()
+
+ num_extra_tokens = 0
+ src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
+ dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
+ if src_size != dst_size:
+ print("Position interpolate %s from %dx%d to %dx%d" % (param_name, src_size, src_size, dst_size, dst_size))
+
+ # extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
+ # rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
+
+ def geometric_progression(a, r, n):
+ return a * (1.0 - r ** n) / (1.0 - r)
+
+ left, right = 1.01, 1.5
+ while right - left > 1e-6:
+ q = (left + right) / 2.0
+ gp = geometric_progression(1, q, src_size // 2)
+ if gp > dst_size // 2:
+ right = q
+ else:
+ left = q
+
+ # if q > 1.090307:
+ # q = 1.090307
+
+ dis = []
+ cur = 1
+ for i in range(src_size // 2):
+ dis.append(cur)
+ cur += q ** (i + 1)
+
+ r_ids = [-_ for _ in reversed(dis)]
+
+ x = r_ids + [0] + dis
+ y = r_ids + [0] + dis
+
+ t = dst_size // 2.0
+ dx = np.arange(-t, t + 0.1, 1.0)
+ dy = np.arange(-t, t + 0.1, 1.0)
+
+ # print("Original positions = %s" % str(x))
+ # print("Target positions = %s" % str(dx))
+
+ all_rel_pos_bias = []
+
+ for i in range(num_attn_heads):
+ z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
+ # cubic 2-D interpolation; assumes `from scipy import interpolate` at the top of this module
+ f = interpolate.interp2d(x, y, z, kind='cubic')
+ all_rel_pos_bias.append(
+ torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))
+
+ rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
+
+ return rel_pos_bias
\ No newline at end of file
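A minimal sketch of how `interpolate_relative_pos_embed` might be invoked when adapting a checkpoint to a larger window size. The shapes are made up (window 7 to window 12, 8 heads), the tensor is a random stand-in for a checkpoint entry, and the call assumes the package's dependencies are installed, including a scipy version that still ships `interp2d` (removed in scipy 1.14):

    import torch
    from ram.models.swin_transformer import interpolate_relative_pos_embed

    src_num_pos = (2 * 7 - 1) ** 2    # 13 x 13 = 169 relative positions
    dst_num_pos = (2 * 12 - 1) ** 2   # 23 x 23 = 529 relative positions
    num_heads = 8

    bias_table = torch.randn(src_num_pos, num_heads)   # stand-in checkpoint tensor
    resized = interpolate_relative_pos_embed(bias_table, dst_num_pos, param_name='demo')
    assert resized.shape == (dst_num_pos, num_heads)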
diff --git a/ram/models/tag2text.py b/ram/models/tag2text.py
new file mode 100644
index 0000000000000000000000000000000000000000..24ca0ad5b986ee2b4ffdd02bf31ac108e46a6b15
--- /dev/null
+++ b/ram/models/tag2text.py
@@ -0,0 +1,419 @@
+'''
+ * The Tag2Text Model
+ * Written by Xinyu Huang
+'''
+import numpy as np
+import json
+import torch
+import warnings
+
+from torch import nn
+from .bert import BertConfig, BertModel, BertLMHeadModel
+from .swin_transformer import SwinTransformer
+
+from .utils import *
+
+warnings.filterwarnings("ignore")
+
+
+class Tag2Text(nn.Module):
+
+ def __init__(self,
+ med_config=f'{CONFIG_PATH}/configs/med_config.json',
+ image_size=384,
+ vit='base',
+ vit_grad_ckpt=False,
+ vit_ckpt_layer=0,
+ prompt='a picture of ',
+ threshold=0.68,
+ delete_tag_index=[127, 2961, 3351, 3265, 3338, 3355, 3359],
+ tag_list=f'{CONFIG_PATH}/data/tag_list.txt'):
+ r""" Tag2Text inference module; both captioning and tagging are included.
+ Tag2Text is an efficient and controllable vision-language pre-training framework.
+ Described in the paper "Tag2Text: Guiding Vision-Language Model via Image Tagging" https://arxiv.org/abs/2303.05657
+
+ Args:
+ med_config (str): path to the configuration file of the mixture-of-encoder-decoder (MED) model
+ image_size (int): input image size
+ vit (str): model size of the vision transformer
+ threshold (float): tagging threshold
+ delete_tag_index (list): tags to delete because they may disturb captioning
+ """
+ super().__init__()
+
+ # create image encoder
+ if vit == 'swin_b':
+ if image_size == 224:
+ vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_224.json'
+ elif image_size == 384:
+ vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_384.json'
+ vision_config = read_json(vision_config_path)
+ assert image_size == vision_config['image_res']
+ # assert config['patch_size'] == 32
+ vision_width = vision_config['vision_width']
+
+ self.visual_encoder = SwinTransformer(
+ img_size=vision_config['image_res'],
+ patch_size=4,
+ in_chans=3,
+ embed_dim=vision_config['embed_dim'],
+ depths=vision_config['depths'],
+ num_heads=vision_config['num_heads'],
+ window_size=vision_config['window_size'],
+ mlp_ratio=4.,
+ qkv_bias=True,
+ drop_rate=0.0,
+ drop_path_rate=0.1,
+ ape=False,
+ patch_norm=True,
+ use_checkpoint=False)
+
+ else:
+ self.visual_encoder, vision_width = create_vit(
+ vit, image_size, vit_grad_ckpt, vit_ckpt_layer)
+
+ # create tokenizer
+ self.tokenizer = init_tokenizer()
+
+ # Tag2Text employs an encoder-decoder architecture for image-tag-text generation: an image-tag interaction encoder and an image-tag-text decoder
+ # create image-tag interaction encoder
+ encoder_config = BertConfig.from_json_file(med_config)
+ encoder_config.encoder_width = vision_width
+ self.tag_encoder = BertModel(config=encoder_config,
+ add_pooling_layer=False)
+
+ # create image-tag-text decoder
+ decoder_config = BertConfig.from_json_file(med_config)
+ self.text_decoder = BertLMHeadModel(config=decoder_config)
+
+ # delete some tags that may disturb captioning
+ # 127: "quarter"; 2961: "back"; 3351: "two"; 3265: "three"; 3338: "four"; 3355: "five"; 3359: "one"
+ self.delete_tag_index = delete_tag_index
+ self.prompt = prompt
+ self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1
+
+ # load tag list
+ self.tag_list = self.load_tag_list(tag_list)
+
+ # create image-tag recognition decoder
+ self.threshold = threshold
+ self.num_class = len(self.tag_list)
+ q2l_config = BertConfig.from_json_file(f'{CONFIG_PATH}/configs/q2l_config.json')
+ q2l_config.encoder_width = vision_width
+ self.tagging_head = BertModel(config=q2l_config,
+ add_pooling_layer=False)
+ self.tagging_head.resize_token_embeddings(len(self.tokenizer))
+ self.label_embed = nn.Embedding(self.num_class, q2l_config.hidden_size)
+ self.fc = GroupWiseLinear(self.num_class,
+ q2l_config.hidden_size,
+ bias=True)
+ self.del_selfattention()
+
+ self.tagging_loss_function = AsymmetricLoss(gamma_neg=7,
+ gamma_pos=0,
+ clip=0.05)
+
+ # share the weights of the lowest two layers of the "image-tag interaction encoder" with the "image-tag recognition decoder"
+ tie_encoder_decoder_weights(self.tag_encoder, self.tagging_head, '',
+ ' ')
+
+ # adjust thresholds for some tags
+ # default threshold: 0.68
+ # 2701: "person"; 2828: "man"; 1167: "woman";
+ tag_threshold = {2701: 0.7, 2828: 0.7, 1167: 0.7}
+ self.class_threshold = torch.ones(self.num_class) * self.threshold
+ for key, value in tag_threshold.items():
+ self.class_threshold[key] = value
+
+ def load_tag_list(self, tag_list_file):
+ with open(tag_list_file, 'r') as f:
+ tag_list = f.read().splitlines()
+ tag_list = np.array(tag_list)
+ return tag_list
+
+ # delete the self-attention layers of the image-tag recognition decoder to reduce computation, following Query2Label
+ def del_selfattention(self):
+ del self.tagging_head.embeddings
+ for layer in self.tagging_head.encoder.layer:
+ del layer.attention
+
+
+ def forward(self, image, caption, tag):
+ """
+ Training forward pass: compute the tagging loss and the image-tag-text generation loss.
+
+ Args:
+ image: type: torch.Tensor shape: batch_size * 3 * 384 * 384
+ caption: type: list[string] len: batch_size
+ tag: type: torch.Tensor shape: batch * class_num (e.g. 3429) value: positive sample is 1.0, negative sample is 0.0
+
+ Returns:
+ loss: type: torch.Tensor
+ """
+
+ image_embeds = self.visual_encoder(image)
+ image_atts = torch.ones(image_embeds.size()[:-1],
+ dtype=torch.long).to(image.device)
+
+ ##================= Image Tagging ================##
+ bs = image_embeds.shape[0]
+ label_embed = self.label_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
+
+ tagging_embed = self.tagging_head(
+ encoder_embeds=label_embed,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=False,
+ mode='tagging',
+ )
+
+ logits = self.fc(tagging_embed[0])
+
+ loss_tag = self.tagging_loss_function(logits, tag)
+
+ ##================= Image-Tag-Text Generation ================##
+ tag = tag.cpu().numpy()
+ tag_input = []
+ for b in range(bs):
+ index = np.argwhere(tag[b] == 1)
+ token = self.tag_list[index].squeeze(axis=1)
+ tag_input.append(' | '.join(token))
+
+ # tokenize input tags
+ tag_input_tokenizer = self.tokenizer(tag_input,
+ padding='max_length',
+ truncation=True,
+ max_length=40,
+ return_tensors="pt").to(
+ image.device)
+ encoder_input_ids = tag_input_tokenizer.input_ids
+ encoder_input_ids[:, 0] = self.tokenizer.enc_token_id
+
+ # put input tags into the image-tag interaction encoder to interact with image embeddings
+ output_tagembedding = self.tag_encoder(
+ encoder_input_ids,
+ attention_mask=tag_input_tokenizer.attention_mask,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=True,
+ )
+
+ text = self.tokenizer(caption,
+ padding='longest',
+ truncation=True,
+ max_length=40,
+ return_tensors="pt").to(
+ image.device)
+
+ decoder_input_ids = text.input_ids
+ decoder_input_ids[:, 0] = self.tokenizer.bos_token_id
+
+ decoder_targets = decoder_input_ids.masked_fill(
+ decoder_input_ids == self.tokenizer.pad_token_id, -100)
+ decoder_targets[:, :self.prompt_length] = -100
+
+ decoder_output = self.text_decoder(decoder_input_ids,
+ attention_mask = text.attention_mask,
+ encoder_hidden_states = output_tagembedding.last_hidden_state,
+ encoder_attention_mask = None,
+ labels = decoder_targets,
+ return_dict = True,
+ )
+
+ loss_t2t = decoder_output.loss
+
+ # balance loss scales: rescale loss_tag to the magnitude of loss_t2t
+ # (the detached ratio carries no gradient, so only the scale changes)
+ loss = loss_t2t + loss_tag / (loss_tag / loss_t2t).detach()
+
+ return loss
+
+ def generate_image_embeds(self,
+ image,
+ condition=False
+ ):
+
+ image_embeds = self.visual_encoder(image)
+
+ return image_embeds
+
+ def condition_forward(self,
+ image,
+ sample=False,
+ num_beams=3,
+ max_length=30,
+ min_length=10,
+ top_p=0.9,
+ repetition_penalty=1.0,
+ tag_input=None,
+ return_tag_predict=False):
+
+ image_embeds = self.visual_encoder(image)
+ image_atts = torch.ones(image_embeds.size()[:-1],
+ dtype=torch.long).to(image.device)
+
+ # recognize image tags with the image-tag recognition decoder
+
+
+ bs = image_embeds.shape[0]
+ label_embed = self.label_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
+ tagging_embed = self.tagging_head(
+ encoder_embeds=label_embed,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=False,
+ mode='tagging',
+ )
+
+ logits = self.fc(tagging_embed[0])
+
+ targets = torch.where(
+ torch.sigmoid(logits) > self.class_threshold.to(image.device),
+ torch.tensor(1.0).to(image.device),
+ torch.zeros(self.num_class).to(image.device))
+
+ # delete some tags that may disturb captioning
+ targets[:, self.delete_tag_index] = 0
+
+ return image_embeds, logits, targets
+
+
+ def generate(self,
+ image,
+ sample=False,
+ num_beams=3,
+ max_length=30,
+ min_length=10,
+ top_p=0.9,
+ repetition_penalty=1.0,
+ tag_input=None,
+ return_tag_predict=False):
+
+ image_embeds = self.visual_encoder(image)
+ image_atts = torch.ones(image_embeds.size()[:-1],
+ dtype=torch.long).to(image.device)
+
+ # if the user did not specify tags, recognize image tags with the image-tag recognition decoder
+ if tag_input is None:
+
+ bs = image_embeds.shape[0]
+ label_embed = self.label_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
+ tagging_embed = self.tagging_head(
+ encoder_embeds=label_embed,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=False,
+ mode='tagging',
+ )
+
+ logits = self.fc(tagging_embed[0])
+
+ targets = torch.where(
+ torch.sigmoid(logits) > self.class_threshold.to(image.device),
+ torch.tensor(1.0).to(image.device),
+ torch.zeros(self.num_class).to(image.device))
+
+ tag = targets.cpu().numpy()
+
+ # delete some tags that may disturb captioning
+ tag[:, self.delete_tag_index] = 0
+
+ tag_input = []
+ for b in range(bs):
+ index = np.argwhere(tag[b] == 1)
+ token = self.tag_list[index].squeeze(axis=1)
+ tag_input.append(', '.join(token))
+
+ tag_output = tag_input
+
+ # beam search for text generation (default)
+ if not sample:
+ image_embeds = image_embeds.repeat_interleave(num_beams, dim=0)
+ tag_input_temp = []
+ for tag in tag_input:
+ for i in range(num_beams):
+ tag_input_temp.append(tag)
+ tag_input = tag_input_temp
+
+ image_atts = torch.ones(image_embeds.size()[:-1],
+ dtype=torch.long).to(image.device)
+
+ # tokenize input tags
+ tag_input_tokenizer = self.tokenizer(tag_input,
+ padding='max_length',
+ truncation=True,
+ max_length=40,
+ return_tensors="pt").to(
+ image.device)
+ encoder_input_ids = tag_input_tokenizer.input_ids
+ encoder_input_ids[:, 0] = self.tokenizer.enc_token_id
+
+ # put input tags into the image-tag interaction encoder to interact with image embeddings
+ output_tagembedding = self.tag_encoder(
+ encoder_input_ids,
+ attention_mask=tag_input_tokenizer.attention_mask,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=True,
+ )
+
+ # prompt trick for better captioning, following BLIP
+ prompt = [self.prompt] * image.size(0)
+ input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(
+ image.device)
+ input_ids[:, 0] = self.tokenizer.bos_token_id
+ input_ids = input_ids[:, :-1]
+
+ if sample:
+ # nucleus sampling
+ model_kwargs = {
+ "encoder_hidden_states": output_tagembedding.last_hidden_state,
+ "encoder_attention_mask": None
+ }
+ outputs = self.text_decoder.generate(
+ input_ids=input_ids,
+ max_length=max_length,
+ min_length=min_length,
+ do_sample=True,
+ top_p=top_p,
+ num_return_sequences=1,
+ eos_token_id=self.tokenizer.sep_token_id,
+ pad_token_id=self.tokenizer.pad_token_id,
+ repetition_penalty=1.1,
+ **model_kwargs)
+ else:
+ # beam search (default)
+ model_kwargs = {
+ "encoder_hidden_states": output_tagembedding.last_hidden_state,
+ "encoder_attention_mask": None
+ }
+ outputs = self.text_decoder.generate(
+ input_ids=input_ids,
+ max_length=max_length,
+ min_length=min_length,
+ num_beams=num_beams,
+ eos_token_id=self.tokenizer.sep_token_id,
+ pad_token_id=self.tokenizer.pad_token_id,
+ repetition_penalty=repetition_penalty,
+ **model_kwargs)
+
+ captions = []
+ for output in outputs:
+ caption = self.tokenizer.decode(output, skip_special_tokens=True)
+ captions.append(caption[len(self.prompt):])
+ if return_tag_predict:
+ return captions, tag_output
+ return captions
+
+
+# load Tag2Text pretrained model parameters
+def tag2text(pretrained='', **kwargs):
+ model = Tag2Text(**kwargs)
+ if pretrained:
+ if kwargs['vit'] == 'swin_b':
+ model, msg = load_checkpoint_swinbase(model, pretrained, kwargs)
+ else:
+ model, msg = load_checkpoint(model, pretrained)
+ print('vit:', kwargs['vit'])
+ # print('msg', msg)
+ return model
+
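A minimal usage sketch of the inference path above, assuming the package is importable as `ram` and its dependencies are installed; the checkpoint and image file names below are placeholders:

    import torch
    from PIL import Image
    from ram.models.tag2text import tag2text
    from ram.transform import get_transform

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    transform = get_transform(image_size=384)

    # 'tag2text_swin_14m.pth' and 'example.jpg' are placeholder paths
    model = tag2text(pretrained='tag2text_swin_14m.pth',
                     image_size=384, vit='swin_b').eval().to(device)
    image = transform(Image.open('example.jpg')).unsqueeze(0).to(device)

    with torch.no_grad():
        captions, tags = model.generate(image, return_tag_predict=True)
    print(tags[0], '->', captions[0])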
diff --git a/ram/models/tag2text_lora.py b/ram/models/tag2text_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..66230e27978966d0d64c8b14ea58f643354ea71d
--- /dev/null
+++ b/ram/models/tag2text_lora.py
@@ -0,0 +1,419 @@
+'''
+ * The Tag2Text Model
+ * Written by Xinyu Huang
+'''
+import numpy as np
+import json
+import torch
+import warnings
+
+from torch import nn
+from .bert_lora import BertConfig, BertModel, BertLMHeadModel
+from .swin_transformer_lora import SwinTransformer
+
+from .utils import *
+
+warnings.filterwarnings("ignore")
+
+
+class Tag2Text(nn.Module):
+
+ def __init__(self,
+ med_config=f'{CONFIG_PATH}/configs/med_config.json',
+ image_size=384,
+ vit='base',
+ vit_grad_ckpt=False,
+ vit_ckpt_layer=0,
+ prompt='a picture of ',
+ threshold=0.68,
+ delete_tag_index=[127, 2961, 3351, 3265, 3338, 3355, 3359],
+ tag_list=f'{CONFIG_PATH}/data/tag_list.txt'):
+ r""" Tag2Text inference module; both captioning and tagging are included.
+ Tag2Text is an efficient and controllable vision-language pre-training framework.
+ Described in the paper "Tag2Text: Guiding Vision-Language Model via Image Tagging" https://arxiv.org/abs/2303.05657
+
+ Args:
+ med_config (str): path to the configuration file of the mixture-of-encoder-decoder (MED) model
+ image_size (int): input image size
+ vit (str): model size of the vision transformer
+ threshold (float): tagging threshold
+ delete_tag_index (list): tags to delete because they may disturb captioning
+ """
+ super().__init__()
+
+ # create image encoder
+ if vit == 'swin_b':
+ if image_size == 224:
+ vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_224.json'
+ elif image_size == 384:
+ vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_384.json'
+ vision_config = read_json(vision_config_path)
+ assert image_size == vision_config['image_res']
+ # assert config['patch_size'] == 32
+ vision_width = vision_config['vision_width']
+
+ self.visual_encoder = SwinTransformer(
+ img_size=vision_config['image_res'],
+ patch_size=4,
+ in_chans=3,
+ embed_dim=vision_config['embed_dim'],
+ depths=vision_config['depths'],
+ num_heads=vision_config['num_heads'],
+ window_size=vision_config['window_size'],
+ mlp_ratio=4.,
+ qkv_bias=True,
+ drop_rate=0.0,
+ drop_path_rate=0.1,
+ ape=False,
+ patch_norm=True,
+ use_checkpoint=False)
+
+ else:
+ self.visual_encoder, vision_width = create_vit(
+ vit, image_size, vit_grad_ckpt, vit_ckpt_layer)
+
+ # create tokenizer
+ self.tokenizer = init_tokenizer()
+
+ # Tag2Text employs an encoder-decoder architecture for image-tag-text generation: an image-tag interaction encoder and an image-tag-text decoder
+ # create image-tag interaction encoder
+ encoder_config = BertConfig.from_json_file(med_config)
+ encoder_config.encoder_width = vision_width
+ self.tag_encoder = BertModel(config=encoder_config,
+ add_pooling_layer=False)
+
+ # create image-tag-text decoder
+ decoder_config = BertConfig.from_json_file(med_config)
+ self.text_decoder = BertLMHeadModel(config=decoder_config)
+
+ # delete some tags that may disturb captioning
+ # 127: "quarter"; 2961: "back"; 3351: "two"; 3265: "three"; 3338: "four"; 3355: "five"; 3359: "one"
+ self.delete_tag_index = delete_tag_index
+ self.prompt = prompt
+ self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1
+
+ # load tag list
+ self.tag_list = self.load_tag_list(tag_list)
+
+ # create image-tag recognition decoder
+ self.threshold = threshold
+ self.num_class = len(self.tag_list)
+ q2l_config = BertConfig.from_json_file(f'{CONFIG_PATH}/configs/q2l_config.json')
+ q2l_config.encoder_width = vision_width
+ self.tagging_head = BertModel(config=q2l_config,
+ add_pooling_layer=False)
+ self.tagging_head.resize_token_embeddings(len(self.tokenizer))
+ self.label_embed = nn.Embedding(self.num_class, q2l_config.hidden_size)
+ self.fc = GroupWiseLinear(self.num_class,
+ q2l_config.hidden_size,
+ bias=True)
+ self.del_selfattention()
+
+ self.tagging_loss_function = AsymmetricLoss(gamma_neg=7,
+ gamma_pos=0,
+ clip=0.05)
+
+ # share the weights of the lowest two layers of the "image-tag interaction encoder" with the "image-tag recognition decoder"
+ tie_encoder_decoder_weights(self.tag_encoder, self.tagging_head, '',
+ ' ')
+
+ # adjust thresholds for some tags
+ # default threshold: 0.68
+ # 2701: "person"; 2828: "man"; 1167: "woman";
+ tag_threshold = {2701: 0.7, 2828: 0.7, 1167: 0.7}
+ self.class_threshold = torch.ones(self.num_class) * self.threshold
+ for key, value in tag_threshold.items():
+ self.class_threshold[key] = value
+
+ def load_tag_list(self, tag_list_file):
+ with open(tag_list_file, 'r') as f:
+ tag_list = f.read().splitlines()
+ tag_list = np.array(tag_list)
+ return tag_list
+
+ # delete the self-attention layers of the image-tag recognition decoder to reduce computation, following Query2Label
+ def del_selfattention(self):
+ del self.tagging_head.embeddings
+ for layer in self.tagging_head.encoder.layer:
+ del layer.attention
+
+
+ def forward(self, image, caption, tag):
+ """
+ Training forward pass: compute the tagging loss and the image-tag-text generation loss.
+
+ Args:
+ image: type: torch.Tensor shape: batch_size * 3 * 384 * 384
+ caption: type: list[string] len: batch_size
+ tag: type: torch.Tensor shape: batch * class_num (e.g. 3429) value: positive sample is 1.0, negative sample is 0.0
+
+ Returns:
+ loss: type: torch.Tensor
+ """
+
+ image_embeds = self.visual_encoder(image)
+ image_atts = torch.ones(image_embeds.size()[:-1],
+ dtype=torch.long).to(image.device)
+
+ ##================= Image Tagging ================##
+ bs = image_embeds.shape[0]
+ label_embed = self.label_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
+
+ tagging_embed = self.tagging_head(
+ encoder_embeds=label_embed,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=False,
+ mode='tagging',
+ )
+
+ logits = self.fc(tagging_embed[0])
+
+ loss_tag = self.tagging_loss_function(logits, tag)
+
+ ##================= Image-Tag-Text Generation ================##
+ tag = tag.cpu().numpy()
+ tag_input = []
+ for b in range(bs):
+ index = np.argwhere(tag[b] == 1)
+ token = self.tag_list[index].squeeze(axis=1)
+ tag_input.append(' | '.join(token))
+
+ # tokenize input tags
+ tag_input_tokenizer = self.tokenizer(tag_input,
+ padding='max_length',
+ truncation=True,
+ max_length=40,
+ return_tensors="pt").to(
+ image.device)
+ encoder_input_ids = tag_input_tokenizer.input_ids
+ encoder_input_ids[:, 0] = self.tokenizer.enc_token_id
+
+ # put input tags into the image-tag interaction encoder to interact with image embeddings
+ output_tagembedding = self.tag_encoder(
+ encoder_input_ids,
+ attention_mask=tag_input_tokenizer.attention_mask,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=True,
+ )
+
+ text = self.tokenizer(caption,
+ padding='longest',
+ truncation=True,
+ max_length=40,
+ return_tensors="pt").to(
+ image.device)
+
+ decoder_input_ids = text.input_ids
+ decoder_input_ids[:, 0] = self.tokenizer.bos_token_id
+
+ decoder_targets = decoder_input_ids.masked_fill(
+ decoder_input_ids == self.tokenizer.pad_token_id, -100)
+ decoder_targets[:, :self.prompt_length] = -100
+
+ decoder_output = self.text_decoder(decoder_input_ids,
+ attention_mask = text.attention_mask,
+ encoder_hidden_states = output_tagembedding.last_hidden_state,
+ encoder_attention_mask = None,
+ labels = decoder_targets,
+ return_dict = True,
+ )
+
+ loss_t2t = decoder_output.loss
+
+ # balance loss scales: rescale loss_tag to the magnitude of loss_t2t
+ # (the detached ratio carries no gradient, so only the scale changes)
+ loss = loss_t2t + loss_tag / (loss_tag / loss_t2t).detach()
+
+ return loss
+
+ def generate_image_embeds(self,
+ image,
+ condition=False
+ ):
+
+ image_embeds = self.visual_encoder(image)
+
+ return image_embeds
+
+ def condition_forward(self,
+ image,
+ sample=False,
+ num_beams=3,
+ max_length=30,
+ min_length=10,
+ top_p=0.9,
+ repetition_penalty=1.0,
+ tag_input=None,
+ return_tag_predict=False):
+
+ image_embeds = self.visual_encoder(image)
+ image_atts = torch.ones(image_embeds.size()[:-1],
+ dtype=torch.long).to(image.device)
+
+ # recognize image tags with the image-tag recognition decoder
+
+
+ bs = image_embeds.shape[0]
+ label_embed = self.label_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
+ tagging_embed = self.tagging_head(
+ encoder_embeds=label_embed,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=False,
+ mode='tagging',
+ )
+
+ logits = self.fc(tagging_embed[0])
+
+ targets = torch.where(
+ torch.sigmoid(logits) > self.class_threshold.to(image.device),
+ torch.tensor(1.0).to(image.device),
+ torch.zeros(self.num_class).to(image.device))
+
+ # delete some tags that may disturb captioning
+ targets[:, self.delete_tag_index] = 0
+
+ return image_embeds, logits, targets
+
+
+ def generate(self,
+ image,
+ sample=False,
+ num_beams=3,
+ max_length=30,
+ min_length=10,
+ top_p=0.9,
+ repetition_penalty=1.0,
+ tag_input=None,
+ return_tag_predict=False):
+
+ image_embeds = self.visual_encoder(image)
+ image_atts = torch.ones(image_embeds.size()[:-1],
+ dtype=torch.long).to(image.device)
+
+ # if the user did not specify tags, recognize image tags with the image-tag recognition decoder
+ if tag_input is None:
+
+ bs = image_embeds.shape[0]
+ label_embed = self.label_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
+ tagging_embed = self.tagging_head(
+ encoder_embeds=label_embed,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=False,
+ mode='tagging',
+ )
+
+ logits = self.fc(tagging_embed[0])
+
+ targets = torch.where(
+ torch.sigmoid(logits) > self.class_threshold.to(image.device),
+ torch.tensor(1.0).to(image.device),
+ torch.zeros(self.num_class).to(image.device))
+
+ tag = targets.cpu().numpy()
+
+ # delete some tags that may disturb captioning
+ tag[:, self.delete_tag_index] = 0
+
+ tag_input = []
+ for b in range(bs):
+ index = np.argwhere(tag[b] == 1)
+ token = self.tag_list[index].squeeze(axis=1)
+ tag_input.append(', '.join(token))
+
+ tag_output = tag_input
+
+ # beam search for text generation (default)
+ if not sample:
+ image_embeds = image_embeds.repeat_interleave(num_beams, dim=0)
+ tag_input_temp = []
+ for tag in tag_input:
+ for i in range(num_beams):
+ tag_input_temp.append(tag)
+ tag_input = tag_input_temp
+
+ image_atts = torch.ones(image_embeds.size()[:-1],
+ dtype=torch.long).to(image.device)
+
+ # tokenize input tags
+ tag_input_tokenizer = self.tokenizer(tag_input,
+ padding='max_length',
+ truncation=True,
+ max_length=40,
+ return_tensors="pt").to(
+ image.device)
+ encoder_input_ids = tag_input_tokenizer.input_ids
+ encoder_input_ids[:, 0] = self.tokenizer.enc_token_id
+
+ # put input tags into the image-tag interaction encoder to interact with image embeddings
+ output_tagembedding = self.tag_encoder(
+ encoder_input_ids,
+ attention_mask=tag_input_tokenizer.attention_mask,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=True,
+ )
+
+ # prompt trick for better captioning, following BLIP
+ prompt = [self.prompt] * image.size(0)
+ input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(
+ image.device)
+ input_ids[:, 0] = self.tokenizer.bos_token_id
+ input_ids = input_ids[:, :-1]
+
+ if sample:
+ # nucleus sampling
+ model_kwargs = {
+ "encoder_hidden_states": output_tagembedding.last_hidden_state,
+ "encoder_attention_mask": None
+ }
+ outputs = self.text_decoder.generate(
+ input_ids=input_ids,
+ max_length=max_length,
+ min_length=min_length,
+ do_sample=True,
+ top_p=top_p,
+ num_return_sequences=1,
+ eos_token_id=self.tokenizer.sep_token_id,
+ pad_token_id=self.tokenizer.pad_token_id,
+ repetition_penalty=1.1,
+ **model_kwargs)
+ else:
+ # beam search (default)
+ model_kwargs = {
+ "encoder_hidden_states": output_tagembedding.last_hidden_state,
+ "encoder_attention_mask": None
+ }
+ outputs = self.text_decoder.generate(
+ input_ids=input_ids,
+ max_length=max_length,
+ min_length=min_length,
+ num_beams=num_beams,
+ eos_token_id=self.tokenizer.sep_token_id,
+ pad_token_id=self.tokenizer.pad_token_id,
+ repetition_penalty=repetition_penalty,
+ **model_kwargs)
+
+ captions = []
+ for output in outputs:
+ caption = self.tokenizer.decode(output, skip_special_tokens=True)
+ captions.append(caption[len(self.prompt):])
+ if return_tag_predict:
+ return captions, tag_output
+ return captions
+
+
+# load Tag2Text pretrained model parameters
+def tag2text(pretrained='', **kwargs):
+ model = Tag2Text(**kwargs)
+ if pretrained:
+ if kwargs['vit'] == 'swin_b':
+ model, msg = load_checkpoint_swinbase(model, pretrained, kwargs)
+ else:
+ model, msg = load_checkpoint(model, pretrained)
+ print('vit:', kwargs['vit'])
+ # print('msg', msg)
+ return model
+
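tag2text_lora.py mirrors tag2text.py but builds on `bert_lora` and `swin_transformer_lora`, whose internals are not part of this diff. As a rough illustration only (not the repo's actual implementation), a LoRA adapter typically wraps a frozen pretrained linear layer with a trainable low-rank residual:

    import torch
    from torch import nn

    class LoRALinear(nn.Module):
        """Generic sketch: y = W x + (alpha / r) * B(A(x)), with W frozen."""
        def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 16):
            super().__init__()
            self.base = base
            for p in self.base.parameters():
                p.requires_grad = False          # keep the pretrained weight frozen
            self.lora_a = nn.Linear(base.in_features, r, bias=False)
            self.lora_b = nn.Linear(r, base.out_features, bias=False)
            nn.init.zeros_(self.lora_b.weight)   # start as a no-op update
            self.scale = alpha / r

        def forward(self, x):
            return self.base(x) + self.scale * self.lora_b(self.lora_a(x))

Only the small `lora_a`/`lora_b` matrices receive gradients, which keeps fine-tuning cheap while leaving the base checkpoint intact.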
diff --git a/ram/models/utils.py b/ram/models/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2adbafc02fabef27cbd6105c6588fca711f8d90b
--- /dev/null
+++ b/ram/models/utils.py
@@ -0,0 +1,364 @@
+import os
+import json
+import torch
+import math
+
+from torch import nn
+from typing import List
+from transformers import BertTokenizer
+from urllib.parse import urlparse
+from timm.models.hub import download_cached_file
+from .vit import VisionTransformer, interpolate_pos_embed  # VisionTransformer is needed by create_vit below
+from .swin_transformer import interpolate_relative_pos_embed
+from pathlib import Path
+CONFIG_PATH = Path(__file__).resolve().parents[1]
+
+def read_json(rpath):
+ with open(rpath, 'r') as f:
+ return json.load(f)
+
+
+def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module,
+ base_model_prefix: str, skip_key: str):
+ uninitialized_encoder_weights: List[str] = []
+ if decoder.__class__ != encoder.__class__:
+ # `logger` was never defined in this module; print keeps the warning visible
+ print(
+ f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized."
+ )
+
+ def tie_encoder_to_decoder_recursively(
+ decoder_pointer: nn.Module,
+ encoder_pointer: nn.Module,
+ module_name: str,
+ uninitialized_encoder_weights: List[str],
+ skip_key: str,
+ depth=0,
+ ):
+ assert isinstance(decoder_pointer, nn.Module) and isinstance(
+ encoder_pointer, nn.Module
+ ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module"
+ if hasattr(decoder_pointer, "weight") and skip_key not in module_name:
+ assert hasattr(encoder_pointer, "weight")
+ encoder_pointer.weight = decoder_pointer.weight
+ if hasattr(decoder_pointer, "bias"):
+ assert hasattr(encoder_pointer, "bias")
+ encoder_pointer.bias = decoder_pointer.bias
+ print(module_name + ' is tied')
+ return
+
+ encoder_modules = encoder_pointer._modules
+ decoder_modules = decoder_pointer._modules
+ if len(decoder_modules) > 0:
+ assert (
+ len(encoder_modules) > 0
+ ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
+
+ all_encoder_weights = set([
+ module_name + "/" + sub_name
+ for sub_name in encoder_modules.keys()
+ ])
+ encoder_layer_pos = 0
+ for name, module in decoder_modules.items():
+ if name.isdigit():
+ encoder_name = str(int(name) + encoder_layer_pos)
+ decoder_name = name
+ if not isinstance(
+ decoder_modules[decoder_name],
+ type(encoder_modules[encoder_name])) and len(
+ encoder_modules) != len(decoder_modules):
+ # this can happen if the name corresponds to the position in a list module list of layers
+ # in this case the decoder has added a cross-attention that the encoder does not have
+ # thus skip this step and subtract one layer pos from encoder
+ encoder_layer_pos -= 1
+ continue
+ elif name not in encoder_modules:
+ continue
+ elif depth > 500:
+ raise ValueError(
+ "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model."
+ )
+ else:
+ decoder_name = encoder_name = name
+ tie_encoder_to_decoder_recursively(
+ decoder_modules[decoder_name],
+ encoder_modules[encoder_name],
+ module_name + "/" + name,
+ uninitialized_encoder_weights,
+ skip_key,
+ depth=depth + 1,
+ )
+ all_encoder_weights.remove(module_name + "/" + encoder_name)
+
+ uninitialized_encoder_weights += list(all_encoder_weights)
+
+ # tie weights recursively
+ tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix,
+ uninitialized_encoder_weights, skip_key)
+
+
+class GroupWiseLinear(nn.Module):
+ # could be changed to:
+ # output = torch.einsum('ijk,zjk->ij', x, self.W)
+ # or output = torch.einsum('ijk,jk->ij', x, self.W[0])
+ def __init__(self, num_class, hidden_dim, bias=True):
+ super().__init__()
+ self.num_class = num_class
+ self.hidden_dim = hidden_dim
+ self.bias = bias
+
+ self.W = nn.Parameter(torch.Tensor(1, num_class, hidden_dim))
+ if bias:
+ self.b = nn.Parameter(torch.Tensor(1, num_class))
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ stdv = 1. / math.sqrt(self.W.size(2))
+ for i in range(self.num_class):
+ self.W[0][i].data.uniform_(-stdv, stdv)
+ if self.bias:
+ for i in range(self.num_class):
+ self.b[0][i].data.uniform_(-stdv, stdv)
+
+ def forward(self, x):
+ # x: B,K,d
+ x = (self.W * x).sum(-1)
+ if self.bias:
+ x = x + self.b
+ return x
+
+
+def init_tokenizer():
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+ tokenizer.add_special_tokens({'bos_token': '[DEC]'})
+ tokenizer.add_special_tokens({'additional_special_tokens': ['[ENC]']})
+ tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
+ return tokenizer
+
+
+def create_vit(vit,
+ image_size,
+ use_grad_checkpointing=False,
+ ckpt_layer=0,
+ drop_path_rate=0):
+
+ assert vit in ['base', 'large'], "vit parameter must be base or large"
+ if vit == 'base':
+ vision_width = 768
+ visual_encoder = VisionTransformer(
+ img_size=image_size,
+ patch_size=16,
+ embed_dim=vision_width,
+ depth=12,
+ num_heads=12,
+ use_grad_checkpointing=use_grad_checkpointing,
+ ckpt_layer=ckpt_layer,
+ drop_path_rate=drop_path_rate)  # simplified from `0 or drop_path_rate` (identical behavior)
+ elif vit == 'large':
+ vision_width = 1024
+ visual_encoder = VisionTransformer(
+ img_size=image_size,
+ patch_size=16,
+ embed_dim=vision_width,
+ depth=24,
+ num_heads=16,
+ use_grad_checkpointing=use_grad_checkpointing,
+ ckpt_layer=ckpt_layer,
+ drop_path_rate=drop_path_rate or 0.1)  # the original `0.1 or drop_path_rate` always evaluated to 0.1, ignoring the argument
+ return visual_encoder, vision_width
+
+
+def is_url(url_or_filename):
+ parsed = urlparse(url_or_filename)
+ return parsed.scheme in ("http", "https")
+
+
+def load_checkpoint(model, url_or_filename):
+ if is_url(url_or_filename):
+ cached_file = download_cached_file(url_or_filename,
+ check_hash=False,
+ progress=True)
+ checkpoint = torch.load(cached_file, map_location='cpu')
+ elif os.path.isfile(url_or_filename):
+ checkpoint = torch.load(url_or_filename, map_location='cpu')
+ else:
+ raise RuntimeError('checkpoint url or path is invalid')
+
+ state_dict = checkpoint['model']
+
+ state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(
+ state_dict['visual_encoder.pos_embed'], model.visual_encoder)
+ if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
+ state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(
+ state_dict['visual_encoder_m.pos_embed'], model.visual_encoder_m)
+ for key in model.state_dict().keys():
+ if key in state_dict.keys():
+ if state_dict[key].shape != model.state_dict()[key].shape:
+ del state_dict[key]
+
+ msg = model.load_state_dict(state_dict, strict=False)
+ print('load checkpoint from %s' % url_or_filename)
+ return model, msg
+
+# def load_checkpoint_condition(model, url_or_filename):
+def load_checkpoint_swinlarge_condition(model, url_or_filename, kwargs):
+ if kwargs['image_size'] == 224:
+ vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_224.json'
+ elif kwargs['image_size'] == 384:
+ vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_384.json'
+ window_size = read_json(vision_config_path)['window_size']
+ print('--------------')
+ print(url_or_filename)
+ print('--------------')
+ if is_url(url_or_filename):
+ cached_file = download_cached_file(url_or_filename,
+ check_hash=False,
+ progress=True)
+ checkpoint = torch.load(cached_file, map_location='cpu')
+ elif os.path.isfile(url_or_filename):
+ checkpoint = torch.load(url_or_filename, map_location='cpu')
+ else:
+ raise RuntimeError('checkpoint url or path is invalid')
+
+ state_dict = checkpoint['params']
+
+ for k in list(state_dict.keys()):
+ if 'relative_position_bias_table' in k:
+ dst_num_pos = (2 * window_size - 1)**2
+ state_dict[k] = interpolate_relative_pos_embed(state_dict[k],
+ dst_num_pos,
+ param_name=k)
+ elif ('relative_position_index' in k) or ('attn_mask' in k):
+ del state_dict[k]
+ elif "vision_multi" in k:
+ state_dict[k.replace("vision_multi",
+ "tagging_head")] = state_dict.pop(k)
+
+ msg = model.load_state_dict(state_dict, strict=False)
+ print('load checkpoint from %s' % url_or_filename)
+ return model, msg
+
+
+def load_checkpoint_swinbase(model, url_or_filename, kwargs):
+ if kwargs['image_size'] == 224:
+ vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_224.json'
+ elif kwargs['image_size'] == 384:
+ vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_384.json'
+ window_size = read_json(vision_config_path)['window_size']
+ print('--------------')
+ print(url_or_filename)
+ print('--------------')
+ if is_url(url_or_filename):
+ cached_file = download_cached_file(url_or_filename,
+ check_hash=False,
+ progress=True)
+ checkpoint = torch.load(cached_file, map_location='cpu')
+ elif os.path.isfile(url_or_filename):
+ checkpoint = torch.load(url_or_filename, map_location='cpu')
+ else:
+ raise RuntimeError('checkpoint url or path is invalid')
+
+ state_dict = checkpoint['model']
+
+ for k in list(state_dict.keys()):
+ if 'relative_position_bias_table' in k:
+ dst_num_pos = (2 * window_size - 1)**2
+ state_dict[k] = interpolate_relative_pos_embed(state_dict[k],
+ dst_num_pos,
+ param_name=k)
+ elif ('relative_position_index' in k) or ('attn_mask' in k):
+ del state_dict[k]
+ elif "vision_multi" in k:
+ state_dict[k.replace("vision_multi",
+ "tagging_head")] = state_dict.pop(k)
+
+ msg = model.load_state_dict(state_dict, strict=False)
+ print('load checkpoint from %s' % url_or_filename)
+ return model, msg
+
+
+def load_checkpoint_swinlarge(model, url_or_filename, kwargs):
+ if kwargs['image_size'] == 224:
+ vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_224.json'
+ elif kwargs['image_size'] == 384:
+ vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_384.json'
+ window_size = read_json(vision_config_path)['window_size']
+ print('--------------')
+ print(url_or_filename)
+ print('--------------')
+ if is_url(url_or_filename):
+ cached_file = download_cached_file(url_or_filename,
+ check_hash=False,
+ progress=True)
+ checkpoint = torch.load(cached_file, map_location='cpu')
+ elif os.path.isfile(url_or_filename):
+ checkpoint = torch.load(url_or_filename, map_location='cpu')
+ else:
+ raise RuntimeError('checkpoint url or path is invalid')
+
+ state_dict = checkpoint['model']
+
+ for k in list(state_dict.keys()):
+ if 'relative_position_bias_table' in k:
+ dst_num_pos = (2 * window_size - 1)**2
+ state_dict[k] = interpolate_relative_pos_embed(state_dict[k],
+ dst_num_pos,
+ param_name=k)
+ elif ('relative_position_index' in k) or ('attn_mask' in k):
+ del state_dict[k]
+ elif "vision_multi" in k:
+ state_dict[k.replace("vision_multi",
+ "tagging_head")] = state_dict.pop(k)
+
+ msg = model.load_state_dict(state_dict, strict=False)
+ print('load checkpoint from %s' % url_or_filename)
+ return model, msg
+
+
+# Tagging loss function
+# copy from https://github.com/Alibaba-MIIL/ASL/blob/main/src/loss_functions/losses.py
+class AsymmetricLoss(nn.Module):
+ def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=True):
+ super(AsymmetricLoss, self).__init__()
+
+ self.gamma_neg = gamma_neg
+ self.gamma_pos = gamma_pos
+ self.clip = clip
+ self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
+ self.eps = eps
+
+ def forward(self, x, y):
+ """
+ Parameters
+ ----------
+ x: input logits
+ y: targets (multi-label binarized vector)
+ """
+
+ # Calculating Probabilities
+ x_sigmoid = torch.sigmoid(x)
+ xs_pos = x_sigmoid
+ xs_neg = 1 - x_sigmoid
+
+ # Asymmetric Clipping
+ if self.clip is not None and self.clip > 0:
+ xs_neg = (xs_neg + self.clip).clamp(max=1)
+
+ # Basic CE calculation
+ los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
+ los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
+ loss = los_pos + los_neg
+
+ # Asymmetric Focusing
+ if self.gamma_neg > 0 or self.gamma_pos > 0:
+ if self.disable_torch_grad_focal_loss:
+ torch.set_grad_enabled(False)
+ pt0 = xs_pos * y
+ pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p
+ pt = pt0 + pt1
+ one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
+ one_sided_w = torch.pow(1 - pt, one_sided_gamma)
+ if self.disable_torch_grad_focal_loss:
+ torch.set_grad_enabled(True)
+ loss *= one_sided_w
+
+ return -loss.sum()
\ No newline at end of file
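The comment inside `GroupWiseLinear` suggests an einsum formulation; a quick sanity check that the two forms agree, under the shape convention from the forward docstring (`x: B,K,d`) and assuming the package's dependencies are installed:

    import torch
    from ram.models.utils import GroupWiseLinear

    B, K, d = 2, 5, 16
    layer = GroupWiseLinear(num_class=K, hidden_dim=d, bias=True)
    x = torch.randn(B, K, d)

    out_sum = layer(x)                                        # (W * x).sum(-1) + b
    out_einsum = torch.einsum('bkd,kd->bk', x, layer.W[0]) + layer.b[0]
    assert torch.allclose(out_sum, out_einsum, atol=1e-6)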
diff --git a/ram/models/vit.py b/ram/models/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..cec3d8e08ed4451d65392feb2e9f4848d1ef3899
--- /dev/null
+++ b/ram/models/vit.py
@@ -0,0 +1,305 @@
+'''
+ * Copyright (c) 2022, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+ * Based on timm code base
+ * https://github.com/rwightman/pytorch-image-models/tree/master/timm
+'''
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import partial
+
+from timm.models.vision_transformer import _cfg, PatchEmbed, resize_pos_embed  # resize_pos_embed (used in _load_weights) lives here in the timm versions this file targets
+from timm.models.registry import register_model
+from timm.models.layers import trunc_normal_, DropPath
+from timm.models.helpers import named_apply, adapt_input_conv
+
+from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
+
+class Mlp(nn.Module):
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
+ """
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+ self.scale = qk_scale or head_dim ** -0.5
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.attn_gradients = None
+ self.attention_map = None
+
+ def save_attn_gradients(self, attn_gradients):
+ self.attn_gradients = attn_gradients
+
+ def get_attn_gradients(self):
+ return self.attn_gradients
+
+ def save_attention_map(self, attention_map):
+ self.attention_map = attention_map
+
+ def get_attention_map(self):
+ return self.attention_map
+
+ def forward(self, x, register_hook=False):
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ if register_hook:
+ self.save_attention_map(attn)
+ attn.register_hook(self.save_attn_gradients)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Module):
+
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ if use_grad_checkpointing:
+ self.attn = checkpoint_wrapper(self.attn)
+ self.mlp = checkpoint_wrapper(self.mlp)
+
+ def forward(self, x, register_hook=False):
+ x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class VisionTransformer(nn.Module):
+ """ Vision Transformer
+ A PyTorch impl of: `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
+ https://arxiv.org/abs/2010.11929
+ """
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
+ num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
+ use_grad_checkpointing=False, ckpt_layer=0):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ num_classes (int): number of classes for classification head
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (float): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
+ drop_rate (float): dropout rate
+ attn_drop_rate (float): attention dropout rate
+ drop_path_rate (float): stochastic depth rate
+ norm_layer: (nn.Module): normalization layer
+ """
+ super().__init__()
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+ self.blocks = nn.ModuleList([
+ Block(
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+ use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
+ )
+ for i in range(depth)])
+ self.norm = norm_layer(embed_dim)
+
+ trunc_normal_(self.pos_embed, std=.02)
+ trunc_normal_(self.cls_token, std=.02)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'pos_embed', 'cls_token'}
+
+ def forward(self, x, register_blk=-1):
+ B = x.shape[0]
+ x = self.patch_embed(x)
+
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ x = x + self.pos_embed[:,:x.size(1),:]
+ x = self.pos_drop(x)
+
+ for i,blk in enumerate(self.blocks):
+ x = blk(x, register_blk==i)
+ x = self.norm(x)
+
+ return x
+
+ @torch.jit.ignore()
+ def load_pretrained(self, checkpoint_path, prefix=''):
+ _load_weights(self, checkpoint_path, prefix)
+
+
+@torch.no_grad()
+def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
+ """ Load weights from .npz checkpoints for official Google Brain Flax implementation
+ """
+ import numpy as np
+
+ def _n2p(w, t=True):
+ if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
+ w = w.flatten()
+ if t:
+ if w.ndim == 4:
+ w = w.transpose([3, 2, 0, 1])
+ elif w.ndim == 3:
+ w = w.transpose([2, 0, 1])
+ elif w.ndim == 2:
+ w = w.transpose([1, 0])
+ return torch.from_numpy(w)
+
+ w = np.load(checkpoint_path)
+ if not prefix and 'opt/target/embedding/kernel' in w:
+ prefix = 'opt/target/'
+
+ if hasattr(model.patch_embed, 'backbone'):
+ # hybrid
+ backbone = model.patch_embed.backbone
+ stem_only = not hasattr(backbone, 'stem')
+ stem = backbone if stem_only else backbone.stem
+ stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
+ stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
+ stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
+ if not stem_only:
+ for i, stage in enumerate(backbone.stages):
+ for j, block in enumerate(stage.blocks):
+ bp = f'{prefix}block{i + 1}/unit{j + 1}/'
+ for r in range(3):
+ getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
+ getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
+ getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
+ if block.downsample is not None:
+ block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
+ block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
+ block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
+ embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
+ else:
+ embed_conv_w = adapt_input_conv(
+ model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
+ model.patch_embed.proj.weight.copy_(embed_conv_w)
+ model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
+ model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
+ pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
+ if pos_embed_w.shape != model.pos_embed.shape:
+ pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
+ pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
+ model.pos_embed.copy_(pos_embed_w)
+ model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
+ model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
+# if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
+# model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
+# model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
+# if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
+# model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
+# model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
+ for i, block in enumerate(model.blocks.children()):
+ block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
+ mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
+ block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
+ block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
+ block.attn.qkv.weight.copy_(torch.cat([
+ _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
+ block.attn.qkv.bias.copy_(torch.cat([
+ _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
+ block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
+ block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
+ for r in range(2):
+ getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
+ getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
+ block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
+ block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
+
+
+def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
+ # interpolate position embedding
+ embedding_size = pos_embed_checkpoint.shape[-1]
+ num_patches = visual_encoder.patch_embed.num_patches
+ num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
+ # height (== width) for the checkpoint position embedding
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+ # height (== width) for the new position embedding
+ new_size = int(num_patches ** 0.5)
+
+ if orig_size != new_size:
+ # class_token and dist_token are kept unchanged
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ print('reshape position embedding from %d to %d' % (orig_size ** 2, new_size ** 2))
+
+ return new_pos_embed
+ else:
+ return pos_embed_checkpoint
\ No newline at end of file
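`interpolate_pos_embed` resizes an absolute position embedding when the checkpoint and the target model use different grids. A small sketch, assuming fairscale and timm are installed (this module imports both) and using a random stand-in for a 224px checkpoint being adapted to a 384px model:

    import torch
    from ram.models.vit import VisionTransformer, interpolate_pos_embed

    # 224px / patch 16 -> 14x14 patches + 1 cls token; target is 384px -> 24x24 + 1
    ckpt_pos_embed = torch.randn(1, 14 * 14 + 1, 768)
    model_384 = VisionTransformer(img_size=384, patch_size=16, embed_dim=768,
                                  depth=1, num_heads=12)  # depth=1 keeps the demo small
    new_pos_embed = interpolate_pos_embed(ckpt_pos_embed, model_384)
    assert new_pos_embed.shape == model_384.pos_embed.shape  # (1, 577, 768)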
diff --git a/ram/transform.py b/ram/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..baff9cdc7b437f842e07849d04e3e9a905d303e2
--- /dev/null
+++ b/ram/transform.py
@@ -0,0 +1,13 @@
+from torchvision.transforms import Normalize, Compose, Resize, ToTensor
+
+
+def convert_to_rgb(image):
+ return image.convert("RGB")
+
+def get_transform(image_size=384):
+ return Compose([
+ convert_to_rgb,
+ Resize((image_size, image_size)),
+ ToTensor(),
+ Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ])
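A short usage sketch of the transform, with a hypothetical image path:

    from PIL import Image
    from ram.transform import get_transform

    transform = get_transform(image_size=384)
    tensor = transform(Image.open('example.jpg'))  # 'example.jpg' is a placeholder
    print(tensor.shape)  # torch.Size([3, 384, 384])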
diff --git a/ram/utils/__init__.py b/ram/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b2595cba0f8fa049dd6d44099e225f8cefc1991
--- /dev/null
+++ b/ram/utils/__init__.py
@@ -0,0 +1,2 @@
+from .metrics import get_mAP, get_PR
+from .openset_utils import build_openset_label_embedding
diff --git a/ram/utils/__pycache__/__init__.cpython-310.pyc b/ram/utils/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c2181eab89285ef86266cd72ad896edfd1208684
Binary files /dev/null and b/ram/utils/__pycache__/__init__.cpython-310.pyc differ
diff --git a/ram/utils/__pycache__/metrics.cpython-310.pyc b/ram/utils/__pycache__/metrics.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fd3c46c03722ebc1941dbaa2fc536f543446ae1b
Binary files /dev/null and b/ram/utils/__pycache__/metrics.cpython-310.pyc differ
diff --git a/ram/utils/__pycache__/openset_utils.cpython-310.pyc b/ram/utils/__pycache__/openset_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e9b5702e10ed8a298f8c6d726375963095d3d7e8
Binary files /dev/null and b/ram/utils/__pycache__/openset_utils.cpython-310.pyc differ
diff --git a/ram/utils/metrics.py b/ram/utils/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..7738dfa1a84d30ec17a292a20d917d75cce8de16
--- /dev/null
+++ b/ram/utils/metrics.py
@@ -0,0 +1,102 @@
+from typing import List, Tuple
+
+import numpy as np
+from numpy import ndarray
+
+
+def get_mAP(
+ preds: ndarray,
+ gt_file: str,
+ taglist: List[str]
+) -> Tuple[float, ndarray]:
+ assert preds.shape[1] == len(taglist)
+
+ # When mapping a test dataset's categories onto our tag system, differing
+ # semantic definitions can collapse several categories onto one tag, so
+ # `taglist` may contain duplicate entries; that special case is handled here.
+ tag2idxs = {}
+ for idx, tag in enumerate(taglist):
+ if tag not in tag2idxs:
+ tag2idxs[tag] = []
+ tag2idxs[tag].append(idx)
+
+ # build targets
+ targets = np.zeros_like(preds)
+ with open(gt_file, "r") as f:
+ lines = [line.strip("\n").split(",") for line in f.readlines()]
+ assert len(lines) == targets.shape[0]
+ for i, line in enumerate(lines):
+ for tag in line[1:]:
+ targets[i, tag2idxs[tag]] = 1.0
+
+ # compute average precision for each class
+ APs = np.zeros(preds.shape[1])
+ for k in range(preds.shape[1]):
+ APs[k] = _average_precision(preds[:, k], targets[:, k])
+
+ return APs.mean(), APs
+
+
+def _average_precision(output: ndarray, target: ndarray) -> float:
+ epsilon = 1e-8
+
+ # sort examples by descending confidence
+ indices = output.argsort()[::-1]
+ # cumulative counts 1..N, the denominators for precision@i
+ total_count_ = np.cumsum(np.ones((len(output), 1)))
+
+ target_ = target[indices]
+ ind = target_ == 1
+ pos_count_ = np.cumsum(ind)
+ total = pos_count_[-1]
+ pos_count_[np.logical_not(ind)] = 0
+ pp = pos_count_ / total_count_
+ precision_at_i_ = np.sum(pp)
+ precision_at_i = precision_at_i_ / (total + epsilon)
+
+ return precision_at_i
+
+
+def get_PR(
+ pred_file: str,
+ gt_file: str,
+ taglist: List[str]
+) -> Tuple[float, float, ndarray, ndarray]:
+ # When mapping a test dataset's categories onto our tag system, differing
+ # semantic definitions can collapse several categories onto one tag, so
+ # `taglist` may contain duplicate entries; that special case is handled here.
+ tag2idxs = {}
+ for idx, tag in enumerate(taglist):
+ if tag not in tag2idxs:
+ tag2idxs[tag] = []
+ tag2idxs[tag].append(idx)
+
+ # build preds
+ with open(pred_file, "r", encoding="utf-8") as f:
+ lines = [line.strip().split(",") for line in f.readlines()]
+ preds = np.zeros((len(lines), len(taglist)), dtype=bool)  # one column per taglist entry, so duplicate tags index safely
+ for i, line in enumerate(lines):
+ for tag in line[1:]:
+ preds[i, tag2idxs[tag]] = True
+
+ # build targets
+ with open(gt_file, "r", encoding="utf-8") as f:
+ lines = [line.strip().split(",") for line in f.readlines()]
+ targets = np.zeros((len(lines), len(taglist)), dtype=bool)
+ for i, line in enumerate(lines):
+ for tag in line[1:]:
+ targets[i, tag2idxs[tag]] = True
+
+ assert preds.shape == targets.shape
+
+ # calculate P and R
+ TPs = ( preds & targets).sum(axis=0) # noqa: E201, E222
+ FPs = ( preds & ~targets).sum(axis=0) # noqa: E201, E222
+ FNs = (~preds & targets).sum(axis=0) # noqa: E201, E222
+ eps = 1.e-9
+ Ps = TPs / (TPs + FPs + eps)
+ Rs = TPs / (TPs + FNs + eps)
+
+ return Ps.mean(), Rs.mean(), Ps, Rs
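
A quick sanity check for these metrics, assuming the `<image_id>,tag1,tag2,...` line format implied by the parsers (the first field is treated as an identifier and skipped):

```python
import tempfile

import numpy as np

from ram.utils.metrics import _average_precision, get_PR

taglist = ["cat", "dog"]
with tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False) as pf, \
     tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False) as gf:
    pf.write("img0,cat\nimg1,cat,dog\n")   # predictions
    gf.write("img0,cat,dog\nimg1,dog\n")   # ground truth
    pred_file, gt_file = pf.name, gf.name

P, R, Ps, Rs = get_PR(pred_file, gt_file, taglist)
print(P, R)  # 0.75 0.75 (up to the eps term)

# AP by hand: scores [0.9, 0.8, 0.7], targets [1, 0, 1]
# -> prec@1 = 1, prec@3 = 2/3, AP = (1 + 2/3) / 2 = 0.833
ap = _average_precision(np.array([0.9, 0.8, 0.7]), np.array([1.0, 0.0, 1.0]))
print(round(float(ap), 3))  # 0.833
```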
diff --git a/ram/utils/openset_utils.py b/ram/utils/openset_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b368405c8eebf9fcee54cfd1362b9b416572264
--- /dev/null
+++ b/ram/utils/openset_utils.py
@@ -0,0 +1,333 @@
+import torch
+import torch.nn as nn
+from clip import clip
+
+
+def article(name):
+ return "an" if name[0] in "aeiou" else "a"
+
+
+def processed_name(name, rm_dot=False):
+ # _ for lvis
+ # / for obj365
+ res = name.replace("_", " ").replace("/", " or ").lower()
+ if rm_dot:
+ res = res.rstrip(".")
+ return res
+
+
+single_template = ["a photo of a {}."]
+
+multiple_templates = [
+ "There is {article} {} in the scene.",
+ "There is the {} in the scene.",
+ "a photo of {article} {} in the scene.",
+ "a photo of the {} in the scene.",
+ "a photo of one {} in the scene.",
+ "itap of {article} {}.",
+ "itap of my {}.", # itap: I took a picture of
+ "itap of the {}.",
+ "a photo of {article} {}.",
+ "a photo of my {}.",
+ "a photo of the {}.",
+ "a photo of one {}.",
+ "a photo of many {}.",
+ "a good photo of {article} {}.",
+ "a good photo of the {}.",
+ "a bad photo of {article} {}.",
+ "a bad photo of the {}.",
+ "a photo of a nice {}.",
+ "a photo of the nice {}.",
+ "a photo of a cool {}.",
+ "a photo of the cool {}.",
+ "a photo of a weird {}.",
+ "a photo of the weird {}.",
+ "a photo of a small {}.",
+ "a photo of the small {}.",
+ "a photo of a large {}.",
+ "a photo of the large {}.",
+ "a photo of a clean {}.",
+ "a photo of the clean {}.",
+ "a photo of a dirty {}.",
+ "a photo of the dirty {}.",
+ "a bright photo of {article} {}.",
+ "a bright photo of the {}.",
+ "a dark photo of {article} {}.",
+ "a dark photo of the {}.",
+ "a photo of a hard to see {}.",
+ "a photo of the hard to see {}.",
+ "a low resolution photo of {article} {}.",
+ "a low resolution photo of the {}.",
+ "a cropped photo of {article} {}.",
+ "a cropped photo of the {}.",
+ "a close-up photo of {article} {}.",
+ "a close-up photo of the {}.",
+ "a jpeg corrupted photo of {article} {}.",
+ "a jpeg corrupted photo of the {}.",
+ "a blurry photo of {article} {}.",
+ "a blurry photo of the {}.",
+ "a pixelated photo of {article} {}.",
+ "a pixelated photo of the {}.",
+ "a black and white photo of the {}.",
+ "a black and white photo of {article} {}.",
+ "a plastic {}.",
+ "the plastic {}.",
+ "a toy {}.",
+ "the toy {}.",
+ "a plushie {}.",
+ "the plushie {}.",
+ "a cartoon {}.",
+ "the cartoon {}.",
+ "an embroidered {}.",
+ "the embroidered {}.",
+ "a painting of the {}.",
+ "a painting of a {}.",
+]
+
+
+openimages_rare_unseen = ['Aerial photography',
+'Aircraft engine',
+'Ale',
+'Aloe',
+'Amphibian',
+'Angling',
+'Anole',
+'Antique car',
+'Arcade game',
+'Arthropod',
+'Assault rifle',
+'Athletic shoe',
+'Auto racing',
+'Backlighting',
+'Bagpipes',
+'Ball game',
+'Barbecue chicken',
+'Barechested',
+'Barquentine',
+'Beef tenderloin',
+'Billiard room',
+'Billiards',
+'Bird of prey',
+'Black swan',
+'Black-and-white',
+'Blond',
+'Boating',
+'Bonbon',
+'Bottled water',
+'Bouldering',
+'Bovine',
+'Bratwurst',
+'Breadboard',
+'Briefs',
+'Brisket',
+'Brochette',
+'Calabaza',
+'Camera operator',
+'Canola',
+'Childbirth',
+'Chordophone',
+'Church bell',
+'Classical sculpture',
+'Close-up',
+'Cobblestone',
+'Coca-cola',
+'Combat sport',
+'Comics',
+'Compact car',
+'Computer speaker',
+'Cookies and crackers',
+'Coral reef fish',
+'Corn on the cob',
+'Cosmetics',
+'Crocodilia',
+'Digital camera',
+'Dishware',
+'Divemaster',
+'Dobermann',
+'Dog walking',
+'Domestic rabbit',
+'Domestic short-haired cat',
+'Double-decker bus',
+'Drums',
+'Electric guitar',
+'Electric piano',
+'Electronic instrument',
+'Equestrianism',
+'Equitation',
+'Erinaceidae',
+'Extreme sport',
+'Falafel',
+'Figure skating',
+'Filling station',
+'Fire apparatus',
+'Firearm',
+'Flatbread',
+'Floristry',
+'Forklift truck',
+'Freight transport',
+'Fried food',
+'Fried noodles',
+'Frigate',
+'Frozen yogurt',
+'Frying',
+'Full moon',
+'Galleon',
+'Glacial landform',
+'Gliding',
+'Go-kart',
+'Goats',
+'Grappling',
+'Great white shark',
+'Gumbo',
+'Gun turret',
+'Hair coloring',
+'Halter',
+'Headphones',
+'Heavy cruiser',
+'Herding',
+'High-speed rail',
+'Holding hands',
+'Horse and buggy',
+'Horse racing',
+'Hound',
+'Hunting knife',
+'Hurdling',
+'Inflatable',
+'Jackfruit',
+'Jeans',
+'Jiaozi',
+'Junk food',
+'Khinkali',
+'Kitesurfing',
+'Lawn game',
+'Leaf vegetable',
+'Lechon',
+'Lifebuoy',
+'Locust',
+'Lumpia',
+'Luxury vehicle',
+'Machine tool',
+'Medical imaging',
+'Melee weapon',
+'Microcontroller',
+'Middle ages',
+'Military person',
+'Military vehicle',
+'Milky way',
+'Miniature Poodle',
+'Modern dance',
+'Molluscs',
+'Monoplane',
+'Motorcycling',
+'Musical theatre',
+'Narcissus',
+'Nest box',
+'Newsagent\'s shop',
+'Nile crocodile',
+'Nordic skiing',
+'Nuclear power plant',
+'Orator',
+'Outdoor shoe',
+'Parachuting',
+'Pasta salad',
+'Peafowl',
+'Pelmeni',
+'Perching bird',
+'Performance car',
+'Personal water craft',
+'Pit bull',
+'Plant stem',
+'Pork chop',
+'Portrait photography',
+'Primate',
+'Procyonidae',
+'Prosciutto',
+'Public speaking',
+'Racewalking',
+'Ramen',
+'Rear-view mirror',
+'Residential area',
+'Ribs',
+'Rice ball',
+'Road cycling',
+'Roller skating',
+'Roman temple',
+'Rowing',
+'Rural area',
+'Sailboat racing',
+'Scaled reptile',
+'Scuba diving',
+'Senior citizen',
+'Shallot',
+'Shinto shrine',
+'Shooting range',
+'Siberian husky',
+'Sledding',
+'Soba',
+'Solar energy',
+'Sport climbing',
+'Sport utility vehicle',
+'Steamed rice',
+'Stemware',
+'Sumo',
+'Surfing Equipment',
+'Team sport',
+'Touring car',
+'Toy block',
+'Trampolining',
+'Underwater diving',
+'Vegetarian food',
+'Wallaby',
+'Water polo',
+'Watercolor paint',
+'Whiskers',
+'Wind wave',
+'Woodwind instrument',
+'Yakitori',
+'Zeppelin']
+
+
+def build_openset_label_embedding(categories=None):
+ if categories is None:
+ categories = openimages_rare_unseen
+ # model, _ = clip.load("ViT-B/16")
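+ # load from a local checkpoint file rather than the model name, so no download is needed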
+ model, _ = clip.load("ViT-B-16.pt")
+ templates = multiple_templates
+
+ run_on_gpu = torch.cuda.is_available()
+
+ with torch.no_grad():
+ openset_label_embedding = []
+ for category in categories:
+ texts = [
+ template.format(
+ processed_name(category, rm_dot=True), article=article(category)
+ )
+ for template in templates
+ ]
+ texts = [
+ "This is " + text if text.startswith("a") or text.startswith("the") else text
+ for text in texts
+ ]
+ texts = clip.tokenize(texts) # tokenize
+ if run_on_gpu:
+ texts = texts.cuda()
+ model = model.cuda()
+ text_embeddings = model.encode_text(texts)
+ text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
+ text_embedding = text_embeddings.mean(dim=0)
+ text_embedding /= text_embedding.norm()
+ openset_label_embedding.append(text_embedding)
+ openset_label_embedding = torch.stack(openset_label_embedding, dim=1)
+ if run_on_gpu:
+ openset_label_embedding = openset_label_embedding.cuda()
+
+ openset_label_embedding = openset_label_embedding.t()
+ return openset_label_embedding, categories
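
The returned matrix is `(num_categories, embed_dim)`, so open-set tagging reduces to cosine similarity against L2-normalized image features. A hedged usage sketch (random features stand in for a real image encoder, and the local `ViT-B-16.pt` checkpoint referenced above must be present):

```python
import torch

from ram.utils.openset_utils import build_openset_label_embedding

label_embed, cats = build_openset_label_embedding(["red panda", "submarine"])

image_feats = torch.randn(4, label_embed.shape[1])     # fake (batch, dim) features
image_feats /= image_feats.norm(dim=-1, keepdim=True)  # L2-normalize like the text side
scores = image_feats @ label_embed.t().cpu().float()   # cosine similarities
print(scores.shape)                                    # torch.Size([4, 2])
```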
diff --git a/requirement.txt b/requirement.txt
new file mode 100644
index 0000000000000000000000000000000000000000..28aa6f7975c6935f440d6b10853bdfb66b2cb823
--- /dev/null
+++ b/requirement.txt
@@ -0,0 +1,14 @@
+diffusers==0.25.0
+torch==2.0.1
+transformers==4.28.1
+xformers==0.0.20
+einops==0.7.0
+open-clip-torch==2.20.0
+peft==0.9.0
+Pillow==9.5.0
+PyYAML==6.0
+huggingface_hub==0.25.2
+numpy==1.23.5
+loralib
+basicsr
+fairscale
\ No newline at end of file
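
After `pip install -r requirement.txt`, the resolved versions can be double-checked with the standard library alone, for example:

```python
from importlib.metadata import version

for pkg in ["diffusers", "torch", "transformers", "xformers", "peft"]:
    print(pkg, version(pkg))
```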
diff --git a/scripts/test/test.sh b/scripts/test/test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..0c4085267172c568fdd7dc7631c80cddce5bf739
--- /dev/null
+++ b/scripts/test/test.sh
@@ -0,0 +1,11 @@
+#!/usr/bin/env bash
+CUDA_VISIBLE_DEVICES=0 python GDPOSR/inferences/test.py \
+--input_image test_LR \
+--output_dir experiment/GDPOSR \
+--pretrained_path ckp/GDPOSR \
+--pretrained_model_name_or_path stable-diffusion-2-1-base \
+--ram_ft_path ckp/DAPE.pth \
+--negprompt 'dotted, noise, blur, lowres, smooth' \
+--prompt 'clean, high-resolution, 8k' \
+--upscale 1 \
+--time_step=100 \
+--time_step_noise=250
\ No newline at end of file
diff --git a/scripts/train/train_GDPOSR.sh b/scripts/train/train_GDPOSR.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c94c22a9f6137d2a85df4531d2a8e4b9bd5bd2b0
--- /dev/null
+++ b/scripts/train/train_GDPOSR.sh
@@ -0,0 +1,36 @@
+#!/usr/bin/env bash
+accelerate launch --main_process_port=12345 --gpu_ids=0,1,2,3,4,5,6,7 --num_processes=8 GDPOSR/train/train_GDPOSR.py \
+ --pretrained_model_name_or_path="stable-diffusion-2-1-base" \
+ --basemodel_path="NAOSD" \
+ --dataset_folder="dataset" \
+ --testdataset_folder="RealSRCrop128" \
+ --resolution=512 \
+ --learning_rate=5e-5 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=1 \
+ --enable_xformers_memory_efficient_attention \
+ --eval_freq 10 \
+ --checkpointing_steps 10 \
+ --mixed_precision='fp16' \
+ --report_to "tensorboard" \
+ --output_dir="experiment/GDPOSR" \
+ --null_text_ratio=1 \
+ --lora_rank_unet_vsd=4 \
+ --lora_rank_unet=4 \
+ --lora_rank_vae=4 \
+ --lambda_lpips=2 \
+ --lambda_l2=1 \
+ --lambda_vsd=1 \
+ --lambda_vsd_lora=1 \
+ --min_dm_step_ratio=0.02 \
+ --max_dm_step_ratio=0.50 \
+ --use_vae_encode_lora \
+ --align_method="adain" \
+ --use_online_deg \
+ --deg_file_path="params_GDPO.yml" \
+ --time_step=100 \
+ --time_step_noise=250 \
+ --groupsize=6 \
+ --negative_prompt="painting, oil painting, illustration, drawing, art, sketch, cartoon, CG Style, 3D render, unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, signature, jpeg artifacts, deformed, lowres, over-smooth" \
+ --tracker_project_name "GDPOSR"
+
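For orientation: assuming the stable-diffusion-2-1-base scheduler's 1000 training timesteps (a property of the scheduler config, not something this script states, and how the training code consumes the ratios is likewise an assumption), the two ratio flags bound the diffusion step sampled for the VSD loss roughly as follows:

```python
num_train_timesteps = 1000                  # SD 2.1-base scheduler default
min_step = int(0.02 * num_train_timesteps)  # --min_dm_step_ratio -> 20
max_step = int(0.50 * num_train_timesteps)  # --max_dm_step_ratio -> 500
print(min_step, max_step)
```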
diff --git a/scripts/train/train_NAOSD.sh b/scripts/train/train_NAOSD.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9c1f639843e685a40f28318bfe8d664eb18f0ad6
--- /dev/null
+++ b/scripts/train/train_NAOSD.sh
@@ -0,0 +1,34 @@
+#!/usr/bin/env bash
+accelerate launch --main_process_port=12345 --gpu_ids=0,1,2,3 --num_processes=4 GDPOSR/train/train_NAOSD.py \
+ --pretrained_model_name_or_path="stable-diffusion-2-1-base" \
+ --dataset_folder="lsdir_ffhq10k" \
+ --testdataset_folder="RealSRCrop128" \
+ --resolution=512 \
+ --learning_rate=5e-5 \
+ --train_batch_size=2 \
+ --gradient_accumulation_steps=2 \
+ --enable_xformers_memory_efficient_attention \
+ --eval_freq 500 \
+ --checkpointing_steps 500 \
+ --mixed_precision='fp16' \
+ --report_to "tensorboard" \
+ --output_dir="experiment/NAOSD" \
+ --null_text_ratio=1 \
+ --lora_rank_unet_vsd=4 \
+ --lora_rank_unet=4 \
+ --lora_rank_vae=4 \
+ --lambda_lpips=2 \
+ --lambda_l2=1 \
+ --lambda_vsd=1 \
+ --lambda_vsd_lora=1 \
+ --min_dm_step_ratio=0.02 \
+ --max_dm_step_ratio=0.50 \
+ --use_vae_encode_lora \
+ --align_method="adain" \
+ --use_online_deg \
+ --deg_file_path="params_GDPO.yml" \
+ --time_step=100 \
+ --time_step_noise=250 \
+ --negative_prompt="painting, oil painting, illustration, drawing, art, sketch, cartoon, CG Style, 3D render, unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, signature, jpeg artifacts, deformed, lowres, over-smooth" \
+ --tracker_project_name "NAOSD"
+
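
Note the two launch commands imply different effective batch sizes:

```python
# num_processes x train_batch_size x gradient_accumulation_steps
gdposr = 8 * 1 * 1  # train_GDPOSR.sh -> 8
naosd = 4 * 2 * 2   # train_NAOSD.sh  -> 16
print(gdposr, naosd)
```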