Spaces:
Runtime error
Runtime error
jiawei-ren commited on
Commit ·
f404def
1
Parent(s): 8a01af6
init
Browse files
app.py
CHANGED
|
@@ -18,7 +18,6 @@ Y_LB = 0
|
|
| 18 |
K = 1
|
| 19 |
B = 0
|
| 20 |
NUM_SEG = 5
|
| 21 |
-
sns.set_theme(palette='colorblind')
|
| 22 |
NUM_EPOCHS = 100
|
| 23 |
PRINT_FREQ = NUM_EPOCHS // 20
|
| 24 |
NUM_TRAIN_SAMPLES = NUM_PER_BUCKET * NUM_SEG
|
|
@@ -76,8 +75,8 @@ def unzip_dataloader(training_loader):
|
|
| 76 |
return all_x, all_y
|
| 77 |
|
| 78 |
|
| 79 |
-
def train(train_loader, training_bundle, num_epochs):
|
| 80 |
-
training_df
|
| 81 |
for epoch in range(num_epochs):
|
| 82 |
for model, optimizer, scheduler, criterion, criterion_name in training_bundle:
|
| 83 |
model.train()
|
|
@@ -92,23 +91,19 @@ def train(train_loader, training_bundle, num_epochs):
|
|
| 92 |
optimizer.step()
|
| 93 |
scheduler.step()
|
| 94 |
if (epoch + 1) % PRINT_FREQ == 0:
|
| 95 |
-
|
|
|
|
| 96 |
|
| 97 |
|
| 98 |
-
def
|
| 99 |
df = df_oracle
|
| 100 |
for model, optimizer, scheduler, criterion, criterion_name in training_bundle:
|
| 101 |
model.eval()
|
| 102 |
y = model(X_demo)
|
| 103 |
df = df.append(make_dataframe(X_demo, y, criterion_name), ignore_index=True)
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
plt.ylim(Y_LB, Y_UB)
|
| 108 |
-
plt.gca().axes.set_xlabel(r'$x$', fontsize=10)
|
| 109 |
-
plt.gca().axes.set_ylabel(r'$y$', fontsize=10)
|
| 110 |
-
plt.savefig('train_log/{:05d}.png'.format(epoch + 1), bbox_inches='tight')
|
| 111 |
-
plt.close()
|
| 112 |
|
| 113 |
|
| 114 |
def make_video():
|
|
@@ -178,7 +173,7 @@ def bmc_loss(pred, target, noise_var):
|
|
| 178 |
return loss * (2 * noise_var)
|
| 179 |
|
| 180 |
|
| 181 |
-
def regress(train_loader):
|
| 182 |
training_bundle = []
|
| 183 |
criterions = {
|
| 184 |
'MSE': torch.nn.MSELoss(),
|
|
@@ -189,7 +184,7 @@ def regress(train_loader):
|
|
| 189 |
criterion = criterions[criterion_name]
|
| 190 |
model, optimizer, scheduler = prepare_model()
|
| 191 |
training_bundle.append((model, optimizer, scheduler, criterion, criterion_name))
|
| 192 |
-
train(train_loader, training_bundle, NUM_EPOCHS)
|
| 193 |
|
| 194 |
|
| 195 |
class DummyDataset(Dataset):
|
|
@@ -205,22 +200,31 @@ class DummyDataset(Dataset):
|
|
| 205 |
return len(self.inputs)
|
| 206 |
|
| 207 |
|
| 208 |
-
def
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
plt.gca().axes.set_xlabel(r'$x$', fontsize=10)
|
| 221 |
plt.gca().axes.set_ylabel(r'$y$', fontsize=10)
|
| 222 |
-
|
| 223 |
-
plt.
|
|
|
|
| 224 |
|
| 225 |
|
| 226 |
def clean_up_logs():
|
|
@@ -228,8 +232,9 @@ def clean_up_logs():
|
|
| 228 |
os.mkdir('train_log')
|
| 229 |
for f in os.listdir('train_log'):
|
| 230 |
os.remove(osp.join('train_log', f))
|
| 231 |
-
|
| 232 |
-
|
|
|
|
| 233 |
|
| 234 |
|
| 235 |
def run(num1, num2, num3, num4, num5, random_seed, submit):
|
|
@@ -238,19 +243,22 @@ def run(num1, num2, num3, num4, num5, random_seed, submit):
|
|
| 238 |
torch.manual_seed(int(random_seed))
|
| 239 |
all_x, all_y, prob = prepare_data(sel_num)
|
| 240 |
train_loader = DataLoader(DummyDataset(all_x, all_y, prob), BATCH_SIZE, shuffle=True)
|
| 241 |
-
|
| 242 |
|
|
|
|
| 243 |
if submit == 0:
|
| 244 |
-
|
| 245 |
-
else:
|
| 246 |
-
text = "Press \"Prepare Training Data\" before changing the training data. You may also change the random seed."
|
| 247 |
if submit == 1:
|
| 248 |
-
|
| 249 |
-
regress(train_loader)
|
| 250 |
make_video()
|
| 251 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
video = "movie.mp4" if submit == 1 else None
|
| 253 |
-
return
|
| 254 |
|
| 255 |
|
| 256 |
if __name__ == '__main__':
|
|
@@ -268,9 +276,9 @@ if __name__ == '__main__':
|
|
| 268 |
],
|
| 269 |
outputs=[
|
| 270 |
gr.outputs.Image(type="file", label="Training data"),
|
| 271 |
-
gr.outputs.Textbox(type="auto", label='What\' s next?'),
|
| 272 |
gr.outputs.Image(type="file", label="Regression result"),
|
| 273 |
-
gr.outputs.Video(type='mp4', label='Training process')
|
|
|
|
| 274 |
],
|
| 275 |
live=True,
|
| 276 |
allow_flagging='never',
|
|
@@ -282,6 +290,7 @@ if __name__ == '__main__':
|
|
| 282 |
[0.1, 0.8, 6.4, 51.2, 100, 0, 'Prepare Training Data'],
|
| 283 |
[1, 10, 100, 10, 1, 0, 'Prepare Training Data'],
|
| 284 |
],
|
| 285 |
-
|
|
|
|
| 286 |
)
|
| 287 |
iface.launch()
|
|
|
|
| 18 |
K = 1
|
| 19 |
B = 0
|
| 20 |
NUM_SEG = 5
|
|
|
|
| 21 |
NUM_EPOCHS = 100
|
| 22 |
PRINT_FREQ = NUM_EPOCHS // 20
|
| 23 |
NUM_TRAIN_SAMPLES = NUM_PER_BUCKET * NUM_SEG
|
|
|
|
| 75 |
return all_x, all_y
|
| 76 |
|
| 77 |
|
| 78 |
+
def train(train_loader, training_df, training_bundle, num_epochs):
|
| 79 |
+
visualize_training_process(training_df, training_bundle, -1)
|
| 80 |
for epoch in range(num_epochs):
|
| 81 |
for model, optimizer, scheduler, criterion, criterion_name in training_bundle:
|
| 82 |
model.train()
|
|
|
|
| 91 |
optimizer.step()
|
| 92 |
scheduler.step()
|
| 93 |
if (epoch + 1) % PRINT_FREQ == 0:
|
| 94 |
+
visualize_training_process(training_df, training_bundle, epoch)
|
| 95 |
+
visualize_training_process(training_df, training_bundle, num_epochs, final=True)
|
| 96 |
|
| 97 |
|
| 98 |
+
def visualize_training_process(training_df, training_bundle, epoch, final=False):
|
| 99 |
df = df_oracle
|
| 100 |
for model, optimizer, scheduler, criterion, criterion_name in training_bundle:
|
| 101 |
model.eval()
|
| 102 |
y = model(X_demo)
|
| 103 |
df = df.append(make_dataframe(X_demo, y, criterion_name), ignore_index=True)
|
| 104 |
+
visualize(training_df, df, 'train_log/{:05d}.png'.format(epoch + 1), fast=True)
|
| 105 |
+
if final:
|
| 106 |
+
visualize(training_df, df, 'regression_result.png', fast=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
|
| 109 |
def make_video():
|
|
|
|
| 173 |
return loss * (2 * noise_var)
|
| 174 |
|
| 175 |
|
| 176 |
+
def regress(train_loader, training_df):
|
| 177 |
training_bundle = []
|
| 178 |
criterions = {
|
| 179 |
'MSE': torch.nn.MSELoss(),
|
|
|
|
| 184 |
criterion = criterions[criterion_name]
|
| 185 |
model, optimizer, scheduler = prepare_model()
|
| 186 |
training_bundle.append((model, optimizer, scheduler, criterion, criterion_name))
|
| 187 |
+
train(train_loader, training_df, training_bundle, NUM_EPOCHS)
|
| 188 |
|
| 189 |
|
| 190 |
class DummyDataset(Dataset):
|
|
|
|
| 200 |
return len(self.inputs)
|
| 201 |
|
| 202 |
|
| 203 |
+
def visualize(training_df, df, save_path, fast=False):
|
| 204 |
+
if fast:
|
| 205 |
+
g_line = sns.lineplot(data=df, x='x', y='y', hue='Method', estimator=None, ci=None)
|
| 206 |
+
plt.xlim((Y_LB - B) / K, (Y_UB - B) / K)
|
| 207 |
+
plt.ylim(Y_LB, Y_UB)
|
| 208 |
+
else:
|
| 209 |
+
g = sns.jointplot(data=training_df, x='x', y='y', color='#003ea1', alpha=0.1, linewidths=0, s=100,
|
| 210 |
+
marginal_kws=dict(bins=torch.linspace(Y_LB, Y_UB, steps=NUM_SEG + 1)),
|
| 211 |
+
xlim=((Y_LB - B) / K, (Y_UB - B) / K),
|
| 212 |
+
ylim=(Y_LB, Y_UB),
|
| 213 |
+
space=0.1,
|
| 214 |
+
height=8,
|
| 215 |
+
ratio=2,
|
| 216 |
+
estimator=None, ci=None,
|
| 217 |
+
legend=False
|
| 218 |
+
)
|
| 219 |
+
g.ax_marg_x.remove()
|
| 220 |
+
g_line = sns.lineplot(data=df, x='x', y='y', hue='Method', ax=g.ax_joint, estimator=None, ci=None)
|
| 221 |
+
g_line.legend_.set_title(None)
|
| 222 |
+
g_line.legend(loc='upper left')
|
| 223 |
plt.gca().axes.set_xlabel(r'$x$', fontsize=10)
|
| 224 |
plt.gca().axes.set_ylabel(r'$y$', fontsize=10)
|
| 225 |
+
|
| 226 |
+
plt.savefig(save_path, bbox_inches='tight')
|
| 227 |
+
plt.clf()
|
| 228 |
|
| 229 |
|
| 230 |
def clean_up_logs():
|
|
|
|
| 232 |
os.mkdir('train_log')
|
| 233 |
for f in os.listdir('train_log'):
|
| 234 |
os.remove(osp.join('train_log', f))
|
| 235 |
+
for f in ['regression_result.png', 'training_data.png', 'movie.mp4']:
|
| 236 |
+
if osp.isfile(f):
|
| 237 |
+
os.remove(f)
|
| 238 |
|
| 239 |
|
| 240 |
def run(num1, num2, num3, num4, num5, random_seed, submit):
|
|
|
|
| 243 |
torch.manual_seed(int(random_seed))
|
| 244 |
all_x, all_y, prob = prepare_data(sel_num)
|
| 245 |
train_loader = DataLoader(DummyDataset(all_x, all_y, prob), BATCH_SIZE, shuffle=True)
|
| 246 |
+
training_df = make_dataframe(all_x, all_y)
|
| 247 |
|
| 248 |
+
clean_up_logs()
|
| 249 |
if submit == 0:
|
| 250 |
+
visualize(training_df, df_oracle, 'training_data.png')
|
|
|
|
|
|
|
| 251 |
if submit == 1:
|
| 252 |
+
regress(train_loader, training_df)
|
|
|
|
| 253 |
make_video()
|
| 254 |
+
if submit == 0:
|
| 255 |
+
text = "Press \"Start Regressing\" if your are happy with the training data. Regression takes ~10s."
|
| 256 |
+
else:
|
| 257 |
+
text = "Press \"Prepare Training Data\" before moving the sliders. You may also change the random seed."
|
| 258 |
+
training_data_plot = 'training_data.png' if submit == 0 else None
|
| 259 |
+
output = 'regression_result.png'.format(NUM_EPOCHS) if submit == 1 else None
|
| 260 |
video = "movie.mp4" if submit == 1 else None
|
| 261 |
+
return training_data_plot, output, video, text
|
| 262 |
|
| 263 |
|
| 264 |
if __name__ == '__main__':
|
|
|
|
| 276 |
],
|
| 277 |
outputs=[
|
| 278 |
gr.outputs.Image(type="file", label="Training data"),
|
|
|
|
| 279 |
gr.outputs.Image(type="file", label="Regression result"),
|
| 280 |
+
gr.outputs.Video(type='mp4', label='Training process'),
|
| 281 |
+
gr.outputs.Textbox(type="auto", label='What\' s next?')
|
| 282 |
],
|
| 283 |
live=True,
|
| 284 |
allow_flagging='never',
|
|
|
|
| 290 |
[0.1, 0.8, 6.4, 51.2, 100, 0, 'Prepare Training Data'],
|
| 291 |
[1, 10, 100, 10, 1, 0, 'Prepare Training Data'],
|
| 292 |
],
|
| 293 |
+
css = ".output-image, .image-preview {height: 500px !important}",
|
| 294 |
+
article="<p style='text-align: center'><a href='https://github.com/jiawei-ren/BalancedMSE' target='_blank'>Balanced MSE @ GitHub</a></p> "
|
| 295 |
)
|
| 296 |
iface.launch()
|