DaniKaEp committed on
Commit 991bb01 · verified · 1 Parent(s): 6b9a350

Adding app, data and model files

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+ df_vae_encoding_April16_all.csv filter=lfs diff=lfs merge=lfs -text
VAE_model_tablets_class.py ADDED
@@ -0,0 +1,165 @@
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+ import pytorch_lightning as pl
+
+ class Flatten(nn.Module):
+     def forward(self, x):
+         return x.view(x.size(0), -1)
+
+ class UnFlatten(nn.Module):
+     def forward(self, x):
+         # Reshape the flat encoder output back to (batch, 256, 16, 16),
+         # the spatial size left after five stride-2 convolutions on 512x512 input.
+         return x.view(x.size(0), 256, 16, 16)
+
+ class VAE(pl.LightningModule):
+     def __init__(self, image_channels=1, h_dim=16*16*256, z_dim=12, lr=1e-3, beta=1,
+                  use_classification_loss=True, num_classes=None, loss_type="standard",
+                  class_weights=None, device=None):
+         super().__init__()
+         self.lr = lr
+         self.beta = beta
+         self.use_classification_loss = use_classification_loss
+
+         # Encoder for 512x512 input: each stride-2 convolution halves the spatial size.
+         self.encoder = nn.Sequential(
+             nn.Conv2d(image_channels, 32, kernel_size=5, stride=2, padding=2),   # 256x256
+             nn.BatchNorm2d(32),
+             nn.LeakyReLU(),
+             nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=2),               # 128x128
+             nn.BatchNorm2d(64),
+             nn.LeakyReLU(),
+             nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),              # 64x64
+             nn.BatchNorm2d(128),
+             nn.LeakyReLU(),
+             nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2),             # 32x32
+             nn.BatchNorm2d(256),
+             nn.LeakyReLU(),
+             nn.Conv2d(256, 256, kernel_size=5, stride=2, padding=2),             # 16x16
+             nn.BatchNorm2d(256),
+             nn.LeakyReLU(),
+             Flatten()
+         )
+
+         self.fc1 = nn.Linear(h_dim, z_dim)  # mu
+         self.fc2 = nn.Linear(h_dim, z_dim)  # logvar
+         self.fc3 = nn.Linear(z_dim, h_dim)  # latent -> decoder input
+
+         # Decoder mirrors the encoder, reconstructing a 512x512 output.
+         self.decoder = nn.Sequential(
+             UnFlatten(),
+             nn.ConvTranspose2d(256, 256, kernel_size=5, stride=2, padding=2, output_padding=1),  # 32x32
+             nn.BatchNorm2d(256),
+             nn.LeakyReLU(),
+             nn.ConvTranspose2d(256, 128, kernel_size=5, stride=2, padding=2, output_padding=1),  # 64x64
+             nn.BatchNorm2d(128),
+             nn.LeakyReLU(),
+             nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2, padding=2, output_padding=1),   # 128x128
+             nn.BatchNorm2d(64),
+             nn.LeakyReLU(),
+             nn.ConvTranspose2d(64, 32, kernel_size=5, stride=2, padding=2, output_padding=1),    # 256x256
+             nn.BatchNorm2d(32),
+             nn.LeakyReLU(),
+             nn.ConvTranspose2d(32, image_channels, kernel_size=5, stride=2, padding=2, output_padding=1),  # 512x512
+             nn.BatchNorm2d(image_channels),
+             nn.Sigmoid(),
+         )
+
+         self.loss_type = loss_type
+         if use_classification_loss:
+             if loss_type == "standard":
+                 self.criterion = nn.CrossEntropyLoss()
+             elif loss_type == "weighted":
+                 if class_weights is None:
+                     raise ValueError("For weighted loss, class_weights must be provided.")
+                 self.class_weights = torch.tensor(class_weights).to(device)
+                 self.criterion = nn.CrossEntropyLoss(weight=self.class_weights)
+             elif loss_type == "focal":
+                 self.criterion = FocalLoss()  # FocalLoss is assumed to be defined or imported elsewhere
+             else:
+                 raise ValueError(f"Unknown loss_type: {loss_type}")
+
+             assert num_classes is not None, "num_classes must be provided if use_classification_loss is True."
+             # Note: nn.CrossEntropyLoss expects raw logits; the Softmax here is
+             # unusual but kept to preserve the behavior the checkpoint was trained with.
+             self.fc_classify = nn.Sequential(
+                 nn.Linear(z_dim, num_classes),
+                 nn.Softmax(dim=1)
+             )
+
+     def reparameterize(self, mu, logvar):
+         std = logvar.mul(0.5).exp_()
+         eps = torch.randn_like(std)  # randn_like already matches std's device
+         return mu + std * eps
+
+     def bottleneck(self, h):
+         mu, logvar = self.fc1(h), self.fc2(h)
+         z = self.reparameterize(mu, logvar)
+         if self.use_classification_loss:
+             class_logits = self.fc_classify(z)
+             return z, mu, logvar, class_logits
+         return z, mu, logvar
+
+     def forward(self, x):
+         if self.use_classification_loss:
+             z, mu, logvar, class_logits = self.bottleneck(self.encoder(x))
+             z = self.fc3(z)
+             return [self.decoder(z), mu, logvar, class_logits]
+         else:
+             z, mu, logvar = self.bottleneck(self.encoder(x))
+             z = self.fc3(z)
+             return [self.decoder(z), mu, logvar]
+
+     def loss_function(self, recons, x, mu, logvar):
+         # Beta-VAE objective: summed MSE reconstruction plus beta-weighted KL divergence.
+         recons_loss = F.mse_loss(recons, x, reduction="sum")
+         kld_loss = torch.sum(-0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp(), dim=1), dim=0)
+         return recons_loss + self.beta * kld_loss
+
+     def classification_loss(self, logits, labels):
+         if self.loss_type == "standard":
+             return F.cross_entropy(logits, labels)
+         else:  # both "weighted" and "focal"
+             return self.criterion(logits, labels)
+
+     def configure_optimizers(self):
+         return torch.optim.Adam(self.parameters(), lr=self.lr)
+
+     def training_step(self, train_batch, batch_idx):
+         x, y = train_batch
+         outputs = self(x)
+
+         recon, mu, logvar = outputs[:3]
+         recon_loss = self.loss_function(recon, x, mu, logvar)
+
+         if self.use_classification_loss:
+             class_loss = self.classification_loss(outputs[3], y)
+             self.log('train_class_loss', class_loss)
+             total_loss = 0.5 * recon_loss + 0.5 * class_loss
+         else:
+             # Without the classification head, the total loss is just the VAE loss.
+             total_loss = recon_loss
+
+         self.log('train_recon_loss', recon_loss)
+         self.log('train_total_loss', total_loss)
+         return total_loss
+
+     def representation(self, x):
+         return self.bottleneck(self.encoder(x))[0]
+
+     def validation_step(self, val_batch, batch_idx):
+         x, y = val_batch
+         outputs = self(x)
+
+         recon, mu, logvar = outputs[:3]
+         recon_loss = self.loss_function(recon, x, mu, logvar)
+
+         if self.use_classification_loss:
+             class_loss = self.classification_loss(outputs[3], y)
+             self.log('val_class_loss', class_loss)
+             total_loss = 0.5 * recon_loss + 0.5 * class_loss
+         else:
+             total_loss = recon_loss
+
+         self.log('val_recon_loss', recon_loss)
+         self.log('val_total_loss', total_loss)
+         return total_loss
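
A minimal usage sketch for the model above (not part of the commit; num_classes=5 is hypothetical, since the commit does not record the value used in training):

    import torch
    from VAE_model_tablets_class import VAE

    model = VAE(num_classes=5)               # hypothetical class count, for illustration only
    x = torch.randn(2, 1, 512, 512)          # batch of two 512x512 grayscale images
    recon, mu, logvar, class_probs = model(x)
    print(recon.shape)        # torch.Size([2, 1, 512, 512])
    print(mu.shape)           # torch.Size([2, 12]), z_dim=12
    print(class_probs.shape)  # torch.Size([2, 5])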
app.py ADDED
@@ -0,0 +1,56 @@
+ import numpy as np
+ import pandas as pd
+ from PIL import Image as PILImage
+ import torch
+
+ from era_data import TabletPeriodDataset
+ from VAE_model_tablets_class import VAE
+
+ import gradio as gr
+
+
+ # Lightning checkpoints store the weights under the 'state_dict' key.
+ # Only the decoder path (fc3 + decoder) is needed here, so the model is
+ # built without the classification head and loaded with strict=False.
+ model = VAE(use_classification_loss=False)
+ checkpoint = torch.load('epoch=22-step=213621.ckpt', map_location='cpu')
+ model.load_state_dict(checkpoint['state_dict'], strict=False)
+ model.eval()
+
+ # Load the per-tablet VAE encodings and average them per period.
+ df_encodings = pd.read_csv('df_vae_encoding_April16_all.csv')
+ df_means = df_encodings.drop(["Period", "Genre", "Genre_Name", "CDLI_id"], axis=1).groupby("Period_Name").mean().reset_index()
+ period_names = df_means['Period_Name'].unique()
+
+ def get_image_from_period(period_name):
+     # Despite the name, this returns the mean latent vector for a period, not an image.
+     period_data = torch.from_numpy(df_means[df_means["Period_Name"] == period_name].drop(["Period_Name"], axis=1).values[0].astype('float32'))
+     return period_data
+
+ def generate_image(period1, period2, interpolation_value):
+     z1 = get_image_from_period(period1)
+     z2 = get_image_from_period(period2)
+
+     # Linear interpolation between the two period means in latent space.
+     i = interpolation_value
+     new_tablet = (1 - i) * z1 + i * z2
+
+     with torch.no_grad():
+         new_tab_long = model.fc3(new_tablet).unsqueeze(0)
+         generated_image = model.decoder(new_tab_long)
+     generated_image = generated_image[0][0].cpu().numpy()
+     generated_image = (generated_image * 255).astype(np.uint8)
+     # gr.Image accepts a PIL image directly, so no PNG byte round-trip is needed.
+     return PILImage.fromarray(generated_image)
+
+ # Define the Gradio interface
+ iface = gr.Interface(
+     fn=generate_image,
+     inputs=[
+         gr.Dropdown(choices=period_names.tolist(), label="Period 1"),
+         gr.Dropdown(choices=period_names.tolist(), label="Period 2"),
+         gr.Slider(0, 1, step=0.1, label="Interpolation")
+     ],
+     outputs=gr.Image(label="Generated Image")
+ )
+
+ if __name__ == "__main__":
+     iface.launch()
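
The app's core operation is linear interpolation between per-period mean latent vectors, z = (1 - i) * z1 + i * z2, decoded through fc3 and the decoder. A quick offline check of that path (hypothetical, not in the commit), reusing the functions defined in app.py:

    # Decode the endpoints and the midpoint of the blend between two periods.
    img_start = generate_image(period_names[0], period_names[1], 0.0)  # pure period 1 mean
    img_mid = generate_image(period_names[0], period_names[1], 0.5)    # halfway blend
    img_end = generate_image(period_names[0], period_names[1], 1.0)    # pure period 2 mean
    img_mid.save('blend_example.png')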
df_vae_encoding_April16_all.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:edf0599719afb46d959a04ed8aafe7015a5321f56adc03a374166947c9b09e32
+ size 13648966
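
The CSV is stored as a Git LFS pointer; the ~13 MB payload holds one VAE encoding per tablet. A sketch of how app.py expects it to be shaped (the column split is inferred from the drop/groupby calls above, not verified against the data):

    import pandas as pd

    df = pd.read_csv('df_vae_encoding_April16_all.csv')
    # Metadata columns referenced in app.py; everything else should be latent dimensions.
    meta_cols = ['CDLI_id', 'Period', 'Period_Name', 'Genre', 'Genre_Name']
    latent_cols = [c for c in df.columns if c not in meta_cols]
    assert len(latent_cols) == 12, "expected z_dim=12 latent columns"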
epoch=22-step=213621.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f55d0199f326978670388de691fa334bc4df34ac4da13f0e9dc1734e5f1dea1e
+ size 94369023
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch
+ torchvision
+ pytorch_lightning
+ ipywidgets
+ numpy
+ pandas
+ Pillow  # the pip package providing the PIL import used in app.py