"""
wild mixture of
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
https://github.com/CompVis/taming-transformers
-- merci
"""
import global_
import os
import sys
from pathlib import Path
import torch
import torch.nn as nn
import numpy as np
import pytorch_lightning as pl
from torch.optim.lr_scheduler import LambdaLR
from einops import rearrange, repeat
from contextlib import contextmanager
from functools import partial
from tqdm import tqdm
from torchvision.utils import make_grid
from pytorch_lightning.utilities.distributed import rank_zero_only
from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
from ldm.modules.ema import LitEma
from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
from ldm.models.diffusion.ddim import DDIMSampler
from torchvision.transforms import Resize
import torchvision.transforms.functional as TF
import torch.nn.functional as F
import math
import time
import random
import copy
from torch.autograd import Variable
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer
from src.Face_models.encoders.model_irse import Backbone
import dlib
from eval_tool.lpips.lpips import LPIPS
from PIL import Image
import argparse
from contextlib import nullcontext
from util_face import *
from util_vis import vis_tensors_A
from my_py_lib.image_util import save_any_A,imgs_2_grid_A
from my_py_lib.torch_util import recursive_to
from my_py_lib.torch_util import custom_repr_v3
from confs import *
from lmk_util.lmk_extractor import LandmarkExtractor,lmkAll_2_lmkMain
from ldm.modules.encoders.modules import FrozenCLIPEmbedder
from ldm.modules.diffusionmodules.openaimodel import UNetModel
from MoE import *
# Maps a conditioning mode to the key name used in the conditioning dict
# passed through the diffusion model ('adm' = class/embedding conditioning).
__conditioning_keys__ = {
    'concat': 'c_concat',
    'crossattn': 'c_crossattn',
    'adm': 'y',
}
def disabled_train(self, mode=True):
    """No-op replacement for ``nn.Module.train``.

    Bound onto a module in place of its ``train`` method so that calls like
    ``model.train()`` or ``model.train(False)`` no longer toggle the
    train/eval state. The ``mode`` argument is accepted and ignored.
    """
    return self
def un_norm_clip(x1):
    """Invert CLIP-style per-channel normalization on an image tensor.

    Accepts either a single image ``(C, H, W)`` or a batch ``(B, C, H, W)``
    and returns a tensor of the same shape with each channel mapped back via
    ``x * std + mean`` using the CLIP preprocessing constants. The input
    tensor is never modified.
    """
    clip_mean = (0.48145466, 0.4578275, 0.40821073)
    clip_std = (0.26862954, 0.26130258, 0.27577711)
    x = x1.clone()  # work on a copy so the caller's tensor stays intact
    batched = len(x.shape) == 4
    if not batched:
        x = x.unsqueeze(0)
    for c, (m, s) in enumerate(zip(clip_mean, clip_std)):
        x[:, c, :, :] = x[:, c, :, :] * s + m
    return x if batched else x.squeeze(0)
def un_norm(x):
    """Map a tensor normalized to ``[-1, 1]`` back to ``[0, 1]``."""
    return (x + 1.0) / 2.0
def save_clip_img(img, path, clip=True):
    """De-normalize a single image tensor (C, H, W) and write it to *path*.

    img:  normalized image tensor — CLIP normalization when ``clip=True``,
          otherwise assumed in ``[-1, 1]``.
    path: output file path understood by ``PIL.Image.save``.
    """
    if clip:
        img = un_norm_clip(img)
    else:
        img = un_norm(img)
    # Clamp in BOTH branches: un_norm_clip can produce values outside [0, 1],
    # which would wrap around (e.g. 1.01 -> 2) when cast to uint8 below.
    img = torch.clamp(img, min=0.0, max=1.0)
    # detach() so tensors that still carry grad don't raise on .numpy()
    img = img.detach().cpu().numpy().transpose((1, 2, 0))
    img = (img * 255).astype(np.uint8)
    Image.fromarray(img).save(path)
