Upload 3 files
Browse files- app.py +75 -0
- diffusion.py +139 -0
- modules.py +243 -0
app.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
import os
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import torch
|
| 6 |
+
import torchvision.transforms as transforms
|
| 7 |
+
import requests
|
| 8 |
+
|
| 9 |
+
# Function to download the model from Google Drive
|
| 10 |
+
def download_file_from_google_drive(id, destination):
    """Download a Google Drive file to a local path.

    A first request is issued for the file. If ``get_confirm_token`` finds a
    confirmation token in that response's cookies, the request is repeated
    with the token passed as the ``confirm`` query parameter. The final
    response body is streamed to disk by ``save_response_content``.

    Args:
        id: Google Drive file identifier. The name shadows the builtin but is
            kept so that existing keyword callers continue to work.
        destination: Path of the local file to write.

    Raises:
        requests.HTTPError: If the final response carries a 4xx/5xx status.
            Without this check an error page would be written to
            ``destination`` as if it were the requested file.
    """
    URL = "https://drive.google.com/uc?export=download"

    # The context manager releases the session's pooled connections; the
    # download must finish inside it because the response is streamed.
    with requests.Session() as session:
        response = session.get(URL, params={'id': id}, stream=True)
        token = get_confirm_token(response)

        if token:
            params = {'id': id, 'confirm': token}
            response = session.get(URL, params=params, stream=True)

        response.raise_for_status()
        save_response_content(response, destination)
|
| 21 |
+
|
| 22 |
+
def get_confirm_token(response):
    """Return the value of the first ``download_warning*`` cookie, if any.

    Args:
        response: Object exposing a ``cookies`` mapping with ``items()``.

    Returns:
        The value of the first cookie whose name starts with
        ``'download_warning'``, or ``None`` when no such cookie exists.
    """
    matching_values = (
        cookie_value
        for cookie_name, cookie_value in response.cookies.items()
        if cookie_name.startswith('download_warning')
    )
    return next(matching_values, None)
|
| 27 |
+
|
| 28 |
+
def save_response_content(response, destination):
    """Write the streamed body of ``response`` to ``destination``.

    Args:
        response: Object whose ``iter_content(chunk_size)`` yields byte chunks.
        destination: Path of the file to create or overwrite (binary mode).
    """
    chunk_size = 32768
    with open(destination, "wb") as out_file:
        # Empty chunks (keep-alives) are dropped instead of being written.
        out_file.writelines(filter(None, response.iter_content(chunk_size)))
|
| 34 |
+
|
| 35 |
+
# Google Drive file ID of the checkpoint and the local filename it is saved as.
# Replace the ID below with your own file ID to serve a different checkpoint.
file_id = '1WJ33nys02XpPDsMO5uIZFiLqTuAT_iuV'
destination = 'ema_ckpt_cond.pt'
download_file_from_google_drive(file_id, destination)

# Preprocessing
from modules import PaletteModelV2
from diffusion import Diffusion_cond

# Fall back to the CPU so the app still starts on hosts without a CUDA device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# `device` is passed explicitly because the model constructor defaults it to
# 'cuda' and uses that value when building the timestep encoding.
model = PaletteModelV2(
    c_in=2,
    c_out=1,
    num_classes=5,
    image_size=256,
    true_img_size=64,
    device=device,
).to(device)
# NOTE(review): torch.load unpickles the downloaded file; only point `file_id`
# at checkpoints you trust.
ckpt = torch.load(destination, map_location=device)
model.load_state_dict(ckpt)

diffusion = Diffusion_cond(noise_steps=1000, img_size=256, device=device)
model.eval()

# Input pipeline: tensor conversion, resize to 256x256, an unconditional
# vertical flip (p=1.0), then normalisation with mean 0.5 and std 0.5.
transform_hmi = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((256, 256)),
    transforms.RandomVerticalFlip(p=1.0),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])
