File size: 1,524 Bytes
592789a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51

import torch
import torch.nn as nn

# Hyper-parameters fixed at training time. SimpleVAE's layer shapes are
# derived from them, so they presumably must match any weights being loaded.
COND_DIM: int = 4  # width of the conditioning vector concatenated to z
LATENT_DIM: int = 32  # size of the latent code z (encoder emits 2x this)

class SimpleVAE(nn.Module):
    """Conditional variational auto-encoder for 3x64x64 images.

    The encoder maps an image to the mean and log-variance of a diagonal
    Gaussian over a ``latent_dim``-dimensional code ``z``. The decoder
    reconstructs an image from ``z`` concatenated with a ``cond_dim``-wide
    conditioning vector.

    The sub-module attribute names (``encoder``, ``decoder_input``,
    ``decoder``) and the layer order inside each ``nn.Sequential`` determine
    the ``state_dict`` keys; keep them stable so saved weights still load.

    Args:
        latent_dim: Size of the latent code. ``None`` (default) uses the
            module-level ``LATENT_DIM``.
        cond_dim: Size of the conditioning vector. ``None`` (default) uses
            the module-level ``COND_DIM``.
    """

    def __init__(self, latent_dim=None, cond_dim=None):
        super().__init__()
        # Resolve defaults at call time so the module constants stay the
        # single source of truth for the trained configuration.
        self.latent_dim = LATENT_DIM if latent_dim is None else latent_dim
        self.cond_dim = COND_DIM if cond_dim is None else cond_dim

        # ENCODER: each Conv2d(kernel=4, stride=2, padding=1) halves H and W,
        # so a 64x64 input becomes a 32x16x16 feature map, which is what the
        # first Linear layer expects once flattened.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, 4, 2, 1), nn.ReLU(),
            nn.Conv2d(16, 32, 4, 2, 1), nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * 16 * 16, 128), nn.ReLU(),
            # Twice the latent size: first half is mu, second half is logvar.
            nn.Linear(128, self.latent_dim * 2),
        )

        # DECODER: project [z, conditions] back to a 32x16x16 map; each
        # ConvTranspose2d(kernel=4, stride=2, padding=1) then doubles H and W
        # up to 64x64, and the final Sigmoid keeps outputs in [0, 1].
        self.decoder_input = nn.Linear(
            self.latent_dim + self.cond_dim, 32 * 16 * 16
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(32, 16, 4, 2, 1), nn.ReLU(),
            nn.ConvTranspose2d(16, 3, 4, 2, 1), nn.Sigmoid(),
        )

    def encode(self, x):
        """Encode images into Gaussian latent parameters.

        Args:
            x: Image batch of shape (batch, 3, 64, 64); other spatial sizes
                do not match the first Linear layer's input width.

        Returns:
            Tuple ``(mu, logvar)``, each of shape (batch, latent_dim).
        """
        result = self.encoder(x)
        mu, logvar = result.chunk(2, dim=1)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        Sampling is unconditional (there is no ``self.training`` check), so
        even in ``eval()`` mode repeated calls return different ``z``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def decode(self, z, conditions):
        """Decode latent codes, conditioned on ``conditions``, into images.

        Args:
            z: Latent codes of shape (batch, latent_dim).
            conditions: Conditioning vectors of shape (batch, cond_dim);
                concatenated to ``z`` along dim 1.

        Returns:
            Images of shape (batch, 3, 64, 64) with values in [0, 1].
        """
        z_c = torch.cat([z, conditions], dim=1)
        h = self.decoder_input(z_c)
        # Un-flatten back to the 32x16x16 bottleneck feature map.
        h = h.view(-1, 32, 16, 16)
        return self.decoder(h)

    def forward(self, x, conditions):
        """Run encode -> sample -> decode.

        Args:
            x: Image batch of shape (batch, 3, 64, 64).
            conditions: Conditioning vectors of shape (batch, cond_dim).

        Returns:
            Tuple ``(recon_x, mu, logvar)``: the reconstruction plus the
            latent Gaussian parameters (needed for the KL term in training).
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z, conditions)
        return recon_x, mu, logvar