# if clip:
# img=TF.normalize(img, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
# else:
# img=TF.normalize(img, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
class IDLoss(nn.Module):
    """Identity loss from a frozen ArcFace-style face-recognition backbone.

    Embeds two face images with an ir_se-50 ``Backbone`` (weights loaded from
    ``opts.other_params.arcface_path``) and penalizes ``1 - similarity``
    between the embeddings. The backbone is kept in eval mode with gradients
    disabled, so only the generated image (``y_hat``) receives gradients.
    """
    def __init__(self,opts,multiscale=False):
        # multiscale is forwarded to the Backbone call in extract_feats; when
        # True the backbone presumably returns a list of per-scale features
        # (verify against Backbone's implementation).
        super(IDLoss, self).__init__()
        print('Loading ResNet ArcFace')
        self.opts = opts 
        self.multiscale = multiscale
        # Pool to 256x256 so the fixed crop below lands on the face region.
        self.face_pool_1 = torch.nn.AdaptiveAvgPool2d((256, 256))
        self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
        # self.facenet=iresnet100(pretrained=False, fp16=False) # changed by sanoojan
        self.facenet.load_state_dict(torch.load(opts.other_params.arcface_path))
        # Backbone expects 112x112 input.
        self.face_pool_2 = torch.nn.AdaptiveAvgPool2d((112, 112))
        self.facenet.eval()
        # Freeze everything: this module is a loss, never trained.
        self.set_requires_grad(False)
        
    def set_requires_grad(self, flag=True):
        """Toggle requires_grad on all parameters of this module."""
        for p in self.parameters():
            p.requires_grad = flag
    
    def extract_feats(self, x,clip_img=True):
        """Embed a batch of face images with the frozen backbone.

        x: image batch; CLIP-normalized when ``clip_img=True`` (it is first
        mapped back to [0, 1] and then re-normalized to [-1, 1], the range
        the ArcFace backbone was trained on).
        Returns whatever the Backbone call yields — with multi_scale this is
        iterated as a list of per-scale feature tensors in forward().
        """
        # breakpoint()
        if clip_img:
            x = un_norm_clip(x)
            x = TF.normalize(x, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        x = self.face_pool_1(x) if x.shape[2]!=256 else x # (1) resize to 256 if needed
        x = x[:, :, 35:223, 32:220] # (2) Crop interesting region
        x = self.face_pool_2(x) # (3) resize to 112 to fit pre-trained model
        # breakpoint()
        x_feats = self.facenet(x, multi_scale=self.multiscale )
        # x_feats = self.facenet(x) # changed by sanoojan
        return x_feats

    def forward(self, y_hat, y,clip_img=True,return_seperate=False):
        """Identity loss between generated faces ``y_hat`` and targets ``y``.

        Returns ``(loss, sim_improvement, seperate_sim)`` where seperate_sim
        is the per-sample similarity list when ``return_seperate`` is True,
        else None. Loss and improvement are averaged per scale and summed
        over scales.
        """
        n_samples = y.shape[0]
        y_feats_ms = self.extract_feats(y,clip_img=clip_img)  # Otherwise use the feature from there
        y_hat_feats_ms = self.extract_feats(y_hat,clip_img=clip_img)
        # Detach target features: gradients must only flow through y_hat.
        y_feats_ms = [y_f.detach() for y_f in y_feats_ms]
        loss_all = 0
        sim_improvement_all = 0
        seperate_sim=[]
        for y_hat_feats, y_feats in zip(y_hat_feats_ms, y_feats_ms):
            loss = 0
            sim_improvement = 0
            count = 0
            # lossess = []
            for i in range(n_samples):
                # Dot product as similarity — presumably the backbone returns
                # L2-normalized embeddings (verify in Backbone); then
                # sim_views (self-similarity of y) is ~1.
                sim_target = y_hat_feats[i].dot(y_feats[i])
                sim_views = y_feats[i].dot(y_feats[i])
                seperate_sim.append(sim_target)
                loss += 1 - sim_target  # id loss
                sim_improvement +=  float(sim_target) - float(sim_views)
                count += 1
            loss_all += loss / count
            sim_improvement_all += sim_improvement / count

        if return_seperate:
            return loss_all, sim_improvement_all, seperate_sim
        return loss_all, sim_improvement_all, None
def uniform_on_device(r1, r2, shape, device):
    """Sample a tensor of *shape* uniformly from ``[r2, r1]`` on *device*.

    Equivalent to ``r2 + (r1 - r2) * U(0, 1)`` drawn elementwise.
    """
    span = r1 - r2
    return torch.rand(*shape, device=device) * span + r2
class LandmarkDetectionModel(nn.Module):
    """Regress 68 2-D facial landmarks from a 640-channel feature map.

    Expects input of shape (B, 640, 64, 64): the conv + 2x2 max-pool halve
    the spatial size to 32x32, then a single linear layer predicts 68*2
    coordinates, returned flat as (B, 136).
    """

    def __init__(self):
        super().__init__()
        # NOTE: attribute names ('features', 'landmark_predictor') and the
        # positional Sequential layout define the state_dict keys — keep them.
        self.features = nn.Sequential(
            nn.Conv2d(640, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.landmark_predictor = nn.Linear(128 * 32 * 32, 68 * 2)

    def forward(self, x):
        """Return landmark predictions of shape (B, 136)."""
        feats = self.features(x)
        flat = feats.flatten(start_dim=1)
        return self.landmark_predictor(flat)
|