|
| 59 |
+
|
| 60 |
+
def generate_image(seed_image):
    """Run the diffusion sampler conditioned on an uploaded image.

    Args:
        seed_image: Path or file object accepted by ``PIL.Image.open``.

    Returns:
        A PIL image built from the sampler output.
    """
    with Image.open(seed_image) as uploaded:
        # The tensor is reshaped to a single channel below, so multi-band
        # uploads (e.g. RGB) are collapsed to grayscale first; single-band
        # images are passed through unchanged.
        if len(uploaded.getbands()) != 1:
            conditioning = uploaded.convert('L')
        else:
            conditioning = uploaded
        seed_image_tensor = transform_hmi(conditioning).reshape(1, 1, 256, 256).to(device)

    generated_image = diffusion.sample(model, y=seed_image_tensor, labels=None, n=1)

    # ToPILImage scales float tensors by 255, so it needs values in [0, 1].
    # The sampler output is presumably in the same [-1, 1] range as the
    # normalised input (the sampler's own rescaling is commented out), so it
    # is clamped and mapped back here.
    generated_image = (generated_image.clamp(-1, 1) + 1) / 2
    generated_image_pil = transforms.ToPILImage()(generated_image.squeeze().cpu())
    return generated_image_pil
|
| 65 |
+
|
| 66 |
+
# Build the Gradio interface around `generate_image` and start serving it.
interface_options = dict(
    fn=generate_image,
    inputs="file",
    outputs="image",
    title="Magnetogram-to-Magnetogram: Generative Forecasting of Solar Evolution",
    description="Upload a LoS magnetogram and predict how it is going to be in 24 hours.",
)
iface = gr.Interface(**interface_options)

