mmdbes's picture
Upload 2 files
056c4b5 verified
import torch
from torch import nn
import torch.nn.functional as F
class VAE(nn.Module):
def __init__(self, in_channels=3, latent_dim=200, hidden_dims=None):
super(VAE, self).__init__()
self.latent_dim = latent_dim
modules = []
if hidden_dims is None:
# ابعاد استاندارد برای یک VAE روی تصاویر 64x64
hidden_dims = [32, 64, 128, 256, 512]
# ------------------
# ساخت Encoder
# ------------------
for h_dim in hidden_dims:
modules.append(
nn.Sequential(
nn.Conv2d(in_channels, out_channels=h_dim,
kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(h_dim),
nn.LeakyReLU())
)
in_channels = h_dim
self.encoder = nn.Sequential(*modules)
self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)
self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)
# ------------------
# ساخت Decoder
# ------------------
modules = []
self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1]*4)
hidden_dims.reverse() # معکوس کردن لیست برای دیکودر
for i in range(len(hidden_dims) - 1):
modules.append(
nn.Sequential(
nn.ConvTranspose2d(hidden_dims[i],
hidden_dims[i+1],
kernel_size=3,
stride=2,
padding=1,
output_padding=1),
nn.BatchNorm2d(hidden_dims[i+1]),
nn.LeakyReLU())
)
self.decoder = nn.Sequential(*modules)
self.final_layer = nn.Sequential(
nn.ConvTranspose2d(hidden_dims[-1],
hidden_dims[-1],
kernel_size=3,
stride=2,
padding=1,
output_padding=1),
nn.BatchNorm2d(hidden_dims[-1]),
nn.LeakyReLU(),
nn.Conv2d(hidden_dims[-1], out_channels=3, kernel_size=3, padding=1),
nn.Sigmoid()
)
def encode(self, x):
"""ورودی را انکود کرده و پارامترهای فضای نهان را برمی‌گرداند"""
result = self.encoder(x)
result = torch.flatten(result, start_dim=1)
# تقسیم نتیجه به دو بخش میانگین و لگاریتم واریانس
mu = self.fc_mu(result)
log_var = self.fc_var(result)
return [mu, log_var]
def decode(self, z):
"""بردار نهان z را به یک تصویر دیکود می‌کند"""
result = self.decoder_input(z)
result = result.view(-1, 512, 2, 2)
result = self.decoder(result)
result = self.final_layer(result)
return result
def reparameterize(self, mu, logvar):
"""ترفند Reparameterization برای نمونه‌برداری"""
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return eps * std + mu
def forward(self, input, **kwargs):
"""یک پاس کامل در شبکه"""
mu, log_var = self.encode(input)
z = self.reparameterize(mu, log_var)
return [self.decode(z), input, mu, log_var]
def loss_f(self, *args, **kwargs):
"""محاسبه خطای VAE"""
recons = args[0]
input = args[1]
mu = args[2]
log_var = args[3]
kld_weight = kwargs['KW'] # وزن خطای KL
recons_loss = F.mse_loss(recons, input)
kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
loss = recons_loss + kld_weight * kld_loss
return loss, recons_loss, kld_loss
def sample(self, num_samples, current_device, **kwargs):
"""تولید نمونه‌های جدید از فضای نهان"""
z = torch.randn(num_samples, self.latent_dim)
z = z.to(current_device)
samples = self.decode(z)
return samples
def generate(self, x, **kwargs):
"""تولید یک تصویر از روی یک تصویر ورودی"""
return self.forward(x, **kwargs)[0]