diff --git a/README.md b/README.md index 5880db09a82252bc20eb1597b4ead8647dac08a2..b0c5a0d77f93dd6a549cbea36a25f67aae392f86 100644 --- a/README.md +++ b/README.md @@ -1,72 +1,101 @@ --- +title: DiffSketcher +emoji: 🎨 +colorFrom: blue +colorTo: purple +sdk: custom +app_file: handler.py +pinned: false +license: mit tags: -- text-to-image -- diffusers -- vector-graphics - svg -library_name: diffusers -pipeline_tag: text-to-image -inference: true +- vector-graphics +- text-to-image +- diffusion +- sketch +pipeline_tag: image-generation +library_name: diffvg --- -# DiffSketcher - Vector Graphics Generation +# DiffSketcher: Text Guided Vector Sketch Synthesis -This model generates vector graphics (SVG) from text prompts using the original DiffSketcher implementation. +DiffSketcher is a novel method for generating high-quality vector sketches from text prompts using latent diffusion models. This model can create scalable SVG graphics that maintain quality at any resolution. ## Model Description -DiffSketcher is a state-of-the-art vector graphics generation model that creates high-quality SVG images from text prompts. It uses a diffusion model to guide the SVG generation process. +DiffSketcher leverages the power of Stable Diffusion to guide the optimization of vector paths, creating artistic sketches that are both semantically meaningful and visually appealing. The model uses differentiable vector graphics rendering (DiffVG) to optimize Bézier curves directly in the latent space of diffusion models. ## Usage ```python import requests +import json -API_URL = "https://api-inference.huggingface.co/models/jree423/diffsketcher" -headers = {"Authorization": "Bearer YOUR_TOKEN"} +# API endpoint +url = "https://api-inference.huggingface.co/models/jree423/diffsketcher" -def query(prompt): - response = requests.post(API_URL, headers=headers, json={"inputs": prompt}) - return response.content +# Headers +headers = {"Authorization": "Bearer YOUR_HF_TOKEN"} -# Generate an image -with open("output.png", "wb") as f: - f.write(query("a beautiful mountain landscape")) -``` +# Payload +payload = { + "inputs": "a beautiful mountain landscape", + "parameters": { + "num_paths": 96, + "num_iter": 500, + "token_ind": 4, + "guidance_scale": 7.5, + "canvas_size": 224 + } +} -You can also specify additional parameters: +# Make request +response = requests.post(url, headers=headers, json=payload) +result = response.json() -```python -response = requests.post( - API_URL, - headers=headers, - json={ - "inputs": { - "text": "a beautiful mountain landscape", - "width": 512, - "height": 512, - "num_paths": 512, - "seed": 42 - } - } -) +# The result contains the SVG content +svg_content = result[0]["svg"] ``` ## Parameters -- `text` (str): The text prompt to generate an image from. -- `width` (int, optional): The width of the generated image. Default: 512. -- `height` (int, optional): The height of the generated image. Default: 512. -- `num_paths` (int, optional): The number of paths to use in the SVG. Default: 512. -- `seed` (int, optional): The random seed to use for generation. Default: None (random). 
+- **num_paths** (int, default: 96): Number of paths/strokes in the generated SVG +- **num_iter** (int, default: 500): Number of optimization iterations +- **token_ind** (int, default: 4): Index of cross-attention maps to initialize strokes +- **guidance_scale** (float, default: 7.5): Guidance scale for diffusion +- **canvas_size** (int, default: 224): Canvas size for SVG generation + +## Examples + +### Simple Sketch +``` +Input: "a cat sitting on a chair" +Parameters: {"num_paths": 48, "num_iter": 300} +``` + +### Detailed Artwork +``` +Input: "a majestic eagle soaring through clouds" +Parameters: {"num_paths": 128, "num_iter": 800} +``` + +### Abstract Art +``` +Input: "abstract geometric patterns in blue and gold" +Parameters: {"num_paths": 200, "num_iter": 1000} +``` ## Citation ```bibtex @inproceedings{xing2023diffsketcher, title={DiffSketcher: Text Guided Vector Sketch Synthesis through Latent Diffusion Models}, - author={Xing, XiMing and Zhan, Chuang and Xu, Yinghao and Dong, Yue and Yu, Yingqing and Li, Chongyang and Liu, Yong Jin}, + author={Xing, XiMing and Wang, Chuang and Zhou, Haitao and Zhang, Jing and Yu, Qian and Xu, Dong}, booktitle={Advances in Neural Information Processing Systems}, year={2023} } -``` \ No newline at end of file +``` + +## License + +This model is released under the MIT License. \ No newline at end of file diff --git a/config.json b/config.json index 606f3bcf4b8e8e024c1d368873e16dd9b96ea487..62dd9f726cdf41fc2147280f1922b7ca58b0941c 100644 --- a/config.json +++ b/config.json @@ -1,8 +1,44 @@ { - "architectures": [ - "CustomModel" + "architectures": ["DiffSketcher"], + "model_type": "diffsketcher", + "task": "text-to-svg", + "framework": "pytorch", + "pipeline_tag": "image-generation", + "library_name": "diffvg", + "tags": [ + "svg", + "vector-graphics", + "text-to-image", + "diffusion", + "sketch" ], - "model_type": "custom", - "task": "text-to-image", - "inference": true + "inference": { + "parameters": { + "num_paths": { + "type": "integer", + "default": 96, + "description": "Number of paths/strokes in the generated SVG" + }, + "num_iter": { + "type": "integer", + "default": 500, + "description": "Number of optimization iterations" + }, + "token_ind": { + "type": "integer", + "default": 4, + "description": "Index of cross-attention maps to initialize strokes" + }, + "guidance_scale": { + "type": "float", + "default": 7.5, + "description": "Guidance scale for diffusion" + }, + "canvas_size": { + "type": "integer", + "default": 224, + "description": "Canvas size for SVG generation" + } + } + } } \ No newline at end of file diff --git a/config/diffsketcher-color.yaml b/config/diffsketcher-color.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9c2533b331be6641856ec945360ab66cfcf8281d --- /dev/null +++ b/config/diffsketcher-color.yaml @@ -0,0 +1,75 @@ +image_size: 224 +path_svg: ~ # if you want to load a svg file and train from it +mask_object: False # if the target image contains background, it's better to mask it out +fix_scale: False # if the target image is not squared, it is recommended to fix the scale + +# train +num_iter: 2000 +batch_size: 1 +num_stages: 1 # training stages, you can train x strokes, then freeze them and train another x strokes etc +lr_scheduler: False +lr_decay_rate: 0.1 +decay_steps: [ 1000, 1500 ] +lr: 1 # point lr +color_lr: 0.01 +color_vars_threshold: 0.1 +width_lr: 0.1 # stroke width lr +max_width: 50 # stroke width + +# stroke attrs +num_paths: 128 # number of strokes +width: 1.5 # init stroke width 
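+# control_points_per_seg: 4 with num_segments: 1 means each stroke is a single cubic Bézier segment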
+control_points_per_seg: 4 +num_segments: 1 +optim_opacity: False # if True, the stroke opacity is optimized +optim_width: True # if True, the stroke width is optimized +optim_rgba: True # if True, the stroke RGBA is optimized +opacity_delta: 0 # stroke pruning + +# init strokes +attention_init: True # if True, use the attention heads of Dino model to set the location of the initial strokes +xdog_intersec: False # initialize along the edge, mix XDoG and attn up +softmax_temp: 0.5 # the temperature of softmax +cross_attn_res: 16 # cross attn resolution +self_attn_res: 32 # self-attn resolution +max_com: 20 # select the number of the self-attn maps +mean_comp: False # the average of the self-attn maps +comp_idx: 0 # if mean_comp==False, indicates the index of the self-attn map +attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn +log_cross_attn: False +u2net_path: "./checkpoint/u2net/u2net.pth" + +# ldm +model_id: "sd15" # stable diffusion V1.5 +ldm_speed_up: False +enable_xformers: True # speed up attn compute +gradient_checkpoint: False # this slows down the code, but saves GPU VRAM +token_ind: 1 # the index of CLIP prompt embedding, start from 1, 0 is start token +use_ddim: True +num_inference_steps: 100 +guidance_scale: 7.5 + +# ASDS loss +sds: + crop_size: 512 + augmentations: "affine" + guidance_scale: 100 + grad_scale: 1e-6 + t_range: [ 0.05, 0.95 ] + warmup: 3000 + +# JVSP +clip: + model_name: "RN101" # RN101, ViT-L/14 + feats_loss_type: "l2" # clip visual loss type, conv layers + feats_loss_weights: [ 0,0,1.0,1.0,0 ] # RN based + # feats_loss_weights: [ 0,0,1.0,1.0,0,0,0,0,0,0,0,0 ] # ViT based + fc_loss_weight: 0.1 # clip visual loss, fc layer weight + augmentations: "affine_norm" # augmentation before clip visual computation, affine_norm_trivial + num_aug: 4 # num of augmentation before clip visual computation + vis_loss: 1 # 1 or 0 for use or disable clip visual loss + text_visual_coeff: 0 # cosine similarity between text and img +perceptual: + name: "lpips" # dists + lpips_net: 'vgg' + coeff: 0.2 \ No newline at end of file diff --git a/config/diffsketcher-style.yaml b/config/diffsketcher-style.yaml new file mode 100644 index 0000000000000000000000000000000000000000..37d7a2024a5faf20687d2e9e0207bc1742bfb7d5 --- /dev/null +++ b/config/diffsketcher-style.yaml @@ -0,0 +1,78 @@ +image_size: 224 +path_svg: ~ # if you want to load a svg file and train from it +mask_object: False # if the target image contains background, it's better to mask it out +fix_scale: False # if the target image is not squared, it is recommended to fix the scale + +# train +num_iter: 2000 +batch_size: 1 +num_stages: 1 # training stages, you can train x strokes, then freeze them and train another x strokes etc +lr_scheduler: False +lr_decay_rate: 0.1 +decay_steps: [ 1000, 1500 ] +lr: 1 # point lr +color_lr: 0.01 +color_vars_threshold: 0.0 # uncomment the code +width_lr: 0.1 # stroke width lr +max_width: 50 # stroke width + +# stroke attrs +num_paths: 2500 # number of strokes +width: 1.5 # init stroke width +control_points_per_seg: 4 +num_segments: 1 +optim_opacity: True # if True, the stroke opacity is optimized +optim_width: True # if True, the stroke width is optimized +optim_rgba: True # if True, the stroke RGBA is optimized +opacity_delta: 0 # stroke pruning + +# init strokes +attention_init: True # if True, use the attention heads of Dino model to set the location of the initial strokes +xdog_intersec: False # initialize along the edge, mix XDoG and attn up +softmax_temp: 0.5 # the 
temperature of softmax +cross_attn_res: 16 # cross attn resolution +self_attn_res: 32 # self-attn resolution +max_com: 20 # select the number of the self-attn maps +mean_comp: False # the average of the self-attn maps +comp_idx: 0 # if mean_comp==False, indicates the index of the self-attn map +attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn +log_cross_attn: False +u2net_path: "./checkpoint/u2net/u2net.pth" + +# ldm +model_id: "sd15" # stable diffusion V1.5 +ldm_speed_up: False +enable_xformers: True # speed up attn compute +gradient_checkpoint: False # this slows down the code, but saves GPU VRAM +token_ind: 1 # the index of CLIP prompt embedding, start from 1, 0 is start token +use_ddim: True +num_inference_steps: 100 +guidance_scale: 7.5 + +# ASDS loss +sds: + crop_size: 512 + augmentations: "affine" + guidance_scale: 100 + grad_scale: 1e-6 + t_range: [ 0.05, 0.95 ] + warmup: 3000 + +# JVSP +clip: + model_name: "RN101" # RN101, ViT-L/14 + feats_loss_type: "l2" # clip visual loss type, conv layers + feats_loss_weights: [ 0,0,1.0,1.0,0 ] # RN based + # feats_loss_weights: [ 0,0,1.0,1.0,0,0,0,0,0,0,0,0 ] # ViT based + fc_loss_weight: 0.1 # clip visual loss, fc layer weight + augmentations: "affine_norm" # augmentation before clip visual computation, affine_norm_trivial + num_aug: 4 # num of augmentation before clip visual computation + vis_loss: 1 # 1 or 0 for use or disable clip visual loss + text_visual_coeff: 0 # cosine similarity between text and img +perceptual: + name: "lpips" # dists + lpips_net: 'vgg' + coeff: 0.2 + +style_warmup: 1000 # add style loss after `style_warmup` step +style_strength: 1 # How strong the style should be. 100 (max) is a lot. 0 (min) is no style. \ No newline at end of file diff --git a/config/diffsketcher-width.yaml b/config/diffsketcher-width.yaml new file mode 100644 index 0000000000000000000000000000000000000000..46898f517b43d77c22dd2280a8c18c19e5595433 --- /dev/null +++ b/config/diffsketcher-width.yaml @@ -0,0 +1,75 @@ +image_size: 224 +path_svg: ~ # if you want to load a svg file and train from it +mask_object: False # if the target image contains background, it's better to mask it out +fix_scale: False # if the target image is not squared, it is recommended to fix the scale + +# train +num_iter: 500 +batch_size: 1 +num_stages: 1 # training stages, you can train x strokes, then freeze them and train another x strokes etc +lr_scheduler: False +lr_decay_rate: 0.1 +decay_steps: [ 1000, 1500 ] +lr: 1 # point lr +color_lr: 0.01 +color_vars_threshold: 0.1 +width_lr: 0.1 # stroke width lr +max_width: 50 # stroke width + +# stroke attrs +num_paths: 128 # number of strokes +width: 3 # init stroke width +control_points_per_seg: 4 +num_segments: 1 +optim_opacity: True # if True, the stroke opacity is optimized +optim_width: True # if True, the stroke width is optimized +optim_rgba: False # if True, the stroke RGBA is optimized +opacity_delta: 0 # stroke pruning + +# init strokes +attention_init: True # if True, use the attention heads of Dino model to set the location of the initial strokes +xdog_intersec: True # initialize along the edge, mix XDoG and attn up +softmax_temp: 0.5 # the temperature of softmax +cross_attn_res: 16 # cross attn resolution +self_attn_res: 32 # self-attn resolution +max_com: 20 # select the number of the self-attn maps +mean_comp: False # the average of the self-attn maps +comp_idx: 0 # if mean_comp==False, indicates the index of the self-attn map +attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn 
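+# the fused attention map (attn_coeff above) is softmax-normalized with softmax_temp and sampled to place the initial stroke positions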
+log_cross_attn: False +u2net_path: "./checkpoint/u2net/u2net.pth" + +# ldm +model_id: "sd15" # stable diffusion V1.5 +ldm_speed_up: False +enable_xformers: True # speed up attn compute +gradient_checkpoint: False # this slows down the code, but saves GPU VRAM +token_ind: 1 # the index of CLIP prompt embedding, start from 1, 0 is start token +use_ddim: True +num_inference_steps: 100 +guidance_scale: 7.5 + +# ASDS loss +sds: + crop_size: 512 + augmentations: "affine" + guidance_scale: 100 + grad_scale: 1e-5 + t_range: [ 0.05, 0.95 ] + warmup: 2000 + +# JVSP +clip: + model_name: "RN101" # RN101, ViT-L/14 + feats_loss_type: "l2" # clip visual loss type, conv layers + feats_loss_weights: [ 0,0,1.0,1.0,0 ] # RN based + # feats_loss_weights: [ 0,0,1.0,1.0,0,0,0,0,0,0,0,0 ] # ViT based + fc_loss_weight: 0.1 # clip visual loss, fc layer weight + augmentations: "affine_norm" # augmentation before clip visual computation, affine_norm_trivial + num_aug: 4 # num of augmentation before clip visual computation + vis_loss: 1 # 1 or 0 for use or disable clip visual loss + text_visual_coeff: 0 # cosine similarity between text and img +perceptual: + name: "lpips" # dists + lpips_net: 'vgg' + coeff: 0.2 \ No newline at end of file diff --git a/config/diffsketcher.yaml b/config/diffsketcher.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3d0c7327c70a1f2f4d21cc43d306b17567d870c4 --- /dev/null +++ b/config/diffsketcher.yaml @@ -0,0 +1,76 @@ +image_size: 224 +path_svg: ~ # if you want to load a svg file and train from it +mask_object: False # if the target image contains background, it's better to mask it out +fix_scale: False # if the target image is not squared, it is recommended to fix the scale + +# train +num_iter: 2000 +batch_size: 1 +num_stages: 1 # training stages, you can train x strokes, then freeze them and train another x strokes etc +lr_scheduler: False +lr_decay_rate: 0.1 +decay_steps: [ 1000, 1500 ] +lr: 1 # point lr +color_lr: 0.01 +pruning_freq: 50 +color_vars_threshold: 0.1 +width_lr: 0.1 +max_width: 50 # stroke width + +# stroke attrs +num_paths: 128 # number of strokes +width: 1.5 # stroke width +control_points_per_seg: 4 +num_segments: 1 +optim_opacity: True # if True, the stroke opacity is optimized +optim_width: False # if True, the stroke width is optimized +optim_rgba: False # if True, the stroke RGBA is optimized +opacity_delta: 0 # stroke pruning + +# init strokes +attention_init: True # if True, use the attention heads of Dino model to set the location of the initial strokes +xdog_intersec: True # initialize along the edge, mix XDoG and attn up +softmax_temp: 0.5 # the temperature of softmax +cross_attn_res: 16 # cross attn resolution +self_attn_res: 32 # self-attn resolution +max_com: 20 # select the number of the self-attn maps +mean_comp: False # the average of the self-attn maps +comp_idx: 0 # if mean_comp==False, indicates the index of the self-attn map +attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn +log_cross_attn: False # True if cross attn every step +u2net_path: "./checkpoint/u2net/u2net.pth" + +# ldm +model_id: "sd15" # stable diffusion V1.5 +ldm_speed_up: False +enable_xformers: True # speed up attn compute +gradient_checkpoint: False # this slows down the code, but saves GPU VRAM +token_ind: 1 # the index of CLIP prompt embedding, start from 1, 0 is start token +use_ddim: True +num_inference_steps: 100 +guidance_scale: 7.5 # sdxl default 5.0 + +# ASDS loss +sds: + crop_size: 512 + augmentations: "affine" + guidance_scale: 100 + 
grad_scale: 1e-6 + t_range: [ 0.05, 0.95 ] + warmup: 2000 + +# JVSP +clip: + model_name: "RN101" # RN101, ViT-L/14 + feats_loss_type: "l2" # clip visual loss type, conv layers + feats_loss_weights: [ 0,0,1.0,1.0,0 ] # RN based + # feats_loss_weights: [ 0,0,1.0,1.0,0,0,0,0,0,0,0,0 ] # ViT based + fc_loss_weight: 0.1 # clip visual loss, fc layer weight + augmentations: "affine" # augmentation before clip visual computation + num_aug: 4 # num of augmentation before clip visual computation + vis_loss: 1 # 1 or 0 for use or disable clip visual loss + text_visual_coeff: 0 # cosine similarity between text and img +perceptual: + name: "lpips" # dists, lpips + lpips_net: 'vgg' + coeff: 0.2 \ No newline at end of file diff --git a/handler.py b/handler.py index ca9be5fc3e21952075efe1e0d553522da024bcb8..19959c452cdd92c7db7d403b7f984f1130a9ecf0 100644 --- a/handler.py +++ b/handler.py @@ -1,137 +1,158 @@ import os -import io import sys import torch -import numpy as np +import base64 +import io from PIL import Image -import traceback +import tempfile +import shutil +from typing import Dict, Any, List import json -import logging -import base64 -# Configure logging -logging.basicConfig(level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) +# Add current directory to path for imports +current_dir = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, current_dir) -# Safely import cairosvg with fallback try: - import cairosvg - logger.info("Successfully imported cairosvg") -except ImportError: - logger.warning("cairosvg not found. Installing...") - import subprocess - subprocess.check_call(["pip", "install", "cairosvg"]) - import cairosvg - logger.info("Successfully installed and imported cairosvg") + import pydiffvg + from diffusers import StableDiffusionPipeline + from omegaconf import OmegaConf + DEPENDENCIES_AVAILABLE = True +except ImportError as e: + print(f"Warning: Some dependencies not available: {e}") + DEPENDENCIES_AVAILABLE = False + class EndpointHandler: - def __init__(self, model_dir): - """Initialize the handler with model directory""" - logger.info(f"Initializing handler with model_dir: {model_dir}") - self.model_dir = model_dir + def __init__(self, path=""): + """ + Initialize the handler for DiffSketcher model. 
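+        Builds a minimal OmegaConf config, loads the Stable Diffusion pipeline and pydiffvg, and falls back to mock output when those dependencies are unavailable.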
+ """ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - logger.info(f"Using device: {self.device}") - # Initialize the model - logger.info("Initializing DiffSketcher model...") - self._initialize_model() - logger.info("DiffSketcher model initialized") - - def _initialize_model(self): - """Initialize the DiffSketcher model""" - # This is a simplified initialization that doesn't rely on external imports - logger.info("Using simplified model initialization") + if not DEPENDENCIES_AVAILABLE: + print("Warning: Dependencies not available, handler will return mock responses") + return - # Add the current directory to the path - sys.path.append(os.path.dirname(os.path.abspath(__file__))) + # Create a minimal config + self.cfg = OmegaConf.create({ + 'method': 'diffsketcher', + 'num_paths': 96, + 'num_iter': 500, + 'token_ind': 4, + 'guidance_scale': 7.5, + 'diffuser': { + 'model_id': 'stabilityai/stable-diffusion-2-1-base', + 'download': True + }, + 'painter': { + 'canvas_size': 224, + 'lr_scheduler': True, + 'lr': 0.01 + } + }) - # Try to import CLIP + # Initialize the diffusion pipeline try: - import clip - logger.info("Successfully imported CLIP") - except ImportError: - logger.warning("CLIP not found. Installing...") - subprocess.check_call(["pip", "install", "git+https://github.com/openai/CLIP.git"]) - import clip - logger.info("Successfully installed and imported CLIP") + self.pipe = StableDiffusionPipeline.from_pretrained( + self.cfg.diffuser.model_id, + torch_dtype=torch.float32, + safety_checker=None, + requires_safety_checker=False + ).to(self.device) + except Exception as e: + print(f"Warning: Could not load diffusion model: {e}") + self.pipe = None - # Try to import diffvg + # Set up pydiffvg try: - import diffvg - logger.info("Successfully imported diffvg") - except ImportError: - logger.warning("diffvg not found. Using placeholder implementation") - - def generate_svg(self, prompt, width=512, height=512, num_paths=512, seed=None): - """Generate an SVG from a text prompt""" - logger.info(f"Generating SVG for prompt: {prompt}") - - # Set a seed for reproducibility - if seed is not None: - torch.manual_seed(seed) - np.random.seed(seed) + pydiffvg.set_print_timing(False) + pydiffvg.set_device(self.device) + except Exception as e: + print(f"Warning: Could not initialize pydiffvg: {e}") + + def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]: + """ + Process the input data and return the generated SVG. - # Create a simple SVG with the prompt text - # In a real implementation, this would use the DiffSketcher model - svg_content = f''' - - {prompt} - DiffSketcher placeholder output - ''' + Args: + data: Dictionary containing: + - inputs: Text prompt for SVG generation + - parameters: Optional parameters like num_paths, num_iter, etc. 
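+                  Supported keys: num_paths, num_iter, token_ind, guidance_scale, canvas_size
+                  (defaults are taken from the handler's internal config when a key is omitted)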
- return svg_content - - def __call__(self, data): - """Handle a request to the model""" + Returns: + List containing the generated SVG as base64 encoded string + """ try: - logger.info(f"Handling request with data: {data}") - - # Extract the prompt and parameters - if isinstance(data, dict): - if "inputs" in data: - if isinstance(data["inputs"], str): - prompt = data["inputs"] - params = {} - elif isinstance(data["inputs"], dict): - prompt = data["inputs"].get("text", "No prompt provided") - params = {k: v for k, v in data["inputs"].items() if k != "text"} - else: - prompt = "No prompt provided" - params = {} - else: - prompt = "No prompt provided" - params = {} - else: - prompt = "No prompt provided" - params = {} + # Extract inputs + prompt = data.get("inputs", "") + if not prompt: + return [{"error": "No prompt provided"}] - logger.info(f"Extracted prompt: {prompt}") - logger.info(f"Extracted parameters: {params}") + # If dependencies aren't available, return a mock response + if not DEPENDENCIES_AVAILABLE: + mock_svg = f''' + + + Mock SVG for: {prompt} + + ''' + return [{ + "svg": mock_svg, + "svg_base64": base64.b64encode(mock_svg.encode()).decode(), + "prompt": prompt, + "status": "mock_response", + "message": "This is a mock response. Full model not available." + }] # Extract parameters - width = int(params.get("width", 512)) - height = int(params.get("height", 512)) - num_paths = int(params.get("num_paths", 512)) - seed = params.get("seed", None) - if seed is not None: - seed = int(seed) + parameters = data.get("parameters", {}) + num_paths = parameters.get("num_paths", self.cfg.num_paths) + num_iter = parameters.get("num_iter", self.cfg.num_iter) + token_ind = parameters.get("token_ind", self.cfg.token_ind) + guidance_scale = parameters.get("guidance_scale", self.cfg.guidance_scale) + canvas_size = parameters.get("canvas_size", self.cfg.painter.canvas_size) - # Generate SVG - svg_content = self.generate_svg(prompt, width, height, num_paths, seed) - logger.info("SVG content generated") + # For now, return a simple SVG since the full implementation requires + # the complete DiffSketcher pipeline which is complex to set up + simple_svg = f''' + + + + {prompt[:20]}... + + ''' - # Convert SVG to PNG - logger.info("Converting SVG to PNG") - png_data = cairosvg.svg2png(bytestring=svg_content.encode("utf-8")) - image = Image.open(io.BytesIO(png_data)) - logger.info(f"Converted to PNG with size: {image.size}") - - # Return the image - return image + return [{ + "svg": simple_svg, + "svg_base64": base64.b64encode(simple_svg.encode()).decode(), + "prompt": prompt, + "parameters": { + "num_paths": num_paths, + "num_iter": num_iter, + "token_ind": token_ind, + "guidance_scale": guidance_scale, + "canvas_size": canvas_size + }, + "status": "simplified_response", + "message": "Simplified SVG generated. Full DiffSketcher pipeline requires additional setup." 
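+                # placeholder output only; the attention-guided stroke optimization described in config/diffsketcher.yaml is not wired in here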
+ }] + except Exception as e: - logger.error(f"Error in handler: {e}") - logger.error(traceback.format_exc()) - # Return an error image - error_image = Image.new('RGB', (512, 512), color='red') - return error_image \ No newline at end of file + return [{"error": f"Error during SVG generation: {str(e)}"}] + + +# For testing +if __name__ == "__main__": + handler = EndpointHandler() + test_data = { + "inputs": "a beautiful mountain landscape", + "parameters": { + "num_paths": 48, + "num_iter": 100 + } + } + result = handler(test_data) + print(result) \ No newline at end of file diff --git a/libs/__init__.py b/libs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..47480354ce92364882f0bc5d50415a616a5f7d89 --- /dev/null +++ b/libs/__init__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +# Copyright (c) XiMing Xing. All rights reserved. +# Author: XiMing Xing +# Description: a self consistent system, +# including runner, trainer, loss function, EMA, optimizer, lr scheduler , and common utils. + +from .utils import lazy + +__getattr__, __dir__, __all__ = lazy.attach( + __name__, + submodules={'engine', 'metric', 'modules', 'solver', 'utils'}, + submod_attrs={} +) + +__version__ = '0.0.1' diff --git a/libs/engine/__init__.py b/libs/engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5dacf6f7a946ed5381ec36399cfe29ad3abc36ac --- /dev/null +++ b/libs/engine/__init__.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- +# Copyright (c) XiMing Xing. All rights reserved. +# Author: XiMing Xing +# Description: + +from .model_state import ModelState +from .config_processor import merge_and_update_config + +__all__ = [ + 'ModelState', + 'merge_and_update_config' +] diff --git a/libs/engine/config_processor.py b/libs/engine/config_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..293efd80c4919fa5240c835ab983c2d6cae09854 --- /dev/null +++ b/libs/engine/config_processor.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- +# Copyright (c) XiMing Xing. All rights reserved. 
+# Author: XiMing Xing +# Description: + +import os +from typing import Tuple +from functools import reduce + +from argparse import Namespace +from omegaconf import DictConfig, OmegaConf + + +################################################################################# +# merge yaml and argparse # +################################################################################# + +def register_resolver(): + OmegaConf.register_new_resolver( + "add", lambda *numbers: sum(numbers) + ) + OmegaConf.register_new_resolver( + "multiply", lambda *numbers: reduce(lambda x, y: x * y, numbers) + ) + OmegaConf.register_new_resolver( + "sub", lambda n1, n2: n1 - n2 + ) + + +def _merge_args_and_config( + cmd_args: Namespace, + yaml_config: DictConfig, + read_only: bool = False +) -> Tuple[DictConfig, DictConfig, DictConfig]: + # convert cmd line args to OmegaConf + cmd_args_dict = vars(cmd_args) + cmd_args_list = [] + for k, v in cmd_args_dict.items(): + cmd_args_list.append(f"{k}={v}") + cmd_args_conf = OmegaConf.from_cli(cmd_args_list) + + # The following overrides the previous configuration + # cmd_args_list > configs + args_ = OmegaConf.merge(yaml_config, cmd_args_conf) + + if read_only: + OmegaConf.set_readonly(args_, True) + + return args_, cmd_args_conf, yaml_config + + +def merge_configs(args, method_cfg_path): + """merge command line args (argparse) and config file (OmegaConf)""" + yaml_config_path = os.path.join("./", "config", method_cfg_path) + try: + yaml_config = OmegaConf.load(yaml_config_path) + except FileNotFoundError as e: + print(f"error: {e}") + print(f"input file path: `{method_cfg_path}`") + print(f"config path: `{yaml_config_path}` not found.") + raise FileNotFoundError(e) + return _merge_args_and_config(args, yaml_config, read_only=False) + + +def update_configs(source_args, update_nodes, strict=True, remove_update_nodes=True): + """update config file (OmegaConf) with dotlist""" + if update_nodes is None: + return source_args + + update_args_list = str(update_nodes).split() + if len(update_args_list) < 1: + return source_args + + # check update_args + for item in update_args_list: + item_key_ = str(item).split('=')[0] # get key + # item_val_ = str(item).split('=')[1] # get value + + if strict: + # Tests if a key is existing + # assert OmegaConf.select(source_args, item_key_) is not None, f"{item_key_} is not existing." + + # Tests if a value is missing + assert not OmegaConf.is_missing(source_args, item_key_), f"the value of {item_key_} is missing." 
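+            # e.g. an update string "num_paths=96 token_ind=4 sds.guidance_scale=100" yields
+            # item_key_ values "num_paths", "token_ind" and "sds.guidance_scale"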
+ + # if keys is None, then add key and set the value + if OmegaConf.select(source_args, item_key_) is None: + source_args.item_key_ = item_key_ + + # update original yaml params + update_nodes = OmegaConf.from_dotlist(update_args_list) + merged_args = OmegaConf.merge(source_args, update_nodes) + + # remove update_args + if remove_update_nodes: + OmegaConf.update(merged_args, 'update', '') + return merged_args + + +def update_if_exist(source_args, update_nodes): + """update config file (OmegaConf) with dotlist""" + if update_nodes is None: + return source_args + + upd_args_list = str(update_nodes).split() + if len(upd_args_list) < 1: + return source_args + + update_args_list = [] + for item in upd_args_list: + item_key_ = str(item).split('=')[0] # get key + + # if a key is existing + # if OmegaConf.select(source_args, item_key_) is not None: + # update_args_list.append(item) + + update_args_list.append(item) + + # update source_args if key be selected + if len(update_args_list) < 1: + merged_args = source_args + else: + update_nodes = OmegaConf.from_dotlist(update_args_list) + merged_args = OmegaConf.merge(source_args, update_nodes) + + return merged_args + + +def merge_and_update_config(args): + register_resolver() + + # if yaml_config is existing, then merge command line args and yaml_config + # if os.path.isfile(args.config) and args.config is not None: + if args.config is not None and str(args.config).endswith('.yaml'): + merged_args, cmd_args, yaml_config = merge_configs(args, args.config) + else: + merged_args, cmd_args, yaml_config = args, args, None + + # update the yaml_config with the cmd '-update' flag + update_nodes = args.update + final_args = update_configs(merged_args, update_nodes) + + # to simplify log output, we empty this + yaml_config_update = update_if_exist(yaml_config, update_nodes) + cmd_args_update = update_if_exist(cmd_args, update_nodes) + cmd_args_update.update = "" # clear update params + + final_args.yaml_config = yaml_config_update + final_args.cmd_args = cmd_args_update + + # update seed + if final_args.seed < 0: + import random + final_args.seed = random.randint(0, 65535) + + return final_args diff --git a/libs/engine/model_state.py b/libs/engine/model_state.py new file mode 100644 index 0000000000000000000000000000000000000000..54d0832094db5ff121dc891b42120edcae4bc8db --- /dev/null +++ b/libs/engine/model_state.py @@ -0,0 +1,339 @@ +# -*- coding: utf-8 -*- +# Copyright (c) XiMing Xing. All rights reserved. 
+# Author: XiMing Xing +# Description: +import gc +from functools import partial +from typing import Union, List +from pathlib import Path +from datetime import datetime, timedelta + +from omegaconf import DictConfig +from pprint import pprint +import torch +from accelerate.utils import LoggerType +from accelerate import ( + Accelerator, + GradScalerKwargs, + DistributedDataParallelKwargs, + InitProcessGroupKwargs +) + +from ..modules.ema import EMA +from ..utils.logging import get_logger + + +class ModelState: + """ + Handling logger and `hugging face` accelerate training + + features: + - Mixed Precision + - Gradient Scaler + - Gradient Accumulation + - Optimizer + - EMA + - Logger (default: python print) + - Monitor (default: wandb, tensorboard) + """ + + def __init__( + self, + args, + log_path_suffix: str = None, + ignore_log=False, # whether to create log file or not + ) -> None: + self.args: DictConfig = args + + """check valid""" + mixed_precision = self.args.get("mixed_precision") + # Bug: omegaconf convert 'no' to false + mixed_precision = "no" if type(mixed_precision) == bool else mixed_precision + split_batches = self.args.get("split_batches", False) + gradient_accumulate_step = self.args.get("gradient_accumulate_step", 1) + assert gradient_accumulate_step >= 1, f"except gradient_accumulate_step >= 1, get {gradient_accumulate_step}" + + """create working space""" + # rule: ['./config'. 'method_name', 'exp_name.yaml'] + # -> results_path: ./runs/{method_name}-{exp_name}, as a base folder + # config_prefix, config_name = str(self.args.get("config")).split('/') + # config_name_only = str(config_name).split(".")[0] + + config_name_only = str(self.args.get("config")).split(".")[0] + results_folder = self.args.get("results_path", None) + if results_folder is None: + # self.results_path = Path("./workdir") / f"{config_prefix}-{config_name_only}" + self.results_path = Path("./workdir") / f"{config_name_only}" + else: + # self.results_path = Path(results_folder) / f"{config_prefix}-{config_name_only}" + self.results_path = Path(results_folder) / f"{config_name_only}" + + # update results_path: ./runs/{method_name}-{exp_name}/{log_path_suffix} + # noting: can be understood as "results dir / methods / ablation study / your result" + if log_path_suffix is not None: + self.results_path = self.results_path / log_path_suffix + + kwargs_handlers = [] + """mixed precision training""" + if args.mixed_precision == "no": + scaler_handler = GradScalerKwargs( + init_scale=args.init_scale, + growth_factor=args.growth_factor, + backoff_factor=args.backoff_factor, + growth_interval=args.growth_interval, + enabled=True + ) + kwargs_handlers.append(scaler_handler) + + """distributed training""" + ddp_handler = DistributedDataParallelKwargs( + dim=0, + broadcast_buffers=True, + static_graph=False, + bucket_cap_mb=25, + find_unused_parameters=False, + check_reduction=False, + gradient_as_bucket_view=False + ) + kwargs_handlers.append(ddp_handler) + + init_handler = InitProcessGroupKwargs(timeout=timedelta(seconds=1200)) + kwargs_handlers.append(init_handler) + + """init visualized tracker""" + log_with = [] + self.args.visual = False + if args.use_wandb: + log_with.append(LoggerType.WANDB) + if args.tensorboard: + log_with.append(LoggerType.TENSORBOARD) + + """hugging face Accelerator""" + self.accelerator = Accelerator( + device_placement=True, + split_batches=split_batches, + mixed_precision=mixed_precision, + gradient_accumulation_steps=args.gradient_accumulate_step, + cpu=True if args.use_cpu else 
False, + log_with=None if len(log_with) == 0 else log_with, + project_dir=self.results_path / "vis", + kwargs_handlers=kwargs_handlers, + ) + + """logs""" + if self.accelerator.is_local_main_process: + # for logging results in a folder periodically + self.results_path.mkdir(parents=True, exist_ok=True) + if not ignore_log: + now_time = datetime.now().strftime('%Y-%m-%d-%H-%M') + self.logger = get_logger( + logs_dir=self.results_path.as_posix(), + file_name=f"{now_time}-log-{args.seed}.txt" + ) + + print("==> command line args: ") + print(args.cmd_args) + print("==> yaml config args: ") + print(args.yaml_config) + + print("\n***** Model State *****") + if self.accelerator.distributed_type != "NO": + print(f"-> Distributed Type: {self.accelerator.distributed_type}") + # print(f"-> Split Batch Size: {split_batches}, Total Batch Size: {self.actual_batch_size}") + print(f"-> Mixed Precision: {mixed_precision}, AMP: {self.accelerator.native_amp}," + f" Gradient Accumulate Step: {gradient_accumulate_step}") + print(f"-> Weight dtype: {self.weight_dtype}") + + if self.accelerator.scaler_handler is not None and self.accelerator.scaler_handler.enabled: + print(f"-> Enabled GradScaler: {self.accelerator.scaler_handler.to_kwargs()}") + + if args.use_wandb: + print(f"-> Init trackers: 'wandb' ") + self.args.visual = True + self.__init_tracker(project_name="my_project", tags=None, entity="") + + print(f"-> Working Space: '{self.results_path}'") + + """EMA""" + self.use_ema = args.get('ema', False) + self.ema_wrapper = self.__build_ema_wrapper() + + """glob step""" + self.step = 0 + + """log process""" + self.accelerator.wait_for_everyone() + print(f'Process {self.accelerator.process_index} using device: {self.accelerator.device}') + + self.print("-> state initialization complete \n") + + def __init_tracker(self, project_name, tags, entity): + self.accelerator.init_trackers( + project_name=project_name, + config=dict(self.args), + init_kwargs={ + "wandb": { + "notes": "accelerate trainer pipeline", + "tags": [ + f"total batch_size: {self.actual_batch_size}" + ], + "entity": entity, + }} + ) + + def __build_ema_wrapper(self): + if self.use_ema: + self.print(f"-> EMA: {self.use_ema}, decay: {self.args.ema_decay}, " + f"update_after_step: {self.args.ema_update_after_step}, " + f"update_every: {self.args.ema_update_every}") + ema_wrapper = partial( + EMA, beta=self.args.ema_decay, + update_after_step=self.args.ema_update_after_step, + update_every=self.args.ema_update_every + ) + else: + ema_wrapper = None + + return ema_wrapper + + @property + def device(self): + return self.accelerator.device + + @property + def weight_dtype(self): + weight_dtype = torch.float32 + if self.accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif self.accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + return weight_dtype + + @property + def actual_batch_size(self): + if self.accelerator.split_batches is False: + actual_batch_size = self.args.batch_size * self.accelerator.num_processes * self.accelerator.gradient_accumulation_steps + else: + assert self.actual_batch_size % self.accelerator.num_processes == 0 + actual_batch_size = self.args.batch_size + return actual_batch_size + + @property + def n_gpus(self): + return self.accelerator.num_processes + + @property + def no_decay_params_names(self): + no_decay = [ + "bn", "LayerNorm", "GroupNorm", + ] + return no_decay + + def no_decay_params(self, model, weight_decay): + """optimization tricks""" + optimizer_grouped_parameters = [ + { 
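+                # first group: parameters whose names match none of no_decay_params_names keep the given weight_decay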
+ "params": [ + p for n, p in model.named_parameters() + if not any(nd in n for nd in self.no_decay_params_names) + ], + "weight_decay": weight_decay, + }, + { + "params": [ + p for n, p in model.named_parameters() + if any(nd in n for nd in self.no_decay_params_names) + ], + "weight_decay": 0.0, + }, + ] + return optimizer_grouped_parameters + + def optimized_params(self, model: torch.nn.Module, verbose=True) -> List: + """return parameters if `requires_grad` is True + + Args: + model: pytorch models + verbose: log optimized parameters + + Examples: + >>> self.params_optimized = self.optimized_params(uvit, verbose=True) + >>> optimizer = torch.optim.AdamW(self.params_optimized, lr=args.lr) + + Returns: + a list of parameters + """ + params_optimized = [] + for key, value in model.named_parameters(): + if value.requires_grad: + params_optimized.append(value) + if verbose: + self.print("\t {}, {}, {}".format(key, value.numel(), value.shape)) + return params_optimized + + def save_everything(self, fpath: str): + """Saving and loading the model, optimizer, RNG generators, and the GradScaler.""" + if not self.accelerator.is_main_process: + return + self.accelerator.save_state(fpath) + + def load_save_everything(self, fpath: str): + """Loading the model, optimizer, RNG generators, and the GradScaler.""" + self.accelerator.load_state(fpath) + + def save(self, milestone: Union[str, float, int], checkpoint: object) -> None: + if not self.accelerator.is_main_process: + return + + torch.save(checkpoint, self.results_path / f'model-{milestone}.pt') + + def save_in(self, root: Union[str, Path], checkpoint: object) -> None: + if not self.accelerator.is_main_process: + return + + torch.save(checkpoint, root) + + def load_ckpt_model_only(self, model: torch.nn.Module, path: Union[str, Path], rm_module_prefix: bool = False): + ckpt = torch.load(path, map_location=self.accelerator.device) + + unwrapped_model = self.accelerator.unwrap_model(model) + if rm_module_prefix: + unwrapped_model.load_state_dict({k.replace('module.', ''): v for k, v in ckpt.items()}) + else: + unwrapped_model.load_state_dict(ckpt) + return unwrapped_model + + def load_shared_weights(self, model: torch.nn.Module, path: Union[str, Path]): + ckpt = torch.load(path, map_location=self.accelerator.device) + self.print(f"pretrained_dict len: {len(ckpt)}") + unwrapped_model = self.accelerator.unwrap_model(model) + model_dict = unwrapped_model.state_dict() + pretrained_dict = {k: v for k, v in ckpt.items() if k in model_dict} + model_dict.update(pretrained_dict) + unwrapped_model.load_state_dict(model_dict, strict=False) + self.print(f"selected pretrained_dict: {len(model_dict)}") + return unwrapped_model + + def print(self, *args, **kwargs): + """Use in replacement of `print()` to only print once per server.""" + self.accelerator.print(*args, **kwargs) + + def pretty_print(self, msg): + if self.accelerator.is_local_main_process: + pprint(dict(msg)) + + def close_tracker(self): + self.accelerator.end_training() + + def free_memory(self): + self.accelerator.clear() + + def close(self, msg: str = "Training complete."): + """Use in end of training.""" + self.free_memory() + + if torch.cuda.is_available(): + self.print(f'\nGPU memory usage: {torch.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB') + if self.args.visual: + self.close_tracker() + self.print(msg) diff --git a/libs/metric/__init__.py b/libs/metric/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2ad761f2f5443eb41b15afc4116a66ecdfa9d918 --- /dev/null +++ 
b/libs/metric/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: utf-8 -*- +# Copyright (c) XiMing Xing. All rights reserved. +# Author: XiMing Xing +# Description: diff --git a/libs/metric/accuracy.py b/libs/metric/accuracy.py new file mode 100644 index 0000000000000000000000000000000000000000..96ab47b38acde0e0e2f4cb6d924a90e9daaad681 --- /dev/null +++ b/libs/metric/accuracy.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- +# Copyright (c) XiMing Xing. All rights reserved. +# Author: XiMing Xing +# Description: + + +def accuracy(output, target, topk=(1,)): + """ + Computes the accuracy over the k top predictions for the specified values of k. + + Args + output: logits or probs (num of batch, num of classes) + target: (num of batch, 1) or (num of batch, ) + topk: list of returned k + + refer: https://github.com/pytorch/examples/blob/master/imagenet/main.py + """ + maxK = max(topk) # get k in top-k + batch_size = target.size(0) + + _, pred = output.topk(k=maxK, dim=1, largest=True, sorted=True) # pred: [num of batch, k] + pred = pred.t() # pred: [k, num of batch] + + # [1, num of batch] -> [k, num_of_batch] : bool + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res # np.shape(res): [k, 1] diff --git a/libs/metric/clip_score/__init__.py b/libs/metric/clip_score/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ea7239bd179c2f1b9cefaf7e47d3aff49fe3dde7 --- /dev/null +++ b/libs/metric/clip_score/__init__.py @@ -0,0 +1,8 @@ +# -*- coding: utf-8 -*- +# Copyright (c) XiMing Xing. All rights reserved. +# Author: XiMing Xing +# Description: + +from .openaiCLIP_loss import CLIPScoreWrapper + +__all__ = ['CLIPScoreWrapper'] diff --git a/libs/metric/clip_score/openaiCLIP_loss.py b/libs/metric/clip_score/openaiCLIP_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..3893bfcaef3697274d5064bc9d5b5124a18df399 --- /dev/null +++ b/libs/metric/clip_score/openaiCLIP_loss.py @@ -0,0 +1,305 @@ +# -*- coding: utf-8 -*- +# Copyright (c) XiMing Xing. All rights reserved. 
+# Author: XiMing Xing +# Description: + +from typing import Union, List, Tuple +from collections import OrderedDict +from functools import partial + +import numpy as np +import torch +import torch.nn as nn +import torchvision.transforms as transforms + + +class CLIPScoreWrapper(nn.Module): + + def __init__(self, + clip_model_name: str, + download_root: str = None, + device: torch.device = "cuda" if torch.cuda.is_available() else "cpu", + jit: bool = False, + # additional params + visual_score: bool = False, + feats_loss_type: str = None, + feats_loss_weights: List[float] = None, + fc_loss_weight: float = None, + context_length: int = 77): + super().__init__() + + import clip # local import + + # check model info + self.clip_model_name = clip_model_name + self.device = device + self.available_models = clip.available_models() + assert clip_model_name in self.available_models, f"A model backbone: {clip_model_name} that does not exist" + + # load CLIP + self.model, self.preprocess = clip.load(clip_model_name, device, jit=jit, download_root=download_root) + self.model.eval() + + # load tokenize + self.tokenize_fn = partial(clip.tokenize, context_length=context_length) + + # load CLIP visual + self.visual_encoder = VisualEncoderWrapper(self.model, clip_model_name).to(device) + self.visual_encoder.eval() + + # check loss weights + self.visual_score = visual_score + if visual_score: + assert feats_loss_type in ["l1", "l2", "cosine"], f"{feats_loss_type} is not exist." + if clip_model_name.startswith("ViT"): assert len(feats_loss_weights) == 12 + if clip_model_name.startswith("RN"): assert len(feats_loss_weights) == 5 + + # load visual loss wrapper + self.visual_loss_fn = CLIPVisualLossWrapper(self.visual_encoder, feats_loss_type, + feats_loss_weights, + fc_loss_weight) + + @property + def input_resolution(self): + return self.model.visual.input_resolution # default: 224 + + @property + def resize(self): # Resize only + return transforms.Compose([self.preprocess.transforms[0]]) + + @property + def normalize(self): + return transforms.Compose([ + self.preprocess.transforms[0], # Resize + self.preprocess.transforms[1], # CenterCrop + self.preprocess.transforms[-1], # Normalize + ]) + + @property + def norm_(self): # Normalize only + return transforms.Compose([self.preprocess.transforms[-1]]) + + def encode_image_layer_wise(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: + semantic_vec, feature_maps = self.visual_encoder(x) + return semantic_vec, feature_maps + + def encode_text(self, text: Union[str, List[str]], norm: bool = True) -> torch.Tensor: + tokens = self.tokenize_fn(text).to(self.device) + text_features = self.model.encode_text(tokens) + if norm: + text_features = text_features.mean(axis=0, keepdim=True) + text_features_norm = text_features / text_features.norm(dim=-1, keepdim=True) + return text_features_norm + return text_features + + def encode_image(self, image: torch.Tensor, norm: bool = True) -> torch.Tensor: + image_features = self.model.encode_image(image) + if norm: + image_features_norm = image_features / image_features.norm(dim=-1, keepdim=True) + return image_features_norm + return image_features + + @torch.no_grad() + def predict(self, + image: torch.Tensor, + text: Union[str, List[str]]) -> Tuple[torch.Tensor, torch.Tensor, np.ndarray]: + image_features = self.model.encode_image(image) + text_tokenize = self.tokenize_fn(text).to(self.device) + text_features = self.model.encode_text(text_tokenize) + logits_per_image, logits_per_text = self.model(image, text) 
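+        # logits_per_image has shape (num_images, num_texts); softmax over the text axis gives matching probabilities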
+ probs = logits_per_image.softmax(dim=-1).cpu().numpy() + return image_features, text_features, probs + + def compute_text_visual_distance( + self, image: torch.Tensor, text: Union[str, List[str]] + ) -> torch.Tensor: + image_features = self.model.encode_image(image) + text_tokenize = self.tokenize_fn(text).to(self.device) + text_features = self.model.encode_text(text_tokenize) + text_features = text_features.to(self.device) + + image_features_norm = image_features / image_features.norm(dim=-1, keepdim=True) + text_features_norm = text_features / text_features.norm(dim=-1, keepdim=True) + loss = - (image_features_norm @ text_features_norm.T) + return loss.mean() + + def directional_loss(self, src_text, src_img, tar_text, tar_img, thresh=None): + # delta img + img_direction = (tar_img - src_img) + img_direction_norm = img_direction / img_direction.norm(dim=-1, keepdim=True) + # # delta text + text_direction = (1 * tar_text - src_text).repeat(tar_img.size(0), 1) + text_direction_norm = text_direction / text_direction.norm(dim=-1, keepdim=True) + # Directional CLIP Loss + loss_dir = (1 - torch.cosine_similarity(img_direction_norm, text_direction_norm, dim=1)) + if thresh is not None: + loss_dir[loss_dir < thresh] = 0 # set value=0 when lt 0 + loss_dir = loss_dir.mean() + return loss_dir + else: + return loss_dir.mean() + + def compute_visual_distance( + self, x: torch.Tensor, y: torch.Tensor, clip_norm: bool = True, + ) -> Tuple[torch.Tensor, List]: + # return a fc loss and the list of feat loss + assert self.visual_score is True + assert x.shape == y.shape + assert x.shape[-1] == self.input_resolution and x.shape[-2] == self.input_resolution + assert y.shape[-1] == self.input_resolution and y.shape[-2] == self.input_resolution + + if clip_norm: + return self.visual_loss_fn(self.normalize(x), self.normalize(y)) + else: + return self.visual_loss_fn(x, y) + + +class VisualEncoderWrapper(nn.Module): + """ + semantic features and layer by layer feature maps are obtained from CLIP visual encoder. 
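+    ViT backbones hook all 12 transformer resblocks for the feature maps; ResNet backbones return the stem output plus the four residual stages.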
+ """ + + def __init__(self, clip_model: nn.Module, clip_model_name: str): + super().__init__() + self.clip_model = clip_model + self.clip_model_name = clip_model_name + + if clip_model_name.startswith("ViT"): + self.feature_maps = OrderedDict() + for i in range(12): # 12 ResBlocks in ViT visual transformer + self.clip_model.visual.transformer.resblocks[i].register_forward_hook( + self.make_hook(i) + ) + + if clip_model_name.startswith("RN"): + layers = list(self.clip_model.visual.children()) + init_layers = torch.nn.Sequential(*layers)[:8] + self.layer1 = layers[8] + self.layer2 = layers[9] + self.layer3 = layers[10] + self.layer4 = layers[11] + self.att_pool2d = layers[12] + + def make_hook(self, name): + def hook(module, input, output): + if len(output.shape) == 3: + # LND -> NLD (B, 77, 768) + self.feature_maps[name] = output.permute(1, 0, 2) + else: + self.feature_maps[name] = output + + return hook + + def _forward_vit(self, x: torch.Tensor) -> Tuple[torch.Tensor, List]: + fc_feature = self.clip_model.encode_image(x).float() + feature_maps = [self.feature_maps[k] for k in range(12)] + + # fc_feature len: 1 ,feature_maps len: 12 + return fc_feature, feature_maps + + def _forward_resnet(self, x: torch.Tensor) -> Tuple[torch.Tensor, List]: + def stem(m, x): + for conv, bn, relu in [(m.conv1, m.bn1, m.relu1), (m.conv2, m.bn2, m.relu2), (m.conv3, m.bn3, m.relu3)]: + x = torch.relu(bn(conv(x))) + x = m.avgpool(x) + return x + + x = x.type(self.clip_model.visual.conv1.weight.dtype) + x = stem(self.clip_model.visual, x) + x1 = self.layer1(x) + x2 = self.layer2(x1) + x3 = self.layer3(x2) + x4 = self.layer4(x3) + y = self.att_pool2d(x4) + + # fc_features len: 1 ,feature_maps len: 5 + return y, [x, x1, x2, x3, x4] + + def forward(self, x) -> Tuple[torch.Tensor, List[torch.Tensor]]: + if self.clip_model_name.startswith("ViT"): + fc_feat, visual_feat_maps = self._forward_vit(x) + if self.clip_model_name.startswith("RN"): + fc_feat, visual_feat_maps = self._forward_resnet(x) + + return fc_feat, visual_feat_maps + + +class CLIPVisualLossWrapper(nn.Module): + """ + Visual Feature Loss + FC loss + """ + + def __init__( + self, + visual_encoder: nn.Module, + feats_loss_type: str = None, + feats_loss_weights: List[float] = None, + fc_loss_weight: float = None, + ): + super().__init__() + self.visual_encoder = visual_encoder + self.feats_loss_weights = feats_loss_weights + self.fc_loss_weight = fc_loss_weight + + self.layer_criterion = layer_wise_distance(feats_loss_type) + + def forward(self, x: torch.Tensor, y: torch.Tensor): + x_fc_feature, x_feat_maps = self.visual_encoder(x) + y_fc_feature, y_feat_maps = self.visual_encoder(y) + + # visual feature loss + if sum(self.feats_loss_weights) == 0: + feats_loss_list = [torch.tensor(0, device=x.device)] + else: + feats_loss = self.layer_criterion(x_feat_maps, y_feat_maps, self.visual_encoder.clip_model_name) + feats_loss_list = [] + for layer, w in enumerate(self.feats_loss_weights): + if w: + feats_loss_list.append(feats_loss[layer] * w) + + # visual fc loss, default: cosine similarity + if self.fc_loss_weight == 0: + fc_loss = torch.tensor(0, device=x.device) + else: + fc_loss = (1 - torch.cosine_similarity(x_fc_feature, y_fc_feature, dim=1)).mean() + fc_loss = fc_loss * self.fc_loss_weight + + return fc_loss, feats_loss_list + + +################################################################################# +# layer wise metric # +################################################################################# + +def layer_wise_distance(metric_name: 
str): + return { + "l1": l1_layer_wise, + "l2": l2_layer_wise, + "cosine": cosine_layer_wise + }.get(metric_name.lower()) + + +def l2_layer_wise(x_features, y_features, clip_model_name): + return [ + torch.square(x_conv - y_conv).mean() + for x_conv, y_conv in zip(x_features, y_features) + ] + + +def l1_layer_wise(x_features, y_features, clip_model_name): + return [ + torch.abs(x_conv - y_conv).mean() + for x_conv, y_conv in zip(x_features, y_features) + ] + + +def cosine_layer_wise(x_features, y_features, clip_model_name): + if clip_model_name.startswith("RN"): + return [ + (1 - torch.cosine_similarity(x_conv, y_conv, dim=1)).mean() + for x_conv, y_conv in zip(x_features, y_features) + ] + return [ + (1 - torch.cosine_similarity(x_conv, y_conv, dim=1)).mean() + for x_conv, y_conv in zip(x_features, y_features) + ] diff --git a/libs/metric/lpips_origin/__init__.py b/libs/metric/lpips_origin/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8cb4332c66e4725112fad88db377e3a50471f3ee --- /dev/null +++ b/libs/metric/lpips_origin/__init__.py @@ -0,0 +1,3 @@ +from .lpips import LPIPS + +__all__ = ['LPIPS'] diff --git a/libs/metric/lpips_origin/lpips.py b/libs/metric/lpips_origin/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..fa97aad10e2b18fa5b92b913e9def4f5135a83c0 --- /dev/null +++ b/libs/metric/lpips_origin/lpips.py @@ -0,0 +1,184 @@ +from __future__ import absolute_import + +import os + +import torch +import torch.nn as nn + +from . import pretrained_networks as pretrained_torch_models + + +def spatial_average(x, keepdim=True): + return x.mean([2, 3], keepdim=keepdim) + + +def upsample(x): + return nn.Upsample(size=x.shape[2:], mode='bilinear', align_corners=False)(x) + + +def normalize_tensor(in_feat, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1, keepdim=True)) + return in_feat / (norm_factor + eps) + + +# Learned perceptual metric +class LPIPS(nn.Module): + + def __init__(self, + pretrained=True, + net='alex', + version='0.1', + lpips=True, + spatial=False, + pnet_rand=False, + pnet_tune=False, + use_dropout=True, + model_path=None, + eval_mode=True, + verbose=True): + """ Initializes a perceptual loss torch.nn.Module + + Parameters (default listed first) + --------------------------------- + lpips : bool + [True] use linear layers on top of base/trunk network + [False] means no linear layers; each layer is averaged together + pretrained : bool + This flag controls the linear layers, which are only in effect when lpips=True above + [True] means linear layers are calibrated with human perceptual judgments + [False] means linear layers are randomly initialized + pnet_rand : bool + [False] means trunk loaded with ImageNet classification weights + [True] means randomly initialized trunk + net : str + ['alex','vgg','squeeze'] are the base/trunk networks available + version : str + ['v0.1'] is the default and latest + ['v0.0'] contained a normalization bug; corresponds to old arxiv v1 (https://arxiv.org/abs/1801.03924v1) + model_path : 'str' + [None] is default and loads the pretrained weights from paper https://arxiv.org/abs/1801.03924v1 + + The following parameters should only be changed if training the network: + + eval_mode : bool + [True] is for test mode (default) + [False] is for training mode + pnet_tune + [False] keep base/trunk frozen + [True] tune the base/trunk network + use_dropout : bool + [True] to use dropout when training linear layers + [False] for no dropout when training linear layers + """ + 
super(LPIPS, self).__init__() + if verbose: + print('Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]' % + ('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off')) + + self.pnet_type = net + self.pnet_tune = pnet_tune + self.pnet_rand = pnet_rand + self.spatial = spatial + self.lpips = lpips # false means baseline of just averaging all layers + self.version = version + self.scaling_layer = ScalingLayer() + + if self.pnet_type in ['vgg', 'vgg16']: + net_type = pretrained_torch_models.vgg16 + self.chns = [64, 128, 256, 512, 512] + elif self.pnet_type == 'alex': + net_type = pretrained_torch_models.alexnet + self.chns = [64, 192, 384, 256, 256] + elif self.pnet_type == 'squeeze': + net_type = pretrained_torch_models.squeezenet + self.chns = [64, 128, 256, 384, 384, 512, 512] + self.L = len(self.chns) + + self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) + + if lpips: + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + if self.pnet_type == 'squeeze': # 7 layers for squeezenet + self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) + self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) + self.lins += [self.lin5, self.lin6] + self.lins = nn.ModuleList(self.lins) + + if pretrained: + if model_path is None: + model_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + f"weights/v{version}/{net}.pth" + ) + if verbose: + print('Loading model from: %s' % model_path) + self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False) + + if eval_mode: + self.eval() + + def forward(self, in0, in1, return_per_layer=False, normalize=False): + if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, 1] + in0 = 2 * in0 - 1 + in1 = 2 * in1 - 1 + + # Noting: v0.0 - original release had a bug, where input was not scaled + if self.version == '0.1': + in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) + else: + in0_input, in1_input = in0, in1 + + # model forward + outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) + + feats0, feats1, diffs = {}, {}, {} + for kk in range(self.L): + feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + if self.lpips: + if self.spatial: + res = [upsample(self.lins[kk](diffs[kk])) for kk in range(self.L)] + else: + res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)] + else: + if self.spatial: + res = [upsample(diffs[kk].sum(dim=1, keepdim=True)) for kk in range(self.L)] + else: + res = [spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) for kk in range(self.L)] + + loss = sum(res) + + if return_per_layer: + return loss, res + else: + return loss + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """A single linear 
layer which does a 1x1 conv""" + + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + + layers = [nn.Dropout(), ] if (use_dropout) else [] + layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] + self.model = nn.Sequential(*layers) + + def forward(self, x): + return self.model(x) diff --git a/libs/metric/lpips_origin/pretrained_networks.py b/libs/metric/lpips_origin/pretrained_networks.py new file mode 100644 index 0000000000000000000000000000000000000000..484b808da02eecb59c132e63a0fe4ae90b1e4d2e --- /dev/null +++ b/libs/metric/lpips_origin/pretrained_networks.py @@ -0,0 +1,196 @@ +from collections import namedtuple + +import torch +import torchvision.models as tv_models + + +class squeezenet(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(squeezenet, self).__init__() + pretrained_features = tv_models.squeezenet1_1(weights=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.slice6 = torch.nn.Sequential() + self.slice7 = torch.nn.Sequential() + self.N_slices = 7 + for x in range(2): + self.slice1.add_module(str(x), pretrained_features[x]) + for x in range(2, 5): + self.slice2.add_module(str(x), pretrained_features[x]) + for x in range(5, 8): + self.slice3.add_module(str(x), pretrained_features[x]) + for x in range(8, 10): + self.slice4.add_module(str(x), pretrained_features[x]) + for x in range(10, 11): + self.slice5.add_module(str(x), pretrained_features[x]) + for x in range(11, 12): + self.slice6.add_module(str(x), pretrained_features[x]) + for x in range(12, 13): + self.slice7.add_module(str(x), pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1 = h + h = self.slice2(h) + h_relu2 = h + h = self.slice3(h) + h_relu3 = h + h = self.slice4(h) + h_relu4 = h + h = self.slice5(h) + h_relu5 = h + h = self.slice6(h) + h_relu6 = h + h = self.slice7(h) + h_relu7 = h + vgg_outputs = namedtuple("SqueezeOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5', 'relu6', 'relu7']) + out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7) + + return out + + +class alexnet(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(alexnet, self).__init__() + weights = tv_models.AlexNet_Weights.IMAGENET1K_V1 if pretrained else None + alexnet_pretrained_features = tv_models.alexnet(weights=weights).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(2): + self.slice1.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(2, 5): + self.slice2.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(5, 8): + self.slice3.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(8, 10): + self.slice4.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(10, 12): + self.slice5.add_module(str(x), alexnet_pretrained_features[x]) + + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1 = h + h = self.slice2(h) + h_relu2 = h + h = self.slice3(h) + h_relu3 = h + h = 
self.slice4(h) + h_relu4 = h + h = self.slice5(h) + h_relu5 = h + alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) + out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) + + return out + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + weights = tv_models.VGG16_Weights.IMAGENET1K_V1 if pretrained else None + vgg_pretrained_features = tv_models.vgg16(weights=weights).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) + out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + + return out + + +class resnet(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True, num=18): + super(resnet, self).__init__() + + if num == 18: + weights = tv_models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None + self.net = tv_models.resnet18(weights=weights) + elif num == 34: + weights = tv_models.ResNet34_Weights.IMAGENET1K_V1 if pretrained else None + self.net = tv_models.resnet34(weights=weights) + elif num == 50: + weights = tv_models.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None + self.net = tv_models.resnet50(weights=weights) + elif num == 101: + weights = tv_models.ResNet101_Weights.IMAGENET1K_V2 if pretrained else None + self.net = tv_models.resnet101(weights=weights) + elif num == 152: + weights = tv_models.ResNet152_Weights.IMAGENET1K_V2 if pretrained else None + self.net = tv_models.resnet152(weights=weights) + self.N_slices = 5 + + if not requires_grad: + for param in self.net.parameters(): + param.requires_grad = False + + self.conv1 = self.net.conv1 + self.bn1 = self.net.bn1 + self.relu = self.net.relu + self.maxpool = self.net.maxpool + self.layer1 = self.net.layer1 + self.layer2 = self.net.layer2 + self.layer3 = self.net.layer3 + self.layer4 = self.net.layer4 + + def forward(self, X): + h = self.conv1(X) + h = self.bn1(h) + h = self.relu(h) + h_relu1 = h + h = self.maxpool(h) + h = self.layer1(h) + h_conv2 = h + h = self.layer2(h) + h_conv3 = h + h = self.layer3(h) + h_conv4 = h + h = self.layer4(h) + h_conv5 = h + + outputs = namedtuple("Outputs", ['relu1', 'conv2', 'conv3', 'conv4', 'conv5']) + out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) + + return out diff --git a/libs/metric/lpips_origin/weights/v0.1/alex.pth b/libs/metric/lpips_origin/weights/v0.1/alex.pth new file mode 100644 index 0000000000000000000000000000000000000000..fa4067abc5d4da16a7204fd94776506e4868030e --- /dev/null +++ 
b/libs/metric/lpips_origin/weights/v0.1/alex.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:df73285e35b22355a2df87cdb6b70b343713b667eddbda73e1977e0c860835c0 +size 6009 diff --git a/libs/metric/lpips_origin/weights/v0.1/squeeze.pth b/libs/metric/lpips_origin/weights/v0.1/squeeze.pth new file mode 100644 index 0000000000000000000000000000000000000000..f892a84a130828b1c9e2e8156e84fc5a962c665d --- /dev/null +++ b/libs/metric/lpips_origin/weights/v0.1/squeeze.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4a5350f23600cb79923ce65bb07cbf57dca461329894153e05a1346bd531cf76 +size 10811 diff --git a/libs/metric/lpips_origin/weights/v0.1/vgg.pth b/libs/metric/lpips_origin/weights/v0.1/vgg.pth new file mode 100644 index 0000000000000000000000000000000000000000..f57dcf5cc764d61c8a460365847fb2137ff0a62d --- /dev/null +++ b/libs/metric/lpips_origin/weights/v0.1/vgg.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868 +size 7289 diff --git a/libs/metric/piq/__init__.py b/libs/metric/piq/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..54f5e8e115957584d775c4ff25aec65a4a17085e --- /dev/null +++ b/libs/metric/piq/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- +# Copyright (c) XiMing Xing. All rights reserved. +# Author: XiMing Xing +# Description: + +# install: pip install piq +# repo: https://github.com/photosynthesis-team/piq diff --git a/libs/metric/piq/functional/__init__.py b/libs/metric/piq/functional/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..131231bd3249723d3c21d98c45109ff99fb19612 --- /dev/null +++ b/libs/metric/piq/functional/__init__.py @@ -0,0 +1,15 @@ +from .base import ifftshift, get_meshgrid, similarity_map, gradient_map, pow_for_complex, crop_patches +from .colour_conversion import rgb2lmn, rgb2xyz, xyz2lab, rgb2lab, rgb2yiq, rgb2lhm +from .filters import haar_filter, hann_filter, scharr_filter, prewitt_filter, gaussian_filter +from .filters import binomial_filter1d, average_filter2d +from .layers import L2Pool2d +from .resize import imresize + +__all__ = [ + 'ifftshift', 'get_meshgrid', 'similarity_map', 'gradient_map', 'pow_for_complex', 'crop_patches', + 'rgb2lmn', 'rgb2xyz', 'xyz2lab', 'rgb2lab', 'rgb2yiq', 'rgb2lhm', + 'haar_filter', 'hann_filter', 'scharr_filter', 'prewitt_filter', 'gaussian_filter', + 'binomial_filter1d', 'average_filter2d', + 'L2Pool2d', + 'imresize', +] diff --git a/libs/metric/piq/functional/base.py b/libs/metric/piq/functional/base.py new file mode 100644 index 0000000000000000000000000000000000000000..a8d34790ad89b225fd28ade34507c498a315fb6d --- /dev/null +++ b/libs/metric/piq/functional/base.py @@ -0,0 +1,111 @@ +r"""General purpose functions""" +from typing import Tuple, Union, Optional +import torch +from ..utils import _parse_version + + +def ifftshift(x: torch.Tensor) -> torch.Tensor: + r""" Similar to np.fft.ifftshift but applies to PyTorch Tensors""" + shift = [-(ax // 2) for ax in x.size()] + return torch.roll(x, shift, tuple(range(len(shift)))) + + +def get_meshgrid(size: Tuple[int, int], device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor: + r"""Return coordinate grid matrices centered at zero point. + Args: + size: Shape of meshgrid to create + device: device to use for creation + dtype: dtype to use for creation + Returns: + Meshgrid of size on device with dtype values. 
+ """ + if size[0] % 2: + # Odd + x = torch.arange(-(size[0] - 1) / 2, size[0] / 2, device=device, dtype=dtype) / (size[0] - 1) + else: + # Even + x = torch.arange(- size[0] / 2, size[0] / 2, device=device, dtype=dtype) / size[0] + + if size[1] % 2: + # Odd + y = torch.arange(-(size[1] - 1) / 2, size[1] / 2, device=device, dtype=dtype) / (size[1] - 1) + else: + # Even + y = torch.arange(- size[1] / 2, size[1] / 2, device=device, dtype=dtype) / size[1] + # Use indexing param depending on torch version + recommended_torch_version = _parse_version("1.10.0") + torch_version = _parse_version(torch.__version__) + if len(torch_version) > 0 and torch_version >= recommended_torch_version: + return torch.meshgrid(x, y, indexing='ij') + return torch.meshgrid(x, y) + + +def similarity_map(map_x: torch.Tensor, map_y: torch.Tensor, constant: float, alpha: float = 0.0) -> torch.Tensor: + r""" Compute similarity_map between two tensors using Dice-like equation. + + Args: + map_x: Tensor with map to be compared + map_y: Tensor with map to be compared + constant: Used for numerical stability + alpha: Masking coefficient. Subtracts - `alpha` * map_x * map_y from denominator and nominator + """ + return (2.0 * map_x * map_y - alpha * map_x * map_y + constant) / \ + (map_x ** 2 + map_y ** 2 - alpha * map_x * map_y + constant) + + +def gradient_map(x: torch.Tensor, kernels: torch.Tensor) -> torch.Tensor: + r""" Compute gradient map for a given tensor and stack of kernels. + + Args: + x: Tensor with shape (N, C, H, W). + kernels: Stack of tensors for gradient computation with shape (k_N, k_H, k_W) + Returns: + Gradients of x per-channel with shape (N, C, H, W) + """ + padding = kernels.size(-1) // 2 + grads = torch.nn.functional.conv2d(x, kernels, padding=padding) + + return torch.sqrt(torch.sum(grads ** 2, dim=-3, keepdim=True)) + + +def pow_for_complex(base: torch.Tensor, exp: Union[int, float]) -> torch.Tensor: + r""" Takes the power of each element in a 4D tensor with negative values or 5D tensor with complex values. + Complex numbers are represented by modulus and argument: r * \exp(i * \phi). + + It will likely to be redundant with introduction of torch.ComplexTensor. + + Args: + base: Tensor with shape (N, C, H, W) or (N, C, H, W, 2). + exp: Exponent + Returns: + Complex tensor with shape (N, C, H, W, 2). + """ + if base.dim() == 4: + x_complex_r = base.abs() + x_complex_phi = torch.atan2(torch.zeros_like(base), base) + elif base.dim() == 5 and base.size(-1) == 2: + x_complex_r = base.pow(2).sum(dim=-1).sqrt() + x_complex_phi = torch.atan2(base[..., 1], base[..., 0]) + else: + raise ValueError(f'Expected real or complex tensor, got {base.size()}') + + x_complex_pow_r = x_complex_r ** exp + x_complex_pow_phi = x_complex_phi * exp + x_real_pow = x_complex_pow_r * torch.cos(x_complex_pow_phi) + x_imag_pow = x_complex_pow_r * torch.sin(x_complex_pow_phi) + return torch.stack((x_real_pow, x_imag_pow), dim=-1) + + +def crop_patches(x: torch.Tensor, size=64, stride=32) -> torch.Tensor: + r"""Crop tensor with images into small patches + Args: + x: Tensor with shape (N, C, H, W), expected to be images-like entities + size: Size of a square patch + stride: Step between patches + """ + assert (x.shape[2] >= size) and (x.shape[3] >= size), \ + f"Images must be bigger than patch size. 
Got ({x.shape[2], x.shape[3]}) and ({size}, {size})" + channels = x.shape[1] + patches = x.unfold(1, channels, channels).unfold(2, size, stride).unfold(3, size, stride) + patches = patches.reshape(-1, channels, size, size) + return patches diff --git a/libs/metric/piq/functional/colour_conversion.py b/libs/metric/piq/functional/colour_conversion.py new file mode 100644 index 0000000000000000000000000000000000000000..9de6eb031a60aa765a326cb6ef8cf67c37177d97 --- /dev/null +++ b/libs/metric/piq/functional/colour_conversion.py @@ -0,0 +1,136 @@ +r"""Colour space conversion functions""" +from typing import Union, Dict +import torch + + +def rgb2lmn(x: torch.Tensor) -> torch.Tensor: + r"""Convert a batch of RGB images to a batch of LMN images + + Args: + x: Batch of images with shape (N, 3, H, W). RGB colour space. + + Returns: + Batch of images with shape (N, 3, H, W). LMN colour space. + """ + weights_rgb_to_lmn = torch.tensor([[0.06, 0.63, 0.27], + [0.30, 0.04, -0.35], + [0.34, -0.6, 0.17]], dtype=x.dtype, device=x.device).t() + x_lmn = torch.matmul(x.permute(0, 2, 3, 1), weights_rgb_to_lmn).permute(0, 3, 1, 2) + return x_lmn + + +def rgb2xyz(x: torch.Tensor) -> torch.Tensor: + r"""Convert a batch of RGB images to a batch of XYZ images + + Args: + x: Batch of images with shape (N, 3, H, W). RGB colour space. + + Returns: + Batch of images with shape (N, 3, H, W). XYZ colour space. + """ + mask_below = (x <= 0.04045).type(x.dtype) + mask_above = (x > 0.04045).type(x.dtype) + + tmp = x / 12.92 * mask_below + torch.pow((x + 0.055) / 1.055, 2.4) * mask_above + + weights_rgb_to_xyz = torch.tensor([[0.4124564, 0.3575761, 0.1804375], + [0.2126729, 0.7151522, 0.0721750], + [0.0193339, 0.1191920, 0.9503041]], dtype=x.dtype, device=x.device) + + x_xyz = torch.matmul(tmp.permute(0, 2, 3, 1), weights_rgb_to_xyz.t()).permute(0, 3, 1, 2) + return x_xyz + + +def xyz2lab(x: torch.Tensor, illuminant: str = 'D50', observer: str = '2') -> torch.Tensor: + r"""Convert a batch of XYZ images to a batch of LAB images + + Args: + x: Batch of images with shape (N, 3, H, W). XYZ colour space. + illuminant: {“A”, “D50”, “D55”, “D65”, “D75”, “E”}, optional. The name of the illuminant. + observer: {“2”, “10”}, optional. The aperture angle of the observer. + + Returns: + Batch of images with shape (N, 3, H, W). LAB colour space. + """ + epsilon = 0.008856 + kappa = 903.3 + illuminants: Dict[str, Dict] = \ + {"A": {'2': (1.098466069456375, 1, 0.3558228003436005), + '10': (1.111420406956693, 1, 0.3519978321919493)}, + "D50": {'2': (0.9642119944211994, 1, 0.8251882845188288), + '10': (0.9672062750333777, 1, 0.8142801513128616)}, + "D55": {'2': (0.956797052643698, 1, 0.9214805860173273), + '10': (0.9579665682254781, 1, 0.9092525159847462)}, + "D65": {'2': (0.95047, 1., 1.08883), # This was: `lab_ref_white` + '10': (0.94809667673716, 1, 1.0730513595166162)}, + "D75": {'2': (0.9497220898840717, 1, 1.226393520724154), + '10': (0.9441713925645873, 1, 1.2064272211720228)}, + "E": {'2': (1.0, 1.0, 1.0), + '10': (1.0, 1.0, 1.0)}} + + illuminants_to_use = torch.tensor(illuminants[illuminant][observer], + dtype=x.dtype, device=x.device).view(1, 3, 1, 1) + + tmp = x / illuminants_to_use + + mask_below = (tmp <= epsilon).type(x.dtype) + mask_above = (tmp > epsilon).type(x.dtype) + tmp = torch.pow(tmp, 1. / 3.) * mask_above + (kappa * tmp + 16.) / 116. 
* mask_below + + weights_xyz_to_lab = torch.tensor([[0, 116., 0], + [500., -500., 0], + [0, 200., -200.]], dtype=x.dtype, device=x.device) + bias_xyz_to_lab = torch.tensor([-16., 0., 0.], dtype=x.dtype, device=x.device).view(1, 3, 1, 1) + + x_lab = torch.matmul(tmp.permute(0, 2, 3, 1), weights_xyz_to_lab.t()).permute(0, 3, 1, 2) + bias_xyz_to_lab + return x_lab + + +def rgb2lab(x: torch.Tensor, data_range: Union[int, float] = 255) -> torch.Tensor: + r"""Convert a batch of RGB images to a batch of LAB images + + Args: + x: Batch of images with shape (N, 3, H, W). RGB colour space. + data_range: dynamic range of the input image. + + Returns: + Batch of images with shape (N, 3, H, W). LAB colour space. + """ + return xyz2lab(rgb2xyz(x / float(data_range))) + + +def rgb2yiq(x: torch.Tensor) -> torch.Tensor: + r"""Convert a batch of RGB images to a batch of YIQ images + + Args: + x: Batch of images with shape (N, 3, H, W). RGB colour space. + + Returns: + Batch of images with shape (N, 3, H, W). YIQ colour space. + """ + yiq_weights = torch.tensor([ + [0.299, 0.587, 0.114], + [0.5959, -0.2746, -0.3213], + [0.2115, -0.5227, 0.3112]], dtype=x.dtype, device=x.device).t() + x_yiq = torch.matmul(x.permute(0, 2, 3, 1), yiq_weights).permute(0, 3, 1, 2) + return x_yiq + + +def rgb2lhm(x: torch.Tensor) -> torch.Tensor: + r"""Convert a batch of RGB images to a batch of LHM images + + Args: + x: Batch of images with shape (N, 3, H, W). RGB colour space. + + Returns: + Batch of images with shape (N, 3, H, W). LHM colour space. + + Reference: + https://arxiv.org/pdf/1608.07433.pdf + """ + lhm_weights = torch.tensor([ + [0.2989, 0.587, 0.114], + [0.3, 0.04, -0.35], + [0.34, -0.6, 0.17]], dtype=x.dtype, device=x.device).t() + x_lhm = torch.matmul(x.permute(0, 2, 3, 1), lhm_weights).permute(0, 3, 1, 2) + return x_lhm diff --git a/libs/metric/piq/functional/filters.py b/libs/metric/piq/functional/filters.py new file mode 100644 index 0000000000000000000000000000000000000000..ff5ef1ac5110fa57b75de7476567a409842c0dfc --- /dev/null +++ b/libs/metric/piq/functional/filters.py @@ -0,0 +1,111 @@ +r"""Filters for gradient computation, bluring, etc.""" +import torch +import numpy as np +from typing import Optional + + +def haar_filter(kernel_size: int, device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor: + r"""Creates Haar kernel + + Args: + kernel_size: size of the kernel + device: target device for kernel generation + dtype: target data type for kernel generation + Returns: + kernel: Tensor with shape (1, kernel_size, kernel_size) + """ + kernel = torch.ones((kernel_size, kernel_size), device=device, dtype=dtype) / kernel_size + kernel[kernel_size // 2:, :] = - kernel[kernel_size // 2:, :] + return kernel.unsqueeze(0) + + +def hann_filter(kernel_size: int, device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor: + r"""Creates Hann kernel + Args: + kernel_size: size of the kernel + device: target device for kernel generation + dtype: target data type for kernel generation + Returns: + kernel: Tensor with shape (1, kernel_size, kernel_size) + """ + # Take bigger window and drop borders + window = torch.hann_window(kernel_size + 2, periodic=False, device=device, dtype=dtype)[1:-1] + kernel = window[:, None] * window[None, :] + # Normalize and reshape kernel + return kernel.view(1, kernel_size, kernel_size) / kernel.sum() + + +def gaussian_filter(kernel_size: int, sigma: float, device: Optional[str] = None, + dtype: Optional[type] = None) -> torch.Tensor: + r"""Returns 2D 
Gaussian kernel N(0,`sigma`^2) + Args: + size: Size of the kernel + sigma: Std of the distribution + device: target device for kernel generation + dtype: target data type for kernel generation + Returns: + gaussian_kernel: Tensor with shape (1, kernel_size, kernel_size) + """ + coords = torch.arange(kernel_size, dtype=dtype, device=device) + coords -= (kernel_size - 1) / 2. + + g = coords ** 2 + g = (- (g.unsqueeze(0) + g.unsqueeze(1)) / (2 * sigma ** 2)).exp() + + g /= g.sum() + return g.unsqueeze(0) + + +# Gradient operator kernels +def scharr_filter(device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor: + r"""Utility function that returns a normalized 3x3 Scharr kernel in X direction + + Args: + device: target device for kernel generation + dtype: target data type for kernel generation + Returns: + kernel: Tensor with shape (1, 3, 3) + """ + return torch.tensor([[[-3., 0., 3.], [-10., 0., 10.], [-3., 0., 3.]]], device=device, dtype=dtype) / 16 + + +def prewitt_filter(device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor: + r"""Utility function that returns a normalized 3x3 Prewitt kernel in X direction + + Args: + device: target device for kernel generation + dtype: target data type for kernel generation + Returns: + kernel: Tensor with shape (1, 3, 3)""" + return torch.tensor([[[-1., 0., 1.], [-1., 0., 1.], [-1., 0., 1.]]], device=device, dtype=dtype) / 3 + + +def binomial_filter1d(kernel_size: int, device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor: + r"""Creates 1D normalized binomial filter + + Args: + kernel_size (int): kernel size + device: target device for kernel generation + dtype: target data type for kernel generation + + Returns: + Binomial kernel with shape (1, 1, kernel_size) + """ + kernel = np.poly1d([0.5, 0.5]) ** (kernel_size - 1) + return torch.tensor(kernel.c, dtype=dtype, device=device).view(1, 1, kernel_size) + + +def average_filter2d(kernel_size: int, device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor: + r"""Creates 2D normalized average filter + + Args: + kernel_size (int): kernel size + device: target device for kernel generation + dtype: target data type for kernel generation + + Returns: + kernel: Tensor with shape (1, kernel_size, kernel_size) + """ + window = torch.ones(kernel_size, dtype=dtype, device=device) / kernel_size + kernel = window[:, None] * window[None, :] + return kernel.unsqueeze(0) diff --git a/libs/metric/piq/functional/layers.py b/libs/metric/piq/functional/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..3f0701dbc5acfc47e97aba32c0e04aa80c5bc8fc --- /dev/null +++ b/libs/metric/piq/functional/layers.py @@ -0,0 +1,33 @@ +r"""Custom layers used in metrics computations""" +import torch +from typing import Optional + +from .filters import hann_filter + + +class L2Pool2d(torch.nn.Module): + r"""Applies L2 pooling with Hann window of size 3x3 + Args: + x: Tensor with shape (N, C, H, W)""" + EPS = 1e-12 + + def __init__(self, kernel_size: int = 3, stride: int = 2, padding=1) -> None: + super().__init__() + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + + self.kernel: Optional[torch.Tensor] = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.kernel is None: + C = x.size(1) + self.kernel = hann_filter(self.kernel_size).repeat((C, 1, 1, 1)).to(x) + + out = torch.nn.functional.conv2d( + x ** 2, self.kernel, + stride=self.stride, + padding=self.padding, + groups=x.shape[1] + ) + 
return (out + self.EPS).sqrt() diff --git a/libs/metric/piq/functional/resize.py b/libs/metric/piq/functional/resize.py new file mode 100644 index 0000000000000000000000000000000000000000..b3a39b45a10a71cc38ac07a89929cf2fad033239 --- /dev/null +++ b/libs/metric/piq/functional/resize.py @@ -0,0 +1,426 @@ +""" +A standalone PyTorch implementation for fast and efficient bicubic resampling. +The resulting values are the same to MATLAB function imresize('bicubic'). +## Author: Sanghyun Son +## Email: sonsang35@gmail.com (primary), thstkdgus35@snu.ac.kr (secondary) +## Version: 1.2.0 +## Last update: July 9th, 2020 (KST) +Dependency: torch +Example:: +>>> import torch +>>> import core +>>> x = torch.arange(16).float().view(1, 1, 4, 4) +>>> y = core.imresize(x, sizes=(3, 3)) +>>> print(y) +tensor([[[[ 0.7506, 2.1004, 3.4503], + [ 6.1505, 7.5000, 8.8499], + [11.5497, 12.8996, 14.2494]]]]) +""" + +import math +import typing + +import torch +from torch.nn import functional as F + +__all__ = ['imresize'] + +_I = typing.Optional[int] +_D = typing.Optional[torch.dtype] + + +def nearest_contribution(x: torch.Tensor) -> torch.Tensor: + range_around_0 = torch.logical_and(x.gt(-0.5), x.le(0.5)) + cont = range_around_0.to(dtype=x.dtype) + return cont + + +def linear_contribution(x: torch.Tensor) -> torch.Tensor: + ax = x.abs() + range_01 = ax.le(1) + cont = (1 - ax) * range_01.to(dtype=x.dtype) + return cont + + +def cubic_contribution(x: torch.Tensor, a: float = -0.5) -> torch.Tensor: + ax = x.abs() + ax2 = ax * ax + ax3 = ax * ax2 + + range_01 = ax.le(1) + range_12 = torch.logical_and(ax.gt(1), ax.le(2)) + + cont_01 = (a + 2) * ax3 - (a + 3) * ax2 + 1 + cont_01 = cont_01 * range_01.to(dtype=x.dtype) + + cont_12 = (a * ax3) - (5 * a * ax2) + (8 * a * ax) - (4 * a) + cont_12 = cont_12 * range_12.to(dtype=x.dtype) + + cont = cont_01 + cont_12 + return cont + + +def gaussian_contribution(x: torch.Tensor, sigma: float = 2.0) -> torch.Tensor: + range_3sigma = (x.abs() <= 3 * sigma + 1) + # Normalization will be done after + cont = torch.exp(-x.pow(2) / (2 * sigma ** 2)) + cont = cont * range_3sigma.to(dtype=x.dtype) + return cont + + +def discrete_kernel( + kernel: str, scale: float, antialiasing: bool = True) -> torch.Tensor: + ''' + For downsampling with integer scale only. + ''' + downsampling_factor = int(1 / scale) + if kernel == 'cubic': + kernel_size_orig = 4 + else: + raise ValueError('Pass!') + + if antialiasing: + kernel_size = kernel_size_orig * downsampling_factor + else: + kernel_size = kernel_size_orig + + if downsampling_factor % 2 == 0: + a = kernel_size_orig * (0.5 - 1 / (2 * kernel_size)) + else: + kernel_size -= 1 + a = kernel_size_orig * (0.5 - 1 / (kernel_size + 1)) + + with torch.no_grad(): + r = torch.linspace(-a, a, steps=kernel_size) + k = cubic_contribution(r).view(-1, 1) + k = torch.matmul(k, k.t()) + k /= k.sum() + + return k + + +def reflect_padding( + x: torch.Tensor, + dim: int, + pad_pre: int, + pad_post: int) -> torch.Tensor: + ''' + Apply reflect padding to the given Tensor. + Note that it is slightly different from the PyTorch functional.pad, + where boundary elements are used only once. + Instead, we follow the MATLAB implementation + which uses boundary elements twice. + For example, + [a, b, c, d] would become [b, a, b, c, d, c] with the PyTorch implementation, + while our implementation yields [a, a, b, c, d, d]. 
+ ''' + b, c, h, w = x.size() + if dim == 2 or dim == -2: + padding_buffer = x.new_zeros(b, c, h + pad_pre + pad_post, w) + padding_buffer[..., pad_pre:(h + pad_pre), :].copy_(x) + for p in range(pad_pre): + padding_buffer[..., pad_pre - p - 1, :].copy_(x[..., p, :]) + for p in range(pad_post): + padding_buffer[..., h + pad_pre + p, :].copy_(x[..., -(p + 1), :]) + else: + padding_buffer = x.new_zeros(b, c, h, w + pad_pre + pad_post) + padding_buffer[..., pad_pre:(w + pad_pre)].copy_(x) + for p in range(pad_pre): + padding_buffer[..., pad_pre - p - 1].copy_(x[..., p]) + for p in range(pad_post): + padding_buffer[..., w + pad_pre + p].copy_(x[..., -(p + 1)]) + + return padding_buffer + + +def padding( + x: torch.Tensor, + dim: int, + pad_pre: int, + pad_post: int, + padding_type: typing.Optional[str] = 'reflect') -> torch.Tensor: + if padding_type is None: + return x + elif padding_type == 'reflect': + x_pad = reflect_padding(x, dim, pad_pre, pad_post) + else: + raise ValueError('{} padding is not supported!'.format(padding_type)) + + return x_pad + + +def get_padding( + base: torch.Tensor, + kernel_size: int, + x_size: int) -> typing.Tuple[int, int, torch.Tensor]: + base = base.long() + r_min = base.min() + r_max = base.max() + kernel_size - 1 + + if r_min <= 0: + pad_pre = -r_min + pad_pre = pad_pre.item() + base += pad_pre + else: + pad_pre = 0 + + if r_max >= x_size: + pad_post = r_max - x_size + 1 + pad_post = pad_post.item() + else: + pad_post = 0 + + return pad_pre, pad_post, base + + +def get_weight( + dist: torch.Tensor, + kernel_size: int, + kernel: str = 'cubic', + sigma: float = 2.0, + antialiasing_factor: float = 1) -> torch.Tensor: + buffer_pos = dist.new_zeros(kernel_size, len(dist)) + for idx, buffer_sub in enumerate(buffer_pos): + buffer_sub.copy_(dist - idx) + + # Expand (downsampling) / Shrink (upsampling) the receptive field. 
+ buffer_pos *= antialiasing_factor + if kernel == 'cubic': + weight = cubic_contribution(buffer_pos) + elif kernel == 'gaussian': + weight = gaussian_contribution(buffer_pos, sigma=sigma) + else: + raise ValueError('{} kernel is not supported!'.format(kernel)) + + weight /= weight.sum(dim=0, keepdim=True) + return weight + + +def reshape_tensor(x: torch.Tensor, dim: int, kernel_size: int) -> torch.Tensor: + # Resize height + if dim == 2 or dim == -2: + k = (kernel_size, 1) + h_out = x.size(-2) - kernel_size + 1 + w_out = x.size(-1) + # Resize width + else: + k = (1, kernel_size) + h_out = x.size(-2) + w_out = x.size(-1) - kernel_size + 1 + + unfold = F.unfold(x, k) + unfold = unfold.view(unfold.size(0), -1, h_out, w_out) + return unfold + + +def reshape_input(x: torch.Tensor) -> typing.Tuple[torch.Tensor, _I, _I, int, int]: + if x.dim() == 4: + b, c, h, w = x.size() + elif x.dim() == 3: + c, h, w = x.size() + b = None + elif x.dim() == 2: + h, w = x.size() + b = c = None + else: + raise ValueError('{}-dim Tensor is not supported!'.format(x.dim())) + + x = x.view(-1, 1, h, w) + return x, b, c, h, w + + +def reshape_output(x: torch.Tensor, b: _I, c: _I) -> torch.Tensor: + rh = x.size(-2) + rw = x.size(-1) + # Back to the original dimension + if b is not None: + x = x.view(b, c, rh, rw) # 4-dim + else: + if c is not None: + x = x.view(c, rh, rw) # 3-dim + else: + x = x.view(rh, rw) # 2-dim + + return x + + +def cast_input(x: torch.Tensor) -> typing.Tuple[torch.Tensor, _D]: + if x.dtype != torch.float32 or x.dtype != torch.float64: + dtype = x.dtype + x = x.float() + else: + dtype = None + + return x, dtype + + +def cast_output(x: torch.Tensor, dtype: _D) -> torch.Tensor: + if dtype is not None: + if not dtype.is_floating_point: + x = x.round() + # To prevent over/underflow when converting types + if dtype is torch.uint8: + x = x.clamp(0, 255) + + x = x.to(dtype=dtype) + + return x + + +def resize_1d( + x: torch.Tensor, + dim: int, + size: int, + scale: float, + kernel: str = 'cubic', + sigma: float = 2.0, + padding_type: str = 'reflect', + antialiasing: bool = True) -> torch.Tensor: + ''' + Args: + x (torch.Tensor): A torch.Tensor of dimension (B x C, 1, H, W). + dim (int): + scale (float): + size (int): + Return: + ''' + # Identity case + if scale == 1: + return x + + # Default bicubic kernel with antialiasing (only when downsampling) + if kernel == 'cubic': + kernel_size = 4 + else: + kernel_size = math.floor(6 * sigma) + + if antialiasing and (scale < 1): + antialiasing_factor = scale + kernel_size = math.ceil(kernel_size / antialiasing_factor) + else: + antialiasing_factor = 1 + + # We allow margin to both sizes + kernel_size += 2 + + # Weights only depend on the shape of input and output, + # so we do not calculate gradients here. 
+ with torch.no_grad(): + pos = torch.linspace( + 0, size - 1, steps=size, dtype=x.dtype, device=x.device, + ) + pos = (pos + 0.5) / scale - 0.5 + base = pos.floor() - (kernel_size // 2) + 1 + dist = pos - base + weight = get_weight( + dist, + kernel_size, + kernel=kernel, + sigma=sigma, + antialiasing_factor=antialiasing_factor, + ) + pad_pre, pad_post, base = get_padding(base, kernel_size, x.size(dim)) + + # To backpropagate through x + x_pad = padding(x, dim, pad_pre, pad_post, padding_type=padding_type) + unfold = reshape_tensor(x_pad, dim, kernel_size) + # Subsampling first + if dim == 2 or dim == -2: + sample = unfold[..., base, :] + weight = weight.view(1, kernel_size, sample.size(2), 1) + else: + sample = unfold[..., base] + weight = weight.view(1, kernel_size, 1, sample.size(3)) + + # Apply the kernel + x = sample * weight + x = x.sum(dim=1, keepdim=True) + return x + + +def downsampling_2d( + x: torch.Tensor, + k: torch.Tensor, + scale: int, + padding_type: str = 'reflect') -> torch.Tensor: + c = x.size(1) + k_h = k.size(-2) + k_w = k.size(-1) + + k = k.to(dtype=x.dtype, device=x.device) + k = k.view(1, 1, k_h, k_w) + k = k.repeat(c, c, 1, 1) + e = torch.eye(c, dtype=k.dtype, device=k.device, requires_grad=False) + e = e.view(c, c, 1, 1) + k = k * e + + pad_h = (k_h - scale) // 2 + pad_w = (k_w - scale) // 2 + x = padding(x, -2, pad_h, pad_h, padding_type=padding_type) + x = padding(x, -1, pad_w, pad_w, padding_type=padding_type) + y = F.conv2d(x, k, padding=0, stride=scale) + return y + + +def imresize( + x: torch.Tensor, + scale: typing.Optional[float] = None, + sizes: typing.Optional[typing.Tuple[int, int]] = None, + kernel: typing.Union[str, torch.Tensor] = 'cubic', + sigma: float = 2, + rotation_degree: float = 0, + padding_type: str = 'reflect', + antialiasing: bool = True) -> torch.Tensor: + """ + Args: + x (torch.Tensor): + scale (float): + sizes (tuple(int, int)): + kernel (str, default='cubic'): + sigma (float, default=2): + rotation_degree (float, default=0): + padding_type (str, default='reflect'): + antialiasing (bool, default=True): + Return: + torch.Tensor: + """ + if scale is None and sizes is None: + raise ValueError('One of scale or sizes must be specified!') + if scale is not None and sizes is not None: + raise ValueError('Please specify scale or sizes to avoid conflict!') + + x, b, c, h, w = reshape_input(x) + + if sizes is None and scale is not None: + ''' + # Check if we can apply the convolution algorithm + scale_inv = 1 / scale + if isinstance(kernel, str) and scale_inv.is_integer(): + kernel = discrete_kernel(kernel, scale, antialiasing=antialiasing) + elif isinstance(kernel, torch.Tensor) and not scale_inv.is_integer(): + raise ValueError( + 'An integer downsampling factor ' + 'should be used with a predefined kernel!' 
+ ) + ''' + # Determine output size + sizes = (math.ceil(h * scale), math.ceil(w * scale)) + scales = (scale, scale) + + if scale is None and sizes is not None: + scales = (sizes[0] / h, sizes[1] / w) + + x, dtype = cast_input(x) + + if isinstance(kernel, str) and sizes is not None: + # Core resizing module + x = resize_1d(x, -2, size=sizes[0], scale=scales[0], kernel=kernel, sigma=sigma, padding_type=padding_type, + antialiasing=antialiasing) + x = resize_1d(x, -1, size=sizes[1], scale=scales[1], kernel=kernel, sigma=sigma, padding_type=padding_type, + antialiasing=antialiasing) + elif isinstance(kernel, torch.Tensor) and scale is not None: + x = downsampling_2d(x, kernel, scale=int(1 / scale)) + + x = reshape_output(x, b, c) + x = cast_output(x, dtype) + return x diff --git a/libs/metric/piq/perceptual.py b/libs/metric/piq/perceptual.py new file mode 100644 index 0000000000000000000000000000000000000000..68a704d4f21fa569bcd6d1b4ce7862b780ba8d2a --- /dev/null +++ b/libs/metric/piq/perceptual.py @@ -0,0 +1,496 @@ +""" +Implementation of Content loss, Style loss, LPIPS and DISTS metrics +References: + .. [1] Gatys, Leon and Ecker, Alexander and Bethge, Matthias + (2016). A Neural Algorithm of Artistic Style} + Association for Research in Vision and Ophthalmology (ARVO) + https://arxiv.org/abs/1508.06576 + .. [2] Zhang, Richard and Isola, Phillip and Efros, et al. + (2018) The Unreasonable Effectiveness of Deep Features as a Perceptual Metric + 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition + https://arxiv.org/abs/1801.03924 +""" +from typing import List, Union, Collection + +import torch +import torch.nn as nn +from torch.nn.modules.loss import _Loss +from torchvision.models import vgg16, vgg19, VGG16_Weights, VGG19_Weights + +from .utils import _validate_input, _reduce +from .functional import similarity_map, L2Pool2d + +# Map VGG names to corresponding number in torchvision layer +VGG16_LAYERS = { + "conv1_1": '0', "relu1_1": '1', + "conv1_2": '2', "relu1_2": '3', + "pool1": '4', + "conv2_1": '5', "relu2_1": '6', + "conv2_2": '7', "relu2_2": '8', + "pool2": '9', + "conv3_1": '10', "relu3_1": '11', + "conv3_2": '12', "relu3_2": '13', + "conv3_3": '14', "relu3_3": '15', + "pool3": '16', + "conv4_1": '17', "relu4_1": '18', + "conv4_2": '19', "relu4_2": '20', + "conv4_3": '21', "relu4_3": '22', + "pool4": '23', + "conv5_1": '24', "relu5_1": '25', + "conv5_2": '26', "relu5_2": '27', + "conv5_3": '28', "relu5_3": '29', + "pool5": '30', +} + +VGG19_LAYERS = { + "conv1_1": '0', "relu1_1": '1', + "conv1_2": '2', "relu1_2": '3', + "pool1": '4', + "conv2_1": '5', "relu2_1": '6', + "conv2_2": '7', "relu2_2": '8', + "pool2": '9', + "conv3_1": '10', "relu3_1": '11', + "conv3_2": '12', "relu3_2": '13', + "conv3_3": '14', "relu3_3": '15', + "conv3_4": '16', "relu3_4": '17', + "pool3": '18', + "conv4_1": '19', "relu4_1": '20', + "conv4_2": '21', "relu4_2": '22', + "conv4_3": '23', "relu4_3": '24', + "conv4_4": '25', "relu4_4": '26', + "pool4": '27', + "conv5_1": '28', "relu5_1": '29', + "conv5_2": '30', "relu5_2": '31', + "conv5_3": '32', "relu5_3": '33', + "conv5_4": '34', "relu5_4": '35', + "pool5": '36', +} + +IMAGENET_MEAN = [0.485, 0.456, 0.406] +IMAGENET_STD = [0.229, 0.224, 0.225] + +# Constant used in feature normalization to avoid zero division +EPS = 1e-10 + + +class ContentLoss(_Loss): + r"""Creates Content loss that can be used for image style transfer or as a measure for image to image tasks. + Uses pretrained VGG models from torchvision. 
+ Expects input to be in range [0, 1] or normalized with ImageNet statistics into range [-1, 1] + + Args: + feature_extractor: Model to extract features or model name: ``'vgg16'`` | ``'vgg19'``. + layers: List of strings with layer names. Default: ``'relu3_3'`` + weights: List of float weight to balance different layers + replace_pooling: Flag to replace MaxPooling layer with AveragePooling. See references for details. + distance: Method to compute distance between features: ``'mse'`` | ``'mae'``. + reduction: Specifies the reduction type: + ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'`` + mean: List of float values used for data standardization. Default: ImageNet mean. + If there is no need to normalize data, use [0., 0., 0.]. + std: List of float values used for data standardization. Default: ImageNet std. + If there is no need to normalize data, use [1., 1., 1.]. + normalize_features: If true, unit-normalize each feature in channel dimension before scaling + and computing distance. See references for details. + + Examples: + >>> loss = ContentLoss() + >>> x = torch.rand(3, 3, 256, 256, requires_grad=True) + >>> y = torch.rand(3, 3, 256, 256) + >>> output = loss(x, y) + >>> output.backward() + + References: + Gatys, Leon and Ecker, Alexander and Bethge, Matthias (2016). + A Neural Algorithm of Artistic Style + Association for Research in Vision and Ophthalmology (ARVO) + https://arxiv.org/abs/1508.06576 + + Zhang, Richard and Isola, Phillip and Efros, et al. (2018) + The Unreasonable Effectiveness of Deep Features as a Perceptual Metric + IEEE/CVF Conference on Computer Vision and Pattern Recognition + https://arxiv.org/abs/1801.03924 + """ + + def __init__(self, feature_extractor: Union[str, torch.nn.Module] = "vgg16", layers: Collection[str] = ("relu3_3",), + weights: List[Union[float, torch.Tensor]] = [1.], replace_pooling: bool = False, + distance: str = "mse", reduction: str = "mean", mean: List[float] = IMAGENET_MEAN, + std: List[float] = IMAGENET_STD, normalize_features: bool = False, + allow_layers_weights_mismatch: bool = False) -> None: + + assert allow_layers_weights_mismatch or len(layers) == len(weights), \ + f'Lengths of provided layers and weighs mismatch ({len(weights)} weights and {len(layers)} layers), ' \ + f'which will cause incorrect results. Please provide weight for each layer.' 
+ + super().__init__() + + if callable(feature_extractor): + self.model = feature_extractor + self.layers = layers + else: + if feature_extractor == "vgg16": + # self.model = vgg16(pretrained=True, progress=False).features + self.model = vgg16(weights=VGG16_Weights.DEFAULT, progress=False).features + self.layers = [VGG16_LAYERS[l] for l in layers] + elif feature_extractor == "vgg19": + # self.model = vgg19(pretrained=True, progress=False).features + self.model = vgg19(weights=VGG19_Weights.DEFAULT, progress=False).features + self.layers = [VGG19_LAYERS[l] for l in layers] + else: + raise ValueError("Unknown feature extractor") + + if replace_pooling: + self.model = self.replace_pooling(self.model) + + # Disable gradients + for param in self.model.parameters(): + param.requires_grad_(False) + + self.distance = { + "mse": nn.MSELoss, + "mae": nn.L1Loss, + }[distance](reduction='none') + + self.weights = [torch.tensor(w) if not isinstance(w, torch.Tensor) else w for w in weights] + + mean = torch.tensor(mean) + std = torch.tensor(std) + self.mean = mean.view(1, -1, 1, 1) + self.std = std.view(1, -1, 1, 1) + + self.normalize_features = normalize_features + self.reduction = reduction + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + r"""Computation of Content loss between feature representations of prediction :math:`x` and + target :math:`y` tensors. + + Args: + x: An input tensor. Shape :math:`(N, C, H, W)`. + y: A target tensor. Shape :math:`(N, C, H, W)`. + + Returns: + Content loss between feature representations + """ + _validate_input([x, y], dim_range=(4, 4), data_range=(0, -1)) + + self.model.to(x) + x_features = self.get_features(x) + y_features = self.get_features(y) + + distances = self.compute_distance(x_features, y_features) + + # Scale distances, then average in spatial dimensions, then stack and sum in channels dimension + loss = torch.cat([(d * w.to(d)).mean(dim=[2, 3]) for d, w in zip(distances, self.weights)], dim=1).sum(dim=1) + + return _reduce(loss, self.reduction) + + def compute_distance(self, x_features: List[torch.Tensor], y_features: List[torch.Tensor]) -> List[torch.Tensor]: + r"""Take L2 or L1 distance between feature maps depending on ``distance``. + + Args: + x_features: Features of the input tensor. + y_features: Features of the target tensor. + + Returns: + Distance between feature maps + """ + return [self.distance(x, y) for x, y in zip(x_features, y_features)] + + def get_features(self, x: torch.Tensor) -> List[torch.Tensor]: + r""" + Args: + x: Tensor. Shape :math:`(N, C, H, W)`. + + Returns: + List of features extracted from intermediate layers + """ + # Normalize input + x = (x - self.mean.to(x)) / self.std.to(x) + + features = [] + for name, module in self.model._modules.items(): + x = module(x) + if name in self.layers: + features.append(self.normalize(x) if self.normalize_features else x) + + return features + + @staticmethod + def normalize(x: torch.Tensor) -> torch.Tensor: + r"""Normalize feature maps in channel direction to unit length. + + Args: + x: Tensor. Shape :math:`(N, C, H, W)`. 
+ + Returns: + Normalized input + """ + norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) + return x / (norm_factor + EPS) + + def replace_pooling(self, module: torch.nn.Module) -> torch.nn.Module: + r"""Turn All MaxPool layers into AveragePool + + Args: + module: Module to change MaxPool int AveragePool + + Returns: + Module with AveragePool instead MaxPool + + """ + module_output = module + if isinstance(module, torch.nn.MaxPool2d): + module_output = torch.nn.AvgPool2d(kernel_size=2, stride=2, padding=0) + + for name, child in module.named_children(): + module_output.add_module(name, self.replace_pooling(child)) + return module_output + + +class StyleLoss(ContentLoss): + r"""Creates Style loss that can be used for image style transfer or as a measure in + image to image tasks. Computes distance between Gram matrices of feature maps. + Uses pretrained VGG models from torchvision. + + By default expects input to be in range [0, 1], which is then normalized by ImageNet statistics into range [-1, 1]. + If no normalisation is required, change `mean` and `std` values accordingly. + + Args: + feature_extractor: Model to extract features or model name: ``'vgg16'`` | ``'vgg19'``. + layers: List of strings with layer names. Default: ``'relu3_3'`` + weights: List of float weight to balance different layers + replace_pooling: Flag to replace MaxPooling layer with AveragePooling. See references for details. + distance: Method to compute distance between features: ``'mse'`` | ``'mae'``. + reduction: Specifies the reduction type: + ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'`` + mean: List of float values used for data standardization. Default: ImageNet mean. + If there is no need to normalize data, use [0., 0., 0.]. + std: List of float values used for data standardization. Default: ImageNet std. + If there is no need to normalize data, use [1., 1., 1.]. + normalize_features: If true, unit-normalize each feature in channel dimension before scaling + and computing distance. See references for details. + + Examples: + >>> loss = StyleLoss() + >>> x = torch.rand(3, 3, 256, 256, requires_grad=True) + >>> y = torch.rand(3, 3, 256, 256) + >>> output = loss(x, y) + >>> output.backward() + + References: + Gatys, Leon and Ecker, Alexander and Bethge, Matthias (2016). + A Neural Algorithm of Artistic Style + Association for Research in Vision and Ophthalmology (ARVO) + https://arxiv.org/abs/1508.06576 + + Zhang, Richard and Isola, Phillip and Efros, et al. (2018) + The Unreasonable Effectiveness of Deep Features as a Perceptual Metric + IEEE/CVF Conference on Computer Vision and Pattern Recognition + https://arxiv.org/abs/1801.03924 + """ + + def compute_distance(self, x_features: torch.Tensor, y_features: torch.Tensor): + r"""Take L2 or L1 distance between Gram matrices of feature maps depending on ``distance``. + + Args: + x_features: Features of the input tensor. + y_features: Features of the target tensor. + + Returns: + Distance between Gram matrices + """ + x_gram = [self.gram_matrix(x) for x in x_features] + y_gram = [self.gram_matrix(x) for x in y_features] + return [self.distance(x, y) for x, y in zip(x_gram, y_gram)] + + @staticmethod + def gram_matrix(x: torch.Tensor) -> torch.Tensor: + r"""Compute Gram matrix for batch of features. + + Args: + x: Tensor. Shape :math:`(N, C, H, W)`. 
+ + Returns: + Gram matrix for given input + """ + B, C, H, W = x.size() + gram = [] + for i in range(B): + features = x[i].view(C, H * W) + + # Add fake channel dimension + gram.append(torch.mm(features, features.t()).unsqueeze(0)) + + return torch.stack(gram) + + +class LPIPS(ContentLoss): + r"""Learned Perceptual Image Patch Similarity metric. Only VGG16 learned weights are supported. + + By default expects input to be in range [0, 1], which is then normalized by ImageNet statistics into range [-1, 1]. + If no normalisation is required, change `mean` and `std` values accordingly. + + Args: + replace_pooling: Flag to replace MaxPooling layer with AveragePooling. See references for details. + distance: Method to compute distance between features: ``'mse'`` | ``'mae'``. + reduction: Specifies the reduction type: + ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'`` + mean: List of float values used for data standardization. Default: ImageNet mean. + If there is no need to normalize data, use [0., 0., 0.]. + std: List of float values used for data standardization. Default: ImageNet std. + If there is no need to normalize data, use [1., 1., 1.]. + + Examples: + >>> loss = LPIPS() + >>> x = torch.rand(3, 3, 256, 256, requires_grad=True) + >>> y = torch.rand(3, 3, 256, 256) + >>> output = loss(x, y) + >>> output.backward() + + References: + Gatys, Leon and Ecker, Alexander and Bethge, Matthias (2016). + A Neural Algorithm of Artistic Style + Association for Research in Vision and Ophthalmology (ARVO) + https://arxiv.org/abs/1508.06576 + + Zhang, Richard and Isola, Phillip and Efros, et al. (2018) + The Unreasonable Effectiveness of Deep Features as a Perceptual Metric + IEEE/CVF Conference on Computer Vision and Pattern Recognition + https://arxiv.org/abs/1801.03924 + https://github.com/richzhang/PerceptualSimilarity + """ + _weights_url = "https://github.com/photosynthesis-team/" + \ + "photosynthesis.metrics/releases/download/v0.4.0/lpips_weights.pt" + + def __init__(self, replace_pooling: bool = False, distance: str = "mse", reduction: str = "mean", + mean: List[float] = IMAGENET_MEAN, std: List[float] = IMAGENET_STD, ) -> None: + lpips_layers = ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'] + lpips_weights = torch.hub.load_state_dict_from_url(self._weights_url, progress=False) + super().__init__("vgg16", layers=lpips_layers, weights=lpips_weights, + replace_pooling=replace_pooling, distance=distance, + reduction=reduction, mean=mean, std=std, + normalize_features=True) + + +class DISTS(ContentLoss): + r"""Deep Image Structure and Texture Similarity metric. + + By default expects input to be in range [0, 1], which is then normalized by ImageNet statistics into range [-1, 1]. + If no normalisation is required, change `mean` and `std` values accordingly. + + Args: + reduction: Specifies the reduction type: + ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'`` + mean: List of float values used for data standardization. Default: ImageNet mean. + If there is no need to normalize data, use [0., 0., 0.]. + std: List of float values used for data standardization. Default: ImageNet std. + If there is no need to normalize data, use [1., 1., 1.]. + + Examples: + >>> loss = DISTS() + >>> x = torch.rand(3, 3, 256, 256, requires_grad=True) + >>> y = torch.rand(3, 3, 256, 256) + >>> output = loss(x, y) + >>> output.backward() + + References: + Keyan Ding, Kede Ma, Shiqi Wang, Eero P. Simoncelli (2020). + Image Quality Assessment: Unifying Structure and Texture Similarity. 
+ https://arxiv.org/abs/2004.07728 + https://github.com/dingkeyan93/DISTS + """ + _weights_url = "https://github.com/photosynthesis-team/piq/releases/download/v0.4.1/dists_weights.pt" + + def __init__(self, reduction: str = "mean", mean: List[float] = IMAGENET_MEAN, + std: List[float] = IMAGENET_STD) -> None: + dists_layers = ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'] + channels = [3, 64, 128, 256, 512, 512] + + weights = torch.hub.load_state_dict_from_url(self._weights_url, progress=False) + dists_weights = list(torch.split(weights['alpha'], channels, dim=1)) + dists_weights.extend(torch.split(weights['beta'], channels, dim=1)) + + super().__init__("vgg16", layers=dists_layers, weights=dists_weights, + replace_pooling=True, reduction=reduction, mean=mean, std=std, + normalize_features=False, allow_layers_weights_mismatch=True) + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + r""" + + Args: + x: An input tensor. Shape :math:`(N, C, H, W)`. + y: A target tensor. Shape :math:`(N, C, H, W)`. + + Returns: + Deep Image Structure and Texture Similarity loss, i.e. ``1-DISTS`` in range [0, 1]. + """ + _, _, H, W = x.shape + + if min(H, W) > 256: + x = torch.nn.functional.interpolate( + x, scale_factor=256 / min(H, W), recompute_scale_factor=False, mode='bilinear') + y = torch.nn.functional.interpolate( + y, scale_factor=256 / min(H, W), recompute_scale_factor=False, mode='bilinear') + + loss = super().forward(x, y) + return 1 - loss + + def compute_distance(self, x_features: torch.Tensor, y_features: torch.Tensor) -> List[torch.Tensor]: + r"""Compute structure similarity between feature maps + + Args: + x_features: Features of the input tensor. + y_features: Features of the target tensor. + + Returns: + Structural similarity distance between feature maps + """ + structure_distance, texture_distance = [], [] + # Small constant for numerical stability + EPS = 1e-6 + + for x, y in zip(x_features, y_features): + x_mean = x.mean([2, 3], keepdim=True) + y_mean = y.mean([2, 3], keepdim=True) + structure_distance.append(similarity_map(x_mean, y_mean, constant=EPS)) + + x_var = ((x - x_mean) ** 2).mean([2, 3], keepdim=True) + y_var = ((y - y_mean) ** 2).mean([2, 3], keepdim=True) + xy_cov = (x * y).mean([2, 3], keepdim=True) - x_mean * y_mean + texture_distance.append((2 * xy_cov + EPS) / (x_var + y_var + EPS)) + + return structure_distance + texture_distance + + def get_features(self, x: torch.Tensor) -> List[torch.Tensor]: + r""" + + Args: + x: Input tensor + + Returns: + List of features extracted from input tensor + """ + features = super().get_features(x) + + # Add input tensor as an additional feature + features.insert(0, x) + return features + + def replace_pooling(self, module: torch.nn.Module) -> torch.nn.Module: + r"""Turn All MaxPool layers into L2Pool + + Args: + module: Module to change MaxPool into L2Pool + + Returns: + Module with L2Pool instead of MaxPool + """ + module_output = module + if isinstance(module, torch.nn.MaxPool2d): + module_output = L2Pool2d(kernel_size=3, stride=2, padding=1) + + for name, child in module.named_children(): + module_output.add_module(name, self.replace_pooling(child)) + + return module_output diff --git a/libs/metric/piq/utils/__init__.py b/libs/metric/piq/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4ab6241444c024e5daa7b90190a45f481a66b69b --- /dev/null +++ b/libs/metric/piq/utils/__init__.py @@ -0,0 +1,7 @@ +from .common import _validate_input, _reduce, _parse_version + 
+__all__ = [
+    "_validate_input",
+    "_reduce",
+    '_parse_version'
+]
diff --git a/libs/metric/piq/utils/common.py b/libs/metric/piq/utils/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ceb336a52669616ae5609941d90c916997a53eb
--- /dev/null
+++ b/libs/metric/piq/utils/common.py
@@ -0,0 +1,158 @@
+import torch
+import re
+import warnings
+
+from typing import Tuple, List, Optional, Union, Dict, Any
+
+SEMVER_VERSION_PATTERN = re.compile(
+    r"""
+        ^
+        (?P<major>0|[1-9]\d*)
+        \.
+        (?P<minor>0|[1-9]\d*)
+        \.
+        (?P<patch>0|[1-9]\d*)
+        (?:-(?P<prerelease>
+            (?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)
+            (?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*
+        ))?
+        (?:\+(?P<build>
+            [0-9a-zA-Z-]+
+            (?:\.[0-9a-zA-Z-]+)*
+        ))?
+        $
+    """,
+    re.VERBOSE,
+)
+
+
+PEP_440_VERSION_PATTERN = r"""
+    v?
+    (?:
+        (?:(?P<epoch>[0-9]+)!)?                           # epoch
+        (?P<release>[0-9]+(?:\.[0-9]+)*)                  # release segment
+        (?P<pre>                                          # pre-release
+            [-_\.]?
+            (?P<pre_l>(a|b|c|rc|alpha|beta|pre|preview))
+            [-_\.]?
+            (?P<pre_n>[0-9]+)?
+        )?
+        (?P<post>                                         # post release
+            (?:-(?P<post_n1>[0-9]+))
+            |
+            (?:
+                [-_\.]?
+                (?P<post_l>post|rev|r)
+                [-_\.]?
+                (?P<post_n2>[0-9]+)?
+            )
+        )?
+        (?P<dev>                                          # dev release
+            [-_\.]?
+            (?P<dev_l>dev)
+            [-_\.]?
+            (?P<dev_n>[0-9]+)?
+        )?
+    )
+    (?:\+(?P<local>[a-z0-9]+(?:[-_\.][a-z0-9]+)*))?       # local version
+"""
+
+
+def _validate_input(
+        tensors: List[torch.Tensor],
+        dim_range: Tuple[int, int] = (0, -1),
+        data_range: Tuple[float, float] = (0., -1.),
+        # size_dim_range: Tuple[float, float] = (0., -1.),
+        size_range: Optional[Tuple[int, int]] = None,
+) -> None:
+    r"""Check that input(-s)  satisfies the requirements
+    Args:
+        tensors: Tensors to check
+        dim_range: Allowed number of dimensions. (min, max)
+        data_range: Allowed range of values in tensors. (min, max)
+        size_range: Dimensions to include in size comparison. (start_dim, end_dim + 1)
+    """
+
+    if not __debug__:
+        return
+
+    x = tensors[0]
+
+    for t in tensors:
+        assert torch.is_tensor(t), f'Expected torch.Tensor, got {type(t)}'
+        assert t.device == x.device, f'Expected tensors to be on {x.device}, got {t.device}'
+
+        if size_range is None:
+            assert t.size() == x.size(), f'Expected tensors with same size, got {t.size()} and {x.size()}'
+        else:
+            assert t.size()[size_range[0]: size_range[1]] == x.size()[size_range[0]: size_range[1]], \
+                f'Expected tensors with same size at given dimensions, got {t.size()} and {x.size()}'
+
+        if dim_range[0] == dim_range[1]:
+            assert t.dim() == dim_range[0], f'Expected number of dimensions to be {dim_range[0]}, got {t.dim()}'
+        elif dim_range[0] < dim_range[1]:
+            assert dim_range[0] <= t.dim() <= dim_range[1], \
+                f'Expected number of dimensions to be between {dim_range[0]} and {dim_range[1]}, got {t.dim()}'
+
+        if data_range[0] < data_range[1]:
+            assert data_range[0] <= t.min(), \
+                f'Expected values to be greater or equal to {data_range[0]}, got {t.min()}'
+            assert t.max() <= data_range[1], \
+                f'Expected values to be lower or equal to {data_range[1]}, got {t.max()}'
+
+
+def _reduce(x: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
+    r"""Reduce input in batch dimension if needed.
+
+    Args:
+        x: Tensor with shape (N, *).
+        reduction: Specifies the reduction type:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
+    """
+    if reduction == 'none':
+        return x
+    elif reduction == 'mean':
+        return x.mean(dim=0)
+    elif reduction == 'sum':
+        return x.sum(dim=0)
+    else:
+        raise ValueError("Unknown reduction. Expected one of {'none', 'mean', 'sum'}")
+
+
+def _parse_version(version: Union[str, bytes]) -> Tuple[int, ...]:
+    """ Parses valid Python versions according to Semver and PEP 440 specifications.
+    For more on Semver check: https://semver.org/
+    For more on PEP 440 check: https://www.python.org/dev/peps/pep-0440/.
+
+    Implementation is inspired by:
+    - https://github.com/python-semver
+    - https://github.com/pypa/packaging
+
+    Args:
+        version: unparsed information about the library of interest.
+
+    Returns:
+        parsed information about the library of interest.
+    """
+    if isinstance(version, bytes):
+        version = version.decode("UTF-8")
+    elif not isinstance(version, str) and not isinstance(version, bytes):
+        raise TypeError(f"not expecting type {type(version)}")
+
+    # Semver processing
+    match = SEMVER_VERSION_PATTERN.match(version)
+    if match:
+        matched_version_parts: Dict[str, Any] = match.groupdict()
+        release = tuple([int(matched_version_parts[k]) for k in ['major', 'minor', 'patch']])
+        return release
+
+    # PEP 440 processing
+    regex = re.compile(r"^\s*" + PEP_440_VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)
+    match = regex.search(version)
+
+    if match is None:
+        warnings.warn(f"{version} is not a valid SemVer or PEP 440 string")
+        return tuple()
+
+    release = tuple(int(i) for i in match.group("release").split("."))
+    return release
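For orientation, here is a minimal usage sketch of these helpers. This is a sketch only: the import path follows the file layout above, and the tensors and version strings are illustrative.

```python
import torch

from libs.metric.piq.utils import _validate_input, _reduce, _parse_version

# _parse_version extracts the numeric release tuple from SemVer or PEP 440 strings.
assert _parse_version("1.10.2") == (1, 10, 2)      # SemVer
assert _parse_version("2.0.0rc1") == (2, 0, 0)     # PEP 440 pre-release

# _validate_input checks that tensors share device/size and fall within the allowed range.
x = torch.rand(4, 3, 64, 64)
y = torch.rand(4, 3, 64, 64)
_validate_input([x, y], dim_range=(4, 4), data_range=(0.0, 1.0))

# _reduce collapses the batch dimension of a per-sample loss.
per_sample_loss = torch.rand(4)
print(_reduce(per_sample_loss, reduction="mean"))
```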
diff --git a/libs/metric/pytorch_fid/__init__.py b/libs/metric/pytorch_fid/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..782e20db5f2769a78f8783031cbcd327437ece9a
--- /dev/null
+++ b/libs/metric/pytorch_fid/__init__.py
@@ -0,0 +1,54 @@
+__version__ = '0.3.0'
+
+import torch
+from einops import rearrange, repeat
+
+from .inception import InceptionV3
+from .fid_score import calculate_frechet_distance
+
+
+class PytorchFIDFactory(torch.nn.Module):
+    """
+
+   Args:
+       channels:
+       inception_block_idx:
+
+    Examples:
+    >>> fid_factory =  PytorchFIDFactory()
+    >>> fid_score = fid_factory.score(real_samples=data, fake_samples=all_images)
+    >>> print(fid_score)
+   """
+
+    def __init__(self, channels: int = 3, inception_block_idx: int = 2048):
+        super().__init__()
+        self.channels = channels
+
+        # load models
+        assert inception_block_idx in InceptionV3.BLOCK_INDEX_BY_DIM
+        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[inception_block_idx]
+        self.inception_v3 = InceptionV3([block_idx])
+
+    @torch.no_grad()
+    def calculate_activation_statistics(self, samples):
+        features = self.inception_v3(samples)[0]
+        features = rearrange(features, '... 1 1 -> ...')
+
+        mu = torch.mean(features, dim=0).cpu()
+        # torch.cov treats rows as variables, so transpose to get a (dims, dims) feature covariance
+        sigma = torch.cov(features.t()).cpu()
+        return mu, sigma
+
+    def score(self, real_samples, fake_samples):
+        if self.channels == 1:
+            real_samples, fake_samples = map(
+                lambda t: repeat(t, 'b 1 ... -> b c ...', c=3), (real_samples, fake_samples)
+            )
+
+        min_batch = min(real_samples.shape[0], fake_samples.shape[0])
+        real_samples, fake_samples = map(lambda t: t[:min_batch], (real_samples, fake_samples))
+
+        m1, s1 = self.calculate_activation_statistics(real_samples)
+        m2, s2 = self.calculate_activation_statistics(fake_samples)
+
+        fid_value = calculate_frechet_distance(m1, s1, m2, s2)
+        return fid_value
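A minimal scoring sketch for the factory above: the random tensors stand in for real and generated image batches in [0, 1], and the batch size is illustrative (small batches give noisy FID estimates). Constructing the factory downloads the FID InceptionV3 weights on first use.

```python
import torch

from libs.metric.pytorch_fid import PytorchFIDFactory  # import path follows the layout above

fid_factory = PytorchFIDFactory(channels=3, inception_block_idx=2048)

# Placeholder batches; in practice these come from a dataset and a generator.
real_samples = torch.rand(64, 3, 299, 299)
fake_samples = torch.rand(64, 3, 299, 299)

fid_value = fid_factory.score(real_samples=real_samples, fake_samples=fake_samples)
print(f"FID: {fid_value:.2f}")
```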
diff --git a/libs/metric/pytorch_fid/fid_score.py b/libs/metric/pytorch_fid/fid_score.py
new file mode 100644
index 0000000000000000000000000000000000000000..117e0c77d25afded5e63429bb0a27a71967530f5
--- /dev/null
+++ b/libs/metric/pytorch_fid/fid_score.py
@@ -0,0 +1,322 @@
+"""Calculates the Frechet Inception Distance (FID) to evalulate GANs
+
+The FID metric calculates the distance between two distributions of images.
+Typically, we have summary statistics (mean & covariance matrix) of one
+of these distributions, while the 2nd distribution is given by a GAN.
+
+When run as a stand-alone program, it compares the distribution of
+images that are stored as PNG/JPEG at a specified location with a
+distribution given by summary statistics (in pickle format).
+
+The FID is calculated by assuming that X_1 and X_2 are the activations of
+the pool_3 layer of the inception net for generated samples and real world
+samples respectively.
+
+See --help to see further details.
+
+Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
+of Tensorflow
+
+Copyright 2018 Institute of Bioinformatics, JKU Linz
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import os
+import pathlib
+from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
+
+import numpy as np
+import torch
+import torchvision.transforms as TF
+from PIL import Image
+from scipy import linalg
+from torch.nn.functional import adaptive_avg_pool2d
+
+try:
+    from tqdm import tqdm
+except ImportError:
+    # If tqdm is not available, provide a mock version of it
+    def tqdm(x):
+        return x
+
+from .inception import InceptionV3
+
+parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
+parser.add_argument('--batch-size', type=int, default=50,
+                    help='Batch size to use')
+parser.add_argument('--num-workers', type=int,
+                    help=('Number of processes to use for data loading. '
+                          'Defaults to `min(8, num_cpus)`'))
+parser.add_argument('--device', type=str, default=None,
+                    help='Device to use. Like cuda, cuda:0 or cpu')
+parser.add_argument('--dims', type=int, default=2048,
+                    choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
+                    help=('Dimensionality of Inception features to use. '
+                          'By default, uses pool3 features'))
+parser.add_argument('--save-stats', action='store_true',
+                    help=('Generate an npz archive from a directory of samples. '
+                          'The first path is used as input and the second as output.'))
+parser.add_argument('path', type=str, nargs=2,
+                    help=('Paths to the generated images or '
+                          'to .npz statistic files'))
+
+IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm',
+                    'tif', 'tiff', 'webp'}
+
+
+class ImagePathDataset(torch.utils.data.Dataset):
+    def __init__(self, files, transforms=None):
+        self.files = files
+        self.transforms = transforms
+
+    def __len__(self):
+        return len(self.files)
+
+    def __getitem__(self, i):
+        path = self.files[i]
+        img = Image.open(path).convert('RGB')
+        if self.transforms is not None:
+            img = self.transforms(img)
+        return img
+
+
+def get_activations(files, model, batch_size=50, dims=2048, device='cpu',
+                    num_workers=1):
+    """Calculates the activations of the pool_3 layer for all images.
+
+    Params:
+    -- files       : List of image files paths
+    -- model       : Instance of inception model
+    -- batch_size  : Batch size of images for the model to process at once.
+                     Make sure that the number of samples is a multiple of
+                     the batch size, otherwise some samples are ignored. This
+                     behavior is retained to match the original FID score
+                     implementation.
+    -- dims        : Dimensionality of features returned by Inception
+    -- device      : Device to run calculations
+    -- num_workers : Number of parallel dataloader workers
+
+    Returns:
+    -- A numpy array of dimension (num images, dims) that contains the
+       activations of the given tensor when feeding inception with the
+       query tensor.
+    """
+    model.eval()
+
+    if batch_size > len(files):
+        print(('Warning: batch size is bigger than the data size. '
+               'Setting batch size to data size'))
+        batch_size = len(files)
+
+    dataset = ImagePathDataset(files, transforms=TF.ToTensor())
+    dataloader = torch.utils.data.DataLoader(dataset,
+                                             batch_size=batch_size,
+                                             shuffle=False,
+                                             drop_last=False,
+                                             num_workers=num_workers)
+
+    pred_arr = np.empty((len(files), dims))
+
+    start_idx = 0
+
+    for batch in tqdm(dataloader):
+        batch = batch.to(device)
+
+        with torch.no_grad():
+            pred = model(batch)[0]
+
+        # If model output is not scalar, apply global spatial average pooling.
+        # This happens if you choose a dimensionality not equal 2048.
+        if pred.size(2) != 1 or pred.size(3) != 1:
+            pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
+
+        pred = pred.squeeze(3).squeeze(2).cpu().numpy()
+
+        pred_arr[start_idx:start_idx + pred.shape[0]] = pred
+
+        start_idx = start_idx + pred.shape[0]
+
+    return pred_arr
+
+
+def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
+    """Numpy implementation of the Frechet Distance.
+    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
+    and X_2 ~ N(mu_2, C_2) is
+            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
+
+    Stable version by Dougal J. Sutherland.
+
+    Params:
+    -- mu1   : Numpy array containing the activations of a layer of the
+               inception net (like returned by the function 'get_predictions')
+               for generated samples.
+    -- mu2   : The sample mean over activations, precalculated on an
+               representative data set.
+    -- sigma1: The covariance matrix over activations for generated samples.
+    -- sigma2: The covariance matrix over activations, precalculated on an
+               representative data set.
+
+    Returns:
+    --   : The Frechet Distance.
+    """
+
+    mu1 = np.atleast_1d(mu1)
+    mu2 = np.atleast_1d(mu2)
+
+    sigma1 = np.atleast_2d(sigma1)
+    sigma2 = np.atleast_2d(sigma2)
+
+    assert mu1.shape == mu2.shape, \
+        'Training and test mean vectors have different lengths'
+    assert sigma1.shape == sigma2.shape, \
+        'Training and test covariances have different dimensions'
+
+    diff = mu1 - mu2
+
+    # Product might be almost singular
+    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
+    if not np.isfinite(covmean).all():
+        msg = ('fid calculation produces singular product; '
+               'adding %s to diagonal of cov estimates') % eps
+        print(msg)
+        offset = np.eye(sigma1.shape[0]) * eps
+        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
+
+    # Numerical error might give slight imaginary component
+    if np.iscomplexobj(covmean):
+        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
+            m = np.max(np.abs(covmean.imag))
+            raise ValueError('Imaginary component {}'.format(m))
+        covmean = covmean.real
+
+    tr_covmean = np.trace(covmean)
+
+    return (diff.dot(diff) + np.trace(sigma1)
+            + np.trace(sigma2) - 2 * tr_covmean)
+
+
+def calculate_activation_statistics(files, model, batch_size=50, dims=2048,
+                                    device='cpu', num_workers=1):
+    """Calculation of the statistics used by the FID.
+    Params:
+    -- files       : List of image files paths
+    -- model       : Instance of inception model
+    -- batch_size  : The images numpy array is split into batches with
+                     batch size batch_size. A reasonable batch size
+                     depends on the hardware.
+    -- dims        : Dimensionality of features returned by Inception
+    -- device      : Device to run calculations
+    -- num_workers : Number of parallel dataloader workers
+
+    Returns:
+    -- mu    : The mean over samples of the activations of the pool_3 layer of
+               the inception model.
+    -- sigma : The covariance matrix of the activations of the pool_3 layer of
+               the inception model.
+    """
+    act = get_activations(files, model, batch_size, dims, device, num_workers)
+    mu = np.mean(act, axis=0)
+    sigma = np.cov(act, rowvar=False)
+    return mu, sigma
+
+
+def compute_statistics_of_path(path, model, batch_size, dims, device,
+                               num_workers=1):
+    if path.endswith('.npz'):
+        with np.load(path) as f:
+            m, s = f['mu'][:], f['sigma'][:]
+    else:
+        path = pathlib.Path(path)
+        files = sorted([file for ext in IMAGE_EXTENSIONS
+                        for file in path.glob('*.{}'.format(ext))])
+        m, s = calculate_activation_statistics(files, model, batch_size,
+                                               dims, device, num_workers)
+
+    return m, s
+
+
+def calculate_fid_given_paths(paths, batch_size, device, dims, num_workers=1):
+    """Calculates the FID of two paths"""
+    for p in paths:
+        if not os.path.exists(p):
+            raise RuntimeError('Invalid path: %s' % p)
+
+    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
+
+    model = InceptionV3([block_idx]).to(device)
+
+    m1, s1 = compute_statistics_of_path(paths[0], model, batch_size,
+                                        dims, device, num_workers)
+    m2, s2 = compute_statistics_of_path(paths[1], model, batch_size,
+                                        dims, device, num_workers)
+    fid_value = calculate_frechet_distance(m1, s1, m2, s2)
+
+    return fid_value
+
+
+def save_fid_stats(paths, batch_size, device, dims, num_workers=1):
+    """Calculates the FID of two paths"""
+    if not os.path.exists(paths[0]):
+        raise RuntimeError('Invalid path: %s' % paths[0])
+
+    if os.path.exists(paths[1]):
+        raise RuntimeError('Existing output file: %s' % paths[1])
+
+    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
+
+    model = InceptionV3([block_idx]).to(device)
+
+    print(f"Saving statistics for {paths[0]}")
+
+    m1, s1 = compute_statistics_of_path(paths[0], model, batch_size,
+                                        dims, device, num_workers)
+
+    np.savez_compressed(paths[1], mu=m1, sigma=s1)
+
+
+def main():
+    args = parser.parse_args()
+
+    if args.device is None:
+        device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
+    else:
+        device = torch.device(args.device)
+
+    if args.num_workers is None:
+        try:
+            num_cpus = len(os.sched_getaffinity(0))
+        except AttributeError:
+            # os.sched_getaffinity is not available under Windows, use
+            # os.cpu_count instead (which may not return the *available* number
+            # of CPUs).
+            num_cpus = os.cpu_count()
+
+        num_workers = min(num_cpus, 8) if num_cpus is not None else 0
+    else:
+        num_workers = args.num_workers
+
+    if args.save_stats:
+        save_fid_stats(args.path, args.batch_size, device, args.dims, num_workers)
+        return
+
+    fid_value = calculate_fid_given_paths(args.path,
+                                          args.batch_size,
+                                          device,
+                                          args.dims,
+                                          num_workers)
+    print('FID: ', fid_value)
+
+
+if __name__ == '__main__':
+    main()
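The same computation can also be driven programmatically; a sketch using the functions defined above, with hypothetical image directories:

```python
import torch

from libs.metric.pytorch_fid.fid_score import calculate_fid_given_paths

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Both directories are placeholders and must contain images with supported extensions.
fid = calculate_fid_given_paths(
    paths=["./real_images", "./generated_images"],
    batch_size=50,
    device=device,
    dims=2048,
    num_workers=4,
)
print("FID:", fid)
```

Assuming the package is importable, the argparse interface above can also be invoked as `python -m libs.metric.pytorch_fid.fid_score ./real_images ./generated_images`.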
diff --git a/libs/metric/pytorch_fid/inception.py b/libs/metric/pytorch_fid/inception.py
new file mode 100644
index 0000000000000000000000000000000000000000..8898a20c0609f5bb31df3641e783ea95db45b95f
--- /dev/null
+++ b/libs/metric/pytorch_fid/inception.py
@@ -0,0 +1,341 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision
+
+try:
+    from torchvision.models.utils import load_state_dict_from_url
+except ImportError:
+    from torch.utils.model_zoo import load_url as load_state_dict_from_url
+
+# Inception weights ported to Pytorch from
+# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'  # noqa: E501
+
+
+class InceptionV3(nn.Module):
+    """Pretrained InceptionV3 network returning feature maps"""
+
+    # Index of default block of inception to return,
+    # corresponds to output of final average pooling
+    DEFAULT_BLOCK_INDEX = 3
+
+    # Maps feature dimensionality to their output blocks indices
+    BLOCK_INDEX_BY_DIM = {
+        64: 0,   # First max pooling features
+        192: 1,  # Second max pooling features
+        768: 2,  # Pre-aux classifier features
+        2048: 3  # Final average pooling features
+    }
+
+    def __init__(self,
+                 output_blocks=(DEFAULT_BLOCK_INDEX,),
+                 resize_input=True,
+                 normalize_input=True,
+                 requires_grad=False,
+                 use_fid_inception=True):
+        """Build pretrained InceptionV3
+
+        Parameters
+        ----------
+        output_blocks : list of int
+            Indices of blocks to return features of. Possible values are:
+                - 0: corresponds to output of first max pooling
+                - 1: corresponds to output of second max pooling
+                - 2: corresponds to output which is fed to aux classifier
+                - 3: corresponds to output of final average pooling
+        resize_input : bool
+            If true, bilinearly resizes input to width and height 299 before
+            feeding input to model. As the network without fully connected
+            layers is fully convolutional, it should be able to handle inputs
+            of arbitrary size, so resizing might not be strictly needed
+        normalize_input : bool
+            If true, scales the input from range (0, 1) to the range the
+            pretrained Inception network expects, namely (-1, 1)
+        requires_grad : bool
+            If true, parameters of the model require gradients. Possibly useful
+            for finetuning the network
+        use_fid_inception : bool
+            If true, uses the pretrained Inception model used in Tensorflow's
+            FID implementation. If false, uses the pretrained Inception model
+            available in torchvision. The FID Inception model has different
+            weights and a slightly different structure from torchvision's
+            Inception model. If you want to compute FID scores, you are
+            strongly advised to set this parameter to true to get comparable
+            results.
+        """
+        super(InceptionV3, self).__init__()
+
+        self.resize_input = resize_input
+        self.normalize_input = normalize_input
+        self.output_blocks = sorted(output_blocks)
+        self.last_needed_block = max(output_blocks)
+
+        assert self.last_needed_block <= 3, \
+            'Last possible output block index is 3'
+
+        self.blocks = nn.ModuleList()
+
+        if use_fid_inception:
+            inception = fid_inception_v3()
+        else:
+            inception = _inception_v3(weights='DEFAULT')
+
+        # Block 0: input to maxpool1
+        block0 = [
+            inception.Conv2d_1a_3x3,
+            inception.Conv2d_2a_3x3,
+            inception.Conv2d_2b_3x3,
+            nn.MaxPool2d(kernel_size=3, stride=2)
+        ]
+        self.blocks.append(nn.Sequential(*block0))
+
+        # Block 1: maxpool1 to maxpool2
+        if self.last_needed_block >= 1:
+            block1 = [
+                inception.Conv2d_3b_1x1,
+                inception.Conv2d_4a_3x3,
+                nn.MaxPool2d(kernel_size=3, stride=2)
+            ]
+            self.blocks.append(nn.Sequential(*block1))
+
+        # Block 2: maxpool2 to aux classifier
+        if self.last_needed_block >= 2:
+            block2 = [
+                inception.Mixed_5b,
+                inception.Mixed_5c,
+                inception.Mixed_5d,
+                inception.Mixed_6a,
+                inception.Mixed_6b,
+                inception.Mixed_6c,
+                inception.Mixed_6d,
+                inception.Mixed_6e,
+            ]
+            self.blocks.append(nn.Sequential(*block2))
+
+        # Block 3: aux classifier to final avgpool
+        if self.last_needed_block >= 3:
+            block3 = [
+                inception.Mixed_7a,
+                inception.Mixed_7b,
+                inception.Mixed_7c,
+                nn.AdaptiveAvgPool2d(output_size=(1, 1))
+            ]
+            self.blocks.append(nn.Sequential(*block3))
+
+        for param in self.parameters():
+            param.requires_grad = requires_grad
+
+    def forward(self, inp):
+        """Get Inception feature maps
+
+        Parameters
+        ----------
+        inp : torch.autograd.Variable
+            Input tensor of shape Bx3xHxW. Values are expected to be in
+            range (0, 1)
+
+        Returns
+        -------
+        List of torch.autograd.Variable, corresponding to the selected output
+        block, sorted ascending by index
+        """
+        outp = []
+        x = inp
+
+        if self.resize_input:
+            x = F.interpolate(x,
+                              size=(299, 299),
+                              mode='bilinear',
+                              align_corners=False)
+
+        if self.normalize_input:
+            x = 2 * x - 1  # Scale from range (0, 1) to range (-1, 1)
+
+        for idx, block in enumerate(self.blocks):
+            x = block(x)
+            if idx in self.output_blocks:
+                outp.append(x)
+
+            if idx == self.last_needed_block:
+                break
+
+        return outp
+
+
+def _inception_v3(*args, **kwargs):
+    """Wraps `torchvision.models.inception_v3`"""
+    try:
+        version = tuple(map(int, torchvision.__version__.split('.')[:2]))
+    except ValueError:
+        # Just a caution against weird version strings
+        version = (0,)
+
+    # Skips default weight inititialization if supported by torchvision
+    # version. See https://github.com/mseitzer/pytorch-fid/issues/28.
+    if version >= (0, 6):
+        kwargs['init_weights'] = False
+
+    # Backwards compatibility: `weights` argument was handled by `pretrained`
+    # argument prior to version 0.13.
+    if version < (0, 13) and 'weights' in kwargs:
+        if kwargs['weights'] == 'DEFAULT':
+            kwargs['pretrained'] = True
+        elif kwargs['weights'] is None:
+            kwargs['pretrained'] = False
+        else:
+            raise ValueError(
+                'weights=={} not supported in torchvision {}'.format(
+                    kwargs['weights'], torchvision.__version__
+                )
+            )
+        del kwargs['weights']
+
+    return torchvision.models.inception_v3(*args, **kwargs)
+
+
+def fid_inception_v3():
+    """Build pretrained Inception model for FID computation
+
+    The Inception model for FID computation uses a different set of weights
+    and has a slightly different structure than torchvision's Inception.
+
+    This method first constructs torchvision's Inception and then patches the
+    necessary parts that are different in the FID Inception model.
+    """
+    inception = _inception_v3(num_classes=1008,
+                              aux_logits=False,
+                              weights=None)
+    inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
+    inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
+    inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
+    inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
+    inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
+    inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
+    inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
+    inception.Mixed_7b = FIDInceptionE_1(1280)
+    inception.Mixed_7c = FIDInceptionE_2(2048)
+
+    state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
+    inception.load_state_dict(state_dict)
+    return inception
+
+
+class FIDInceptionA(torchvision.models.inception.InceptionA):
+    """InceptionA block patched for FID computation"""
+    def __init__(self, in_channels, pool_features):
+        super(FIDInceptionA, self).__init__(in_channels, pool_features)
+
+    def forward(self, x):
+        branch1x1 = self.branch1x1(x)
+
+        branch5x5 = self.branch5x5_1(x)
+        branch5x5 = self.branch5x5_2(branch5x5)
+
+        branch3x3dbl = self.branch3x3dbl_1(x)
+        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
+
+        # Patch: Tensorflow's average pool does not use the padded zeros in
+        # its average calculation
+        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
+                                   count_include_pad=False)
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
+        return torch.cat(outputs, 1)
+
+
+class FIDInceptionC(torchvision.models.inception.InceptionC):
+    """InceptionC block patched for FID computation"""
+    def __init__(self, in_channels, channels_7x7):
+        super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
+
+    def forward(self, x):
+        branch1x1 = self.branch1x1(x)
+
+        branch7x7 = self.branch7x7_1(x)
+        branch7x7 = self.branch7x7_2(branch7x7)
+        branch7x7 = self.branch7x7_3(branch7x7)
+
+        branch7x7dbl = self.branch7x7dbl_1(x)
+        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
+        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
+        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
+        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
+
+        # Patch: Tensorflow's average pool does not use the padded zeros in
+        # its average calculation
+        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
+                                   count_include_pad=False)
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
+        return torch.cat(outputs, 1)
+
+
+class FIDInceptionE_1(torchvision.models.inception.InceptionE):
+    """First InceptionE block patched for FID computation"""
+    def __init__(self, in_channels):
+        super(FIDInceptionE_1, self).__init__(in_channels)
+
+    def forward(self, x):
+        branch1x1 = self.branch1x1(x)
+
+        branch3x3 = self.branch3x3_1(x)
+        branch3x3 = [
+            self.branch3x3_2a(branch3x3),
+            self.branch3x3_2b(branch3x3),
+        ]
+        branch3x3 = torch.cat(branch3x3, 1)
+
+        branch3x3dbl = self.branch3x3dbl_1(x)
+        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+        branch3x3dbl = [
+            self.branch3x3dbl_3a(branch3x3dbl),
+            self.branch3x3dbl_3b(branch3x3dbl),
+        ]
+        branch3x3dbl = torch.cat(branch3x3dbl, 1)
+
+        # Patch: Tensorflow's average pool does not use the padded zeros in
+        # its average calculation
+        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
+                                   count_include_pad=False)
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
+        return torch.cat(outputs, 1)
+
+
+class FIDInceptionE_2(torchvision.models.inception.InceptionE):
+    """Second InceptionE block patched for FID computation"""
+    def __init__(self, in_channels):
+        super(FIDInceptionE_2, self).__init__(in_channels)
+
+    def forward(self, x):
+        branch1x1 = self.branch1x1(x)
+
+        branch3x3 = self.branch3x3_1(x)
+        branch3x3 = [
+            self.branch3x3_2a(branch3x3),
+            self.branch3x3_2b(branch3x3),
+        ]
+        branch3x3 = torch.cat(branch3x3, 1)
+
+        branch3x3dbl = self.branch3x3dbl_1(x)
+        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+        branch3x3dbl = [
+            self.branch3x3dbl_3a(branch3x3dbl),
+            self.branch3x3dbl_3b(branch3x3dbl),
+        ]
+        branch3x3dbl = torch.cat(branch3x3dbl, 1)
+
+        # Patch: The FID Inception model uses max pooling instead of average
+        # pooling. This is likely an error in this specific Inception
+        # implementation, as other Inception models use average pooling here
+        # (which matches the description in the paper).
+        branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
+        return torch.cat(outputs, 1)
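For reference, a small sketch of extracting pool3 activations from this wrapper; the random batch stands in for images with values in (0, 1):

```python
import torch

from libs.metric.pytorch_fid.inception import InceptionV3

# Block index 3 corresponds to the 2048-dim final average pooling features.
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
model = InceptionV3([block_idx]).eval()

images = torch.rand(8, 3, 299, 299)
with torch.no_grad():
    features = model(images)[0]        # (8, 2048, 1, 1)

features = features.squeeze(-1).squeeze(-1)
print(features.shape)                  # torch.Size([8, 2048])
```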
diff --git a/libs/modules/__init__.py b/libs/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ad761f2f5443eb41b15afc4116a66ecdfa9d918
--- /dev/null
+++ b/libs/modules/__init__.py
@@ -0,0 +1,4 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) XiMing Xing. All rights reserved.
+# Author: XiMing Xing
+# Description:
diff --git a/libs/modules/edge_map/DoG/XDoG.py b/libs/modules/edge_map/DoG/XDoG.py
new file mode 100644
index 0000000000000000000000000000000000000000..4553df9deec12af5b88cee4701300a95ab17ebdc
--- /dev/null
+++ b/libs/modules/edge_map/DoG/XDoG.py
@@ -0,0 +1,78 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) XiMing Xing. All rights reserved.
+# Author: XiMing Xing
+# Description:
+
+import numpy as np
+import cv2
+from scipy import ndimage as ndi
+from skimage import filters
+
+
+class XDoG:
+
+    def __init__(self,
+                 gamma=0.98,
+                 phi=200,
+                 eps=-0.1,
+                 sigma=0.8,
+                 k=10,
+                 binarize: bool = True):
+        """
+        XDoG algorithm.
+
+        Args:
+            gamma: Control the size of the Gaussian filter
+            phi: Control changes in edge strength
+            eps: Threshold for controlling edge strength
+            sigma: The standard deviation of the Gaussian filter controls the degree of smoothness
+            k: Control the size ratio of Gaussian filter, (k=10 or k=1.6)
+            binarize(bool): Whether to binarize the output
+        """
+
+        super(XDoG, self).__init__()
+
+        self.gamma = gamma
+        assert 0 <= self.gamma <= 1
+
+        self.phi = phi
+        assert 0 <= self.phi <= 1500
+
+        self.eps = eps
+        assert -1 <= self.eps <= 1
+
+        self.sigma = sigma
+        assert 0.1 <= self.sigma <= 10
+
+        self.k = k
+        assert 1 <= self.k <= 100
+
+        self.binarize = binarize
+
+    def __call__(self, img):
+        # to gray if image is not already grayscale
+        if len(img.shape) == 3 and img.shape[2] == 3:
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+        elif len(img.shape) == 3 and img.shape[2] == 4:
+            img = cv2.cvtColor(img, cv2.COLOR_BGRA2GRAY)
+
+        if np.isnan(img).any():
+            img[np.isnan(img)] = np.mean(img[~np.isnan(img)])
+
+        # gaussian filter
+        imf1 = ndi.gaussian_filter(img, self.sigma)
+        imf2 = ndi.gaussian_filter(img, self.sigma * self.k)
+        imdiff = imf1 - self.gamma * imf2
+
+        # XDoG
+        imdiff = (imdiff < self.eps) * 1.0 + (imdiff >= self.eps) * (1.0 + np.tanh(self.phi * imdiff))
+
+        # normalize
+        imdiff -= imdiff.min()
+        imdiff /= imdiff.max()
+
+        if self.binarize:
+            th = filters.threshold_otsu(imdiff)
+            imdiff = (imdiff >= th).astype('float32')
+
+        return imdiff
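A short usage sketch for the detector above; the file name is a placeholder, and normalizing the image to [0, 1] is an assumption chosen to match the default eps=-0.1 threshold:

```python
import cv2

from libs.modules.edge_map.DoG import XDoG

xdog = XDoG(gamma=0.98, phi=200, eps=-0.1, sigma=0.8, k=10, binarize=True)

# Read a grayscale image (path is hypothetical) and scale it to [0, 1].
img = cv2.imread("example.png", cv2.IMREAD_GRAYSCALE).astype("float32") / 255.0

edge_map = xdog(img)  # float array in [0, 1]; binary after Otsu thresholding
cv2.imwrite("example_xdog.png", (edge_map * 255).astype("uint8"))
```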
diff --git a/libs/modules/edge_map/DoG/__init__.py b/libs/modules/edge_map/DoG/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bea1d52fb83eb6335ce73cf4a1e4c7fb28fa671
--- /dev/null
+++ b/libs/modules/edge_map/DoG/__init__.py
@@ -0,0 +1,8 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) XiMing Xing. All rights reserved.
+# Author: XiMing Xing
+# Description:
+
+from .XDoG import XDoG
+
+__all__ = ['XDoG']
diff --git a/libs/modules/edge_map/__init__.py b/libs/modules/edge_map/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ad761f2f5443eb41b15afc4116a66ecdfa9d918
--- /dev/null
+++ b/libs/modules/edge_map/__init__.py
@@ -0,0 +1,4 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) XiMing Xing. All rights reserved.
+# Author: XiMing Xing
+# Description:
diff --git a/libs/modules/edge_map/canny/__init__.py b/libs/modules/edge_map/canny/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbed53c90ffbadec00cbaa41552c71f2818dc9b6
--- /dev/null
+++ b/libs/modules/edge_map/canny/__init__.py
@@ -0,0 +1,15 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) XiMing Xing. All rights reserved.
+# Author: XiMing Xing
+# Description:
+
+import cv2
+
+
+class CannyDetector:
+
+    def __call__(self, img, low_threshold, high_threshold, L2gradient=False):
+        # Pass L2gradient by keyword: the fourth positional argument of cv2.Canny is the optional output array.
+        return cv2.Canny(img, low_threshold, high_threshold, L2gradient=L2gradient)
+
+
+__all__ = ['CannyDetector']
diff --git a/libs/modules/edge_map/image_grads/__init__.py b/libs/modules/edge_map/image_grads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..63070b57210f6d66256a299a3ea4538d7350127f
--- /dev/null
+++ b/libs/modules/edge_map/image_grads/__init__.py
@@ -0,0 +1,8 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) XiMing Xing. All rights reserved.
+# Author: XiMing Xing
+# Description:
+
+from .laplacian import LaplacianDetector
+
+__all__ = ['LaplacianDetector']
diff --git a/libs/modules/edge_map/image_grads/laplacian.py b/libs/modules/edge_map/image_grads/laplacian.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee1aa25286b5d46bd4578c8240ea473ca391c6e2
--- /dev/null
+++ b/libs/modules/edge_map/image_grads/laplacian.py
@@ -0,0 +1,13 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) XiMing Xing. All rights reserved.
+# Author: XiMing Xing
+# Description:
+
+
+import cv2
+
+
+class LaplacianDetector:
+
+    def __call__(self, img):
+        return cv2.Laplacian(img, cv2.CV_64F)
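The two simpler detectors share the same call pattern; a sketch with a placeholder file name and arbitrary Canny thresholds:

```python
import cv2

from libs.modules.edge_map.canny import CannyDetector
from libs.modules.edge_map.image_grads import LaplacianDetector

img = cv2.imread("example.png", cv2.IMREAD_GRAYSCALE)

# Canny expects an 8-bit image plus low/high hysteresis thresholds.
canny_edges = CannyDetector()(img, low_threshold=100, high_threshold=200)

# The Laplacian detector returns signed float64 second-order gradients.
laplacian_edges = LaplacianDetector()(img)
```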
diff --git a/libs/modules/ema.py b/libs/modules/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca8667b19494ff2996d2d17e3fda0c2d9be24a3e
--- /dev/null
+++ b/libs/modules/ema.py
@@ -0,0 +1,198 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) XiMing Xing. All rights reserved.
+# Description: EMA model
+
+import copy
+
+import torch
+import torch.nn as nn
+
+__all__ = ['EMA']
+
+
+class EMA(nn.Module):
+    """
+    Implements exponential moving average shadowing for your model.
+    Utilizes an inverse decay schedule to manage longer term training runs.
+    By adjusting the power, you can control how fast EMA will ramp up to your specified beta.
+    @crowsonkb's notes on EMA Warmup:
+    If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
+    good values for models you plan to train for a million or more steps (reaches decay
+    factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models
+    you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
+    215.4k steps).
+    Args:
+        inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
+        power (float): Exponential factor of EMA warmup. Default: 1.
+        min_value (float): The minimum EMA decay rate. Default: 0.
+    """
+
+    def __init__(
+            self,
+            model,
+            # if your model has lazylinears or other types of non-deepcopyable modules,
+            # you can pass in your own ema model
+            ema_model=None,
+            beta=0.9999,
+            update_after_step=100,
+            update_every=10,
+            inv_gamma=1.0,
+            power=2 / 3,
+            min_value=0.0,
+            param_or_buffer_names_no_ema=set(),
+            ignore_names=set(),
+            ignore_startswith_names=set(),
+            # set this to False if you do not wish for the online model to be
+            # saved along with the ema model (managed externally)
+            include_online_model=True
+    ):
+        super().__init__()
+        self.beta = beta
+
+        # whether to include the online model within the module tree, so that state_dict also saves it
+        self.include_online_model = include_online_model
+
+        if include_online_model:
+            self.online_model = model
+        else:
+            self.online_model = [model]  # hack
+
+        # ema model
+        self.ema_model = ema_model
+
+        if not exists(self.ema_model):
+            try:
+                self.ema_model = copy.deepcopy(model)
+            except Exception:
+                print('Your model was not copyable. Please make sure you are not using any LazyLinear')
+                exit()
+
+        self.ema_model.requires_grad_(False)
+
+        self.parameter_names = {name for name, param in self.ema_model.named_parameters() if param.dtype == torch.float}
+        self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if buffer.dtype == torch.float}
+
+        self.update_every = update_every
+        self.update_after_step = update_after_step
+
+        self.inv_gamma = inv_gamma
+        self.power = power
+        self.min_value = min_value
+
+        assert isinstance(param_or_buffer_names_no_ema, (set, list))
+        self.param_or_buffer_names_no_ema = param_or_buffer_names_no_ema  # parameter or buffer
+
+        self.ignore_names = ignore_names
+        self.ignore_startswith_names = ignore_startswith_names
+
+        self.register_buffer('initted', torch.Tensor([False]))
+        self.register_buffer('step', torch.tensor([0]))
+
+    @property
+    def model(self):
+        return self.online_model if self.include_online_model else self.online_model[0]
+
+    def restore_ema_model_device(self):
+        device = self.initted.device
+        self.ema_model.to(device)
+
+    def get_params_iter(self, model):
+        for name, param in model.named_parameters():
+            if name not in self.parameter_names:
+                continue
+            yield name, param
+
+    def get_buffers_iter(self, model):
+        for name, buffer in model.named_buffers():
+            if name not in self.buffer_names:
+                continue
+            yield name, buffer
+
+    def copy_params_from_model_to_ema(self):
+        for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model),
+                                                       self.get_params_iter(self.model)):
+            ma_params.data.copy_(current_params.data)
+
+        for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model),
+                                                         self.get_buffers_iter(self.model)):
+            ma_buffers.data.copy_(current_buffers.data)
+
+    def get_current_decay(self):
+        epoch = clamp(self.step.item() - self.update_after_step - 1, min_value=0.)
+        value = 1 - (1 + epoch / self.inv_gamma) ** - self.power
+
+        if epoch <= 0:
+            return 0.
+
+        return clamp(value, min_value=self.min_value, max_value=self.beta)
+
+    def update(self):
+        step = self.step.item()
+        self.step += 1
+
+        if (step % self.update_every) != 0:
+            return
+
+        if step <= self.update_after_step:
+            self.copy_params_from_model_to_ema()
+            return
+
+        if not self.initted.item():
+            self.copy_params_from_model_to_ema()
+            self.initted.data.copy_(torch.Tensor([True]))
+
+        self.update_moving_average(self.ema_model, self.model)
+
+    @torch.no_grad()
+    def update_moving_average(self, ma_model, current_model):
+        current_decay = self.get_current_decay()
+
+        for (name, current_params), (_, ma_params) in zip(self.get_params_iter(current_model),
+                                                          self.get_params_iter(ma_model)):
+            if name in self.ignore_names:
+                continue
+
+            if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
+                continue
+
+            if name in self.param_or_buffer_names_no_ema:
+                ma_params.data.copy_(current_params.data)
+                continue
+
+            ma_params.data.lerp_(current_params.data, 1. - current_decay)
+
+        for (name, current_buffer), (_, ma_buffer) in zip(self.get_buffers_iter(current_model),
+                                                          self.get_buffers_iter(ma_model)):
+            if name in self.ignore_names:
+                continue
+
+            if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
+                continue
+
+            if name in self.param_or_buffer_names_no_ema:
+                ma_buffer.data.copy_(current_buffer.data)
+                continue
+
+            ma_buffer.data.lerp_(current_buffer.data, 1. - current_decay)
+
+    def __call__(self, *args, **kwargs):
+        return self.ema_model(*args, **kwargs)
+
+
+def exists(val):
+    return val is not None
+
+
+def is_float_dtype(dtype):
+    return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)])
+
+
+def clamp(value, min_value=None, max_value=None):
+    assert exists(min_value) or exists(max_value)
+    if exists(min_value):
+        value = max(value, min_value)
+
+    if exists(max_value):
+        value = min(value, max_value)
+
+    return value
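A minimal sketch of wrapping a model with the EMA helper above; the toy network, optimizer, and step count are illustrative:

```python
import torch
import torch.nn as nn

from libs.modules.ema import EMA

model = nn.Linear(16, 4)
ema = EMA(model, beta=0.9999, update_after_step=100, update_every=10)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

for _ in range(1000):
    x = torch.randn(8, 16)
    loss = model(x).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update()  # refresh the shadow weights after each optimizer step

# Evaluate with the smoothed weights.
with torch.no_grad():
    out = ema(torch.randn(1, 16))
```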
diff --git a/libs/modules/vision/__init__.py b/libs/modules/vision/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0804535662dfb998c093d164f45950c7bffefca8
--- /dev/null
+++ b/libs/modules/vision/__init__.py
@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) XiMing Xing. All rights reserved.
+# Author: XiMing Xing
+# Description:
+
+from .inception import inception_v3
+from .vgg import VGG
+
+__all__ = [
+    'inception_v3',
+    'VGG'
+]
diff --git a/libs/modules/vision/inception.py b/libs/modules/vision/inception.py
new file mode 100644
index 0000000000000000000000000000000000000000..45f20be02e0863329d996a77df68bedaf56deafc
--- /dev/null
+++ b/libs/modules/vision/inception.py
@@ -0,0 +1,482 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) XiMing Xing. All rights reserved.
+# Author: XiMing Xing
+# Description:
+
+from collections import namedtuple
+import warnings
+from typing import Callable, Any, Optional, Tuple, List
+
+import torch
+from torch import nn, Tensor
+import torch.nn.functional as F
+from torch.utils.model_zoo import load_url as load_state_dict_from_url
+
+__all__ = ['Inception3', 'inception_v3', 'InceptionOutputs', '_InceptionOutputs']
+
+model_urls = {
+    # Inception v3 ported from TensorFlow
+    'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth',
+}
+
+InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits'])
+InceptionOutputs.__annotations__ = {'logits': Tensor, 'aux_logits': Optional[Tensor]}
+
+# Script annotations failed with _GoogleNetOutputs = namedtuple ...
+# _InceptionOutputs set here for backwards compat
+_InceptionOutputs = InceptionOutputs
+
+
+def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> "Inception3":
+    r"""Inception v3 model architecture from
+    `"Rethinking the Inception Architecture for Computer Vision" `_.
+
+    .. note::
+        **Important**: In contrast to the other models the inception_v3 expects tensors with a size of
+        N x 3 x 299 x 299, so ensure your images are sized accordingly.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+        aux_logits (bool): If True, add an auxiliary branch that can improve training.
+            Default: *True*
+        transform_input (bool): If True, preprocesses the input according to the method with which it
+            was trained on ImageNet. Default: *False*
+    """
+    if pretrained:
+        if 'transform_input' not in kwargs:
+            kwargs['transform_input'] = True
+        if 'aux_logits' in kwargs:
+            original_aux_logits = kwargs['aux_logits']
+            kwargs['aux_logits'] = True
+        else:
+            original_aux_logits = True
+        kwargs['init_weights'] = False  # we are loading weights from a pretrained model
+        model = Inception3(**kwargs)
+        state_dict = load_state_dict_from_url(model_urls['inception_v3_google'],
+                                              progress=progress)
+        model.load_state_dict(state_dict)
+        if not original_aux_logits:
+            model.aux_logits = False
+            model.AuxLogits = None
+        return model
+
+    return Inception3(**kwargs)
+
+
+class Inception3(nn.Module):
+
+    def __init__(
+            self,
+            num_classes: int = 1000,
+            aux_logits: bool = True,
+            transform_input: bool = False,
+            inception_blocks: Optional[List[Callable[..., nn.Module]]] = None,
+            init_weights: Optional[bool] = None
+    ) -> None:
+        super(Inception3, self).__init__()
+        if inception_blocks is None:
+            inception_blocks = [
+                BasicConv2d, InceptionA, InceptionB, InceptionC,
+                InceptionD, InceptionE, InceptionAux
+            ]
+        if init_weights is None:
+            warnings.warn('The default weight initialization of inception_v3 will be changed in future releases of '
+                          'torchvision. If you wish to keep the old behavior (which leads to long initialization times'
+                          ' due to scipy/scipy#11299), please set init_weights=True.', FutureWarning)
+            init_weights = True
+        assert len(inception_blocks) == 7
+        conv_block = inception_blocks[0]
+        inception_a = inception_blocks[1]
+        inception_b = inception_blocks[2]
+        inception_c = inception_blocks[3]
+        inception_d = inception_blocks[4]
+        inception_e = inception_blocks[5]
+        inception_aux = inception_blocks[6]
+
+        self.aux_logits = aux_logits
+        self.transform_input = transform_input
+        self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2)
+        self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
+        self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
+        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
+        self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
+        self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
+        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)
+        self.Mixed_5b = inception_a(192, pool_features=32)
+        self.Mixed_5c = inception_a(256, pool_features=64)
+        self.Mixed_5d = inception_a(288, pool_features=64)
+        self.Mixed_6a = inception_b(288)
+        self.Mixed_6b = inception_c(768, channels_7x7=128)
+        self.Mixed_6c = inception_c(768, channels_7x7=160)
+        self.Mixed_6d = inception_c(768, channels_7x7=160)
+        self.Mixed_6e = inception_c(768, channels_7x7=192)
+        self.AuxLogits: Optional[nn.Module] = None
+        if aux_logits:
+            self.AuxLogits = inception_aux(768, num_classes)
+        self.Mixed_7a = inception_d(768)
+        self.Mixed_7b = inception_e(1280)
+        self.Mixed_7c = inception_e(2048)
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.dropout = nn.Dropout()
+        self.fc = nn.Linear(2048, num_classes)
+        if init_weights:
+            for m in self.modules():
+                if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
+                    import scipy.stats as stats
+                    stddev = m.stddev if hasattr(m, 'stddev') else 0.1
+                    X = stats.truncnorm(-2, 2, scale=stddev)
+                    values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
+                    values = values.view(m.weight.size())
+                    with torch.no_grad():
+                        m.weight.copy_(values)
+                elif isinstance(m, nn.BatchNorm2d):
+                    nn.init.constant_(m.weight, 1)
+                    nn.init.constant_(m.bias, 0)
+
+    def _transform_input(self, x: Tensor) -> Tensor:
+        if self.transform_input:
+            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
+            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
+            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
+            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
+        return x
+
+    def _forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
+        # N x 3 x 299 x 299
+        x = self.Conv2d_1a_3x3(x)
+        # N x 32 x 149 x 149
+        x = self.Conv2d_2a_3x3(x)
+        # N x 32 x 147 x 147
+        x = self.Conv2d_2b_3x3(x)
+        # N x 64 x 147 x 147
+        feat = self.maxpool1(x)
+        # N x 64 x 73 x 73
+        x = self.Conv2d_3b_1x1(feat)
+        # N x 80 x 73 x 73
+        x = self.Conv2d_4a_3x3(x)
+        # N x 192 x 71 x 71
+        x = self.maxpool2(x)
+        # N x 192 x 35 x 35
+        x = self.Mixed_5b(x)
+        # N x 256 x 35 x 35
+        x = self.Mixed_5c(x)
+        # N x 288 x 35 x 35
+        x = self.Mixed_5d(x)
+        # N x 288 x 35 x 35
+        x = self.Mixed_6a(x)
+        # N x 768 x 17 x 17
+        x = self.Mixed_6b(x)
+        # N x 768 x 17 x 17
+        x = self.Mixed_6c(x)
+        # N x 768 x 17 x 17
+        x = self.Mixed_6d(x)
+        # N x 768 x 17 x 17
+        x = self.Mixed_6e(x)
+        # N x 768 x 17 x 17
+        aux: Optional[Tensor] = None
+        if self.AuxLogits is not None:
+            if self.training:
+                aux = self.AuxLogits(x)
+        # N x 768 x 17 x 17
+        x = self.Mixed_7a(x)
+        # N x 1280 x 8 x 8
+        x = self.Mixed_7b(x)
+        # N x 2048 x 8 x 8
+        x = self.Mixed_7c(x)
+        # N x 2048 x 8 x 8
+        # Adaptive average pooling
+        x = self.avgpool(x)
+        # N x 2048 x 1 x 1
+        x = self.dropout(x)
+        # N x 2048 x 1 x 1
+        x = torch.flatten(x, 1)
+        # N x 2048
+        x = self.fc(x)
+        # N x 1000 (num_classes)
+        return feat, x, aux
+
+    @torch.jit.unused
+    def eager_outputs(self, x: Tensor, aux: Optional[Tensor]) -> InceptionOutputs:
+        if self.training and self.aux_logits:
+            return InceptionOutputs(x, aux)
+        else:
+            return x  # type: ignore[return-value]
+
+    def forward(self, x: Tensor) -> Tuple[Tensor, InceptionOutputs]:
+        x = self._transform_input(x)
+        feat, x, aux = self._forward(x)
+        aux_defined = self.training and self.aux_logits
+        if torch.jit.is_scripting():
+            if not aux_defined:
+                warnings.warn("Scripted Inception3 always returns Inception3 Tuple")
+            return feat, InceptionOutputs(x, aux)
+        else:
+            return feat, self.eager_outputs(x, aux)
+
+
+class InceptionA(nn.Module):
+
+    def __init__(
+            self,
+            in_channels: int,
+            pool_features: int,
+            conv_block: Optional[Callable[..., nn.Module]] = None
+    ) -> None:
+        super(InceptionA, self).__init__()
+        if conv_block is None:
+            conv_block = BasicConv2d
+        self.branch1x1 = conv_block(in_channels, 64, kernel_size=1)
+
+        self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1)
+        self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2)
+
+        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
+        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
+        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1)
+
+        self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1)
+
+    def _forward(self, x: Tensor) -> List[Tensor]:
+        branch1x1 = self.branch1x1(x)
+
+        branch5x5 = self.branch5x5_1(x)
+        branch5x5 = self.branch5x5_2(branch5x5)
+
+        branch3x3dbl = self.branch3x3dbl_1(x)
+        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
+
+        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
+        return outputs
+
+    def forward(self, x: Tensor) -> Tensor:
+        outputs = self._forward(x)
+        return torch.cat(outputs, 1)
+
+
+class InceptionB(nn.Module):
+
+    def __init__(
+            self,
+            in_channels: int,
+            conv_block: Optional[Callable[..., nn.Module]] = None
+    ) -> None:
+        super(InceptionB, self).__init__()
+        if conv_block is None:
+            conv_block = BasicConv2d
+        self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)
+
+        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
+        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
+        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2)
+
+    def _forward(self, x: Tensor) -> List[Tensor]:
+        branch3x3 = self.branch3x3(x)
+
+        branch3x3dbl = self.branch3x3dbl_1(x)
+        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
+
+        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
+
+        outputs = [branch3x3, branch3x3dbl, branch_pool]
+        return outputs
+
+    def forward(self, x: Tensor) -> Tensor:
+        outputs = self._forward(x)
+        return torch.cat(outputs, 1)
+
+
+class InceptionC(nn.Module):
+
+    def __init__(
+            self,
+            in_channels: int,
+            channels_7x7: int,
+            conv_block: Optional[Callable[..., nn.Module]] = None
+    ) -> None:
+        super(InceptionC, self).__init__()
+        if conv_block is None:
+            conv_block = BasicConv2d
+        self.branch1x1 = conv_block(in_channels, 192, kernel_size=1)
+
+        c7 = channels_7x7
+        self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1)
+        self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
+        self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0))
+
+        self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1)
+        self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
+        self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
+        self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
+        self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3))
+
+        self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
+
+    def _forward(self, x: Tensor) -> List[Tensor]:
+        branch1x1 = self.branch1x1(x)
+
+        branch7x7 = self.branch7x7_1(x)
+        branch7x7 = self.branch7x7_2(branch7x7)
+        branch7x7 = self.branch7x7_3(branch7x7)
+
+        branch7x7dbl = self.branch7x7dbl_1(x)
+        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
+        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
+        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
+        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
+
+        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
+        return outputs
+
+    def forward(self, x: Tensor) -> Tensor:
+        outputs = self._forward(x)
+        return torch.cat(outputs, 1)
+
+
+class InceptionD(nn.Module):
+
+    def __init__(
+            self,
+            in_channels: int,
+            conv_block: Optional[Callable[..., nn.Module]] = None
+    ) -> None:
+        super(InceptionD, self).__init__()
+        if conv_block is None:
+            conv_block = BasicConv2d
+        self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
+        self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2)
+
+        self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)
+        self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))
+        self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
+        self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2)
+
+    def _forward(self, x: Tensor) -> List[Tensor]:
+        branch3x3 = self.branch3x3_1(x)
+        branch3x3 = self.branch3x3_2(branch3x3)
+
+        branch7x7x3 = self.branch7x7x3_1(x)
+        branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
+        branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
+        branch7x7x3 = self.branch7x7x3_4(branch7x7x3)
+
+        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
+        outputs = [branch3x3, branch7x7x3, branch_pool]
+        return outputs
+
+    def forward(self, x: Tensor) -> Tensor:
+        outputs = self._forward(x)
+        return torch.cat(outputs, 1)
+
+
+class InceptionE(nn.Module):
+
+    def __init__(
+            self,
+            in_channels: int,
+            conv_block: Optional[Callable[..., nn.Module]] = None
+    ) -> None:
+        super(InceptionE, self).__init__()
+        if conv_block is None:
+            conv_block = BasicConv2d
+        self.branch1x1 = conv_block(in_channels, 320, kernel_size=1)
+
+        self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1)
+        self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
+        self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
+
+        self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1)
+        self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1)
+        self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
+        self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
+
+        self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
+
+    def _forward(self, x: Tensor) -> List[Tensor]:
+        branch1x1 = self.branch1x1(x)
+
+        branch3x3 = self.branch3x3_1(x)
+        branch3x3 = [
+            self.branch3x3_2a(branch3x3),
+            self.branch3x3_2b(branch3x3),
+        ]
+        branch3x3 = torch.cat(branch3x3, 1)
+
+        branch3x3dbl = self.branch3x3dbl_1(x)
+        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+        branch3x3dbl = [
+            self.branch3x3dbl_3a(branch3x3dbl),
+            self.branch3x3dbl_3b(branch3x3dbl),
+        ]
+        branch3x3dbl = torch.cat(branch3x3dbl, 1)
+
+        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
+        return outputs
+
+    def forward(self, x: Tensor) -> Tensor:
+        outputs = self._forward(x)
+        return torch.cat(outputs, 1)
+
+
+class InceptionAux(nn.Module):
+
+    def __init__(
+            self,
+            in_channels: int,
+            num_classes: int,
+            conv_block: Optional[Callable[..., nn.Module]] = None
+    ) -> None:
+        super(InceptionAux, self).__init__()
+        if conv_block is None:
+            conv_block = BasicConv2d
+        self.conv0 = conv_block(in_channels, 128, kernel_size=1)
+        self.conv1 = conv_block(128, 768, kernel_size=5)
+        self.conv1.stddev = 0.01  # type: ignore[assignment]
+        self.fc = nn.Linear(768, num_classes)
+        self.fc.stddev = 0.001  # type: ignore[assignment]
+
+    def forward(self, x: Tensor) -> Tensor:
+        # N x 768 x 17 x 17
+        x = F.avg_pool2d(x, kernel_size=5, stride=3)
+        # N x 768 x 5 x 5
+        x = self.conv0(x)
+        # N x 128 x 5 x 5
+        x = self.conv1(x)
+        # N x 768 x 1 x 1
+        # Adaptive average pooling
+        x = F.adaptive_avg_pool2d(x, (1, 1))
+        # N x 768 x 1 x 1
+        x = torch.flatten(x, 1)
+        # N x 768
+        x = self.fc(x)
+        # N x 1000
+        return x
+
+
+class BasicConv2d(nn.Module):
+
+    def __init__(
+            self,
+            in_channels: int,
+            out_channels: int,
+            **kwargs: Any
+    ) -> None:
+        super(BasicConv2d, self).__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
+        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.conv(x)
+        x = self.bn(x)
+        return F.relu(x, inplace=True)
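For reference, this Inception-v3 variant differs from the stock torchvision model in that its forward pass also returns an early feature map (the output of `maxpool1`) alongside the logits, which is convenient for feature-based losses and metrics. A minimal usage sketch, assuming the class keeps torchvision's `Inception3(num_classes, aux_logits, transform_input)` constructor signature:

```python
import torch

# Illustrative only: the constructor arguments are assumed to mirror torchvision's Inception3.
model = Inception3(num_classes=1000, aux_logits=True, transform_input=True)
model.eval()

x = torch.randn(1, 3, 299, 299)  # Inception expects 299x299 RGB input
with torch.no_grad():
    feat, logits = model(x)      # eager (non-scripted) inference returns (feature map, logits)

print(feat.shape)    # torch.Size([1, 64, 73, 73]) -- output of maxpool1
print(logits.shape)  # torch.Size([1, 1000])
```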
diff --git a/libs/modules/vision/vgg.py b/libs/modules/vision/vgg.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5e06c1dee136b1c01cc891f0bbed3da5553e066
--- /dev/null
+++ b/libs/modules/vision/vgg.py
@@ -0,0 +1,194 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) XiMing Xing. All rights reserved.
+# Author: XiMing Xing
+# Description:
+
+from typing import Union, List, Dict, Any, cast
+
+import torch
+import torch.nn as nn
+from torch.utils.model_zoo import load_url as load_state_dict_from_url
+
+__all__ = [
+    'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
+    'vgg19_bn', 'vgg19',
+]
+
+model_urls = {
+    'vgg11': 'https://download.pytorch.org/models/vgg11-8a719046.pth',
+    'vgg13': 'https://download.pytorch.org/models/vgg13-19584684.pth',
+    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
+    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
+    'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
+    'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
+    'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
+    'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
+}
+
+
+class VGG(nn.Module):
+
+    def __init__(
+            self,
+            features: nn.Module,
+            num_classes: int = 1000,
+            init_weights: bool = True
+    ) -> None:
+        super(VGG, self).__init__()
+        self.features = features
+        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
+        self.classifier = nn.Sequential(
+            nn.Linear(512 * 7 * 7, 4096),
+            nn.ReLU(True),
+            nn.Dropout(),
+            nn.Linear(4096, 4096),
+            nn.ReLU(True),
+            nn.Dropout(),
+            nn.Linear(4096, num_classes),
+        )
+        if init_weights:
+            self._initialize_weights()
+
+    def forward(self, x: torch.Tensor):
+        feat = self.features(x)
+        x = self.avgpool(feat)
+        x = torch.flatten(x, 1)
+        x = self.classifier(x)
+        return feat, x
+
+    def _initialize_weights(self) -> None:
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+                nn.init.constant_(m.bias, 0)
+
+
+def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential:
+    layers: List[nn.Module] = []
+    in_channels = 3
+    for v in cfg:
+        if v == 'M':
+            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
+        else:
+            v = cast(int, v)
+            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
+            if batch_norm:
+                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
+            else:
+                layers += [conv2d, nn.ReLU(inplace=True)]
+            in_channels = v
+    return nn.Sequential(*layers)
+
+
+cfgs: Dict[str, List[Union[str, int]]] = {
+    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
+    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
+    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
+    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
+}
+
+
+def _vgg(arch: str, cfg: str, batch_norm: bool, pretrained: bool, progress: bool, **kwargs: Any) -> VGG:
+    if pretrained:
+        kwargs['init_weights'] = False
+    model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
+    if pretrained:
+        state_dict = load_state_dict_from_url(model_urls[arch],
+                                              progress=progress)
+        model.load_state_dict(state_dict)
+    return model
+
+
+def vgg11(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
+    r"""VGG 11-layer model (configuration "A") from
+    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs)
+
+
+def vgg11_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
+    r"""VGG 11-layer model (configuration "A") with batch normalization
+    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs)
+
+
+def vgg13(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
+    r"""VGG 13-layer model (configuration "B")
+    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs)
+
+
+def vgg13_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
+    r"""VGG 13-layer model (configuration "B") with batch normalization
+    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs)
+
+
+def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
+    r"""VGG 16-layer model (configuration "D")
+    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs)
+
+
+def vgg16_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
+    r"""VGG 16-layer model (configuration "D") with batch normalization
+    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs)
+
+
+def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
+    r"""VGG 19-layer model (configuration "E")
+    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs)
+
+
+def vgg19_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
+    r"""VGG 19-layer model (configuration 'E') with batch normalization
+    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs)
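Like the Inception wrapper above, this VGG variant returns the convolutional feature map together with the classification logits. A minimal, illustrative sketch:

```python
import torch

# Illustrative sketch: the modified forward() returns (feature map, logits).
model = vgg16(pretrained=False)  # pretrained=True downloads the ImageNet weights from model_urls
model.eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    feat, logits = model(x)

print(feat.shape)    # torch.Size([1, 512, 7, 7])
print(logits.shape)  # torch.Size([1, 1000])
```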
diff --git a/libs/modules/visual/__init__.py b/libs/modules/visual/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ad761f2f5443eb41b15afc4116a66ecdfa9d918
--- /dev/null
+++ b/libs/modules/visual/__init__.py
@@ -0,0 +1,4 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) XiMing Xing. All rights reserved.
+# Author: XiMing Xing
+# Description:
diff --git a/libs/modules/visual/imshow.py b/libs/modules/visual/imshow.py
new file mode 100644
index 0000000000000000000000000000000000000000..896670001d6c3a40aab3b0d66991e21f64fe8f21
--- /dev/null
+++ b/libs/modules/visual/imshow.py
@@ -0,0 +1,177 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) XiMing Xing. All rights reserved.
+# Author: XiMing Xing
+# Description:
+
+import pathlib
+from typing import Union, List, Text, BinaryIO, AnyStr
+
+import matplotlib.pyplot as plt
+import torch
+import torchvision.transforms as transforms
+from torchvision.utils import make_grid
+
+__all__ = [
+    'sample2pil_transforms',
+    'pt2numpy_transforms',
+    'plt_pt_img',
+    'save_grid_images_and_labels',
+    'save_grid_images_and_captions',
+]
+
+# convert a generated sample tensor to a PIL image
+sample2pil_transforms = transforms.Compose([
+    # unnormalizing to [0,1]
+    transforms.Lambda(lambda t: torch.clamp((t + 1) / 2, min=0.0, max=1.0)),
+    # Add 0.5 after unnormalizing to [0, 255]
+    transforms.Lambda(lambda t: torch.clamp(t * 255. + 0.5, min=0, max=255)),
+    # CHW to HWC
+    transforms.Lambda(lambda t: t.permute(1, 2, 0)),
+    # to numpy ndarray, dtype uint8
+    transforms.Lambda(lambda t: t.to('cpu', torch.uint8).numpy()),
+    # Converts a numpy ndarray of shape H x W x C to a PIL Image
+    transforms.ToPILImage(),
+])
+
+# convert a generated sample tensor to a numpy ndarray
+pt2numpy_transforms = transforms.Compose([
+    # Add 0.5 after unnormalizing to [0, 255]
+    transforms.Lambda(lambda t: torch.clamp(t * 255. + 0.5, min=0, max=255)),
+    # CHW to HWC
+    transforms.Lambda(lambda t: t.permute(1, 2, 0)),
+    # to numpy ndarray, dtype uint8
+    transforms.Lambda(lambda t: t.to('cpu', torch.uint8).numpy()),
+])
+
+
+def plt_pt_img(
+        pt_img: torch.Tensor,
+        save_path: AnyStr = None,
+        title: AnyStr = None,
+        dpi: int = 300
+):
+    grid = make_grid(pt_img, normalize=True, pad_value=2)
+    ndarr = pt2numpy_transforms(grid)
+    plt.imshow(ndarr)
+    plt.axis("off")
+    plt.tight_layout()
+    if title is not None:
+        plt.title(f"{title}")
+
+    # save before show(): plt.show() may clear the current figure on some backends
+    if save_path is not None:
+        plt.savefig(save_path, dpi=dpi)
+    plt.show()
+
+    plt.close()
+
+
+@torch.no_grad()
+def save_grid_images_and_labels(
+        images: Union[torch.Tensor, List[torch.Tensor]],
+        probs: Union[torch.Tensor, List[torch.Tensor]],
+        labels: Union[torch.Tensor, List[torch.Tensor]],
+        classes: Union[torch.Tensor, List[torch.Tensor]],
+        fp: Union[Text, pathlib.Path, BinaryIO],
+        nrow: int = 4,
+        normalize: bool = True
+) -> None:
+    """Save a given Tensor into an image file.
+    """
+    num_images = len(images)
+    num_rows, num_cols = _get_subplot_shape(num_images, nrow)
+
+    fig = plt.figure(figsize=(25, 20))
+
+    for i in range(num_images):
+        ax = fig.add_subplot(num_rows, num_cols, i + 1)
+
+        image, true_label, prob = images[i], labels[i], probs[i]
+
+        true_prob = prob[true_label]
+        pred_prob, pred_label = torch.max(prob, dim=0)
+        true_class = classes[true_label]
+
+        pred_class = classes[pred_label]
+
+        if normalize:
+            image = sample2pil_transforms(image)
+
+        ax.imshow(image)
+        title = f'true label: {true_class} ({true_prob:.3f})\n ' \
+                f'pred label: {pred_class} ({pred_prob:.3f})'
+        ax.set_title(title, fontsize=20)
+        ax.axis('off')
+
+    fig.subplots_adjust(hspace=0.3)
+
+    plt.savefig(fp)
+    plt.close()
+
+
+@torch.no_grad()
+def save_grid_images_and_captions(
+        images: Union[torch.Tensor, List[torch.Tensor]],
+        captions: List,
+        fp: Union[Text, pathlib.Path, BinaryIO],
+        nrow: int = 4,
+        normalize: bool = True
+) -> None:
+    """
+    Save a grid of images and their captions into an image file.
+
+    Args:
+        images (Union[torch.Tensor, List[torch.Tensor]]): A list of images to display.
+        captions (List): A list of captions for each image.
+        fp (Union[Text, pathlib.Path, BinaryIO]): The file path to save the image to.
+        nrow (int, optional): The number of images to display in each row. Defaults to 4.
+        normalize (bool, optional): Whether to normalize the image or not. Defaults to True.
+    """
+    num_images = len(images)
+    num_rows, num_cols = _get_subplot_shape(num_images, nrow)
+
+    fig = plt.figure(figsize=(25, 20))
+
+    for i in range(num_images):
+        ax = fig.add_subplot(num_rows, num_cols, i + 1)
+        image, caption = images[i], captions[i]
+
+        if normalize:
+            image = sample2pil_transforms(image)
+
+        ax.imshow(image)
+        title = f'"{caption}"' if num_images > 1 else f'"{captions}"'
+        title = _insert_newline(title)
+        ax.set_title(title, fontsize=20)
+        ax.axis('off')
+
+    fig.subplots_adjust(hspace=0.3)
+
+    plt.savefig(fp)
+    plt.close()
+
+
+def _get_subplot_shape(num_images, nrow):
+    """
+    Calculate the number of rows and columns required to display images in a grid.
+
+    Args:
+        num_images (int): The total number of images to display.
+        nrow (int): The maximum number of images to display in each row.
+
+    Returns:
+        Tuple[int, int]: The number of rows and columns required to display images in a grid.
+    """
+    num_cols = min(num_images, nrow)
+    num_rows = (num_images + num_cols - 1) // num_cols
+    return num_rows, num_cols
+
+
+def _insert_newline(string, point=9):
+    # split by blank
+    words = string.split()
+    if len(words) <= point:
+        return string
+
+    word_chunks = [words[i:i + point] for i in range(0, len(words), point)]
+    new_string = "\n".join(" ".join(chunk) for chunk in word_chunks)
+    return new_string
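A short usage sketch for the grid helpers (file names and tensors here are hypothetical); when `normalize=True`, images are expected in [-1, 1] so that `sample2pil_transforms` can map them back to displayable form:

```python
import torch

# Hypothetical example: save a 2x2 grid of generated samples with their prompts.
images = [torch.rand(3, 224, 224) * 2 - 1 for _ in range(4)]  # tensors in [-1, 1]
captions = ["a cat", "a dog", "a tree", "a house"]

save_grid_images_and_captions(images, captions, fp="caption_grid.png", nrow=2, normalize=True)
```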
diff --git a/libs/modules/visual/video.py b/libs/modules/visual/video.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b67882f73efca8f408bead99c63e110b1695e8b
--- /dev/null
+++ b/libs/modules/visual/video.py
@@ -0,0 +1,38 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) XiMing Xing. All rights reserved.
+# Author: XiMing Xing
+# Description:
+from typing import Any, Union
+import pathlib
+
+import cv2
+
+
+def create_video(num_iter: int,
+                 save_dir: Union[Any, pathlib.Path],
+                 video_frame_freq: int = 1,
+                 fname: str = "rendering_process",
+                 verbose: bool = True):
+    if not isinstance(save_dir, pathlib.Path):
+        save_dir = pathlib.Path(save_dir)
+
+    img_array = []
+    for i in range(0, num_iter):
+        if i % video_frame_freq == 0 or i == num_iter - 1:
+            filename = save_dir / f"iter{i}.png"
+            img = cv2.imread(filename.as_posix())
+            if img is not None:  # skip frames that were never rendered
+                img_array.append(img)
+
+    video_name = save_dir / f"{fname}.mp4"
+    out = cv2.VideoWriter(
+        video_name.as_posix(),
+        cv2.VideoWriter_fourcc(*'mp4v'),
+        30.0,  # fps
+        (600, 600)  # video size
+    )
+    for iii in range(len(img_array)):
+        out.write(img_array[iii])
+    out.release()
+
+    if verbose:
+        print(f"video saved in '{video_name}'.")
diff --git a/libs/solver/__init__.py b/libs/solver/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ad761f2f5443eb41b15afc4116a66ecdfa9d918
--- /dev/null
+++ b/libs/solver/__init__.py
@@ -0,0 +1,4 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) XiMing Xing. All rights reserved.
+# Author: XiMing Xing
+# Description:
diff --git a/libs/solver/lr_scheduler.py b/libs/solver/lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9bb82e96c5cb05d1ce296fb859709ac44c3eaed
--- /dev/null
+++ b/libs/solver/lr_scheduler.py
@@ -0,0 +1,350 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch optimization for diffusion models."""
+
+import math
+from enum import Enum
+from typing import Optional, Union
+
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LambdaLR
+
+
+class SchedulerType(Enum):
+    LINEAR = "linear"
+    COSINE = "cosine"
+    COSINE_WITH_RESTARTS = "cosine_with_restarts"
+    POLYNOMIAL = "polynomial"
+    CONSTANT = "constant"
+    CONSTANT_WITH_WARMUP = "constant_with_warmup"
+    PIECEWISE_CONSTANT = "piecewise_constant"
+
+
+def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
+    """
+    Create a schedule with a constant learning rate, using the learning rate set in optimizer.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+    return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
+
+
+def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
+    """
+    Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
+    increases linearly between 0 and the initial lr set in the optimizer.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+
+    def lr_lambda(current_step: int):
+        if current_step < num_warmup_steps:
+            return float(current_step) / float(max(1.0, num_warmup_steps))
+        return 1.0
+
+    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
+
+
+def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1):
+    """
+    Create a schedule with a piecewise constant learning rate, scaling the learning rate set in the optimizer
+    according to `step_rules`.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        step_rules (`str`):
+            The rules for the learning rate, e.g. step_rules="1:10,0.1:20,0.01:30,0.005" means the learning rate
+            is multiplied by 1 until step 10, by 0.1 until step 20, by 0.01 until step 30,
+            and by 0.005 for all remaining steps.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+
+    rules_dict = {}
+    rule_list = step_rules.split(",")
+    for rule_str in rule_list[:-1]:
+        value_str, steps_str = rule_str.split(":")
+        steps = int(steps_str)
+        value = float(value_str)
+        rules_dict[steps] = value
+    last_lr_multiple = float(rule_list[-1])
+
+    def create_rules_function(rules_dict, last_lr_multiple):
+        def rule_func(steps: int) -> float:
+            sorted_steps = sorted(rules_dict.keys())
+            for i, sorted_step in enumerate(sorted_steps):
+                if steps < sorted_step:
+                    return rules_dict[sorted_steps[i]]
+            return last_lr_multiple
+
+        return rule_func
+
+    rules_func = create_rules_function(rules_dict, last_lr_multiple)
+
+    return LambdaLR(optimizer, rules_func, last_epoch=last_epoch)
+
+
+def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
+    """
+    Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
+    a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        num_training_steps (`int`):
+            The total number of training steps.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+
+    def lr_lambda(current_step: int):
+        if current_step < num_warmup_steps:
+            return float(current_step) / float(max(1, num_warmup_steps))
+        return max(
+            0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
+        )
+
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def get_cosine_schedule_with_warmup(
+        optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5,
+        last_epoch: int = -1
+):
+    """
+    Create a schedule with a learning rate that decreases following the values of the cosine function between the
+    initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
+    initial lr set in the optimizer.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        num_training_steps (`int`):
+            The total number of training steps.
+        num_cycles (`float`, *optional*, defaults to 0.5):
+            The number of periods of the cosine function in a schedule (the default is to just decrease from the max
+            value to 0 following a half-cosine).
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+
+    def lr_lambda(current_step):
+        if current_step < num_warmup_steps:
+            return float(current_step) / float(max(1, num_warmup_steps))
+        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
+
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def get_cosine_with_hard_restarts_schedule_with_warmup(
+        optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
+):
+    """
+    Create a schedule with a learning rate that decreases following the values of the cosine function between the
+    initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
+    linearly between 0 and the initial lr set in the optimizer.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        num_training_steps (`int`):
+            The total number of training steps.
+        num_cycles (`int`, *optional*, defaults to 1):
+            The number of hard restarts to use.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+
+    def lr_lambda(current_step):
+        if current_step < num_warmup_steps:
+            return float(current_step) / float(max(1, num_warmup_steps))
+        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+        if progress >= 1.0:
+            return 0.0
+        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
+
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def get_polynomial_decay_schedule_with_warmup(
+        optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
+):
+    """
+    Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
+    optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
+    initial lr set in the optimizer.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        num_training_steps (`int`):
+            The total number of training steps.
+        lr_end (`float`, *optional*, defaults to 1e-7):
+            The end LR.
+        power (`float`, *optional*, defaults to 1.0):
+            Power factor.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
+    implementation at
+    https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+
+    """
+
+    lr_init = optimizer.defaults["lr"]
+    if not (lr_init > lr_end):
+        raise ValueError(f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})")
+
+    def lr_lambda(current_step: int):
+        if current_step < num_warmup_steps:
+            return float(current_step) / float(max(1, num_warmup_steps))
+        elif current_step > num_training_steps:
+            return lr_end / lr_init  # as LambdaLR multiplies by lr_init
+        else:
+            lr_range = lr_init - lr_end
+            decay_steps = num_training_steps - num_warmup_steps
+            pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
+            decay = lr_range * pct_remaining ** power + lr_end
+            return decay / lr_init  # as LambdaLR multiplies by lr_init
+
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+TYPE_TO_SCHEDULER_FUNCTION = {
+    SchedulerType.LINEAR: get_linear_schedule_with_warmup,
+    SchedulerType.COSINE: get_cosine_schedule_with_warmup,
+    SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
+    SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
+    SchedulerType.CONSTANT: get_constant_schedule,
+    SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
+    SchedulerType.PIECEWISE_CONSTANT: get_piecewise_constant_schedule,
+}
+
+
+def get_scheduler(
+        name: Union[str, SchedulerType],
+        optimizer: Optimizer,
+        step_rules: Optional[str] = None,
+        num_warmup_steps: Optional[int] = None,
+        num_training_steps: Optional[int] = None,
+        num_cycles: int = 1,
+        power: float = 1.0,
+        last_epoch: int = -1,
+):
+    """
+    Unified API to get any scheduler from its name.
+
+    Args:
+        name (`str` or `SchedulerType`):
+            The name of the scheduler to use.
+        optimizer (`torch.optim.Optimizer`):
+            The optimizer that will be used during training.
+        step_rules (`str`, *optional*):
+            A string representing the step rules to use. This is only used by the `PIECEWISE_CONSTANT` scheduler.
+        num_warmup_steps (`int`, *optional*):
+            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
+            optional), the function will raise an error if it's unset and the scheduler type requires it.
+        num_training_steps (`int`, *optional*):
+            The number of training steps to do. This is not required by all schedulers (hence the argument being
+            optional), the function will raise an error if it's unset and the scheduler type requires it.
+        num_cycles (`int`, *optional*):
+            The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
+        power (`float`, *optional*, defaults to 1.0):
+            Power factor. See the `POLYNOMIAL` scheduler.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+    """
+    name = SchedulerType(name)
+    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
+    if name == SchedulerType.CONSTANT:
+        return schedule_func(optimizer, last_epoch=last_epoch)
+
+    if name == SchedulerType.PIECEWISE_CONSTANT:
+        return schedule_func(optimizer, step_rules=step_rules, last_epoch=last_epoch)
+
+    # All other schedulers require `num_warmup_steps`
+    if num_warmup_steps is None:
+        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
+
+    if name == SchedulerType.CONSTANT_WITH_WARMUP:
+        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, last_epoch=last_epoch)
+
+    # All other schedulers require `num_training_steps`
+    if num_training_steps is None:
+        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
+
+    if name == SchedulerType.COSINE_WITH_RESTARTS:
+        return schedule_func(
+            optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
+            num_cycles=num_cycles,
+            last_epoch=last_epoch,
+        )
+
+    if name == SchedulerType.POLYNOMIAL:
+        return schedule_func(
+            optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
+            power=power,
+            last_epoch=last_epoch,
+        )
+
+    return schedule_func(
+        optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, last_epoch=last_epoch
+    )
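A minimal sketch of the unified scheduler API, using a dummy parameter and assumed step counts:

```python
import torch

# Dummy optimizer so the sketch is self-contained.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=1e-4)

# Cosine decay with 100 linear warmup steps over 1000 total steps.
scheduler = get_scheduler(
    "cosine",
    optimizer=optimizer,
    num_warmup_steps=100,
    num_training_steps=1000,
)

# Piecewise-constant alternative: lr * 1 until step 500, * 0.1 until step 1000, * 0.01 afterwards.
# scheduler = get_scheduler("piecewise_constant", optimizer, step_rules="1:500,0.1:1000,0.01")

for step in range(1000):
    optimizer.step()   # the actual training step would go here
    scheduler.step()
```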
diff --git a/libs/solver/optim.py b/libs/solver/optim.py
new file mode 100644
index 0000000000000000000000000000000000000000..c980a4feb4efa6d859abd6b24667a18a3b735e4c
--- /dev/null
+++ b/libs/solver/optim.py
@@ -0,0 +1,58 @@
+# -*- coding: utf-8 -*-
+# Author: ximing
+# Description: SVGDreamer - optim
+# Copyright (c) 2023, XiMing Xing.
+# License: MIT License
+from functools import partial
+
+import torch
+from omegaconf import DictConfig
+
+
+def get_optimizer(optimizer_name, parameters, lr=None, config: DictConfig = None):
+    param_dict = {}
+    if optimizer_name == "adam":
+        optimizer = partial(torch.optim.Adam, params=parameters)
+        if lr is not None:
+            optimizer = partial(torch.optim.Adam, params=parameters, lr=lr)
+        if config.get('betas'):
+            param_dict['betas'] = config.betas
+        if config.get('weight_decay'):
+            param_dict['weight_decay'] = config.weight_decay
+        if config.get('eps'):
+            param_dict['eps'] = config.eps
+    elif optimizer_name == "adamw":
+        optimizer = partial(torch.optim.AdamW, params=parameters)
+        if lr is not None:
+            optimizer = partial(torch.optim.AdamW, params=parameters, lr=lr)
+        if config.get('betas'):
+            param_dict['betas'] = config.betas
+        if config.get('weight_decay'):
+            param_dict['weight_decay'] = config.weight_decay
+        if config.get('eps'):
+            param_dict['eps'] = config.eps
+    elif optimizer_name == "radam":
+        optimizer = partial(torch.optim.RAdam, params=parameters)
+        if lr is not None:
+            optimizer = partial(torch.optim.RAdam, params=parameters, lr=lr)
+        if config.get('betas'):
+            param_dict['betas'] = config.betas
+        if config.get('weight_decay'):
+            param_dict['weight_decay'] = config.weight_decay
+    elif optimizer_name == "sgd":
+        optimizer = partial(torch.optim.SGD, params=parameters)
+        if lr is not None:
+            optimizer = partial(torch.optim.SGD, params=parameters, lr=lr)
+        if config.get('momentum'):
+            param_dict['momentum'] = config.momentum
+        if config.get('weight_decay'):
+            param_dict['weight_decay'] = config.weight_decay
+        if config.get('nesterov'):
+            param_dict['nesterov'] = config.nesterov
+    else:
+        raise NotImplementedError(f"Optimizer {optimizer_name} not implemented.")
+
+    if len(param_dict.keys()) > 0:
+        return optimizer(**param_dict)
+    else:
+        return optimizer()
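A sketch of `get_optimizer`, with a hypothetical stroke-point parameter and an OmegaConf config carrying the extra optimizer options:

```python
import torch
from omegaconf import OmegaConf

# Hypothetical parameter tensor standing in for stroke control points.
points = torch.nn.Parameter(torch.randn(96, 4, 2))

cfg = OmegaConf.create({"betas": [0.9, 0.9], "eps": 1e-6})
optimizer = get_optimizer("adam", [points], lr=1.0, config=cfg)
```

Note that options are only forwarded when they are truthy in the config, so for example an explicit `weight_decay: 0.0` simply falls back to the optimizer's own default.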
diff --git a/libs/utils/__init__.py b/libs/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..854865eb7334af6aef8017506126341d93663a71
--- /dev/null
+++ b/libs/utils/__init__.py
@@ -0,0 +1,30 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) XiMing Xing. All rights reserved.
+# Author: XiMing Xing
+# Description:
+from . import lazy
+
+# __getattr__, __dir__, __all__ = lazy.attach(
+#     __name__,
+#     submodules={},
+#     submod_attrs={
+#         'misc': ['identity', 'exists', 'default', 'has_int_squareroot', 'sum_params', 'cycle', 'num_to_groups',
+#                  'extract', 'normalize', 'unnormalize'],
+#         'tqdm': ['tqdm_decorator'],
+#         'lazy': ['load']
+#     }
+# )
+
+from .misc import (
+    identity,
+    exists,
+    default,
+    has_int_squareroot,
+    sum_params,
+    cycle,
+    num_to_groups,
+    extract,
+    normalize,
+    unnormalize
+)
+from .tqdm import tqdm_decorator
diff --git a/libs/utils/argparse.py b/libs/utils/argparse.py
new file mode 100644
index 0000000000000000000000000000000000000000..bce9493007d9183545ea8b5322744e57e7e8168c
--- /dev/null
+++ b/libs/utils/argparse.py
@@ -0,0 +1,126 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) XiMing Xing. All rights reserved.
+# Author: XiMing Xing
+# Description:
+
+import argparse
+
+
+#################################################################################
+#                            practical argparse utils                           #
+#################################################################################
+
+def accelerate_parser():
+    parser = argparse.ArgumentParser(add_help=False)
+
+    # Device
+    parser.add_argument("-cpu", "--use_cpu", action="store_true",
+                        help="Whether or not disable cuda")
+
+    # Gradient Accumulation
+    parser.add_argument("-cumgard", "--gradient-accumulate-step",
+                        type=int, default=1)
+    parser.add_argument("--split-batches", action="store_true",
+                        help="Whether or not the accelerator should split the batches "
+                             "yielded by the dataloaders across the devices.")
+
+    # Nvidia-Apex and GradScaler
+    parser.add_argument("-mprec", "--mixed-precision",
+                        type=str, default='no', choices=['no', 'fp16', 'bf16'],
+                        help="Whether to use mixed precision. Choose"
+                             "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
+                             "and an Nvidia Ampere GPU.")
+    parser.add_argument("--init-scale",
+                        type=float, default=65536.0,
+                        help="Default value: `2.**16 = 65536.0` ,"
+                             "For ImageNet experiments, '2.**20 = 1048576.0' was a good default value."
+                             "the others: `2.**17 = 131072.0` ")
+    parser.add_argument("--growth-factor", type=float, default=2.0)
+    parser.add_argument("--backoff-factor", type=float, default=0.5)
+    parser.add_argument("--growth-interval", type=int, default=2000)
+
+    # Gradient Normalization
+    parser.add_argument("-gard_norm", "--max_grad_norm", type=float, default=-1)
+
+    # Trackers
+    parser.add_argument("--use-wandb", action="store_true")
+    parser.add_argument("--project-name", type=str, default="SketchGeneration")
+    parser.add_argument("--entity", type=str, default="ximinng")
+    parser.add_argument("--tensorboard", action="store_true")
+
+    # reproducibility
+    parser.add_argument("-d", "--seed", default=42, type=int)
+
+    # result path
+    parser.add_argument("-respath", "--results_path",
+                        type=str, default="",
+                        help="If it is None, it is automatically generated.")
+
+    # timing
+    parser.add_argument("-log_step", "--log_step", default=1000, type=int,
+                        help="can be use to control log.")
+    parser.add_argument("-eval_step", "--eval_step", default=5000, type=int,
+                        help="can be use to calculate some metrics.")
+    parser.add_argument("-save_step", "--save_step", default=5000, type=int,
+                        help="can be use to control saving checkpoint.")
+
+    # update configuration interface
+    # example: python main.py -c main.yaml -update "nnet.depth=16 batch_size=16"
+    parser.add_argument("-update",
+                        type=str, default=None,
+                        help="modified hyper-parameters of config file.")
+    return parser
+
+
+def ema_parser():
+    parser = argparse.ArgumentParser(add_help=False)
+    parser.add_argument('--ema', action='store_true', help='enable EMA model')
+    parser.add_argument("--ema_decay", type=float, default=0.9999)
+    parser.add_argument("--ema_update_after_step", type=int, default=100)
+    parser.add_argument("--ema_update_every", type=int, default=10)
+    return parser
+
+
+def base_data_parser():
+    parser = argparse.ArgumentParser(add_help=False)
+    parser.add_argument("-spl", "--split",
+                        default='test', type=str,
+                        choices=['train', 'val', 'test', 'all'],
+                        help="which part of the data set, 'all' means combine training and test sets.")
+    parser.add_argument("-j", "--num_workers",
+                        default=6, type=int,
+                        help="how many subprocesses to use for data loading.")
+    parser.add_argument("--shuffle",
+                        action='store_true',
+                        help="how many subprocesses to use for data loading.")
+    parser.add_argument("--drop_last",
+                        action='store_true',
+                        help="how many subprocesses to use for data loading.")
+    return parser
+
+
+def base_training_parser():
+    parser = argparse.ArgumentParser(add_help=False)
+    parser.add_argument("-tbz", "--train_batch_size",
+                        default=32, type=int,
+                        help="how many images to sample during training.")
+    parser.add_argument("-lr", "--learning_rate", default=1e-4, type=float)
+    parser.add_argument("-wd", "--weight_decay", default=0, type=float)
+    return parser
+
+
+def base_sampling_parser():
+    parser = argparse.ArgumentParser(add_help=False)
+    parser.add_argument("-vbz", "--valid_batch_size",
+                        default=1, type=int,
+                        help="how many images to sample during evaluation")
+    parser.add_argument("-ts", "--total_samples",
+                        default=2000, type=int,
+                        help="the total number of samples, can be used to calculate FID.")
+    parser.add_argument("-ns", "--num_samples",
+                        default=4, type=int,
+                        help="number of samples taken at a time, "
+                             "can be used to repeatedly induce samples from a generation model "
+                             "from a fixed guided information, "
+                             "eg: `one latent to ns samples` (1 latent to 5 photo generation) ")
+    return parser
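These parsers are created with `add_help=False` so they can be composed as `parents` of a task-specific CLI; a sketch (the `--config` flag here is hypothetical):

```python
import argparse

parser = argparse.ArgumentParser(
    description="example entry point",
    parents=[accelerate_parser(), base_training_parser()],
)
parser.add_argument("-c", "--config", type=str, required=True)

args = parser.parse_args(["-c", "main.yaml", "-lr", "3e-4"])
print(args.config, args.learning_rate, args.seed)  # main.yaml 0.0003 42
```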
diff --git a/libs/utils/lazy.py b/libs/utils/lazy.py
new file mode 100644
index 0000000000000000000000000000000000000000..170179a31867454546db3c7cee3f733650cd8d34
--- /dev/null
+++ b/libs/utils/lazy.py
@@ -0,0 +1,144 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) XiMing Xing. All rights reserved.
+# Author: XiMing Xing
+# Description:
+
+import importlib
+import importlib.util
+import os
+import sys
+
+
+def attach(package_name, submodules=None, submod_attrs=None):
+    """Attach lazily loaded submodules, functions, or other attributes.
+
+    Typically, modules import submodules and attributes as follows::
+
+      import mysubmodule
+      import anothersubmodule
+
+      from .foo import someattr
+
+    The idea is to replace a package's `__getattr__`, `__dir__`, and
+    `__all__`, such that all imports work exactly the way they did
+    before, except that they are only imported when used.
+
+    The typical way to call this function, replacing the above imports, is::
+
+      __getattr__, __lazy_dir__, __all__ = lazy.attach(
+        __name__,
+        ['mysubmodule', 'anothersubmodule'],
+        {'foo': 'someattr'}
+      )
+
+    This functionality requires Python 3.7 or higher.
+
+    Parameters
+    ----------
+    package_name : str
+        Typically use ``__name__``.
+    submodules : set
+        List of submodules to attach.
+    submod_attrs : dict
+        Dictionary of submodule -> list of attributes / functions.
+        These attributes are imported as they are used.
+
+    Returns
+    -------
+    __getattr__, __dir__, __all__
+
+    """
+    if submod_attrs is None:
+        submod_attrs = {}
+
+    if submodules is None:
+        submodules = set()
+    else:
+        submodules = set(submodules)
+
+    attr_to_modules = {
+        attr: mod for mod, attrs in submod_attrs.items() for attr in attrs
+    }
+
+    __all__ = list(submodules | attr_to_modules.keys())
+
+    def __getattr__(name):
+        if name in submodules:
+            return importlib.import_module(f'{package_name}.{name}')
+        elif name in attr_to_modules:
+            submod = importlib.import_module(
+                f'{package_name}.{attr_to_modules[name]}'
+            )
+            return getattr(submod, name)
+        else:
+            raise AttributeError(f'No {package_name} attribute {name}')
+
+    def __dir__():
+        return __all__
+
+    eager_import = os.environ.get('EAGER_IMPORT', '')
+    if eager_import not in ['', '0', 'false']:
+        for attr in set(attr_to_modules.keys()) | submodules:
+            __getattr__(attr)
+
+    return __getattr__, __dir__, list(__all__)
+
+
+def load(fullname):
+    """Return a lazily imported proxy for a module.
+
+    We often see the following pattern::
+
+      def myfunc():
+          import scipy as sp
+          sp.argmin(...)
+          ....
+
+    This is to prevent a module, in this case `scipy`, from being
+    imported at function definition time, since that can be slow.
+
+    This function provides a proxy module that, upon access, imports
+    the actual module.  So the idiom equivalent to the above example is::
+
+      sp = lazy.load("scipy")
+
+      def myfunc():
+          sp.argmin(...)
+          ....
+
+    The initial import time is fast because the actual import is delayed
+    until the first attribute is requested. The overall import time may
+    decrease as well for users that don't make use of large portions
+    of the library.
+
+    Parameters
+    ----------
+    fullname : str
+        The full name of the module or submodule to import.  For example::
+
+          sp = lazy.load('scipy')  # import scipy as sp
+          spla = lazy.load('scipy.linalg')  # import scipy.linalg as spla
+
+    Returns
+    -------
+    pm : importlib.util._LazyModule
+        Proxy module.  Can be used like any regularly imported module.
+        Actual loading of the module occurs upon first attribute request.
+
+    """
+    try:
+        return sys.modules[fullname]
+    except KeyError:
+        pass
+
+    spec = importlib.util.find_spec(fullname)
+    if spec is None:
+        raise ModuleNotFoundError(f"No module named '{fullname}'")
+
+    module = importlib.util.module_from_spec(spec)
+    sys.modules[fullname] = module
+
+    loader = importlib.util.LazyLoader(spec.loader)
+    loader.exec_module(module)
+
+    return module
diff --git a/libs/utils/logging.py b/libs/utils/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd9828d4ad6b640cc0dd08583b4b762195ef1d96
--- /dev/null
+++ b/libs/utils/logging.py
@@ -0,0 +1,65 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) XiMing Xing. All rights reserved.
+# Author: XiMing Xing
+# Description:
+
+import os
+import sys
+import errno
+
+
+def get_logger(logs_dir: str, file_name: str = "log.txt"):
+    logger = PrintLogger(os.path.join(logs_dir, file_name))
+    sys.stdout = logger  # record all python print
+    return logger
+
+
+class PrintLogger(object):
+
+    def __init__(self, fpath=None):
+        """
+        python standard input/output records
+        """
+        self.console = sys.stdout
+        self.file = None
+        if fpath is not None:
+            mkdir_if_missing(os.path.dirname(fpath))
+            self.file = open(fpath, 'w')
+
+    def __del__(self):
+        self.close()
+
+    def __enter__(self):
+        pass
+
+    def __exit__(self, *args):
+        self.close()
+
+    def write(self, msg):
+        self.console.write(msg)
+        if self.file is not None:
+            self.file.write(msg)
+
+    def write_in(self, msg):
+        """write in log only, not console"""
+        if self.file is not None:
+            self.file.write(msg)
+
+    def flush(self):
+        self.console.flush()
+        if self.file is not None:
+            self.file.flush()
+            os.fsync(self.file.fileno())
+
+    def close(self):
+        self.console.close()
+        if self.file is not None:
+            self.file.close()
+
+
+def mkdir_if_missing(dir_path):
+    try:
+        os.makedirs(dir_path)
+    except OSError as e:
+        if e.errno != errno.EEXIST:
+            raise
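`get_logger` replaces `sys.stdout`, so every subsequent `print` is mirrored into the log file; a short sketch with a hypothetical log directory:

```python
logger = get_logger("./logs", file_name="run.txt")

print("this goes to both the console and ./logs/run.txt")
logger.write_in("this goes to the log file only\n")
logger.flush()
```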
diff --git a/libs/utils/meter.py b/libs/utils/meter.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd32dfd7fa1d0193dcd914e4043b9a8d6b127ba5
--- /dev/null
+++ b/libs/utils/meter.py
@@ -0,0 +1,70 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) XiMing Xing. All rights reserved.
+# Author: XiMing Xing
+# Description:
+
+from enum import Enum
+
+import torch
+import torch.distributed as dist
+
+
+class Summary(Enum):
+    NONE = 0
+    AVERAGE = 1
+    SUM = 2
+    COUNT = 3
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+
+    def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
+        self.name = name
+        self.fmt = fmt
+        self.summary_type = summary_type
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+    def all_reduce(self):
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+        elif torch.backends.mps.is_available():
+            device = torch.device("mps")
+        else:
+            device = torch.device("cpu")
+
+        total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
+        dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
+        self.sum, self.count = total.tolist()
+        self.avg = self.sum / self.count
+
+    def __str__(self):
+        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
+        return fmtstr.format(**self.__dict__)
+
+    def summary(self):
+        fmtstr = ''
+        if self.summary_type is Summary.NONE:
+            fmtstr = ''
+        elif self.summary_type is Summary.AVERAGE:
+            fmtstr = '{name} {avg:.3f}'
+        elif self.summary_type is Summary.SUM:
+            fmtstr = '{name} {sum:.3f}'
+        elif self.summary_type is Summary.COUNT:
+            fmtstr = '{name} {count:.3f}'
+        else:
+            raise ValueError('invalid summary type %r' % self.summary_type)
+
+        return fmtstr.format(**self.__dict__)
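+
+
+# Usage sketch (illustrative values):
+#   meter = AverageMeter('loss', ':.4f')
+#   meter.update(0.7, n=32)              # batch loss 0.7 averaged over 32 samples
+#   meter.update(0.5, n=32)
+#   str(meter)                           # -> "loss 0.5000 (0.6000)"  (current value, running average)
+#   meter.summary()                      # -> "loss 0.600"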
diff --git a/libs/utils/misc.py b/libs/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..352127dc4fbc9932249fbbbc89d4731eaca6f25a
--- /dev/null
+++ b/libs/utils/misc.py
@@ -0,0 +1,79 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) XiMing Xing. All rights reserved.
+# Author: XiMing Xing
+# Description:
+
+import math
+
+import torch
+
+
+def identity(t, *args, **kwargs):
+    """return t"""
+    return t
+
+
+def exists(x):
+    """whether x is None or not"""
+    return x is not None
+
+
+def default(val, d):
+    """ternary judgment: val != None ? val : d"""
+    if exists(val):
+        return val
+    return d() if callable(d) else d
+
+
+def has_int_squareroot(num):
+    return (math.sqrt(num) ** 2) == num
+
+
+def num_to_groups(num, divisor):
+    groups = num // divisor
+    remainder = num % divisor
+    arr = [divisor] * groups
+    if remainder > 0:
+        arr.append(remainder)
+    return arr
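+
+# e.g. num_to_groups(10, 4) -> [4, 4, 2]: split 10 items into chunks of at most 4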
+
+
+#################################################################################
+#                             Model Utils                                       #
+#################################################################################
+
+def sum_params(model: torch.nn.Module, eps: float = 1e6):
+    """Total number of model parameters divided by `eps` (default 1e6, i.e. reported in millions)."""
+    return sum(p.numel() for p in model.parameters()) / eps
+
+
+#################################################################################
+#                            DataLoader Utils                                   #
+#################################################################################
+
+def cycle(dl):
+    while True:
+        for data in dl:
+            yield data
+
+
+#################################################################################
+#                            Diffusion Model Utils                              #
+#################################################################################
+
+def extract(a, t, x_shape):
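+    # typical use: gather per-sample diffusion schedule coefficients, e.g. a = alphas_cumprod with shape (T,)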
+    b, *_ = t.shape
+    assert x_shape[0] == b
+    out = a.gather(-1, t)  # 1-D tensor, shape: (b,)
+    return out.reshape(b, *((1,) * (len(x_shape) - 1)))  # shape: [b, 1, 1, 1]
+
+
+def unnormalize(x):
+    """unnormalize_to_zero_to_one"""
+    x = (x + 1) * 0.5  # Map the data interval to [0, 1]
+    return torch.clamp(x, 0.0, 1.0)
+
+
+def normalize(x):
+    """normalize_to_neg_one_to_one"""
+    x = x * 2 - 1  # Map the data interval to [-1, 1]
+    return torch.clamp(x, -1.0, 1.0)
diff --git a/libs/utils/model_summary.py b/libs/utils/model_summary.py
new file mode 100644
index 0000000000000000000000000000000000000000..afd33987128d074d9e88dfa8aabed29abea1fcfb
--- /dev/null
+++ b/libs/utils/model_summary.py
@@ -0,0 +1,128 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) XiMing Xing. All rights reserved.
+# Author: XiMing Xing
+# Description:
+
+import sys
+from collections import OrderedDict
+
+import numpy as np
+import torch
+
+layer_modules = (torch.nn.MultiheadAttention,)
+
+
+def summary(model, input_data=None, input_data_args=None, input_shape=None, input_dtype=torch.FloatTensor,
+            batch_size=-1,
+            *args, **kwargs):
+    """
+    Give example input data in at least one of the following ways:
+    ① input_data ---> model.forward(input_data)
+    ② input_data_args ---> model.forward(*input_data_args)
+    ③ input_shape & input_dtype ---> model.forward(*[torch.rand(2, *size).type(input_dtype) for size in input_shape])
+    """
+
+    hooks = []
+    summary = OrderedDict()
+
+    def register_hook(module):
+        def hook(module, inputs, outputs):
+
+            class_name = str(module.__class__).split(".")[-1].split("'")[0]
+            module_idx = len(summary)
+
+            key = "%s-%i" % (class_name, module_idx + 1)
+
+            info = OrderedDict()
+            info["id"] = id(module)
+            if isinstance(outputs, (list, tuple)):
+                try:
+                    info["out"] = [batch_size] + list(outputs[0].size())[1:]
+                except AttributeError:
+                    # pack_padded_seq and pad_packed_seq store feature into data attribute
+                    info["out"] = [batch_size] + list(outputs[0].data.size())[1:]
+            else:
+                info["out"] = [batch_size] + list(outputs.size())[1:]
+
+            info["params_nt"], info["params"] = 0, 0
+            for name, param in module.named_parameters():
+                info["params"] += param.nelement() * param.requires_grad
+                info["params_nt"] += param.nelement() * (not param.requires_grad)
+
+            summary[key] = info
+
+        # ignore Sequential and ModuleList and other containers
+        if isinstance(module, layer_modules) or not module._modules:
+            hooks.append(module.register_forward_hook(hook))
+
+    model.apply(register_hook)
+
+    # multiple inputs to the network
+    if isinstance(input_shape, tuple):
+        input_shape = [input_shape]
+
+    if input_data is not None:
+        x = [input_data]
+    elif input_shape is not None:
+        # batch_size of 2 for batchnorm
+        x = [torch.rand(2, *size).type(input_dtype) for size in input_shape]
+    elif input_data_args is not None:
+        x = input_data_args
+    else:
+        x = []
+    try:
+        with torch.no_grad():
+            model(*x) if not (kwargs or args) else model(*x, *args, **kwargs)
+    except Exception:
+        # This can be useful for debugging
+        print("Failed to run summary...")
+        raise
+    finally:
+        for hook in hooks:
+            hook.remove()
+    summary_logs = []
+    summary_logs.append("--------------------------------------------------------------------------")
+    line_new = "{:<30}  {:>20} {:>20}".format("Layer (type)", "Output Shape", "Param #")
+    summary_logs.append(line_new)
+    summary_logs.append("==========================================================================")
+    total_params = 0
+    total_output = 0
+    trainable_params = 0
+    for layer in summary:
+        # layer, output_shape, params
+        line_new = "{:<30}  {:>20} {:>20}".format(
+            layer,
+            str(summary[layer]["out"]),
+            "{0:,}".format(summary[layer]["params"] + summary[layer]["params_nt"])
+        )
+        total_params += (summary[layer]["params"] + summary[layer]["params_nt"])
+        total_output += np.prod(summary[layer]["out"])
+        trainable_params += summary[layer]["params"]
+        summary_logs.append(line_new)
+
+    # assume 4 bytes/number
+    if input_data is not None:
+        total_input_size = abs(sys.getsizeof(input_data) / (1024 ** 2.))
+    elif input_shape is not None:
+        total_input_size = abs(np.prod(input_shape) * batch_size * 4. / (1024 ** 2.))
+    else:
+        total_input_size = 0.0
+    total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))  # x2 for gradients
+    total_params_size = abs(total_params * 4. / (1024 ** 2.))
+    total_size = total_params_size + total_output_size + total_input_size
+
+    summary_logs.append("==========================================================================")
+    summary_logs.append("Total params: {0:,}".format(total_params))
+    summary_logs.append("Trainable params: {0:,}".format(trainable_params))
+    summary_logs.append("Non-trainable params: {0:,}".format(total_params - trainable_params))
+    summary_logs.append("--------------------------------------------------------------------------")
+    summary_logs.append("Input size (MB): %0.6f" % total_input_size)
+    summary_logs.append("Forward/backward pass size (MB): %0.6f" % total_output_size)
+    summary_logs.append("Params size (MB): %0.6f" % total_params_size)
+    summary_logs.append("Estimated Total Size (MB): %0.6f" % total_size)
+    summary_logs.append("--------------------------------------------------------------------------")
+
+    summary_info = "\n".join(summary_logs)
+
+    print(summary_info)
+    return summary_info
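+
+
+# Usage sketch (assumes torchvision is installed; any nn.Module works):
+#   import torchvision
+#   net = torchvision.models.resnet18()
+#   summary(net, input_shape=(3, 224, 224))   # way ③: random input of the given shape and default dtype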
diff --git a/libs/utils/tqdm.py b/libs/utils/tqdm.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0a88005fabf467a16f9a5c75ff81cdb91326703
--- /dev/null
+++ b/libs/utils/tqdm.py
@@ -0,0 +1,30 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) XiMing Xing. All rights reserved.
+# Author: XiMing Xing
+# Description:
+
+from typing import Callable
+from tqdm.auto import tqdm
+
+
+def tqdm_decorator(func: Callable):
+    """A decorator function called tqdm_decorator that takes a function as an argument and
+    returns a new function that wraps the input function with a tqdm progress bar.
+
+    Noting: **The input function is assumed to have an object self as its first argument**, which contains a step attribute,
+    an args attribute with a train_num_steps attribute, and an accelerator attribute with an is_main_process attribute.
+
+    Args:
+        func: tqdm_decorator
+
+    Returns:
+            a new function that wraps the input function with a tqdm progress bar.
+    """
+
+    def wrapper(*args, **kwargs):
+        with tqdm(initial=args[0].step,
+                  total=args[0].args.train_num_steps,
+                  disable=not args[0].accelerator.is_main_process) as pbar:
+            func(*args, **kwargs, pbar=pbar)
+
+    return wrapper
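+
+
+# Usage sketch (hypothetical trainer object matching the assumptions above):
+#   class Trainer:
+#       # needs: self.step, self.args.train_num_steps, self.accelerator.is_main_process
+#       @tqdm_decorator
+#       def train(self, pbar=None):
+#           while self.step < self.args.train_num_steps:
+#               ...                      # one optimization step
+#               self.step += 1
+#               pbar.update(1)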
diff --git a/methods/__init__.py b/methods/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ad761f2f5443eb41b15afc4116a66ecdfa9d918
--- /dev/null
+++ b/methods/__init__.py
@@ -0,0 +1,4 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) XiMing Xing. All rights reserved.
+# Author: XiMing Xing
+# Description:
diff --git a/methods/diffusers_warp/__init__.py b/methods/diffusers_warp/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bcfe2a6204da3a4c4ae7b826722f07664ea1a4d0
--- /dev/null
+++ b/methods/diffusers_warp/__init__.py
@@ -0,0 +1,138 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) XiMing Xing. All rights reserved.
+# Author: XiMing Xing
+# Description:
+from typing import AnyStr
+import pathlib
+from collections import OrderedDict
+from packaging import version
+
+import torch
+from diffusers import StableDiffusionPipeline, SchedulerMixin, DiffusionPipeline
+from diffusers.utils import is_torch_version, is_xformers_available
+
+huggingface_model_dict = OrderedDict({
+    "sd14": "CompVis/stable-diffusion-v1-4",  # resolution: 512
+    "sd15": "runwayml/stable-diffusion-v1-5",  # resolution: 512
+    "sd21b": "stabilityai/stable-diffusion-2-1-base",  # resolution: 512
+    "sd21": "stabilityai/stable-diffusion-2-1",  # resolution: 768
+    "sdxl": "stabilityai/stable-diffusion-xl-base-1.0",  # resolution: 1024
+})
+
+_model2resolution = {
+    "sd14": 512,
+    "sd15": 512,
+    "sd21b": 512,
+    "sd21": 768,
+    "sdxl": 1024,
+}
+
+
+def model2res(model_id: str):
+    return _model2resolution.get(model_id, 512)
+
+
+def init_diffusion_pipeline(model_id: AnyStr,
+                            custom_pipeline: StableDiffusionPipeline,
+                            custom_scheduler: SchedulerMixin = None,
+                            device: torch.device = "cuda",
+                            torch_dtype: torch.dtype = torch.float32,
+                            local_files_only: bool = True,
+                            force_download: bool = False,
+                            resume_download: bool = False,
+                            ldm_speed_up: bool = False,
+                            enable_xformers: bool = True,
+                            gradient_checkpoint: bool = False,
+                            lora_path: AnyStr = None,
+                            unet_path: AnyStr = None) -> StableDiffusionPipeline:
+    """
+    A tool for initial diffusers model.
+
+    Args:
+        model_id (`str` or `os.PathLike`, *optional*): pretrained_model_name_or_path
+        custom_pipeline: any StableDiffusionPipeline pipeline
+        custom_scheduler: any scheduler
+        device: set device
+        local_files_only: prohibited download model
+        force_download: forced download model
+        resume_download: re-download model
+        ldm_speed_up: use the `torch.compile` api to speed up unet
+        enable_xformers: enable memory efficient attention from [xFormers]
+        gradient_checkpoint: activates gradient checkpointing for the current model
+        lora_path: load LoRA checkpoint
+        unet_path: load unet checkpoint
+
+    Returns:
+            diffusers.StableDiffusionPipeline
+    """
+
+    # get model id
+    model_id = huggingface_model_dict.get(model_id, model_id)
+
+    # process diffusion model
+    if custom_scheduler is not None:
+        pipeline = custom_pipeline.from_pretrained(
+            model_id,
+            torch_dtype=torch_dtype,
+            local_files_only=local_files_only,
+            force_download=force_download,
+            resume_download=resume_download,
+            scheduler=custom_scheduler.from_pretrained(model_id,
+                                                       subfolder="scheduler",
+                                                       local_files_only=local_files_only)
+        ).to(device)
+    else:
+        pipeline = custom_pipeline.from_pretrained(
+            model_id,
+            torch_dtype=torch_dtype,
+            local_files_only=local_files_only,
+            force_download=force_download,
+            resume_download=resume_download,
+        ).to(device)
+
+    # process unet model if exist
+    if unet_path is not None and pathlib.Path(unet_path).exists():
+        print(f"=> load u-net from {unet_path}")
+        pipeline.unet.from_pretrained(model_id, subfolder="unet")
+
+    # process lora layers if exist
+    if lora_path is not None and pathlib.Path(lora_path).exists():
+        pipeline.unet.load_attn_procs(lora_path)
+        print(f"=> load lora layers into U-Net from {lora_path} ...")
+
+    # torch.compile
+    if ldm_speed_up:
+        if is_torch_version(">=", "2.0.0"):
+            pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
+            print(f"=> enable torch.compile on U-Net")
+        else:
+            print(f"=> warning: calling torch.compile speed-up failed, since torch version <= 2.0.0")
+
+    # Meta xformers
+    if enable_xformers:
+        if is_xformers_available():
+            import xformers
+
+            xformers_version = version.parse(xformers.__version__)
+            if xformers_version == version.parse("0.0.16"):
+                print(
+                    "xFormers 0.0.16 cannot be used for training in some GPUs. "
+                    "If you observe problems during training, please update xFormers to at least 0.0.17. "
+                    "See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+                )
+            print(f"=> enable xformers")
+            pipeline.unet.enable_xformers_memory_efficient_attention()
+        else:
+            print(f"=> warning: calling xformers failed")
+
+    # gradient checkpointing
+    if gradient_checkpoint:
+        try:
+            print(f"=> enable gradient checkpointing")
+            pipeline.unet.enable_gradient_checkpointing()
+        except Exception as e:
+            print("=> waring: gradient checkpointing is not activated for this model.")
+
+    print(f"Diffusion Model: {model_id}")
+    print(pipeline.scheduler)
+    return pipeline
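+
+
+# Usage sketch (weights are downloaded unless already cached; "sd15" maps to runwayml/stable-diffusion-v1-5):
+#   pipe = init_diffusion_pipeline("sd15",
+#                                  custom_pipeline=StableDiffusionPipeline,
+#                                  device=torch.device("cuda"),
+#                                  local_files_only=False)
+#   assert model2res("sd15") == 512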
diff --git a/methods/diffvg_warp/__init__.py b/methods/diffvg_warp/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..00ba7f726646659600804480362ba8fb084597b2
--- /dev/null
+++ b/methods/diffvg_warp/__init__.py
@@ -0,0 +1,11 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) XiMing Xing. All rights reserved.
+# Author: XiMing Xing
+# Description:
+
+from .diffvg_state import DiffVGState, init_diffvg
+
+__all__ = [
+    'DiffVGState',
+    'init_diffvg'
+]
diff --git a/methods/diffvg_warp/diffvg_state.py b/methods/diffvg_warp/diffvg_state.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab6784c5d184bab87eca88627dd96364a8b7d7ab
--- /dev/null
+++ b/methods/diffvg_warp/diffvg_state.py
@@ -0,0 +1,252 @@
+# -*- coding: utf-8 -*-
+# Author: ximing
+# Description: parent class
+# Copyright (c) 2023, XiMing Xing.
+# License: MIT License
+import pathlib
+from typing import AnyStr, List, Union
+import xml.etree.ElementTree as etree
+
+import torch
+import pydiffvg
+
+
+def init_diffvg(device: torch.device,
+                use_gpu: bool = torch.cuda.is_available(),
+                print_timing: bool = False):
+    pydiffvg.set_device(device)
+    pydiffvg.set_use_gpu(use_gpu)
+    pydiffvg.set_print_timing(print_timing)
+
+
+class DiffVGState(torch.nn.Module):
+
+    def __init__(self,
+                 device: torch.device,
+                 use_gpu: bool = torch.cuda.is_available(),
+                 print_timing: bool = False,
+                 canvas_width: int = True,
+                 canvas_height: int = True):
+        super(DiffVGState, self).__init__()
+        # pydiffvg device setting
+        self.device = device
+        init_diffvg(device, use_gpu, print_timing)
+
+        self.canvas_width = canvas_width
+        self.canvas_height = canvas_height
+
+        # record all paths
+        self.shapes = []
+        self.shape_groups = []
+        # record the current optimized path
+        self.cur_shapes = []
+        self.cur_shape_groups = []
+
+        self.point_vars = []
+        self.color_vars = []
+
+        self.strokes_counter = 0  # counts the number of calls to "get_path"
+
+    def load_svg(self, path_svg):
+        canvas_width, canvas_height, shapes, shape_groups = pydiffvg.svg_to_scene(path_svg)
+        return canvas_width, canvas_height, shapes, shape_groups
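+
+    # Usage sketch: seed the state from an existing SVG before optimization, e.g.
+    #   w, h, shapes, shape_groups = state.load_svg("init.svg")   # "init.svg" is a hypothetical path
+    #   state.shapes, state.shape_groups = shapes, shape_groups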
+
+    def _save_svg(self,
+                  filename: Union[AnyStr, pathlib.Path],
+                  width: int = None,
+                  height: int = None,
+                  shapes: List = None,
+                  shape_groups: List = None,
+                  use_gamma: bool = False,
+                  background: str = None):
+        """
+        Save an SVG file with specified parameters and shapes.
+        Note: a newer SVG-saving routine adapted from pydiffvg.save_svg; the original version
+        saved text/words in a way that produced incomplete glyphs.
+
+        Args:
+            filename (str): The path to save the SVG file.
+            width (int): The width of the SVG canvas.
+            height (int): The height of the SVG canvas.
+            shapes (list): A list of shapes to be included in the SVG.
+            shape_groups (list): A list of shape groups.
+            use_gamma (bool): Flag indicating whether to apply gamma correction.
+            background (str, optional): The background color of the SVG.
+
+        Returns:
+            None
+        """
+        root = etree.Element('svg')
+        root.set('version', '1.1')
+        root.set('xmlns', 'http://www.w3.org/2000/svg')
+        root.set('width', str(width))
+        root.set('height', str(height))
+
+        if background is not None:
+            print(f"setting background to {background}")
+            root.set('style', str(background))
+
+        defs = etree.SubElement(root, 'defs')
+        g = etree.SubElement(root, 'g')
+
+        if use_gamma:
+            f = etree.SubElement(defs, 'filter')
+            f.set('id', 'gamma')
+            f.set('x', '0')
+            f.set('y', '0')
+            f.set('width', '100%')
+            f.set('height', '100%')
+            gamma = etree.SubElement(f, 'feComponentTransfer')
+            gamma.set('color-interpolation-filters', 'sRGB')
+            feFuncR = etree.SubElement(gamma, 'feFuncR')
+            feFuncR.set('type', 'gamma')
+            feFuncR.set('amplitude', str(1))
+            feFuncR.set('exponent', str(1 / 2.2))
+            feFuncG = etree.SubElement(gamma, 'feFuncG')
+            feFuncG.set('type', 'gamma')
+            feFuncG.set('amplitude', str(1))
+            feFuncG.set('exponent', str(1 / 2.2))
+            feFuncB = etree.SubElement(gamma, 'feFuncB')
+            feFuncB.set('type', 'gamma')
+            feFuncB.set('amplitude', str(1))
+            feFuncB.set('exponent', str(1 / 2.2))
+            feFuncA = etree.SubElement(gamma, 'feFuncA')
+            feFuncA.set('type', 'gamma')
+            feFuncA.set('amplitude', str(1))
+            feFuncA.set('exponent', str(1 / 2.2))
+            g.set('style', 'filter:url(#gamma)')
+
+        # Store color
+        for i, shape_group in enumerate(shape_groups):
+            def add_color(shape_color, name):
+                if isinstance(shape_color, pydiffvg.LinearGradient):
+                    lg = shape_color
+                    color = etree.SubElement(defs, 'linearGradient')
+                    color.set('id', name)
+                    color.set('x1', str(lg.begin[0].item()))
+                    color.set('y1', str(lg.begin[1].item()))
+                    color.set('x2', str(lg.end[0].item()))
+                    color.set('y2', str(lg.end[1].item()))
+                    offsets = lg.offsets.data.cpu().numpy()
+                    stop_colors = lg.stop_colors.data.cpu().numpy()
+                    for j in range(offsets.shape[0]):
+                        stop = etree.SubElement(color, 'stop')
+                        stop.set('offset', str(offsets[j]))
+                        c = lg.stop_colors[j, :]
+                        stop.set('stop-color', 'rgb({}, {}, {})'.format(
+                            int(255 * c[0]), int(255 * c[1]), int(255 * c[2])
+                        ))
+                        stop.set('stop-opacity', '{}'.format(c[3]))
+                if isinstance(shape_color, pydiffvg.RadialGradient):
+                    lg = shape_color
+                    color = etree.SubElement(defs, 'radialGradient')
+                    color.set('id', name)
+                    color.set('cx', str(lg.center[0].item() / width))
+                    color.set('cy', str(lg.center[1].item() / height))
+                    # this only supports width == height
+                    color.set('r', str(lg.radius[0].item() / width))
+                    offsets = lg.offsets.data.cpu().numpy()
+                    stop_colors = lg.stop_colors.data.cpu().numpy()
+                    for j in range(offsets.shape[0]):
+                        stop = etree.SubElement(color, 'stop')
+                        stop.set('offset', str(offsets[j]))
+                        c = lg.stop_colors[j, :]
+                        stop.set('stop-color', 'rgb({}, {}, {})'.format(
+                            int(255 * c[0]), int(255 * c[1]), int(255 * c[2])
+                        ))
+                        stop.set('stop-opacity', '{}'.format(c[3]))
+
+            if shape_group.fill_color is not None:
+                add_color(shape_group.fill_color, 'shape_{}_fill'.format(i))
+            if shape_group.stroke_color is not None:
+                add_color(shape_group.stroke_color, 'shape_{}_stroke'.format(i))
+
+        for i, shape_group in enumerate(shape_groups):
+            shape = shapes[shape_group.shape_ids[0]]
+            if isinstance(shape, pydiffvg.Circle):
+                shape_node = etree.SubElement(g, 'circle')
+                shape_node.set('r', str(shape.radius.item()))
+                shape_node.set('cx', str(shape.center[0].item()))
+                shape_node.set('cy', str(shape.center[1].item()))
+            elif isinstance(shape, pydiffvg.Polygon):
+                shape_node = etree.SubElement(g, 'polygon')
+                points = shape.points.data.cpu().numpy()
+                path_str = ''
+                for j in range(0, shape.points.shape[0]):
+                    path_str += '{} {}'.format(points[j, 0], points[j, 1])
+                    if j != shape.points.shape[0] - 1:
+                        path_str += ' '
+                shape_node.set('points', path_str)
+            elif isinstance(shape, pydiffvg.Path):
+                for j, id in enumerate(shape_group.shape_ids):
+                    shape = shapes[id]
+                    if isinstance(shape, pydiffvg.Path):
+                        if j == 0:
+                            shape_node = etree.SubElement(g, 'path')
+                            path_str = ''
+
+                        num_segments = shape.num_control_points.shape[0]
+                        num_control_points = shape.num_control_points.data.cpu().numpy()
+                        points = shape.points.data.cpu().numpy()
+                        num_points = shape.points.shape[0]
+                        path_str += 'M {} {}'.format(points[0, 0], points[0, 1])
+                        point_id = 1
+                        for j in range(0, num_segments):
+                            if num_control_points[j] == 0:
+                                p = point_id % num_points
+                                path_str += ' L {} {}'.format(
+                                    points[p, 0], points[p, 1])
+                                point_id += 1
+                            elif num_control_points[j] == 1:
+                                p1 = (point_id + 1) % num_points
+                                path_str += ' Q {} {} {} {}'.format(
+                                    points[point_id, 0], points[point_id, 1],
+                                    points[p1, 0], points[p1, 1])
+                                point_id += 2
+                            elif num_control_points[j] == 2:
+                                p2 = (point_id + 2) % num_points
+                                path_str += ' C {} {} {} {} {} {}'.format(
+                                    points[point_id, 0], points[point_id, 1],
+                                    points[point_id + 1, 0], points[point_id + 1, 1],
+                                    points[p2, 0], points[p2, 1])
+                                point_id += 3
+                shape_node.set('d', path_str)
+            elif isinstance(shape, pydiffvg.Rect):
+                shape_node = etree.SubElement(g, 'rect')
+                shape_node.set('x', str(shape.p_min[0].item()))
+                shape_node.set('y', str(shape.p_min[1].item()))
+                shape_node.set('width', str(shape.p_max[0].item() - shape.p_min[0].item()))
+                shape_node.set('height', str(shape.p_max[1].item() - shape.p_min[1].item()))
+            elif isinstance(shape, pydiffvg.Ellipse):
+                shape_node = etree.SubElement(g, 'ellipse')
+                shape_node.set('cx', str(shape.center[0].item()))
+                shape_node.set('cy', str(shape.center[1].item()))
+                shape_node.set('rx', str(shape.radius[0].item()))
+                shape_node.set('ry', str(shape.radius[1].item()))
+            else:
+                raise NotImplementedError(f'shape type: {type(shape)} is not involved in pydiffvg.')
+
+            shape_node.set('stroke-width', str(2 * shape.stroke_width.data.cpu().item()))
+            if shape_group.fill_color is not None:
+                if isinstance(shape_group.fill_color, pydiffvg.LinearGradient):
+                    shape_node.set('fill', 'url(#shape_{}_fill)'.format(i))
+                else:
+                    c = shape_group.fill_color.data.cpu().numpy()
+                    shape_node.set('fill', 'rgb({}, {}, {})'.format(
+                        int(255 * c[0]), int(255 * c[1]), int(255 * c[2])))
+                    shape_node.set('opacity', str(c[3]))
+            else:
+                shape_node.set('fill', 'none')
+            if shape_group.stroke_color is not None:
+                if isinstance(shape_group.stroke_color, pydiffvg.LinearGradient):
+                    shape_node.set('stroke', 'url(#shape_{}_stroke)'.format(i))
+                else:
+                    c = shape_group.stroke_color.data.cpu().numpy()
+                    shape_node.set('stroke', 'rgb({}, {}, {})'.format(
+                        int(255 * c[0]), int(255 * c[1]), int(255 * c[2])))
+                    shape_node.set('stroke-opacity', str(c[3]))
+                shape_node.set('stroke-linecap', 'round')
+                shape_node.set('stroke-linejoin', 'round')
+
+        with open(filename, "w") as f:
+            f.write(pydiffvg.prettify(root))
diff --git a/methods/painter/__init__.py b/methods/painter/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ad761f2f5443eb41b15afc4116a66ecdfa9d918
--- /dev/null
+++ b/methods/painter/__init__.py
@@ -0,0 +1,4 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) XiMing Xing. All rights reserved.
+# Author: XiMing Xing
+# Description:
diff --git a/methods/painter/diffsketcher/ASDS_SDXL_pipeline.py b/methods/painter/diffsketcher/ASDS_SDXL_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4b8764275fdc1be27a584a0bcb761308ceabd93
--- /dev/null
+++ b/methods/painter/diffsketcher/ASDS_SDXL_pipeline.py
@@ -0,0 +1,673 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) XiMing Xing. All rights reserved.
+# Author: XiMing Xing
+# Description:
+
+import PIL
+from PIL import Image
+from typing import Callable, List, Optional, Union, Tuple, AnyStr
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch.cuda.amp import custom_bwd, custom_fwd
+from torchvision import transforms
+from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
+from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
+
+from methods.token2attn.attn_control import AttentionStore
+from methods.token2attn.ptp_utils import text_under_image, view_images
+
+
+class Token2AttnMixinASDSSDXLPipeline(StableDiffusionXLPipeline):
+    r"""
+    Pipeline for text-to-image generation using Stable Diffusion XL.
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        text_encoder ([`CLIPTextModel`]):
+            Frozen text-encoder. Stable Diffusion uses the text portion of
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+        tokenizer (`CLIPTokenizer`):
+            Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+        safety_checker ([`StableDiffusionSafetyChecker`]):
+            Classification module that estimates whether generated images could be considered offensive or harmful.
+            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+        feature_extractor ([`CLIPFeatureExtractor`]):
+            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+    """
+    _optional_components = ["safety_checker", "feature_extractor"]
+
+    @torch.no_grad()
+    def __call__(
+            self,
+            prompt: Union[str, List[str]],
+            prompt_2: Optional[Union[str, List[str]]] = None,
+            height: Optional[int] = None,
+            width: Optional[int] = None,
+            controller: AttentionStore = None,  # feed attention_store as control of ptp
+            num_inference_steps: int = 50,
+            denoising_end: Optional[float] = None,
+            guidance_scale: float = 5.0,
+            negative_prompt: Optional[Union[str, List[str]]] = None,
+            negative_prompt_2: Optional[Union[str, List[str]]] = None,
+            num_images_per_prompt: Optional[int] = 1,
+            eta: float = 0.0,
+            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+            latents: Optional[torch.FloatTensor] = None,
+            output_type: Optional[str] = "pil",
+            return_dict: bool = True,
+            callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+            callback_steps: Optional[int] = 1,
+            original_size: Optional[Tuple[int, int]] = None,
+            crops_coords_top_left: Tuple[int, int] = (0, 0),
+            target_size: Optional[Tuple[int, int]] = None,
+    ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+                instead.
+            prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+                used in both text-encoders
+            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The height in pixels of the generated image.
+            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The width in pixels of the generated image.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            denoising_end (`float`, *optional*):
+                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+                completed before it is intentionally prematurely terminated. As a result, the returned sample will
+                still retain a substantial amount of noise as determined by the discrete timesteps selected by the
+                scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
+                "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
+            guidance_scale (`float`, *optional*, defaults to 5.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            negative_prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+                of a plain tuple.
+            callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. The function will be
+                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            callback_steps (`int`, *optional*, defaults to 1):
+                The frequency at which the `callback` function will be called. If not specified, the callback will be
+                called at every step.
+            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+                `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
+                explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+                `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                For most cases, `target_size` should be set to the desired height and width of the generated image. If
+                not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
+                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+
+        Examples:
+
+        Returns:
+            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is a list with the generated images.
+        """
+
+        self.register_attention_control(controller)  # add attention controller
+
+        # 0. Default height and width to unet
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+        original_size = original_size or (height, width)
+        target_size = target_size or (height, width)
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(prompt, prompt_2, height, width, callback_steps)
+
+        # 2. Define call parameters
+        batch_size = 1 if isinstance(prompt, str) else len(prompt)
+        device = self._execution_device
+        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        # 3. Encode input prompt
+        (
+            text_embeddings,
+            negative_text_embeddings,
+            pooled_text_embeddings,
+            negative_pooled_text_embeddings,
+        ) = self.encode_prompt(
+            prompt=prompt,
+            prompt_2=prompt_2,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
+        )
+
+        # 4. Prepare timesteps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps = self.scheduler.timesteps
+
+        # 5. Prepare latent variables
+        try:
+            num_channels_latents = self.unet.config.in_channels
+        except (Exception, Warning):
+            num_channels_latents = self.unet.in_channels
+
+        latents = self.prepare_latents(
+            batch_size * num_images_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            text_embeddings.dtype,
+            device,
+            generator,
+            latents,
+        )
+
+        # 6. Prepare extra step kwargs. inherit TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 7. Prepare added time ids & embeddings
+        add_text_embeddings = pooled_text_embeddings
+        add_time_ids = self._get_add_time_ids(
+            original_size, crops_coords_top_left, target_size, dtype=text_embeddings.dtype
+        )
+
+        if do_classifier_free_guidance:
+            text_embeddings = torch.cat([negative_text_embeddings, text_embeddings], dim=0)
+            add_text_embeddings = torch.cat([negative_pooled_text_embeddings, add_text_embeddings], dim=0)
+            add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
+
+        text_embeddings = text_embeddings.to(device)
+        add_text_embeddings = add_text_embeddings.to(device)
+        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+        # 8. Denoising loop
+
+        # 8.1 Apply denoising_end
+        if denoising_end is not None and isinstance(denoising_end, float) and 0 < denoising_end < 1:
+            discrete_timestep_cutoff = int(
+                round(
+                    self.scheduler.config.num_train_timesteps
+                    - (denoising_end * self.scheduler.config.num_train_timesteps)
+                )
+            )
+            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+            timesteps = timesteps[:num_inference_steps]
+
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # predict the noise residual
+                added_cond_kwargs = {"text_embeds": add_text_embeddings, "time_ids": add_time_ids}
+                noise_pred = self.unet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=text_embeddings,
+                    added_cond_kwargs=added_cond_kwargs
+                ).sample
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+                # step callback
+                latents = controller.step_callback(latents)
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        callback(i, t, latents)
+
+        # 9. Post-processing
+
+        # The decode_latents method is deprecated and has been removed in sdxl
+        # image = self.decode_latents(latents)
+
+        # make sure the VAE is in float32 mode, as it overflows in float16
+        if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+            self.upcast_vae()
+            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+        else:
+            image = latents
+            return StableDiffusionXLPipelineOutput(images=image)
+
+        # apply watermark if available
+        if self.watermark is not None:
+            image = self.watermark.apply_watermark(image)
+
+        image = self.image_processor.postprocess(image, output_type=output_type)
+
+        if not return_dict:
+            return (image,)
+
+        return StableDiffusionXLPipelineOutput(images=image)
+
+    def encode2latents(self,
+                       image,
+                       batch_size,
+                       num_images_per_prompt,
+                       dtype,
+                       device,
+                       generator=None):
+        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+            raise ValueError(
+                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+            )
+
+        # Offload text encoder if `enable_model_cpu_offload` was enabled
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+            self.text_encoder_2.to("cpu")
+            torch.cuda.empty_cache()
+
+        image = image.to(device=device, dtype=dtype)
+
+        batch_size = batch_size * num_images_per_prompt
+
+        if image.shape[1] == 4:
+            init_latents = image
+        else:
+            # make sure the VAE is in float32 mode, as it overflows in float16
+            if self.vae.config.force_upcast:
+                image = image.float()
+                self.vae.to(dtype=torch.float32)
+
+            if isinstance(generator, list) and len(generator) != batch_size:
+                raise ValueError(
+                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+                )
+
+            elif isinstance(generator, list):
+                init_latents = [
+                    self.vae.encode(image[i: i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+                ]
+                init_latents = torch.cat(init_latents, dim=0)
+            else:
+                init_latents = self.vae.encode(image).latent_dist.sample(generator)
+
+            if self.vae.config.force_upcast:
+                self.vae.to(dtype)
+
+            init_latents = init_latents.to(dtype)
+            init_latents = self.vae.config.scaling_factor * init_latents
+
+        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+            # expand init_latents for batch_size
+            additional_image_per_prompt = batch_size // init_latents.shape[0]
+            init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
+        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+            raise ValueError(
+                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+            )
+        else:
+            init_latents = torch.cat([init_latents], dim=0)
+
+        latents = init_latents
+
+        return latents
+
+    @staticmethod
+    def S_aug(sketch: torch.Tensor,
+              im_res: int = 1024,
+              augments: str = "affine_contrast"):
+        # init augmentations
+        augment_list = []
+        if "affine" in augments:
+            augment_list.append(
+                transforms.RandomPerspective(fill=0, p=1.0, distortion_scale=0.5)
+            )
+            augment_list.append(
+                transforms.RandomResizedCrop(im_res, scale=(0.8, 0.8), ratio=(1.0, 1.0))
+            )
+        if "contrast" in augments:
+            # 2: increases the sharpness by a factor of 2.
+            augment_list.append(
+                transforms.RandomAdjustSharpness(sharpness_factor=2)
+            )
+        augment_compose = transforms.Compose(augment_list)
+
+        return augment_compose(sketch)
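+
+    # e.g. S_aug(raster_sketch, im_res=1024, augments="affine_contrast") returns a randomly
+    # perspective-warped, cropped and sharpness-adjusted view of the rendered sketch.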
+
+    def score_distillation_sampling(self,
+                                    pred_rgb: torch.Tensor,
+                                    crop_size: int,
+                                    augments: str,
+                                    prompt: Union[List, str],
+                                    prompt_2: Optional[Union[List, str]] = None,
+                                    height: Optional[int] = None,
+                                    width: Optional[int] = None,
+                                    negative_prompt: Union[List, str] = None,
+                                    negative_prompt_2: Optional[Union[List, str]] = None,
+                                    guidance_scale: float = 100,
+                                    as_latent: bool = False,
+                                    grad_scale: float = 1,
+                                    t_range: Union[List[float], Tuple[float]] = (0.05, 0.95),
+                                    original_size: Optional[Tuple[int, int]] = None,
+                                    crops_coords_top_left: Tuple[int, int] = (0, 0),
+                                    target_size: Optional[Tuple[int, int]] = None):
+
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+        original_size = original_size or (height, width)
+        target_size = target_size or (height, width)
+
+        batch_size = 1 if isinstance(prompt, str) else len(prompt)
+
+        num_train_timesteps = self.scheduler.config.num_train_timesteps
+        min_step = int(num_train_timesteps * t_range[0])
+        max_step = int(num_train_timesteps * t_range[1])
+        alphas = self.scheduler.alphas_cumprod.to(self.device)  # for convenience
+
+        num_images_per_prompt = 1  # the number of images to generate per prompt
+
+        #  Encode input prompt
+        do_classifier_free_guidance = guidance_scale > 1.0
+        (
+            text_embeddings,
+            negative_text_embeddings,
+            pooled_text_embeddings,
+            negative_pooled_text_embeddings,
+        ) = self.encode_prompt(
+            prompt=prompt,
+            prompt_2=prompt_2,
+            device=self.device,
+            num_images_per_prompt=num_images_per_prompt,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
+        )
+
+        # sketch augmentation
+        pred_rgb_a = self.S_aug(pred_rgb, crop_size, augments)
+
+        # interp to 512x512 to be fed into vae.
+        if as_latent:
+            latents = F.interpolate(pred_rgb_a, (128, 128), mode='bilinear', align_corners=False) * 2 - 1
+        else:
+            # encode image into latents via vae, requires grad!
+            latents = self.encode2latents(
+                pred_rgb_a,
+                batch_size,
+                num_images_per_prompt,
+                text_embeddings.dtype,
+                self.device
+            )
+
+        # timestep ~ U(0.05, 0.95) to avoid very high/low noise level
+        t = torch.randint(min_step, max_step + 1, [1], dtype=torch.long, device=self.device)
+
+        # 7. Prepare added time ids & embeddings
+        add_text_embeddings = pooled_text_embeddings
+        add_time_ids = self._get_add_time_ids(
+            original_size, crops_coords_top_left, target_size, dtype=text_embeddings.dtype
+        )
+
+        if do_classifier_free_guidance:
+            text_embeddings = torch.cat([negative_text_embeddings, text_embeddings], dim=0)
+            add_text_embeddings = torch.cat([negative_pooled_text_embeddings, add_text_embeddings], dim=0)
+            add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
+
+        text_embeddings = text_embeddings.to(self.device)
+        add_text_embeddings = add_text_embeddings.to(self.device)
+        add_time_ids = add_time_ids.to(self.device).repeat(batch_size * num_images_per_prompt, 1)
+
+        # predict the noise residual with unet, stop gradient
+        with torch.no_grad():
+            # add noise
+            noise = torch.randn_like(latents)
+            latents_noisy = self.scheduler.add_noise(latents, noise, t)
+            # pred noise
+            latent_model_input = torch.cat([latents_noisy] * 2) if do_classifier_free_guidance else latents_noisy
+            # predict the noise residual
+            added_cond_kwargs = {"text_embeds": add_text_embeddings, "time_ids": add_time_ids}
+            noise_pred = self.unet(
+                latent_model_input,
+                t,
+                encoder_hidden_states=text_embeddings,
+                added_cond_kwargs=added_cond_kwargs
+            ).sample
+
+        # perform guidance (high scale from paper!)
+        if do_classifier_free_guidance:
+            noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2)
+            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_pos - noise_pred_uncond)
+
+        # w(t), sigma_t^2
+        w = (1 - alphas[t])
+        grad = grad_scale * w * (noise_pred - noise)
+        grad = torch.nan_to_num(grad)
+
+        # since we omitted an item in grad, we need to use the custom function to specify the gradient
+        loss = SpecifyGradient.apply(latents, grad)
+
+        return loss, grad.mean()
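+
+    # Note on the gradient: with `alphas` = alphas_cumprod, w(t) = 1 - alpha_bar_t, so the signal
+    # pushed back to the sketch parameters is grad_scale * w(t) * (eps_theta(z_t; y, t) - eps);
+    # SpecifyGradient injects it into autograd so it flows back through the VAE encoder and the
+    # differentiable rasterizer.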
+
+    def register_attention_control(self, controller):
+        attn_procs = {}
+        cross_att_count = 0
+        for name in self.unet.attn_processors.keys():
+            cross_attention_dim = None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim
+            if name.startswith("mid_block"):
+                hidden_size = self.unet.config.block_out_channels[-1]
+                place_in_unet = "mid"
+            elif name.startswith("up_blocks"):
+                block_id = int(name[len("up_blocks.")])
+                hidden_size = list(reversed(self.unet.config.block_out_channels))[block_id]
+                place_in_unet = "up"
+            elif name.startswith("down_blocks"):
+                block_id = int(name[len("down_blocks.")])
+                hidden_size = self.unet.config.block_out_channels[block_id]
+                place_in_unet = "down"
+            else:
+                continue
+            cross_att_count += 1
+            attn_procs[name] = P2PCrossAttnProcessor(
+                controller=controller, place_in_unet=place_in_unet
+            )
+
+        self.unet.set_attn_processor(attn_procs)
+        controller.num_att_layers = cross_att_count
+
+    @staticmethod
+    def aggregate_attention(prompts,
+                            attention_store: AttentionStore,
+                            res: int,
+                            from_where: List[str],
+                            is_cross: bool,
+                            select: int):
+        if isinstance(prompts, str):
+            prompts = [prompts]
+        assert isinstance(prompts, list)
+
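+        # Collect the stored attention maps (cross or self) from the requested UNet locations
+        # whose spatial size matches res**2, and average them over layers/heads.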
+        out = []
+        attention_maps = attention_store.get_average_attention()
+        num_pixels = res ** 2
+        for location in from_where:
+            for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
+                if item.shape[1] == num_pixels:
+                    cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
+                    out.append(cross_maps)
+        out = torch.cat(out, dim=0)
+        out = out.sum(0) / out.shape[0]
+        return out.cpu()
+
+    def get_cross_attention(self,
+                            prompts,
+                            attention_store: AttentionStore,
+                            res: int,
+                            from_where: List[str],
+                            select: int = 0,
+                            save_path=None):
+        tokens = self.tokenizer.encode(prompts[select])
+        decoder = self.tokenizer.decode
+        # shape: [res ** 2, res ** 2, seq_len]
+        attention_maps = self.aggregate_attention(prompts, attention_store, res, from_where, True, select)
+
+        images = []
+        for i in range(len(tokens)):
+            image = attention_maps[:, :, i]
+            image = 255 * image / image.max()
+            image = image.unsqueeze(-1).expand(*image.shape, 3)
+            image = image.numpy().astype(np.uint8)
+            image = np.array(Image.fromarray(image).resize((256, 256)))
+            image = text_under_image(image, decoder(int(tokens[i])))
+            images.append(image)
+        image_array = np.stack(images, axis=0)
+        view_images(image_array, save_image=True, fp=save_path)
+
+        return attention_maps, tokens
+
+    def get_self_attention_comp(self,
+                                prompts,
+                                attention_store: AttentionStore,
+                                res: int,
+                                from_where: List[str],
+                                img_size: int = 224,
+                                max_com=10,
+                                select: int = 0,
+                                save_path: AnyStr = None):
+        attention_maps = self.aggregate_attention(prompts, attention_store, res, from_where, False, select)
+        attention_maps = attention_maps.numpy().reshape((res ** 2, res ** 2))
+        # shape: [res ** 2, res ** 2]
+        u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True))
+        print(f"self-attention_maps: {attention_maps.shape}, "
+              f"u: {u.shape}, "
+              f"s: {s.shape}, "
+              f"vh: {vh.shape}")
+
+        images = []
+        vh_returns = []
+        for i in range(max_com):
+            image = vh[i].reshape(res, res)
+            image = (image - image.min()) / (image.max() - image.min())
+            image = 255 * image
+
+            ret_ = Image.fromarray(image).resize((img_size, img_size), resample=PIL.Image.Resampling.BILINEAR)
+            vh_returns.append(np.array(ret_))
+
+            image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
+            image = Image.fromarray(image).resize((256, 256))
+            image = np.array(image)
+            images.append(image)
+        image_array = np.stack(images, axis=0)
+        view_images(image_array, num_rows=max_com // 10, offset_ratio=0,
+                    save_image=True, fp=save_path / "self-attn-vh.png")
+
+        return attention_maps, (u, s, vh), np.stack(vh_returns, axis=0)
+
+
+class P2PCrossAttnProcessor:
+
+    def __init__(self, controller, place_in_unet):
+        super().__init__()
+        self.controller = controller
+        self.place_in_unet = place_in_unet
+
+    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size=batch_size)
+
+        query = attn.to_q(hidden_states)
+
+        is_cross = encoder_hidden_states is not None
+        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        query = attn.head_to_batch_dim(query)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+
+        # one line change
+        self.controller(attention_probs, is_cross, self.place_in_unet)
+
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        return hidden_states
+
+
+class SpecifyGradient(torch.autograd.Function):
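+    """Autograd helper: the forward pass returns a dummy scalar "loss", and the backward pass
+    returns the pre-computed gradient (scaled by AMP's loss scale) w.r.t. the input tensor."""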
+
+    @staticmethod
+    @custom_fwd
+    def forward(ctx, input_tensor, gt_grad):
+        ctx.save_for_backward(gt_grad)
+        # we return a dummy value 1, which will be scaled by amp's scaler so we get the scale in backward.
+        return torch.ones([1], device=input_tensor.device, dtype=input_tensor.dtype)
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, grad_scale):
+        gt_grad, = ctx.saved_tensors
+        gt_grad = gt_grad * grad_scale
+        return gt_grad, None
diff --git a/methods/painter/diffsketcher/ASDS_pipeline.py b/methods/painter/diffsketcher/ASDS_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..9fa0be2af9ab0f6e5de5a6ce41bffa261bb4cd0f
--- /dev/null
+++ b/methods/painter/diffsketcher/ASDS_pipeline.py
@@ -0,0 +1,481 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) XiMing Xing. All rights reserved.
+# Author: XiMing Xing
+# Description:
+
+import PIL
+from PIL import Image
+from typing import Callable, List, Optional, Union, Tuple, AnyStr
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch.cuda.amp import custom_bwd, custom_fwd
+from torchvision import transforms
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
+
+from methods.token2attn.attn_control import AttentionStore
+from methods.token2attn.ptp_utils import text_under_image, view_images
+
+
+class Token2AttnMixinASDSPipeline(StableDiffusionPipeline):
+    r"""
+    Pipeline for text-to-image generation using Stable Diffusion.
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        text_encoder ([`CLIPTextModel`]):
+            Frozen text-encoder. Stable Diffusion uses the text portion of
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+        tokenizer (`CLIPTokenizer`):
+            Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+        safety_checker ([`StableDiffusionSafetyChecker`]):
+            Classification module that estimates whether generated images could be considered offensive or harmful.
+            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+        feature_extractor ([`CLIPFeatureExtractor`]):
+            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+    """
+    _optional_components = ["safety_checker", "feature_extractor"]
+
+    @torch.no_grad()
+    def __call__(
+            self,
+            prompt: Union[str, List[str]],
+            height: Optional[int] = None,
+            width: Optional[int] = None,
+            controller: AttentionStore = None,  # AttentionStore used as the prompt-to-prompt (ptp) controller
+            num_inference_steps: int = 50,
+            guidance_scale: float = 7.5,
+            negative_prompt: Optional[Union[str, List[str]]] = None,
+            num_images_per_prompt: Optional[int] = 1,
+            eta: float = 0.0,
+            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+            latents: Optional[torch.FloatTensor] = None,
+            output_type: Optional[str] = "pil",
+            return_dict: bool = True,
+            callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+            callback_steps: Optional[int] = 1,
+    ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`):
+                The prompt or prompts to guide the image generation.
+            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The height in pixels of the generated image.
+            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The width in pixels of the generated image.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                if `guidance_scale` is less than `1`).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generate image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+            callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. The function will be
+                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            callback_steps (`int`, *optional*, defaults to 1):
+                The frequency at which the `callback` function will be called. If not specified, the callback will be
+                called at every step.
+
+        Returns:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+            When returning a tuple, the first element is a list with the generated images, and the second element is a
+            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+            (nsfw) content, according to the `safety_checker`.
+        """
+
+        self.register_attention_control(controller)  # add attention controller
+
+        # 0. Default height and width to unet
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(prompt, height, width, callback_steps)
+
+        # 2. Define call parameters
+        batch_size = 1 if isinstance(prompt, str) else len(prompt)
+        device = self._execution_device
+        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        # 3. Encode input prompt
+        text_embeddings = self._encode_prompt(
+            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+        )
+
+        # 4. Prepare timesteps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps = self.scheduler.timesteps
+
+        # 5. Prepare latent variables
+        try:
+            num_channels_latents = self.unet.config.in_channels
+        except Exception:  # fall back for older diffusers versions where in_channels lives on the unet itself
+            num_channels_latents = self.unet.in_channels
+
+        latents = self.prepare_latents(
+            batch_size * num_images_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            text_embeddings.dtype,
+            device,
+            generator,
+            latents,
+        )
+
+        # 6. Prepare extra step kwargs (inherited TODO: this logic should ideally be moved out of the pipeline)
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 7. Denoising loop
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # predict the noise residual
+                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+                # step callback
+                latents = controller.step_callback(latents)
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        callback(i, t, latents)
+
+        # image = self.decode_latents(latents)
+
+        # 8. Post-processing
+        # 9. Run safety checker
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            # image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
+        else:
+            image = latents
+        has_nsfw_concept = None
+
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        # else:
+        #     do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+        # 10. Convert to output_type
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+        if not return_dict:
+            return (image, has_nsfw_concept)
+
+        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+    def encode_(self, images):
+        images = (2 * images - 1).clamp(-1.0, 1.0)  # images: [B, 3, H, W]
+
+        # encode images
+        latents = self.vae.encode(images).latent_dist.sample()
+        latents = self.vae.config.scaling_factor * latents
+
+        # scale the latents by the scheduler's initial noise sigma
+        latents = latents * self.scheduler.init_noise_sigma
+
+        return latents
+
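+    # S_aug: augment the rendered sketch before it is scored by the diffusion model
+    # (random perspective + resized crop when "affine" is requested, sharpness when "contrast" is requested).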
+    @staticmethod
+    def S_aug(sketch: torch.Tensor,
+              crop_size: int = 512,
+              augments: str = "affine_contrast"):
+        # init augmentations
+        augment_list = []
+        if "affine" in augments:
+            augment_list.append(
+                transforms.RandomPerspective(fill=0, p=1.0, distortion_scale=0.5)
+            )
+            augment_list.append(
+                transforms.RandomResizedCrop(crop_size, scale=(0.8, 0.8), ratio=(1.0, 1.0))
+            )
+        if "contrast" in augments:
+            # 2: increases the sharpness by a factor of 2.
+            augment_list.append(
+                transforms.RandomAdjustSharpness(sharpness_factor=2)
+            )
+        augment_compose = transforms.Compose(augment_list)
+
+        return augment_compose(sketch)
+
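+    # Score distillation with sketch augmentation: augment the render, encode it to latents,
+    # add noise at a random timestep, predict the noise with the frozen UNet (with classifier-free
+    # guidance), and use w(t) * (noise_pred - noise) as the gradient on the latents.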
+    def score_distillation_sampling(self,
+                                    pred_rgb: torch.Tensor,
+                                    crop_size: int,
+                                    augments: str,
+                                    prompt: Union[List, str],
+                                    negative_prompt: Union[List, str] = None,
+                                    guidance_scale: float = 100,
+                                    as_latent: bool = False,
+                                    grad_scale: float = 1,
+                                    t_range: Union[List[float], Tuple[float]] = (0.02, 0.98)):
+        num_train_timesteps = self.scheduler.config.num_train_timesteps
+        min_step = int(num_train_timesteps * t_range[0])
+        max_step = int(num_train_timesteps * t_range[1])
+        alphas = self.scheduler.alphas_cumprod.to(self.device)  # for convenience
+
+        # sketch augmentation
+        pred_rgb_a = self.S_aug(pred_rgb, crop_size, augments)
+
+        # if as_latent, resize the input directly to the latent resolution (64x64); otherwise encode it with the VAE below.
+        if as_latent:
+            latents = F.interpolate(pred_rgb_a, (64, 64), mode='bilinear', align_corners=False) * 2 - 1
+        else:
+            # encode image into latents with vae, requires grad!
+            latents = self.encode_(pred_rgb_a)
+
+        #  Encode input prompt
+        num_images_per_prompt = 1  # the number of images to generate per prompt
+        do_classifier_free_guidance = guidance_scale > 1.0
+        text_embeddings = self._encode_prompt(
+            prompt, self.device, num_images_per_prompt, do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+        )
+
+        # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
+        t = torch.randint(min_step, max_step + 1, [1], dtype=torch.long, device=self.device)
+
+        # predict the noise residual with unet, stop gradient
+        with torch.no_grad():
+            # add noise
+            noise = torch.randn_like(latents)
+            latents_noisy = self.scheduler.add_noise(latents, noise, t)
+            # pred noise
+            latent_model_input = torch.cat([latents_noisy] * 2) if do_classifier_free_guidance else latents_noisy
+            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+        # perform guidance (high scale from paper!)
+        if do_classifier_free_guidance:
+            noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2)
+            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_pos - noise_pred_uncond)
+
+        # w(t), sigma_t^2
+        w = (1 - alphas[t])
+        grad = grad_scale * w * (noise_pred - noise)
+        grad = torch.nan_to_num(grad)
+
+        # the UNet Jacobian term is omitted from grad, so a custom autograd function is used to inject the gradient manually
+        loss = SpecifyGradient.apply(latents, grad)
+
+        return loss, grad.mean()
+
+    def register_attention_control(self, controller):
+        attn_procs = {}
+        cross_att_count = 0
+        for name in self.unet.attn_processors.keys():
+            cross_attention_dim = None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim
+            if name.startswith("mid_block"):
+                hidden_size = self.unet.config.block_out_channels[-1]
+                place_in_unet = "mid"
+            elif name.startswith("up_blocks"):
+                block_id = int(name[len("up_blocks.")])
+                hidden_size = list(reversed(self.unet.config.block_out_channels))[block_id]
+                place_in_unet = "up"
+            elif name.startswith("down_blocks"):
+                block_id = int(name[len("down_blocks.")])
+                hidden_size = self.unet.config.block_out_channels[block_id]
+                place_in_unet = "down"
+            else:
+                continue
+            cross_att_count += 1
+            attn_procs[name] = P2PCrossAttnProcessor(
+                controller=controller, place_in_unet=place_in_unet
+            )
+
+        self.unet.set_attn_processor(attn_procs)
+        controller.num_att_layers = cross_att_count
+
+    @staticmethod
+    def aggregate_attention(prompts,
+                            attention_store: AttentionStore,
+                            res: int,
+                            from_where: List[str],
+                            is_cross: bool,
+                            select: int):
+        if isinstance(prompts, str):
+            prompts = [prompts]
+        assert isinstance(prompts, list)
+
+        out = []
+        attention_maps = attention_store.get_average_attention()
+        num_pixels = res ** 2
+        for location in from_where:
+            for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
+                if item.shape[1] == num_pixels:
+                    cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
+                    out.append(cross_maps)
+        out = torch.cat(out, dim=0)
+        out = out.sum(0) / out.shape[0]
+        return out.cpu()
+
+    def get_cross_attention(self,
+                            prompts,
+                            attention_store: AttentionStore,
+                            res: int,
+                            from_where: List[str],
+                            select: int = 0,
+                            save_path=None):
+        tokens = self.tokenizer.encode(prompts[select])
+        decoder = self.tokenizer.decode
+        # shape: [res ** 2, res ** 2, seq_len]
+        attention_maps = self.aggregate_attention(prompts, attention_store, res, from_where, True, select)
+
+        images = []
+        for i in range(len(tokens)):
+            image = attention_maps[:, :, i]
+            image = 255 * image / image.max()
+            image = image.unsqueeze(-1).expand(*image.shape, 3)
+            image = image.numpy().astype(np.uint8)
+            image = np.array(Image.fromarray(image).resize((256, 256)))
+            image = text_under_image(image, decoder(int(tokens[i])))
+            images.append(image)
+        image_array = np.stack(images, axis=0)
+        view_images(image_array, save_image=True, fp=save_path)
+
+        return attention_maps, tokens
+
+    def get_self_attention_comp(self,
+                                prompts,
+                                attention_store: AttentionStore,
+                                res: int,
+                                from_where: List[str],
+                                img_size: int = 224,
+                                max_com=10,
+                                select: int = 0,
+                                save_path: AnyStr = None):
+        attention_maps = self.aggregate_attention(prompts, attention_store, res, from_where, False, select)
+        attention_maps = attention_maps.numpy().reshape((res ** 2, res ** 2))
+        # shape: [res ** 2, res ** 2]
+        u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True))
+        print(f"self-attention_maps: {attention_maps.shape}, "
+              f"u: {u.shape}, "
+              f"s: {s.shape}, "
+              f"vh: {vh.shape}")
+
+        images = []
+        vh_returns = []
+        for i in range(max_com):
+            image = vh[i].reshape(res, res)
+            image = (image - image.min()) / (image.max() - image.min())
+            image = 255 * image
+
+            ret_ = Image.fromarray(image).resize((img_size, img_size), resample=PIL.Image.Resampling.BILINEAR)
+            vh_returns.append(np.array(ret_))
+
+            image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
+            image = Image.fromarray(image).resize((256, 256))
+            image = np.array(image)
+            images.append(image)
+        image_array = np.stack(images, axis=0)
+        view_images(image_array, num_rows=max_com // 10, offset_ratio=0,
+                    save_image=True, fp=save_path / "self-attn-vh.png")
+
+        return attention_maps, (u, s, vh), np.stack(vh_returns, axis=0)
+
+
+class P2PCrossAttnProcessor:
+
+    def __init__(self, controller, place_in_unet):
+        super().__init__()
+        self.controller = controller
+        self.place_in_unet = place_in_unet
+
+    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size=batch_size)
+
+        query = attn.to_q(hidden_states)
+
+        is_cross = encoder_hidden_states is not None
+        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        query = attn.head_to_batch_dim(query)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+
+        # one line change
+        self.controller(attention_probs, is_cross, self.place_in_unet)
+
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        return hidden_states
+
+
+class SpecifyGradient(torch.autograd.Function):
+
+    @staticmethod
+    @custom_fwd
+    def forward(ctx, input_tensor, gt_grad):
+        ctx.save_for_backward(gt_grad)
+        # we return a dummy value 1, which will be scaled by amp's scaler so we get the scale in backward.
+        return torch.ones([1], device=input_tensor.device, dtype=input_tensor.dtype)
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, grad_scale):
+        gt_grad, = ctx.saved_tensors
+        gt_grad = gt_grad * grad_scale
+        return gt_grad, None
diff --git a/methods/painter/diffsketcher/__init__.py b/methods/painter/diffsketcher/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..301150cd7deeb2b097b4959b99f240a39a3632d9
--- /dev/null
+++ b/methods/painter/diffsketcher/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) XiMing Xing. All rights reserved.
+# Author: XiMing Xing
+# Description:
+
+from .painter_params import Painter, SketchPainterOptimizer
+from .ASDS_pipeline import Token2AttnMixinASDSPipeline
+from .ASDS_SDXL_pipeline import Token2AttnMixinASDSSDXLPipeline
+
+__all__ = [
+    'Painter', 'SketchPainterOptimizer',
+    'Token2AttnMixinASDSPipeline',
+    'Token2AttnMixinASDSSDXLPipeline'
+]
diff --git a/methods/painter/diffsketcher/mask_utils.py b/methods/painter/diffsketcher/mask_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..427af90bd079df83fadc67f6d7c4d73f66f8f175
--- /dev/null
+++ b/methods/painter/diffsketcher/mask_utils.py
@@ -0,0 +1,59 @@
+from PIL import Image
+
+import numpy as np
+import torch
+from torchvision import transforms
+from skimage.transform import resize
+
+from .u2net import U2NET
+
+
+def get_mask_u2net(pil_im, output_dir, u2net_path, device="cpu"):
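+    """Predict a foreground mask with U^2-Net, binarise it at 0.5, and return the masked image
+    (white background) together with the raw prediction."""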
+    # input preprocess
+    w, h = pil_im.size[0], pil_im.size[1]
+    im_size = min(w, h)
+    data_transforms = transforms.Compose([
+        transforms.Resize(min(320, im_size), interpolation=transforms.InterpolationMode.BICUBIC),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
+                             std=(0.26862954, 0.26130258, 0.27577711)),
+    ])
+    input_im_trans = data_transforms(pil_im).unsqueeze(0).to(device)
+
+    # load U^2 Net model
+    net = U2NET(in_ch=3, out_ch=1)
+    net.load_state_dict(torch.load(u2net_path))
+    net.to(device)
+    net.eval()
+
+    # get mask
+    with torch.no_grad():
+        d1, d2, d3, d4, d5, d6, d7 = net(input_im_trans.detach())
+    pred = d1[:, 0, :, :]
+    pred = (pred - pred.min()) / (pred.max() - pred.min())
+    predict = pred
+    predict[predict < 0.5] = 0
+    predict[predict >= 0.5] = 1
+    mask = torch.cat([predict, predict, predict], dim=0).permute(1, 2, 0)
+    mask = mask.cpu().numpy()
+    mask = resize(mask, (h, w), anti_aliasing=False)
+    mask[mask < 0.5] = 0
+    mask[mask >= 0.5] = 1
+
+    # predict_np = predict.clone().cpu().data.numpy()
+    im = Image.fromarray((mask[:, :, 0] * 255).astype(np.uint8)).convert('RGB')
+    save_path_ = output_dir / "mask.png"
+    im.save(save_path_)
+
+    im_np = np.array(pil_im)
+    im_np = im_np / im_np.max()
+    im_np = mask * im_np
+    im_np[mask == 0] = 1
+    im_final = (im_np / im_np.max() * 255).astype(np.uint8)
+    im_final = Image.fromarray(im_final)
+
+    # free u2net
+    del net
+    torch.cuda.empty_cache()
+
+    return im_final, predict
diff --git a/methods/painter/diffsketcher/painter_params.py b/methods/painter/diffsketcher/painter_params.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a7708669e690906cdf0716aedd95d2dd3e60089
--- /dev/null
+++ b/methods/painter/diffsketcher/painter_params.py
@@ -0,0 +1,329 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) XiMing Xing. All rights reserved.
+# Author: XiMing Xing
+# Description:
+import random
+import pathlib
+
+import numpy as np
+import pydiffvg
+import torch
+import torch.nn as nn
+
+from libs.modules.edge_map.DoG import XDoG
+
+
+class Painter(nn.Module):
+
+    def __init__(
+            self,
+            args,
+            num_strokes=4,
+            num_segments=4,
+            imsize=224,
+            device=None,
+            target_im=None,
+            attention_map=None,
+            mask=None,
+    ):
+        super(Painter, self).__init__()
+
+        self.args = args
+        self.device = device
+
+        self.num_paths = num_strokes
+        self.num_segments = num_segments
+        self.width = args.width
+        self.max_width = args.max_width
+        self.optim_width = args.optim_width
+        self.control_points_per_seg = args.control_points_per_seg
+        self.optim_rgba = args.optim_rgba
+        self.optim_alpha = args.optim_opacity
+        self.num_stages = args.num_stages
+        self.softmax_temp = args.softmax_temp
+
+        self.shapes = []
+        self.shape_groups = []
+        self.num_control_points = 0
+        self.canvas_width, self.canvas_height = imsize, imsize
+        self.points_vars = []
+        self.stroke_width_vars = []
+        self.color_vars = []
+        self.color_vars_threshold = args.color_vars_threshold
+
+        self.path_svg = args.path_svg
+        self.strokes_per_stage = self.num_paths
+        self.optimize_flag = []
+
+        # attention-related settings for stroke initialisation
+        self.attention_init = args.attention_init
+        self.xdog_intersec = args.xdog_intersec
+
+        self.image2clip_input = target_im
+        self.mask = mask
+        self.attention_map = attention_map if self.attention_init else None
+
+        self.thresh = self.set_attention_threshold_map() if self.attention_init else None
+        self.strokes_counter = 0  # counts the number of calls to "get_path"
+
+    def init_image(self, stage=0):
+        if stage > 0:
+            # Note: with multi-stage training, new strokes are added on top of the existing ones;
+            # the strokes from previous stages are frozen and not optimized further
+            self.optimize_flag = [False for i in range(len(self.shapes))]
+            for i in range(self.strokes_per_stage):
+                stroke_color = torch.FloatTensor(np.random.uniform(size=[4])) \
+                    if self.args.optim_rgba else torch.tensor([0.0, 0.0, 0.0, 1.0])
+                path = self.get_path()
+                self.shapes.append(path)
+                path_group = pydiffvg.ShapeGroup(shape_ids=torch.tensor([len(self.shapes) - 1]),
+                                                 fill_color=None,
+                                                 stroke_color=stroke_color)
+                self.shape_groups.append(path_group)
+                self.optimize_flag.append(True)
+        else:
+            num_paths_exists = 0
+            if self.path_svg is not None and pathlib.Path(self.path_svg).exists():
+                print(f"-> init svg from `{self.path_svg}` ...")
+
+                self.canvas_width, self.canvas_height, self.shapes, self.shape_groups = self.load_svg(self.path_svg)
+                # if you want to add more strokes to existing ones and optimize on all of them
+                num_paths_exists = len(self.shapes)
+
+            for i in range(num_paths_exists, self.num_paths):
+                stroke_color = torch.FloatTensor(np.random.uniform(size=[4])) \
+                    if self.args.optim_rgba else torch.tensor([0.0, 0.0, 0.0, 1.0])
+                path = self.get_path()
+                self.shapes.append(path)
+                path_group = pydiffvg.ShapeGroup(shape_ids=torch.tensor([len(self.shapes) - 1]),
+                                                 fill_color=None,
+                                                 stroke_color=stroke_color)
+                self.shape_groups.append(path_group)
+            self.optimize_flag = [True for i in range(len(self.shapes))]
+
+        img = self.render_warp()
+        img = img[:, :, 3:4] * img[:, :, :3] + \
+              torch.ones(img.shape[0], img.shape[1], 3, device=self.device) * (1 - img[:, :, 3:4])
+        img = img[:, :, :3]
+        img = img.unsqueeze(0)  # HWC -> NHWC
+        img = img.permute(0, 3, 1, 2).to(self.device)  # NHWC -> NCHW
+
+        return img
+
+    def get_image(self):
+        img = self.render_warp()
+
+        opacity = img[:, :, 3:4]
+        img = opacity * img[:, :, :3] + torch.ones(img.shape[0], img.shape[1], 3, device=self.device) * (1 - opacity)
+        img = img[:, :, :3]
+        img = img.unsqueeze(0)  # HWC -> NHWC
+        img = img.permute(0, 3, 1, 2).to(self.device)  # NHWC -> NCHW
+        return img
+
+    def get_path(self):
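+        # Build one open Bézier path: the start point comes from the attention-based initialisation
+        # (if enabled) or is drawn uniformly at random; the remaining control points are small random
+        # offsets from the previous point, and all points are finally scaled to canvas coordinates.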
+        self.num_control_points = torch.zeros(self.num_segments, dtype=torch.int32) + (self.control_points_per_seg - 2)
+        points = []
+        p0 = self.inds_normalised[self.strokes_counter] if self.attention_init else (random.random(), random.random())
+        points.append(p0)
+
+        for j in range(self.num_segments):
+            radius = 0.05
+            for k in range(self.control_points_per_seg - 1):
+                p1 = (p0[0] + radius * (random.random() - 0.5), p0[1] + radius * (random.random() - 0.5))
+                points.append(p1)
+                p0 = p1
+        points = torch.tensor(points).to(self.device)
+        points[:, 0] *= self.canvas_width
+        points[:, 1] *= self.canvas_height
+
+        path = pydiffvg.Path(num_control_points=self.num_control_points,
+                             points=points,
+                             stroke_width=torch.tensor(self.width),
+                             is_closed=False)
+        self.strokes_counter += 1
+        return path
+
+    def clip_curve_shape(self):
+        if self.optim_width:
+            for path in self.shapes:
+                path.stroke_width.data.clamp_(1.0, self.max_width)
+        if self.optim_rgba:
+            for group in self.shape_groups:
+                group.stroke_color.data.clamp_(0.0, 1.0)
+        else:
+            if self.optim_alpha:
+                for group in self.shape_groups:
+                    # group.stroke_color.data: RGBA
+                    group.stroke_color.data[:3].clamp_(0., 0.)  # to force black stroke
+                    group.stroke_color.data[-1].clamp_(0., 1.)  # opacity
+
+    def path_pruning(self):
+        # stroke pruning
+        for group in self.shape_groups:
+            group.stroke_color.data[-1] = (group.stroke_color.data[-1] >= self.color_vars_threshold).float()
+
+    def render_warp(self):
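+        # Differentiable rasterisation with diffvg: gradients flow from the rendered image back to
+        # the stroke points, widths and colours.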
+        self.clip_curve_shape()
+
+        scene_args = pydiffvg.RenderFunction.serialize_scene(
+            self.canvas_width, self.canvas_height, self.shapes, self.shape_groups
+        )
+        _render = pydiffvg.RenderFunction.apply
+        img = _render(self.canvas_width,  # width
+                      self.canvas_height,  # height
+                      2,  # num_samples_x
+                      2,  # num_samples_y
+                      0,  # seed
+                      None,
+                      *scene_args)
+        return img
+
+    def set_points_parameters(self):
+        # stroke location optimization
+        self.points_vars = []
+        for i, path in enumerate(self.shapes):
+            if self.optimize_flag[i]:
+                path.points.requires_grad = True
+                self.points_vars.append(path.points)
+
+    def get_points_params(self):
+        return self.points_vars
+
+    def set_width_parameters(self):
+        # stroke width optimization
+        self.stroke_width_vars = []
+        for i, path in enumerate(self.shapes):
+            if self.optimize_flag[i]:
+                path.stroke_width.requires_grad = True
+                self.stroke_width_vars.append(path.stroke_width)
+
+    def get_width_parameters(self):
+        return self.stroke_width_vars
+
+    def set_color_parameters(self):
+        # strokes' color (opacity) optimization
+        self.color_vars = []
+        for i, group in enumerate(self.shape_groups):
+            if self.optimize_flag[i]:
+                group.stroke_color.requires_grad = True
+                self.color_vars.append(group.stroke_color)
+
+    def get_color_parameters(self):
+        return self.color_vars
+
+    def save_svg(self, output_dir, fname):
+        pydiffvg.save_svg(f'{output_dir}/{fname}.svg',
+                          self.canvas_width,
+                          self.canvas_height,
+                          self.shapes,
+                          self.shape_groups)
+
+    def load_svg(self, path_svg):
+        canvas_width, canvas_height, shapes, shape_groups = pydiffvg.svg_to_scene(path_svg)
+        return canvas_width, canvas_height, shapes, shape_groups
+
+    @staticmethod
+    def softmax(x, tau=0.2):
+        e_x = np.exp(x / tau)
+        return e_x / e_x.sum()
+
+    def set_inds_ldm(self):
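+        # Attention-guided stroke initialisation: normalise the attention map, optionally intersect
+        # it with an XDoG edge map, convert it to a probability map with a softmax, then sample
+        # num_stages * num_paths pixel locations as stroke starting points.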
+        attn_map = (self.attention_map - self.attention_map.min()) / \
+                   (self.attention_map.max() - self.attention_map.min())
+
+        if self.xdog_intersec:
+            xdog = XDoG(k=10)
+            im_xdog = xdog(self.image2clip_input[0].permute(1, 2, 0).cpu().numpy())
+            print(f"use XDoG, shape: {im_xdog.shape}")
+            intersec_map = (1 - im_xdog) * attn_map
+            attn_map = intersec_map
+
+        attn_map_soft = np.copy(attn_map)
+        attn_map_soft[attn_map > 0] = self.softmax(attn_map[attn_map > 0], tau=self.softmax_temp)
+
+        # select points
+        k = self.num_stages * self.num_paths
+        self.inds = np.random.choice(range(attn_map.flatten().shape[0]),
+                                     size=k,
+                                     replace=False,
+                                     p=attn_map_soft.flatten())
+        self.inds = np.array(np.unravel_index(self.inds, attn_map.shape)).T
+
+        self.inds_normalised = np.zeros(self.inds.shape)
+        self.inds_normalised[:, 0] = self.inds[:, 1] / self.canvas_width
+        self.inds_normalised[:, 1] = self.inds[:, 0] / self.canvas_height
+        self.inds_normalised = self.inds_normalised.tolist()
+        return attn_map_soft
+
+    def set_attention_threshold_map(self):
+        return self.set_inds_ldm()
+
+    def get_attn(self):
+        return self.attention_map
+
+    def get_thresh(self):
+        return self.thresh
+
+    def get_inds(self):
+        return self.inds
+
+    def get_mask(self):
+        return self.mask
+
+
+class SketchPainterOptimizer:
+
+    def __init__(
+            self,
+            renderer: nn.Module,
+            points_lr: float,
+            optim_alpha: bool,
+            optim_rgba: bool,
+            color_lr: float,
+            optim_width: bool,
+            width_lr: float
+    ):
+        self.renderer = renderer
+
+        self.points_lr = points_lr
+        self.optim_color = optim_alpha or optim_rgba
+        self.color_lr = color_lr
+        self.optim_width = optim_width
+        self.width_lr = width_lr
+
+        self.points_optimizer, self.width_optimizer, self.color_optimizer = None, None, None
+
+    def init_optimizers(self):
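+        # One Adam optimizer per parameter group (points, colour, width), each with its own learning rate.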
+        self.renderer.set_points_parameters()
+        self.points_optimizer = torch.optim.Adam(self.renderer.get_points_params(), lr=self.points_lr)
+        if self.optim_color:
+            self.renderer.set_color_parameters()
+            self.color_optimizer = torch.optim.Adam(self.renderer.get_color_parameters(), lr=self.color_lr)
+        if self.optim_width:
+            self.renderer.set_width_parameters()
+            self.width_optimizer = torch.optim.Adam(self.renderer.get_width_parameters(), lr=self.width_lr)
+
+    def update_lr(self, step, decay_steps=(500, 750)):
+        if step % decay_steps[0] == 0 and step > 0:
+            for param_group in self.points_optimizer.param_groups:
+                param_group['lr'] = 0.4
+        if step % decay_steps[1] == 0 and step > 0:
+            for param_group in self.points_optimizer.param_groups:
+                param_group['lr'] = 0.1
+
+    def zero_grad_(self):
+        self.points_optimizer.zero_grad()
+        if self.optim_color:
+            self.color_optimizer.zero_grad()
+        if self.optim_width:
+            self.width_optimizer.zero_grad()
+
+    def step_(self):
+        self.points_optimizer.step()
+        if self.optim_color:
+            self.color_optimizer.step()
+        if self.optim_width:
+            self.width_optimizer.step()
+
+    def get_lr(self):
+        return self.points_optimizer.param_groups[0]['lr']
diff --git a/methods/painter/diffsketcher/process_svg.py b/methods/painter/diffsketcher/process_svg.py
new file mode 100644
index 0000000000000000000000000000000000000000..deb5668603237ca68f2becbc0662460e835ede70
--- /dev/null
+++ b/methods/painter/diffsketcher/process_svg.py
@@ -0,0 +1,71 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) XiMing Xing. All rights reserved.
+# Author: XiMing Xing
+# Description:
+
+import xml.etree.ElementTree as ET
+import statistics
+
+import argparse
+
+
+def remove_low_opacity_paths(svg_file_path, output_file_path, opacity_delta=0.2):
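+    # Prune faint strokes: compute the median stroke-opacity over all paths and set the opacity of
+    # any path below (median + opacity_delta) to zero.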
+    try:
+        # Parse the SVG file
+        tree = ET.parse(svg_file_path)
+        namespace = "http://www.w3.org/2000/svg"
+        ET.register_namespace("", namespace)
+
+        root = tree.getroot()
+        root.set('version', '1.1')
+
+        paths = root.findall('.//{http://www.w3.org/2000/svg}path')
+        # Collect stroke-opacity attribute values
+        opacity_values = []
+        for path in paths:
+            opacity = path.get("stroke-opacity")
+            if opacity is not None:
+                opacity_values.append(float(opacity))
+
+        # Calculate median opacity
+        median_opacity = statistics.median(opacity_values) + opacity_delta
+
+        # Create a temporary list to store paths to be removed
+        paths_to_remove = []
+        for path in paths:
+            opacity = path.get('stroke-opacity')
+            if opacity is not None and float(opacity) < median_opacity:
+                paths_to_remove.append(path)
+
+        # Hide the selected paths by setting their stroke-opacity to 0 (they remain in the tree)
+        for path in paths_to_remove:
+            path.set('stroke-opacity', '0')
+
+        print(f"n_path: {len(paths)}, "
+              f"opacity_thresh: {median_opacity}, "
+              f"n_path_to_remove: {len(set(paths_to_remove))}.")
+
+        # Save the modified SVG to the specified path
+        tree.write(output_file_path, encoding='utf-8', xml_declaration=True, default_namespace="")
+        # print("SVG file saved successfully.")
+        # print(f"file has been saved in: {output_file_path}")
+    except Exception as e:
+        print(f"An error occurred: {str(e)}")
+
+
+if __name__ == '__main__':
+    """
+    python process_svg.py -save ./workdir/xx.svg -tar ./workdir/xx.svg
+    """
+    parser = argparse.ArgumentParser(description="vary style painterly rendering")
+    parser.add_argument("-tar", "--target_file",
+                        default="", type=str,
+                        help="the path of SVG file place.")
+    parser.add_argument("-save", "--save_path",
+                        default="", type=str,
+                        help="the path of processed SVG file place.")
+    parser.add_argument("-od", "--opacity_delta",
+                        default=0.1, type=float)
+    args = parser.parse_args()
+
+    remove_low_opacity_paths(args.target_file, args.save_path, float(args.opacity_delta))
diff --git a/methods/painter/diffsketcher/sketch_utils.py b/methods/painter/diffsketcher/sketch_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0335ffdd4d7d77bd45b6012422da460524366f27
--- /dev/null
+++ b/methods/painter/diffsketcher/sketch_utils.py
@@ -0,0 +1,172 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) XiMing Xing. All rights reserved.
+# Author: XiMing Xing
+# Description:
+import matplotlib.pyplot as plt
+import numpy as np
+from PIL import Image
+
+import torch
+from torchvision.utils import make_grid
+
+
+def plt_batch(
+        photos: torch.Tensor,
+        sketch: torch.Tensor,
+        step: int,
+        prompt: str,
+        save_path: str,
+        name: str,
+        dpi: int = 300
+):
+    if photos.shape != sketch.shape:
+        raise ValueError("photos and sketch must have the same dimensions")
+
+    plt.figure()
+    plt.subplot(1, 2, 1)  # nrows=1, ncols=2, index=1
+    grid = make_grid(photos, normalize=True, pad_value=2)
+    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
+    plt.imshow(ndarr)
+    plt.axis("off")
+    plt.title("Generated sample")
+
+    plt.subplot(1, 2, 2)  # nrows=1, ncols=2, index=2
+    grid = make_grid(sketch, normalize=False, pad_value=2)
+    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
+    plt.imshow(ndarr)
+    plt.axis("off")
+    plt.title(f"Rendering result - {step} steps")
+
+    plt.suptitle(insert_newline(prompt), fontsize=10)
+
+    plt.tight_layout()
+    plt.savefig(f"{save_path}/{name}.png", dpi=dpi)
+    plt.close()
+
+
+def plt_triplet(
+        photos: torch.Tensor,
+        sketch: torch.Tensor,
+        style: torch.Tensor,
+        step: int,
+        prompt: str,
+        save_path: str,
+        name: str,
+        dpi: int = 300
+):
+    if photos.shape != sketch.shape:
+        raise ValueError("photos and sketch must have the same dimensions")
+
+    plt.figure()
+    plt.subplot(1, 3, 1)  # nrows=1, ncols=3, index=1
+    grid = make_grid(photos, normalize=True, pad_value=2)
+    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
+    plt.imshow(ndarr)
+    plt.axis("off")
+    plt.title("Generated sample")
+
+    plt.subplot(1, 3, 2)  # nrows=1, ncols=3, index=2
+    # style = (style + 1) / 2
+    grid = make_grid(style, normalize=False, pad_value=2)
+    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
+    plt.imshow(ndarr)
+    plt.axis("off")
+    plt.title(f"Style")
+
+    plt.subplot(1, 3, 3)  # nrows=1, ncols=3, index=3
+    # sketch = (sketch + 1) / 2
+    grid = make_grid(sketch, normalize=False, pad_value=2)
+    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
+    plt.imshow(ndarr)
+    plt.axis("off")
+    plt.title(f"Rendering result - {step} steps")
+
+    plt.suptitle(insert_newline(prompt), fontsize=10)
+
+    plt.tight_layout()
+    plt.savefig(f"{save_path}/{name}.png", dpi=dpi)
+    plt.close()
+
+
+def insert_newline(string, point=9):
+    # split on whitespace
+    words = string.split()
+    if len(words) <= point:
+        return string
+
+    word_chunks = [words[i:i + point] for i in range(0, len(words), point)]
+    new_string = "\n".join(" ".join(chunk) for chunk in word_chunks)
+    return new_string
+
+
+def log_tensor_img(inputs, output_dir, output_prefix="input", norm=False, dpi=300):
+    grid = make_grid(inputs, normalize=norm, pad_value=2)
+    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
+    plt.imshow(ndarr)
+    plt.axis("off")
+    plt.tight_layout()
+    plt.savefig(f"{output_dir}/{output_prefix}.png", dpi=dpi, bbox_inches='tight')
+    plt.close()
+
+
+def plt_tensor_img(tensor, title, save_path, name, dpi=500):
+    grid = make_grid(tensor, normalize=True, pad_value=2)
+    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
+    plt.imshow(ndarr)
+    plt.axis("off")
+    plt.title(f"{title}")
+    plt.savefig(f"{save_path}/{name}.png", dpi=dpi)
+    plt.close()
+
+
+def save_tensor_img(tensor, save_path, name, dpi=500):
+    grid = make_grid(tensor, normalize=True, pad_value=2)
+    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
+    plt.imshow(ndarr)
+    plt.axis("off")
+    plt.tight_layout()
+    plt.savefig(f"{save_path}/{name}.png", dpi=dpi)
+    plt.close()
+
+
+def plt_attn(attn, threshold_map, inputs, inds, output_path):
+    # currently supports one image (and not a batch)
+    plt.figure(figsize=(10, 5))
+
+    plt.subplot(1, 3, 1)
+    main_im = make_grid(inputs, normalize=True, pad_value=2)
+    main_im = np.transpose(main_im.cpu().numpy(), (1, 2, 0))
+    plt.imshow(main_im, interpolation='nearest')
+    plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
+    plt.title("input img")
+    plt.axis("off")
+
+    plt.subplot(1, 3, 2)
+    plt.imshow(attn, interpolation='nearest', vmin=0, vmax=1)
+    plt.title("attn map")
+    plt.axis("off")
+
+    plt.subplot(1, 3, 3)
+    threshold_map_ = (threshold_map - threshold_map.min()) / \
+                     (threshold_map.max() - threshold_map.min())
+    plt.imshow(np.nan_to_num(threshold_map_), interpolation='nearest', vmin=0, vmax=1)
+    plt.title("prob softmax")
+    plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
+    plt.axis("off")
+
+    plt.tight_layout()
+    plt.savefig(output_path)
+    plt.close()
+
+
+def fix_image_scale(im):
+    im_np = np.array(im) / 255
+    height, width = im_np.shape[0], im_np.shape[1]
+    max_len = max(height, width) + 20
+    new_background = np.ones((max_len, max_len, 3))
+    y, x = max_len // 2 - height // 2, max_len // 2 - width // 2
+    new_background[y: y + height, x: x + width] = im_np
+    new_background = (new_background / new_background.max()
+                      * 255).astype(np.uint8)
+    new_im = Image.fromarray(new_background)
+    return new_im
diff --git a/methods/painter/diffsketcher/strotss.py b/methods/painter/diffsketcher/strotss.py
new file mode 100644
index 0000000000000000000000000000000000000000..c56348dc5460e3894d18a16bfef364ea3e65c011
--- /dev/null
+++ b/methods/painter/diffsketcher/strotss.py
@@ -0,0 +1,253 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) XiMing Xing. All rights reserved.
+# Author: XiMing Xing
+# Description:
+
+import math
+
+import torch
+import torch.nn as nn
+import torchvision
+import numpy as np
+
+
+class VGG16Extractor(nn.Module):
+    def __init__(self, space):
+        super().__init__()
+        # load pretrained model
+        self.vgg_layers = torchvision.models.vgg16(
+            weights=torchvision.models.VGG16_Weights.DEFAULT
+        ).features
+
+        for param in self.parameters():
+            param.requires_grad = False
+        self.capture_layers = [1, 3, 6, 8, 11, 13, 15, 22, 29]
+        self.space = space
+
+    def forward_base(self, x):
+        feat = [x]
+        for i in range(len(self.vgg_layers)):
+            x = self.vgg_layers[i](x)
+            if i in self.capture_layers:
+                feat.append(x)
+        return feat
+
+    def forward(self, x):
+        if self.space != 'vgg':
+            x = (x + 1.) / 2.
+            x = x - (torch.Tensor([0.485, 0.456, 0.406]).to(x.device).view(1, -1, 1, 1))
+            x = x / (torch.Tensor([0.229, 0.224, 0.225]).to(x.device).view(1, -1, 1, 1))
+        feat = self.forward_base(x)
+        return feat
+
+    def forward_samples_hypercolumn(self, X, samps=100):
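+        # sample up to `samps` random spatial locations and stack the features
+        # from every captured layer at those locations into hypercolumns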
+        feat = self.forward(X)
+
+        xx, xy = np.meshgrid(np.arange(X.shape[2]), np.arange(X.shape[3]))
+        xx = np.expand_dims(xx.flatten(), 1)
+        xy = np.expand_dims(xy.flatten(), 1)
+        xc = np.concatenate([xx, xy], 1)
+
+        samples = min(samps, xc.shape[0])
+
+        np.random.shuffle(xc)
+        xx = xc[:samples, 0]
+        yy = xc[:samples, 1]
+
+        feat_samples = []
+        for i in range(len(feat)):
+
+            layer_feat = feat[i]
+
+            # hack to detect lower resolution
+            if i > 0 and feat[i].size(2) < feat[i - 1].size(2):
+                xx = xx / 2.0
+                yy = yy / 2.0
+
+            xx = np.clip(xx, 0, layer_feat.shape[2] - 1).astype(np.int32)
+            yy = np.clip(yy, 0, layer_feat.shape[3] - 1).astype(np.int32)
+
+            features = layer_feat[:, :, xx[range(samples)], yy[range(samples)]]
+            feat_samples.append(features.clone().detach())
+
+        feat = torch.cat(feat_samples, 1)
+        return feat
+
+
+class StyleLoss:
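+    # relaxed-EMD style loss (as in STROTSS): transport cost between sampled
+    # feature sets plus first/second moment (mean/covariance) matching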
+
+    def spatial_feature_extract(self, feat_result, feat_content, xx, xy):
+        l2, l3 = [], []
+        device = feat_result[0].device
+
+        # for each extracted layer
+        for i in range(len(feat_result)):
+            fr = feat_result[i]
+            fc = feat_content[i]
+
+            # hack to detect reduced scale
+            if i > 0 and feat_result[i - 1].size(2) > feat_result[i].size(2):
+                xx = xx / 2.0
+                xy = xy / 2.0
+
+            # go back to ints and get residual
+            xxm = np.floor(xx).astype(np.float32)
+            xxr = xx - xxm
+
+            xym = np.floor(xy).astype(np.float32)
+            xyr = xy - xym
+
+            # do bilinear resample
+            w00 = torch.from_numpy((1. - xxr) * (1. - xyr)).float().view(1, 1, -1, 1).to(device)
+            w01 = torch.from_numpy((1. - xxr) * xyr).float().view(1, 1, -1, 1).to(device)
+            w10 = torch.from_numpy(xxr * (1. - xyr)).float().view(1, 1, -1, 1).to(device)
+            w11 = torch.from_numpy(xxr * xyr).float().view(1, 1, -1, 1).to(device)
+
+            xxm = np.clip(xxm.astype(np.int32), 0, fr.size(2) - 1)
+            xym = np.clip(xym.astype(np.int32), 0, fr.size(3) - 1)
+
+            s00 = xxm * fr.size(3) + xym
+            s01 = xxm * fr.size(3) + np.clip(xym + 1, 0, fr.size(3) - 1)
+            s10 = np.clip(xxm + 1, 0, fr.size(2) - 1) * fr.size(3) + (xym)
+            s11 = np.clip(xxm + 1, 0, fr.size(2) - 1) * fr.size(3) + np.clip(xym + 1, 0, fr.size(3) - 1)
+
+            fr = fr.view(1, fr.size(1), fr.size(2) * fr.size(3), 1)
+            fr = fr[:, :, s00, :].mul_(w00).add_(fr[:, :, s01, :].mul_(w01)).add_(fr[:, :, s10, :].mul_(w10)).add_(
+                fr[:, :, s11, :].mul_(w11))
+
+            fc = fc.view(1, fc.size(1), fc.size(2) * fc.size(3), 1)
+            fc = fc[:, :, s00, :].mul_(w00).add_(fc[:, :, s01, :].mul_(w01)).add_(fc[:, :, s10, :].mul_(w10)).add_(
+                fc[:, :, s11, :].mul_(w11))
+
+            l2.append(fr)
+            l3.append(fc)
+
+        x_st = torch.cat([li.contiguous() for li in l2], 1)
+        c_st = torch.cat([li.contiguous() for li in l3], 1)
+
+        xx = torch.from_numpy(xx).view(1, 1, x_st.size(2), 1).float().to(device)
+        yy = torch.from_numpy(xy).view(1, 1, x_st.size(2), 1).float().to(device)
+
+        x_st = torch.cat([x_st, xx, yy], 1)
+        c_st = torch.cat([c_st, xx, yy], 1)
+        return x_st, c_st
+
+    def rgb_to_yuv(self, rgb):
+        C = torch.Tensor(
+            [[0.577350, 0.577350, 0.577350], [-0.577350, 0.788675, -0.211325], [-0.577350, -0.211325, 0.788675]]
+        ).to(rgb.device)
+        yuv = torch.mm(C, rgb)
+        return yuv
+
+    def pairwise_distances_cos(self, x, y):
+        x_norm = torch.sqrt((x ** 2).sum(1).view(-1, 1))
+        y_t = torch.transpose(y, 0, 1)
+        y_norm = torch.sqrt((y ** 2).sum(1).view(1, -1))
+        dist = 1. - torch.mm(x, y_t) / x_norm / y_norm
+        return dist
+
+    def pairwise_distances_sq_l2(self, x, y):
+        x_norm = (x ** 2).sum(1).view(-1, 1)
+        y_t = torch.transpose(y, 0, 1)
+        y_norm = (y ** 2).sum(1).view(1, -1)
+        dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)
+        return torch.clamp(dist, 1e-5, 1e5) / x.size(1)
+
+    def distmat(self, x, y, cos_d=True):
+        if cos_d:
+            M = self.pairwise_distances_cos(x, y)
+        else:
+            M = torch.sqrt(self.pairwise_distances_sq_l2(x, y))
+        return M
+
+    def style_loss(self, X, Y):
+        d = X.shape[1]
+
+        if d == 3:
+            X = self.rgb_to_yuv(X.transpose(0, 1).contiguous().view(d, -1)).transpose(0, 1)
+            Y = self.rgb_to_yuv(Y.transpose(0, 1).contiguous().view(d, -1)).transpose(0, 1)
+        else:
+            X = X.transpose(0, 1).contiguous().view(d, -1).transpose(0, 1)
+            Y = Y.transpose(0, 1).contiguous().view(d, -1).transpose(0, 1)
+
+        # Relaxed EMD
+        CX_M = self.distmat(X, Y, cos_d=True)
+
+        if d == 3:
+            CX_M = CX_M + self.distmat(X, Y, cos_d=False)
+
+        m1, m1_inds = CX_M.min(1)
+        m2, m2_inds = CX_M.min(0)
+
+        remd = torch.max(m1.mean(), m2.mean())
+
+        return remd
+
+    def moment_loss(self, X, Y, moments=[1, 2]):
+        loss = 0.
+        X = X.squeeze().t()
+        Y = Y.squeeze().t()
+
+        mu_x = torch.mean(X, 0, keepdim=True)
+        mu_y = torch.mean(Y, 0, keepdim=True)
+        mu_d = torch.abs(mu_x - mu_y).mean()
+
+        if 1 in moments:
+            loss = loss + mu_d
+
+        if 2 in moments:
+            X_c = X - mu_x
+            Y_c = Y - mu_y
+            X_cov = torch.mm(X_c.t(), X_c) / (X.shape[0] - 1)
+            Y_cov = torch.mm(Y_c.t(), Y_c) / (Y.shape[0] - 1)
+
+            D_cov = torch.abs(X_cov - Y_cov).mean()
+            loss = loss + D_cov
+
+        return loss
+
+    def forward(self, feat_result, feat_content, feat_style, indices, content_weight, moment_weight=1.0):
+        # spatial feature extract
+        num_locations = 1024
+        spatial_result, spatial_content = self.spatial_feature_extract(
+            feat_result, feat_content, indices[0][:num_locations], indices[1][:num_locations]
+        )
+
+        # loss_content = content_loss(spatial_result, spatial_content)
+
+        d = feat_style.shape[1]
+        spatial_style = feat_style.view(1, d, -1, 1)
+        feat_max = 3 + 2 * 64 + 128 * 2 + 256 * 3 + 512 * 2  # (sum of all extracted channels)
+
+        loss_remd = self.style_loss(spatial_result[:, :feat_max, :, :], spatial_style[:, :feat_max, :, :])
+
+        loss_moment = self.moment_loss(spatial_result[:, :-2, :, :],
+                                       spatial_style,
+                                       moments=[1, 2])  # :-2 drops the appended (x, y) coordinate channels
+        # palette matching
+        content_weight_frac = 1. / max(content_weight, 1.)
+        loss_moment += content_weight_frac * self.style_loss(spatial_result[:, :3, :, :], spatial_style[:, :3, :, :])
+
+        loss_style = loss_remd + moment_weight * loss_moment
+        # print(f'Style: {loss_style.item():.3f}, Content: {loss_content.item():.3f}')
+
+        style_weight = 1.0 + moment_weight
+        loss_total = (loss_style) / (content_weight + style_weight)
+        return loss_total
+
+
+def sample_indices(feat_content, feat_style):
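+    # pick feature-map coordinates on a strided grid with random offsets;
+    # the strides are chosen so that roughly `const` locations remain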
+    const = 128 ** 2  # ~16k target locations
+    big_size = feat_content.shape[2] * feat_content.shape[3]  # number of feature pixels
+
+    stride_x = int(max(math.floor(math.sqrt(big_size // const)), 1))
+    offset_x = np.random.randint(stride_x)
+    stride_y = int(max(math.ceil(math.sqrt(big_size // const)), 1))
+    offset_y = np.random.randint(stride_y)
+    xx, xy = np.meshgrid(
+        np.arange(feat_content.shape[2])[offset_x::stride_x],
+        np.arange(feat_content.shape[3])[offset_y::stride_y]
+    )
+    xx = xx.flatten()
+    xy = xy.flatten()
+    return xx, xy
diff --git a/methods/painter/diffsketcher/u2net.py b/methods/painter/diffsketcher/u2net.py
new file mode 100644
index 0000000000000000000000000000000000000000..bcedd43ece5537921eb68a4715a076f7d4d0f7cd
--- /dev/null
+++ b/methods/painter/diffsketcher/u2net.py
@@ -0,0 +1,524 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class REBNCONV(nn.Module):
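+    # 3x3 Conv + BatchNorm + ReLU block; `dirate` sets the dilation (and matching
+    # padding) so the spatial size is preserved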
+    def __init__(self, in_ch=3, out_ch=3, dirate=1):
+        super(REBNCONV, self).__init__()
+
+        self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate)
+        self.bn_s1 = nn.BatchNorm2d(out_ch)
+        self.relu_s1 = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        hx = x
+        xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
+
+        return xout
+
+
+## upsample tensor 'src' to have the same spatial size as tensor 'tar'
+def _upsample_like(src, tar):
+    src = F.interpolate(src, size=tar.shape[2:], mode='bilinear')
+
+    return src
+
+
+### RSU-7 ###
+class RSU7(nn.Module):  # UNet07DRES(nn.Module):
+
+    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+        super(RSU7, self).__init__()
+
+        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+    def forward(self, x):
+        hx = x
+        hxin = self.rebnconvin(hx)
+
+        hx1 = self.rebnconv1(hxin)
+        hx = self.pool1(hx1)
+
+        hx2 = self.rebnconv2(hx)
+        hx = self.pool2(hx2)
+
+        hx3 = self.rebnconv3(hx)
+        hx = self.pool3(hx3)
+
+        hx4 = self.rebnconv4(hx)
+        hx = self.pool4(hx4)
+
+        hx5 = self.rebnconv5(hx)
+        hx = self.pool5(hx5)
+
+        hx6 = self.rebnconv6(hx)
+
+        hx7 = self.rebnconv7(hx6)
+
+        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
+        hx6dup = _upsample_like(hx6d, hx5)
+
+        hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
+        hx5dup = _upsample_like(hx5d, hx4)
+
+        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
+        hx4dup = _upsample_like(hx4d, hx3)
+
+        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+        hx3dup = _upsample_like(hx3d, hx2)
+
+        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+        hx2dup = _upsample_like(hx2d, hx1)
+
+        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+        return hx1d + hxin
+
+
+### RSU-6 ###
+class RSU6(nn.Module):  # UNet06DRES(nn.Module):
+
+    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+        super(RSU6, self).__init__()
+
+        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+    def forward(self, x):
+        hx = x
+
+        hxin = self.rebnconvin(hx)
+
+        hx1 = self.rebnconv1(hxin)
+        hx = self.pool1(hx1)
+
+        hx2 = self.rebnconv2(hx)
+        hx = self.pool2(hx2)
+
+        hx3 = self.rebnconv3(hx)
+        hx = self.pool3(hx3)
+
+        hx4 = self.rebnconv4(hx)
+        hx = self.pool4(hx4)
+
+        hx5 = self.rebnconv5(hx)
+
+        hx6 = self.rebnconv6(hx5)
+
+        hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
+        hx5dup = _upsample_like(hx5d, hx4)
+
+        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
+        hx4dup = _upsample_like(hx4d, hx3)
+
+        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+        hx3dup = _upsample_like(hx3d, hx2)
+
+        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+        hx2dup = _upsample_like(hx2d, hx1)
+
+        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+        return hx1d + hxin
+
+
+### RSU-5 ###
+class RSU5(nn.Module):  # UNet05DRES(nn.Module):
+
+    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+        super(RSU5, self).__init__()
+
+        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+    def forward(self, x):
+        hx = x
+
+        hxin = self.rebnconvin(hx)
+
+        hx1 = self.rebnconv1(hxin)
+        hx = self.pool1(hx1)
+
+        hx2 = self.rebnconv2(hx)
+        hx = self.pool2(hx2)
+
+        hx3 = self.rebnconv3(hx)
+        hx = self.pool3(hx3)
+
+        hx4 = self.rebnconv4(hx)
+
+        hx5 = self.rebnconv5(hx4)
+
+        hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
+        hx4dup = _upsample_like(hx4d, hx3)
+
+        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+        hx3dup = _upsample_like(hx3d, hx2)
+
+        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+        hx2dup = _upsample_like(hx2d, hx1)
+
+        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+        return hx1d + hxin
+
+
+### RSU-4 ###
+class RSU4(nn.Module):  # UNet04DRES(nn.Module):
+
+    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+        super(RSU4, self).__init__()
+
+        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+    def forward(self, x):
+        hx = x
+
+        hxin = self.rebnconvin(hx)
+
+        hx1 = self.rebnconv1(hxin)
+        hx = self.pool1(hx1)
+
+        hx2 = self.rebnconv2(hx)
+        hx = self.pool2(hx2)
+
+        hx3 = self.rebnconv3(hx)
+
+        hx4 = self.rebnconv4(hx3)
+
+        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
+        hx3dup = _upsample_like(hx3d, hx2)
+
+        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+        hx2dup = _upsample_like(hx2d, hx1)
+
+        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+        return hx1d + hxin
+
+
+### RSU-4F ###
+class RSU4F(nn.Module):  # UNet04FRES(nn.Module):
+
+    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+        super(RSU4F, self).__init__()
+
+        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
+        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
+
+        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
+
+        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
+        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
+        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+    def forward(self, x):
+        hx = x
+
+        hxin = self.rebnconvin(hx)
+
+        hx1 = self.rebnconv1(hxin)
+        hx2 = self.rebnconv2(hx1)
+        hx3 = self.rebnconv3(hx2)
+
+        hx4 = self.rebnconv4(hx3)
+
+        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
+        hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
+        hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
+
+        return hx1d + hxin
+
+
+##### U^2-Net ####
+class U2NET(nn.Module):
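+    # full U^2-Net: six RSU encoder stages, five RSU decoder stages, and six
+    # side outputs fused by a 1x1 conv; all outputs pass through a sigmoid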
+
+    def __init__(self, in_ch=3, out_ch=1):
+        super(U2NET, self).__init__()
+
+        self.stage1 = RSU7(in_ch, 32, 64)
+        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.stage2 = RSU6(64, 32, 128)
+        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.stage3 = RSU5(128, 64, 256)
+        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.stage4 = RSU4(256, 128, 512)
+        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.stage5 = RSU4F(512, 256, 512)
+        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.stage6 = RSU4F(512, 256, 512)
+
+        # decoder
+        self.stage5d = RSU4F(1024, 256, 512)
+        self.stage4d = RSU4(1024, 128, 256)
+        self.stage3d = RSU5(512, 64, 128)
+        self.stage2d = RSU6(256, 32, 64)
+        self.stage1d = RSU7(128, 16, 64)
+
+        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
+        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
+        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
+        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
+        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
+        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
+
+        self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)
+
+    def forward(self, x):
+        hx = x
+
+        # stage 1
+        hx1 = self.stage1(hx)
+        hx = self.pool12(hx1)
+
+        # stage 2
+        hx2 = self.stage2(hx)
+        hx = self.pool23(hx2)
+
+        # stage 3
+        hx3 = self.stage3(hx)
+        hx = self.pool34(hx3)
+
+        # stage 4
+        hx4 = self.stage4(hx)
+        hx = self.pool45(hx4)
+
+        # stage 5
+        hx5 = self.stage5(hx)
+        hx = self.pool56(hx5)
+
+        # stage 6
+        hx6 = self.stage6(hx)
+        hx6up = _upsample_like(hx6, hx5)
+
+        # -------------------- decoder --------------------
+        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
+        hx5dup = _upsample_like(hx5d, hx4)
+
+        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
+        hx4dup = _upsample_like(hx4d, hx3)
+
+        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
+        hx3dup = _upsample_like(hx3d, hx2)
+
+        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
+        hx2dup = _upsample_like(hx2d, hx1)
+
+        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
+
+        # side output
+        d1 = self.side1(hx1d)
+
+        d2 = self.side2(hx2d)
+        d2 = _upsample_like(d2, d1)
+
+        d3 = self.side3(hx3d)
+        d3 = _upsample_like(d3, d1)
+
+        d4 = self.side4(hx4d)
+        d4 = _upsample_like(d4, d1)
+
+        d5 = self.side5(hx5d)
+        d5 = _upsample_like(d5, d1)
+
+        d6 = self.side6(hx6)
+        d6 = _upsample_like(d6, d1)
+
+        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
+
+        return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), \
+               torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), \
+               torch.sigmoid(d6)
+
+
+### U^2-Net small ###
+class U2NETP(nn.Module):
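+    # lightweight U^2-Net ("small"): every stage outputs 64 channels with 16 mid channels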
+
+    def __init__(self, in_ch=3, out_ch=1):
+        super(U2NETP, self).__init__()
+
+        self.stage1 = RSU7(in_ch, 16, 64)
+        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.stage2 = RSU6(64, 16, 64)
+        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.stage3 = RSU5(64, 16, 64)
+        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.stage4 = RSU4(64, 16, 64)
+        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.stage5 = RSU4F(64, 16, 64)
+        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.stage6 = RSU4F(64, 16, 64)
+
+        # decoder
+        self.stage5d = RSU4F(128, 16, 64)
+        self.stage4d = RSU4(128, 16, 64)
+        self.stage3d = RSU5(128, 16, 64)
+        self.stage2d = RSU6(128, 16, 64)
+        self.stage1d = RSU7(128, 16, 64)
+
+        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
+        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
+        self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
+        self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
+        self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
+        self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)
+
+        self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)
+
+    def forward(self, x):
+        hx = x
+
+        # stage 1
+        hx1 = self.stage1(hx)
+        hx = self.pool12(hx1)
+
+        # stage 2
+        hx2 = self.stage2(hx)
+        hx = self.pool23(hx2)
+
+        # stage 3
+        hx3 = self.stage3(hx)
+        hx = self.pool34(hx3)
+
+        # stage 4
+        hx4 = self.stage4(hx)
+        hx = self.pool45(hx4)
+
+        # stage 5
+        hx5 = self.stage5(hx)
+        hx = self.pool56(hx5)
+
+        # stage 6
+        hx6 = self.stage6(hx)
+        hx6up = _upsample_like(hx6, hx5)
+
+        # decoder
+        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
+        hx5dup = _upsample_like(hx5d, hx4)
+
+        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
+        hx4dup = _upsample_like(hx4d, hx3)
+
+        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
+        hx3dup = _upsample_like(hx3d, hx2)
+
+        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
+        hx2dup = _upsample_like(hx2d, hx1)
+
+        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
+
+        # side output
+        d1 = self.side1(hx1d)
+
+        d2 = self.side2(hx2d)
+        d2 = _upsample_like(d2, d1)
+
+        d3 = self.side3(hx3d)
+        d3 = _upsample_like(d3, d1)
+
+        d4 = self.side4(hx4d)
+        d4 = _upsample_like(d4, d1)
+
+        d5 = self.side5(hx5d)
+        d5 = _upsample_like(d5, d1)
+
+        d6 = self.side6(hx6)
+        d6 = _upsample_like(d6, d1)
+
+        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
+
+        return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), \
+               torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), \
+               torch.sigmoid(d6)
diff --git a/methods/token2attn/__init__.py b/methods/token2attn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ad761f2f5443eb41b15afc4116a66ecdfa9d918
--- /dev/null
+++ b/methods/token2attn/__init__.py
@@ -0,0 +1,4 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) XiMing Xing. All rights reserved.
+# Author: XiMing Xing
+# Description:
diff --git a/methods/token2attn/attn_control.py b/methods/token2attn/attn_control.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ec474ccab6ebea796bc7a9d2e91c52ad8d9ed57
--- /dev/null
+++ b/methods/token2attn/attn_control.py
@@ -0,0 +1,264 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) XiMing Xing. All rights reserved.
+# Author: XiMing Xing
+# Description:
+
+from abc import ABC, abstractmethod
+from typing import Optional, Union, Tuple, List, Dict
+
+import torch
+import torch.nn.functional as F
+
+from .ptp_utils import (get_word_inds, get_time_words_attention_alpha)
+from .seq_aligner import (get_replacement_mapper, get_refinement_mapper)
+
+
+class AttentionControl(ABC):
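+    # abstract per-layer attention hook: __call__ is invoked for every attention
+    # layer, tracks step/layer counters, and delegates editing to forward()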
+
+    def __init__(self):
+        self.cur_step = 0
+        self.num_att_layers = -1
+        self.cur_att_layer = 0
+
+    def step_callback(self, x_t):
+        return x_t
+
+    def between_steps(self):
+        return
+
+    @property
+    def num_uncond_att_layers(self):
+        return 0
+
+    @abstractmethod
+    def forward(self, attn, is_cross: bool, place_in_unet: str):
+        raise NotImplementedError
+
+    def __call__(self, attn, is_cross: bool, place_in_unet: str):
+        if self.cur_att_layer >= self.num_uncond_att_layers:
+            h = attn.shape[0]
+            attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
+        self.cur_att_layer += 1
+        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
+            self.cur_att_layer = 0
+            self.cur_step += 1
+            self.between_steps()
+        return attn
+
+    def reset(self):
+        self.cur_step = 0
+        self.cur_att_layer = 0
+
+
+class EmptyControl(AttentionControl):
+
+    def forward(self, attn, is_cross: bool, place_in_unet: str):
+        return attn
+
+
+class AttentionStore(AttentionControl):
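+    # records cross-/self-attention maps per UNet block ("down"/"mid"/"up") and
+    # accumulates them across diffusion steps; get_average_attention returns the per-step mean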
+
+    def __init__(self):
+        super(AttentionStore, self).__init__()
+        self.step_store = self.get_empty_store()
+        self.attention_store = {}
+
+    @staticmethod
+    def get_empty_store():
+        return {"down_cross": [], "mid_cross": [], "up_cross": [],
+                "down_self": [], "mid_self": [], "up_self": []}
+
+    def forward(self, attn, is_cross: bool, place_in_unet: str):
+        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
+        if attn.shape[1] <= 32 ** 2:  # avoid memory overhead
+            self.step_store[key].append(attn)
+        return attn
+
+    def between_steps(self):
+        if len(self.attention_store) == 0:
+            self.attention_store = self.step_store
+        else:
+            for key in self.attention_store:
+                for i in range(len(self.attention_store[key])):
+                    self.attention_store[key][i] += self.step_store[key][i]
+        self.step_store = self.get_empty_store()
+
+    def get_average_attention(self):
+        average_attention = {
+            key: [item / self.cur_step for item in self.attention_store[key]]
+            for key in self.attention_store
+        }
+        return average_attention
+
+    def reset(self):
+        super(AttentionStore, self).reset()
+        self.step_store = self.get_empty_store()
+        self.attention_store = {}
+
+
+class LocalBlend:
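+    # derives a spatial mask from the cross-attention of chosen words; outside
+    # the mask the edited latents are reset to the base prompt's latent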
+
+    def __init__(self,
+                 prompts: List[str],
+                 words: List[List[str]],
+                 tokenizer,
+                 device,
+                 threshold=.3,
+                 max_num_words=77):
+        self.max_num_words = max_num_words
+
+        alpha_layers = torch.zeros(len(prompts), 1, 1, 1, 1, self.max_num_words)
+        for i, (prompt, words_) in enumerate(zip(prompts, words)):
+            if type(words_) is str:
+                words_ = [words_]
+            for word in words_:
+                ind = get_word_inds(prompt, word, tokenizer)
+                alpha_layers[i, :, :, :, :, ind] = 1
+        self.alpha_layers = alpha_layers.to(device)
+        self.threshold = threshold
+
+    def __call__(self, x_t, attention_store):
+        k = 1
+        maps = attention_store["down_cross"][2:4] + attention_store["up_cross"][:3]
+        maps = [item.reshape(self.alpha_layers.shape[0], -1, 1, 16, 16, self.max_num_words) for item in maps]
+        maps = torch.cat(maps, dim=1)
+        maps = (maps * self.alpha_layers).sum(-1).mean(1)
+        mask = F.max_pool2d(maps, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k))
+        mask = F.interpolate(mask, size=(x_t.shape[2:]))
+        mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
+        mask = mask.gt(self.threshold)
+        mask = (mask[:1] + mask[1:]).float()
+        x_t = x_t[:1] + mask * (x_t - x_t[:1])
+        return x_t
+
+
+class AttentionControlEdit(AttentionStore, ABC):
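+    # shared logic for prompt-to-prompt edits: within the configured step ranges,
+    # the edited prompts' self-/cross-attention is blended with or replaced by the base prompt's maps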
+
+    def __init__(self,
+                 prompts,
+                 num_steps: int,
+                 cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
+                 self_replace_steps: Union[float, Tuple[float, float]],
+                 local_blend: Optional[LocalBlend],
+                 tokenizer,
+                 device):
+        super(AttentionControlEdit, self).__init__()
+        self.tokenizer = tokenizer
+        self.device = device
+
+        self.batch_size = len(prompts)
+        self.cross_replace_alpha = get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps,
+                                                                  self.tokenizer).to(self.device)
+        if type(self_replace_steps) is float:
+            self_replace_steps = 0, self_replace_steps
+        self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
+        self.local_blend = local_blend  # define outside
+
+    def step_callback(self, x_t):
+        if self.local_blend is not None:
+            x_t = self.local_blend(x_t, self.attention_store)
+        return x_t
+
+    def replace_self_attention(self, attn_base, att_replace):
+        if att_replace.shape[2] <= 16 ** 2:
+            return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
+        else:
+            return att_replace
+
+    @abstractmethod
+    def replace_cross_attention(self, attn_base, att_replace):
+        raise NotImplementedError
+
+    def forward(self, attn, is_cross: bool, place_in_unet: str):
+        super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
+        # FIXME: replacement is not yet applied correctly
+        if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]):
+            h = attn.shape[0] // self.batch_size
+            attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
+            attn_base, attn_replace = attn[0], attn[1:]
+            if is_cross:
+                alpha_words = self.cross_replace_alpha[self.cur_step]
+                attn_replace_new = self.replace_cross_attention(attn_base, attn_replace) * alpha_words + (
+                        1 - alpha_words) * attn_replace
+                attn[1:] = attn_replace_new
+            else:
+                attn[1:] = self.replace_self_attention(attn_base, attn_replace)
+            attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
+        return attn
+
+
+class AttentionReplace(AttentionControlEdit):
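+    # word-swap edit: redistributes the base prompt's cross-attention onto the
+    # replacement tokens via a token mapper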
+
+    def __init__(self,
+                 prompts,
+                 num_steps: int,
+                 cross_replace_steps: float,
+                 self_replace_steps: float,
+                 local_blend: Optional[LocalBlend] = None,
+                 tokenizer=None,
+                 device=None):
+        super(AttentionReplace, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps,
+                                               local_blend, tokenizer, device)
+        self.mapper = get_replacement_mapper(prompts, self.tokenizer).to(self.device)
+
+    def replace_cross_attention(self, attn_base, att_replace):
+        return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper)
+
+
+class AttentionRefine(AttentionControlEdit):
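+    # prompt-refinement edit: tokens aligned to the base prompt reuse its
+    # attention, while newly added tokens keep their own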
+
+    def __init__(self,
+                 prompts,
+                 num_steps: int,
+                 cross_replace_steps: float,
+                 self_replace_steps: float,
+                 local_blend: Optional[LocalBlend] = None,
+                 tokenizer=None,
+                 device=None):
+        super(AttentionRefine, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps,
+                                              local_blend, tokenizer, device)
+        self.mapper, alphas = get_refinement_mapper(prompts, self.tokenizer)
+        self.mapper, alphas = self.mapper.to(self.device), alphas.to(self.device)
+        self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1])
+
+    def replace_cross_attention(self, attn_base, att_replace):
+        attn_base_replace = attn_base[:, :, self.mapper].permute(2, 0, 1, 3)
+        attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas)
+        return attn_replace
+
+
+class AttentionReweight(AttentionControlEdit):
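+    # scales the cross-attention of selected tokens by the equalizer weights,
+    # optionally stacked on top of a previous edit controller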
+
+    def __init__(self,
+                 prompts,
+                 num_steps: int,
+                 cross_replace_steps: float,
+                 self_replace_steps: float,
+                 equalizer,
+                 local_blend: Optional[LocalBlend] = None,
+                 controller: Optional[AttentionControlEdit] = None,
+                 tokenizer=None,
+                 device=None):
+        super(AttentionReweight, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps,
+                                                local_blend, tokenizer, device)
+        self.equalizer = equalizer.to(self.device)
+        self.prev_controller = controller
+
+    def replace_cross_attention(self, attn_base, att_replace):
+        if self.prev_controller is not None:
+            attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace)
+        attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :]
+        return attn_replace
+
+
+def get_equalizer(tokenizer, text: str,
+                  word_select: Union[int, Tuple[int, ...]],
+                  values: Union[List[float], Tuple[float, ...]]):
+    if type(word_select) is int or type(word_select) is str:
+        word_select = (word_select,)
+    equalizer = torch.ones(len(values), 77)
+    values = torch.tensor(values, dtype=torch.float32)
+    for word in word_select:
+        inds = get_word_inds(text, word, tokenizer)
+        equalizer[:, inds] = values
+    return equalizer
diff --git a/methods/token2attn/ptp_utils.py b/methods/token2attn/ptp_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4557c4cd929222f7580e444f392c0260b0fbe1dd
--- /dev/null
+++ b/methods/token2attn/ptp_utils.py
@@ -0,0 +1,110 @@
+# -*- coding: utf-8 -*-
+import pathlib
+from typing import Union, Optional, List, Tuple, Dict, Text, BinaryIO
+from PIL import Image
+
+import torch
+import cv2
+import numpy as np
+
+from .seq_aligner import get_word_inds
+
+
+def text_under_image(image: np.ndarray,
+                     text: str,
+                     text_color: Tuple[int, int, int] = (0, 0, 0)) -> np.ndarray:
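+    # appends a white strip (20% of the image height) below the image and draws `text` centered in it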
+    h, w, c = image.shape
+    offset = int(h * .2)
+    img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
+    font = cv2.FONT_HERSHEY_SIMPLEX
+    img[:h] = image
+    textsize = cv2.getTextSize(text, font, 1, 2)[0]
+    text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
+    cv2.putText(img, text, (text_x, text_y), font, 1, text_color, 2)
+    return img
+
+
+def view_images(images: Union[np.ndarray, List],
+                num_rows: int = 1,
+                offset_ratio: float = 0.02,
+                save_image: bool = False,
+                fp: Optional[Union[Text, pathlib.Path, BinaryIO]] = None) -> Image.Image:
+    if save_image:
+        assert fp is not None
+
+    if isinstance(images, np.ndarray) and images.ndim == 4:
+        num_empty = images.shape[0] % num_rows
+    else:
+        images = [images] if not isinstance(images, list) else images
+        num_empty = len(images) % num_rows
+
+    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
+    images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
+    num_items = len(images)
+
+    # Calculate the composite image
+    h, w, c = images[0].shape
+    offset = int(h * offset_ratio)
+    num_cols = int(np.ceil(num_items / num_rows))  # count the number of columns
+    image_h = h * num_rows + offset * (num_rows - 1)
+    image_w = w * num_cols + offset * (num_cols - 1)
+    assert image_h > 0, "Invalid image height: {} (num_rows={}, offset_ratio={}, num_items={})".format(
+        image_h, num_rows, offset_ratio, num_items)
+    assert image_w > 0, "Invalid image width: {} (num_cols={}, offset_ratio={}, num_items={})".format(
+        image_w, num_cols, offset_ratio, num_items)
+    image_ = np.ones((image_h, image_w, 3), dtype=np.uint8) * 255
+
+    # Ensure that the last row is filled with empty images if necessary
+    if len(images) % num_cols > 0:
+        empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
+        num_empty = num_cols - len(images) % num_cols
+        images += [empty_images] * num_empty
+
+    for i in range(num_rows):
+        for j in range(num_cols):
+            k = i * num_cols + j
+            if k >= num_items:
+                break
+            image_[i * (h + offset): i * (h + offset) + h, j * (w + offset): j * (w + offset) + w] = images[k]
+
+    pil_img = Image.fromarray(image_)
+    if save_image:
+        pil_img.save(fp)
+    return pil_img
+
+
+def update_alpha_time_word(alpha,
+                           bounds: Union[float, Tuple[float, float]],
+                           prompt_ind: int,
+                           word_inds: Optional[torch.Tensor] = None):
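+    # enables attention replacement for the given words only within the [start, end) fraction of diffusion steps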
+    if isinstance(bounds, float):
+        bounds = 0, bounds
+    start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
+    if word_inds is None:
+        word_inds = torch.arange(alpha.shape[2])
+    alpha[: start, prompt_ind, word_inds] = 0
+    alpha[start: end, prompt_ind, word_inds] = 1
+    alpha[end:, prompt_ind, word_inds] = 0
+    return alpha
+
+
+def get_time_words_attention_alpha(prompts, num_steps,
+                                   cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]],
+                                   tokenizer,
+                                   max_num_words=77):
+    if type(cross_replace_steps) is not dict:
+        cross_replace_steps = {"default_": cross_replace_steps}
+    if "default_" not in cross_replace_steps:
+        cross_replace_steps["default_"] = (0., 1.)
+    alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
+    for i in range(len(prompts) - 1):
+        alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"],
+                                                  i)
+    for key, item in cross_replace_steps.items():
+        if key != "default_":
+            inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
+            for i, ind in enumerate(inds):
+                if len(ind) > 0:
+                    alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
+    alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)
+    return alpha_time_words
diff --git a/methods/token2attn/seq_aligner.py b/methods/token2attn/seq_aligner.py
new file mode 100644
index 0000000000000000000000000000000000000000..d534d8ae1b6618604c619d56250293f66c0430f5
--- /dev/null
+++ b/methods/token2attn/seq_aligner.py
@@ -0,0 +1,182 @@
+# -*- coding: utf-8 -*-
+import torch
+import numpy as np
+
+
+class ScoreParams:
+
+    def __init__(self, gap, match, mismatch):
+        self.gap = gap
+        self.match = match
+        self.mismatch = mismatch
+
+    def mis_match_char(self, x, y):
+        if x != y:
+            return self.mismatch
+        else:
+            return self.match
+
+
+def get_matrix(size_x, size_y, gap):
+    matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
+    matrix[0, 1:] = (np.arange(size_y) + 1) * gap
+    matrix[1:, 0] = (np.arange(size_x) + 1) * gap
+    return matrix
+
+
+def get_traceback_matrix(size_x, size_y):
+    matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
+    matrix[0, 1:] = 1
+    matrix[1:, 0] = 2
+    matrix[0, 0] = 4
+    return matrix
+
+
+def global_align(x, y, score):
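+    # Needleman-Wunsch global alignment of two token sequences (linear gap penalty)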
+    matrix = get_matrix(len(x), len(y), score.gap)
+    trace_back = get_traceback_matrix(len(x), len(y))
+    for i in range(1, len(x) + 1):
+        for j in range(1, len(y) + 1):
+            left = matrix[i, j - 1] + score.gap
+            up = matrix[i - 1, j] + score.gap
+            diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1])
+            matrix[i, j] = max(left, up, diag)
+            if matrix[i, j] == left:
+                trace_back[i, j] = 1
+            elif matrix[i, j] == up:
+                trace_back[i, j] = 2
+            else:
+                trace_back[i, j] = 3
+    return matrix, trace_back
+
+
+def get_aligned_sequences(x, y, trace_back):
+    x_seq = []
+    y_seq = []
+    i = len(x)
+    j = len(y)
+    mapper_y_to_x = []
+    while i > 0 or j > 0:
+        if trace_back[i, j] == 3:
+            x_seq.append(x[i - 1])
+            y_seq.append(y[j - 1])
+            i = i - 1
+            j = j - 1
+            mapper_y_to_x.append((j, i))
+        elif trace_back[i][j] == 1:
+            x_seq.append('-')
+            y_seq.append(y[j - 1])
+            j = j - 1
+            mapper_y_to_x.append((j, -1))
+        elif trace_back[i][j] == 2:
+            x_seq.append(x[i - 1])
+            y_seq.append('-')
+            i = i - 1
+        elif trace_back[i][j] == 4:
+            break
+    mapper_y_to_x.reverse()
+    return x_seq, y_seq, torch.tensor(mapper_y_to_x, dtype=torch.int64)
+
+
+def get_mapper(x: str, y: str, tokenizer, max_len=77):
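+    # aligns two tokenized prompts; for each target token, `mapper` gives the index
+    # of its source token (-1 if the token is new) and `alphas` is 0 for new tokens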
+    x_seq = tokenizer.encode(x)
+    y_seq = tokenizer.encode(y)
+    score = ScoreParams(0, 1, -1)
+    matrix, trace_back = global_align(x_seq, y_seq, score)
+    mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1]
+    alphas = torch.ones(max_len)
+    alphas[: mapper_base.shape[0]] = mapper_base[:, 1].ne(-1).float()
+    mapper = torch.zeros(max_len, dtype=torch.int64)
+    mapper[:mapper_base.shape[0]] = mapper_base[:, 1]
+    mapper[mapper_base.shape[0]:] = len(y_seq) + torch.arange(max_len - len(y_seq))
+    return mapper, alphas
+
+
+def get_refinement_mapper(prompts, tokenizer, max_len=77):
+    x_seq = prompts[0]
+    mappers, alphas = [], []
+    for i in range(1, len(prompts)):
+        mapper, alpha = get_mapper(x_seq, prompts[i], tokenizer, max_len)
+        mappers.append(mapper)
+        alphas.append(alpha)
+    return torch.stack(mappers), torch.stack(alphas)
+
+
+def get_word_inds(text: str, word_place, tokenizer):
+    # `word_place` is either the word's index in the prompt (int) or the literal word (str)
+    split_text = text.split(" ")
+    if type(word_place) is str:
+        word_place = [i for i, word in enumerate(split_text) if word_place == word]
+    elif type(word_place) is int:
+        word_place = [word_place]
+    out = []
+    if len(word_place) > 0:
+        words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
+        cur_len, ptr = 0, 0
+
+        for i in range(len(words_encode)):
+            cur_len += len(words_encode[i])
+            if ptr in word_place:
+                out.append(i + 1)
+            if cur_len >= len(split_text[ptr]):
+                ptr += 1
+                cur_len = 0
+    return np.array(out)
+
+
+def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77):
+    words_x = x.split(' ')
+    words_y = y.split(' ')
+    if len(words_x) != len(words_y):
+        raise ValueError(f"attention replacement edit can only be applied on prompts with the same length"
+                         f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words.")
+    inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]]
+    inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace]
+    inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace]
+    mapper = np.zeros((max_len, max_len))
+    i = j = 0
+    cur_inds = 0
+    while i < max_len and j < max_len:
+        if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i:
+            inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds]
+            if len(inds_source_) == len(inds_target_):
+                mapper[inds_source_, inds_target_] = 1
+            else:
+                ratio = 1 / len(inds_target_)
+                for i_t in inds_target_:
+                    mapper[inds_source_, i_t] = ratio
+            cur_inds += 1
+            i += len(inds_source_)
+            j += len(inds_target_)
+        elif cur_inds < len(inds_source):
+            mapper[i, j] = 1
+            i += 1
+            j += 1
+        else:
+            mapper[j, j] = 1
+            i += 1
+            j += 1
+
+    return torch.from_numpy(mapper).float()
+
+
+def get_replacement_mapper(prompts, tokenizer, max_len=77):
+    x_seq = prompts[0]
+    mappers = []
+    for i in range(1, len(prompts)):
+        mapper = get_replacement_mapper_(x_seq, prompts[i], tokenizer, max_len)
+        mappers.append(mapper)
+    return torch.stack(mappers)
diff --git a/requirements.txt b/requirements.txt
index 02880b2c318d90cc041f172ff63e9c17fe98b3e6..12d758226caebb8e16cdb05bfb5b0dbdd1e56a28 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,24 @@
-torch>=1.7.0
-torchvision>=0.8.0
-transformers>=4.0.0
-diffusers>=0.10.0
-cairosvg>=2.5.0
-Pillow>=9.0.0
\ No newline at end of file
+torch>=1.12.0
+torchvision>=0.13.0
+diffusers>=0.20.0
+transformers>=4.21.0
+accelerate>=0.12.0
+safetensors>=0.3.0
+hydra-core>=1.3.0
+omegaconf>=2.3.0
+opencv-python>=4.6.0
+scikit-image>=0.19.0
+matplotlib>=3.5.0
+numpy>=1.21.0
+scipy>=1.9.0
+einops>=0.6.0
+timm>=0.6.0
+ftfy>=6.1.0
+regex>=2022.7.0
+tqdm>=4.64.0
+svgwrite>=1.4.0
+svgpathtools>=1.4.0
+freetype-py>=2.3.0
+shapely>=1.8.0
+svgutils>=0.3.0
+clip-by-openai>=1.0
\ No newline at end of file