iface.launch()
|
diffusion.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
Created on Tue Apr 25 14:45:59 2023
|
| 4 |
+
|
| 5 |
+
@author: pio-r
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import logging
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
logging.basicConfig(format="%(asctime)s - %(levelname)s: %(message)s", level=logging.INFO, datefmt="%I:%M:%S")
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class Diffusion_cond:
    """Linear-schedule diffusion process with conditional DDPM/DDIM sampling.

    The constructor precomputes the noise schedule tensors on ``device``:
    ``beta``, ``alpha = 1 - beta``, the cumulative product ``alpha_hat`` and
    the "previous step" variants of ``alpha`` and ``alpha_hat`` (shifted by
    one position with a leading 1.0).
    """

    # Sampler names accepted by ``sample``.
    _SAMPLING_MODES = ('ddpm', 'ddim')

    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=256, img_channel=1, device="cuda"):
        """Store the hyper-parameters and build the schedule tensors.

        Args:
            noise_steps: Number of diffusion timesteps.
            beta_start: First value of the linear beta schedule.
            beta_end: Last value of the linear beta schedule.
            img_size: Height and width of the square images that are sampled.
            img_channel: Number of channels of the sampled images.
            device: Device on which schedule tensors and samples are placed.
        """
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.img_channel = img_channel
        self.img_size = img_size
        self.device = device

        self.beta = self.prepare_noise_schedule().to(device)
        self.alpha = 1. - self.beta
        leading_one = torch.tensor([1.0]).to(device)
        self.alphas_prev = torch.cat([leading_one, self.alpha[:-1]], dim=0)
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)
        self.alphas_cumprod_prev = torch.cat([leading_one, self.alpha_hat[:-1]], dim=0)

    def prepare_noise_schedule(self):
        """Return the linear variance schedule (as proposed by Ho et al. 2020)."""
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def noise_images(self, x, t):
        """Apply the forward noising process to ``x`` at timesteps ``t``.

        Args:
            x: Batch of images; the schedule values are broadcast over three
                trailing dimensions.
            t: Integer tensor of timesteps, one per batch element.

        Returns:
            Tuple ``(noised, eps)`` where ``eps`` is the Gaussian noise drawn
            and ``noised = sqrt(alpha_hat[t]) * x + sqrt(1 - alpha_hat[t]) * eps``.
        """
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        eps = torch.randn_like(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * eps, eps

    def sample_timesteps(self, n):
        """Draw ``n`` random integer timesteps in ``[1, noise_steps)``."""
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    def sample(self, model, n, y, labels, cfg_scale=3, eta=1, sampling_mode='ddpm'):
        """Generate ``n`` images by iterating the reverse process.

        Args:
            model: Callable invoked as ``model(x, y, labels, t)`` that returns
                the predicted noise.
            n: Number of images to sample.
            y: Conditioning input forwarded unchanged to ``model``.
            labels: Class labels forwarded to ``model``, or ``None``.
            cfg_scale: Interpolation weight between the prediction made
                without labels and the one made with labels. Only applied
                when ``labels`` is given; with ``labels=None`` both
                predictions come from the same call, so the second forward
                pass is skipped.
            eta: Scale of the stochastic term in the ``'ddim'`` update.
            sampling_mode: ``'ddpm'`` or ``'ddim'``.

        Returns:
            Tensor of shape ``(n, img_channel, img_size, img_size)`` on
            ``self.device``; values are not clamped or rescaled.

        Raises:
            ValueError: If ``sampling_mode`` is not a supported sampler.
        """
        # Fail before doing any work instead of returning pure noise.
        if sampling_mode not in self._SAMPLING_MODES:
            raise ValueError('The sampler {} is not implemented'.format(sampling_mode))

        logging.info(f"Sampling {n} new images....")
        # Remember the caller's mode so it can be restored afterwards.
        was_training = getattr(model, "training", True)
        model.eval()
        try:
            with torch.no_grad():  # algorithm 2 from DDPM
                x = torch.randn((n, self.img_channel, self.img_size, self.img_size)).to(self.device)
                steps = reversed(range(1, self.noise_steps))  # reverse loop from T to 1
                for i in tqdm(steps, position=0, total=self.noise_steps - 1):
                    t = (torch.ones(n) * i).long().to(self.device)  # timestep tensor of length n
                    predicted_noise = model(x, y, labels, t)
                    if cfg_scale > 0 and labels is not None:
                        uncond_predicted_noise = model(x, y, None, t)
                        predicted_noise = torch.lerp(uncond_predicted_noise, predicted_noise, cfg_scale)

                    alpha = self.alpha[t][:, None, None, None]
                    alpha_hat = self.alpha_hat[t][:, None, None, None]

                    if sampling_mode == 'ddpm':
                        beta = self.beta[t][:, None, None, None]
                        # No noise is added on the final step (i == 1).
                        noise = torch.randn_like(x) if i > 1 else torch.zeros_like(x)
                        x = (
                            1 / torch.sqrt(alpha)
                            * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise)
                            + torch.sqrt(beta) * noise
                        )
                    else:  # 'ddim' -- sampling adjusted from Stable Diffusion
                        alpha_prev = self.alphas_cumprod_prev[t][:, None, None, None]
                        sigma = (
                            eta
                            * torch.sqrt((1 - alpha_prev) / (1 - alpha_hat)
                                         * (1 - alpha_hat / alpha_prev))
                        )
                        pred_x0 = (x - torch.sqrt(1 - alpha_hat) * predicted_noise) / torch.sqrt(alpha_hat)
                        noise = torch.randn_like(x)
                        # Zero where t == 0 so that no noise would be added there.
                        nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
                        x = (
                            torch.sqrt(alpha_prev) * pred_x0
                            + torch.sqrt(1 - alpha_prev - sigma ** 2) * predicted_noise
                            + nonzero_mask * sigma * noise
                        )
        finally:
            model.train(was_training)
        return x
|
| 92 |
+
|
| 93 |
+
# Shared mean-squared-error criterion used by ``psnr``.
mse = nn.MSELoss()


