import torch from torch import nn import torch.nn.functional as F class VAE(nn.Module): def __init__(self, in_channels=3, latent_dim=200, hidden_dims=None): super(VAE, self).__init__() self.latent_dim = latent_dim modules = [] if hidden_dims is None: # ابعاد استاندارد برای یک VAE روی تصاویر 64x64 hidden_dims = [32, 64, 128, 256, 512] # ------------------ # ساخت Encoder # ------------------ for h_dim in hidden_dims: modules.append( nn.Sequential( nn.Conv2d(in_channels, out_channels=h_dim, kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(h_dim), nn.LeakyReLU()) ) in_channels = h_dim self.encoder = nn.Sequential(*modules) self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim) self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim) # ------------------ # ساخت Decoder # ------------------ modules = [] self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1]*4) hidden_dims.reverse() # معکوس کردن لیست برای دیکودر for i in range(len(hidden_dims) - 1): modules.append( nn.Sequential( nn.ConvTranspose2d(hidden_dims[i], hidden_dims[i+1], kernel_size=3, stride=2, padding=1, output_padding=1), nn.BatchNorm2d(hidden_dims[i+1]), nn.LeakyReLU()) ) self.decoder = nn.Sequential(*modules) self.final_layer = nn.Sequential( nn.ConvTranspose2d(hidden_dims[-1], hidden_dims[-1], kernel_size=3, stride=2, padding=1, output_padding=1), nn.BatchNorm2d(hidden_dims[-1]), nn.LeakyReLU(), nn.Conv2d(hidden_dims[-1], out_channels=3, kernel_size=3, padding=1), nn.Sigmoid() ) def encode(self, x): """ورودی را انکود کرده و پارامترهای فضای نهان را برمی‌گرداند""" result = self.encoder(x) result = torch.flatten(result, start_dim=1) # تقسیم نتیجه به دو بخش میانگین و لگاریتم واریانس mu = self.fc_mu(result) log_var = self.fc_var(result) return [mu, log_var] def decode(self, z): """بردار نهان z را به یک تصویر دیکود می‌کند""" result = self.decoder_input(z) result = result.view(-1, 512, 2, 2) result = self.decoder(result) result = self.final_layer(result) return result def reparameterize(self, mu, logvar): """ترفند Reparameterization برای نمونه‌برداری""" std = torch.exp(0.5 * logvar) eps = torch.randn_like(std) return eps * std + mu def forward(self, input, **kwargs): """یک پاس کامل در شبکه""" mu, log_var = self.encode(input) z = self.reparameterize(mu, log_var) return [self.decode(z), input, mu, log_var] def loss_f(self, *args, **kwargs): """محاسبه خطای VAE""" recons = args[0] input = args[1] mu = args[2] log_var = args[3] kld_weight = kwargs['KW'] # وزن خطای KL recons_loss = F.mse_loss(recons, input) kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0) loss = recons_loss + kld_weight * kld_loss return loss, recons_loss, kld_loss def sample(self, num_samples, current_device, **kwargs): """تولید نمونه‌های جدید از فضای نهان""" z = torch.randn(num_samples, self.latent_dim) z = z.to(current_device) samples = self.decode(z) return samples def generate(self, x, **kwargs): """تولید یک تصویر از روی یک تصویر ورودی""" return self.forward(x, **kwargs)[0]