def psnr(input: torch.Tensor, target: torch.Tensor, max_val: float) -> torch.Tensor:
    r"""Compute the PSNR between two images.

    PSNR is the Peak Signal to Noise Ratio, which is derived from the mean
    squared error. Given an m x n image, the PSNR is:

    .. math::

        \text{PSNR} = 10 \log_{10} \bigg(\frac{\text{MAX}_I^2}{MSE(I,T)}\bigg)

    where

    .. math::

        \text{MSE}(I,T) = \frac{1}{mn}\sum_{i=0}^{m-1}\sum_{j=0}^{n-1} [I(i,j) - T(i,j)]^2

    and :math:`\text{MAX}_I` is the maximum possible input value
    (e.g for floating point images :math:`\text{MAX}_I=1`).

    Args:
        input: the input image with arbitrary shape :math:`(*)`.
        target: the reference image with the same shape as ``input``.
        max_val: The maximum value in the input tensor.

    Return:
        the computed PSNR as a scalar tensor.

    Raises:
        TypeError: if either argument is not a ``torch.Tensor`` or if the two
            shapes differ.

    Examples:
        >>> ones = torch.ones(1)
        >>> psnr(ones, 1.2 * ones, 2.) # 10 * log(4/((1.2-1)**2)) / log(10)
        tensor(20.0000)

    Reference:
        https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio#Definition
    """
    # Each message reports the type of the argument that actually failed.
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor but got {type(input)}.")

    if not isinstance(target, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor but got {type(target)}.")

    if input.shape != target.shape:
        raise TypeError(f"Expected tensors of equal shapes, but got {input.shape} and {target.shape}")

    return 10.0 * torch.log10(max_val**2 / mse(input, target))
|
modules.py
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
Created on Tue Apr 25 14:28:21 2023
|
| 4 |
+
|
| 5 |
+
@author: pio-r
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
from torch.nn import functional as F
|
| 11 |
+
from torch.utils.checkpoint import checkpoint
|
| 12 |
+
|
| 13 |
+
class EMA:
    """Exponential moving average of model parameters.

    ``beta`` is the weight kept from the running average on each update and
    ``step`` counts how many times ``step_ema`` has been called.
    """

    def __init__(self, beta):
        super().__init__()
        self.beta = beta
        self.step = 0

    def update_model_average(self, ma_model, current_model):
        """Blend every parameter of ``current_model`` into ``ma_model``."""
        paired_params = zip(current_model.parameters(), ma_model.parameters())
        for live_param, averaged_param in paired_params:
            averaged_param.data = self.update_average(averaged_param.data, live_param.data)

    def update_average(self, old, new):
        """Return ``old * beta + (1 - beta) * new``, or ``new`` if ``old`` is None."""
        if old is not None:
            return old * self.beta + (1 - self.beta) * new
        return new

    def step_ema(self, ema_model, model, step_start_ema=2000):
        """Advance the average by one step.

        While fewer than ``step_start_ema`` steps have been taken the averaged
        model is overwritten with a copy of ``model``; afterwards the
        parameters are blended. The step counter is incremented either way.
        """
        still_warming_up = self.step < step_start_ema
        if still_warming_up:
            self.reset_parameters(ema_model, model)
        else:
            self.update_model_average(ema_model, model)
        self.step += 1

    def reset_parameters(self, ema_model, model):
        """Copy the full state dict of ``model`` into ``ema_model``."""
        source_state = model.state_dict()
        ema_model.load_state_dict(source_state)
|
| 39 |
+
|
| 40 |
+
class SelfAttention(nn.Module):
    """Self-attention block over the spatial positions of a feature map.

    Layer norm -> multi-head attention -> skip connection -> feed-forward
    stack (layer norm, linear, GELU, linear) with a second skip connection.
    The input is expected to hold ``channels`` channels on a
    ``size`` x ``size`` grid.
    """

    def __init__(self, channels, size):
        super().__init__()
        self.channels = channels
        self.size = size
        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        feed_forward_layers = [
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        ]
        self.ff_self = nn.Sequential(*feed_forward_layers)

    def forward(self, x):
        """Attend across all positions and return a map of the input layout."""
        positions = self.size * self.size
        # Flatten the grid so every spatial position becomes one token.
        tokens = x.view(-1, self.channels, positions).transpose(1, 2)
        normed = self.ln(tokens)
        attended, _ = self.mha(normed, normed, normed)
        attended = attended + tokens
        attended = self.ff_self(attended) + attended
        restored = attended.transpose(2, 1)
        return restored.view(-1, self.channels, self.size, self.size)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class DoubleConv(nn.Module):
    """Two 3x3 convolutions, each followed by a single-group GroupNorm.

    Layout: conv -> GroupNorm -> GELU -> conv -> GroupNorm. With
    ``residual=True`` the input is added to the result and passed through a
    final GELU. ``mid_channels`` falls back to ``out_channels`` when it is
    not given.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
        super().__init__()
        self.residual = residual
        hidden_channels = mid_channels or out_channels
        conv_options = dict(kernel_size=3, padding=1, bias=False)
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, **conv_options),
            nn.GroupNorm(1, hidden_channels),
            nn.GELU(),
            nn.Conv2d(hidden_channels, out_channels, **conv_options),
            nn.GroupNorm(1, out_channels),
        )

    def forward(self, x):
        """Apply the convolution stack, with the optional residual path."""
        convolved = self.double_conv(x)
        if not self.residual:
            return convolved
        return F.gelu(x + convolved)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class Down(nn.Module):
    """Encoder stage: 2x max-pool, two DoubleConv blocks, plus an embedding.

    The embedding vector ``t`` is projected from ``emb_dim`` to
    ``out_channels`` and added to every spatial position of the result.
    """

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()
        halve_resolution = nn.MaxPool2d(2)
        refine = DoubleConv(in_channels, in_channels, residual=True)
        widen = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(halve_resolution, refine, widen)

        # Linear projection that brings the embedding to the channel count.
        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, out_channels),
        )

    def forward(self, x, t):
        """Downsample ``x`` and add the projected embedding ``t``."""
        features = self.maxpool_conv(x)
        height, width = features.shape[-2], features.shape[-1]
        projected = self.emb_layer(t)
        # Tile the per-channel embedding over the whole feature map.
        emb = projected[:, :, None, None].repeat(1, 1, height, width)
        return features + emb
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class Up(nn.Module):
    """Decoder stage that merges an upsampled map with an encoder skip map.

    The input is upsampled 2x (bilinear), concatenated with the skip
    connection along the channel axis, passed through two DoubleConv blocks,
    and the projected embedding is added to every spatial position.
    ``in_channels`` is the channel count after concatenation.
    """

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()

        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        refine = DoubleConv(in_channels, in_channels, residual=True)
        narrow = DoubleConv(in_channels, out_channels, in_channels // 2)
        self.conv = nn.Sequential(refine, narrow)

        # Linear projection that brings the embedding to the channel count.
        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, out_channels),
        )

    def forward(self, x, skip_x, t):
        """Upsample ``x``, fuse it with ``skip_x`` and add the embedding ``t``."""
        upsampled = self.up(x)
        merged = torch.cat([skip_x, upsampled], dim=1)
        features = self.conv(merged)
        height, width = features.shape[-2], features.shape[-1]
        projected = self.emb_layer(t)
        # Tile the per-channel embedding over the whole feature map.
        emb = projected[:, :, None, None].repeat(1, 1, height, width)
        return features + emb
|
| 146 |
+
|
| 147 |
+
class PaletteModelV2(nn.Module):
    """U-Net style noise predictor conditioned on a reference image.

    The noisy image ``x`` and the conditioning image ``y`` are concatenated
    along the channel axis, so ``c_in`` must equal their combined channel
    count. A sinusoidal timestep encoding (plus an optional label embedding)
    is injected into every Down/Up stage. No self-attention stages are
    instantiated between the blocks.

    Args:
        c_in: Number of input channels after concatenating ``x`` and ``y``.
        c_out: Number of output channels.
        image_size: Base channel width of the network (the first convolution
            outputs this many channels).
        time_dim: Size of the timestep/label embedding.
        device: Stored on the instance and used to place the optional
            ``latent`` encoder.
        latent: When truthy, an extra convolutional encoder ``self.latent``
            is created; ``forward`` does not call it.
        true_img_size: Stored on the instance; not used by the layers built
            here.
        num_classes: When given, a label embedding table with this many
            entries is created.
    """

    def __init__(self, c_in=1, c_out=1, image_size=64, time_dim=256, device='cuda', latent=False, true_img_size=64, num_classes=None):
        super().__init__()

        self.true_img_size = true_img_size
        self.image_size = image_size
        self.time_dim = time_dim
        self.device = device

        # Encoder
        self.inc = DoubleConv(c_in, self.image_size)
        self.down1 = Down(self.image_size, self.image_size*2)
        self.down2 = Down(self.image_size*2, self.image_size*4)
        self.down3 = Down(self.image_size*4, self.image_size*4)

        # Bottleneck
        self.bot1 = DoubleConv(self.image_size*4, self.image_size*8)
        self.bot2 = DoubleConv(self.image_size*8, self.image_size*8)
        self.bot3 = DoubleConv(self.image_size*8, self.image_size*4)

        # Decoder: reverse of the encoder. Input widths include the channels
        # of the concatenated skip connection.
        self.up1 = Up(self.image_size*8, self.image_size*2)
        self.up2 = Up(self.image_size*4, self.image_size)
        self.up3 = Up(self.image_size*2, self.image_size)
        # 1x1 convolution projecting back to the output channel count.
        self.outc = nn.Conv2d(self.image_size, c_out, kernel_size=1)

        if num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_dim)

        if latent:
            self.latent = nn.Sequential(
                nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU(0.2),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU(0.2),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU(0.2),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Flatten(),
                nn.Linear(64 * 8 * 8, 256)).to(device)

    def pos_encoding(self, t, channels):
        """Return the sinusoidal encoding of the timesteps ``t``.

        Args:
            t: Float tensor of shape ``(batch, 1)`` holding timestep values.
            channels: Size of the encoding; the first half holds sines and
                the second half cosines.

        Returns:
            Tensor of shape ``(batch, channels)`` on the same device as ``t``.
        """
        # Build the frequencies on the device of ``t`` rather than on the
        # constructor's ``device`` argument, so the encoding follows the
        # module wherever it has been moved.
        exponents = torch.arange(0, channels, 2, device=t.device).float() / channels
        inv_freq = 1.0 / (10000 ** exponents)
        scaled = t.repeat(1, channels // 2) * inv_freq
        return torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=-1)

    def forward(self, x, y, lab, t):
        """Predict the noise in ``x`` given the conditioning image ``y``.

        Args:
            x: Noisy image batch.
            y: Conditioning image batch, concatenated with ``x`` on dim 1.
            lab: Optional label indices for ``self.label_emb``; ``None``
                skips the label embedding.
            t: Integer timestep tensor with one entry per batch element.

        Returns:
            Output of the final 1x1 convolution (``c_out`` channels).
        """
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t, self.time_dim)

        if lab is not None:
            t = t + self.label_emb(lab)

        # Concatenate the noisy image and the reference image.
        x = torch.cat([x, y], dim=1)

        # Encoder
        x1 = self.inc(x)
        x2 = self.down1(x1, t)
        x3 = self.down2(x2, t)
        x4 = self.down3(x3, t)

        # Bottleneck
        x4 = self.bot1(x4)
        x4 = self.bot2(x4)
        x4 = self.bot3(x4)

        # Decoder: each stage also receives the matching encoder output.
        x = self.up1(x4, x3, t)
        x = self.up2(x, x2, t)
        x = self.up3(x, x1, t)
        output = self.outc(x)

        return output
|