Spaces:
Runtime error
Runtime error
jiawei-ren commited on
Commit ·
9c31709
1
Parent(s): ebf41b0
init
Browse files
app.py
CHANGED
|
@@ -92,7 +92,7 @@ def train(train_loader, training_df, training_bundle, num_epochs):
|
|
| 92 |
scheduler.step()
|
| 93 |
if (epoch + 1) % PRINT_FREQ == 0:
|
| 94 |
visualize_training_process(training_df, training_bundle, epoch)
|
| 95 |
-
visualize_training_process(training_df, training_bundle, num_epochs, final=True)
|
| 96 |
|
| 97 |
|
| 98 |
def visualize_training_process(training_df, training_bundle, epoch, final=False):
|
|
@@ -101,7 +101,7 @@ def visualize_training_process(training_df, training_bundle, epoch, final=False)
|
|
| 101 |
model.eval()
|
| 102 |
y = model(X_demo)
|
| 103 |
df = df.append(make_dataframe(X_demo, y, criterion_name), ignore_index=True)
|
| 104 |
-
visualize(training_df, df, 'train_log/{:05d}.png'.format(epoch + 1), fast=True)
|
| 105 |
if final:
|
| 106 |
visualize(training_df, df, 'regression_result.png', fast=False)
|
| 107 |
|
|
@@ -200,7 +200,7 @@ class DummyDataset(Dataset):
|
|
| 200 |
return len(self.inputs)
|
| 201 |
|
| 202 |
|
| 203 |
-
def visualize(training_df, df, save_path, fast=False):
|
| 204 |
if fast:
|
| 205 |
f = plt.figure(figsize=(3, 3))
|
| 206 |
g = f.add_subplot(111)
|
|
@@ -208,24 +208,26 @@ def visualize(training_df, df, save_path, fast=False):
|
|
| 208 |
plt.xlim((Y_LB - B) / K, (Y_UB - B) / K)
|
| 209 |
plt.ylim(Y_LB, Y_UB)
|
| 210 |
else:
|
| 211 |
-
g = sns.jointplot(data=training_df, x='x', y='y', color='#003ea1', alpha=0.1, linewidths=0, s=
|
| 212 |
marginal_kws=dict(bins=torch.linspace(Y_LB, Y_UB, steps=NUM_SEG + 1)),
|
| 213 |
xlim=((Y_LB - B) / K, (Y_UB - B) / K),
|
| 214 |
ylim=(Y_LB, Y_UB),
|
| 215 |
space=0.1,
|
| 216 |
-
height=
|
| 217 |
ratio=2,
|
| 218 |
estimator=None, ci=None,
|
| 219 |
legend=False,
|
| 220 |
)
|
| 221 |
g.ax_marg_x.remove()
|
| 222 |
g_line = sns.lineplot(data=df, x='x', y='y', hue='Method', ax=g.ax_joint, estimator=None, ci=None)
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
|
|
|
|
|
|
| 227 |
|
| 228 |
-
plt.savefig(save_path, bbox_inches='tight')
|
| 229 |
plt.close()
|
| 230 |
|
| 231 |
|
|
@@ -267,11 +269,11 @@ if __name__ == '__main__':
|
|
| 267 |
iface = gr.Interface(
|
| 268 |
fn=run,
|
| 269 |
inputs=[
|
| 270 |
-
gr.inputs.Slider(0, 100, default=20, step=0.1, label='Label percentage in [0, 2)'),
|
| 271 |
-
gr.inputs.Slider(0, 100, default=20, step=0.1, label='Label percentage in [2, 4)'),
|
| 272 |
-
gr.inputs.Slider(0, 100, default=20, step=0.1, label='Label percentage in [4, 6)'),
|
| 273 |
-
gr.inputs.Slider(0, 100, default=20, step=0.1, label='Label percentage in [6, 8)'),
|
| 274 |
gr.inputs.Slider(0, 100, default=20, step=0.1, label='Label percentage in [8, 10)'),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 275 |
gr.inputs.Number(default=0, label='Random Seed', optional=False),
|
| 276 |
gr.inputs.Radio(['Prepare Training Data', 'Start Regressing!'],
|
| 277 |
type="index", default=None, label='Mode', optional=False),
|
|
|
|
| 92 |
scheduler.step()
|
| 93 |
if (epoch + 1) % PRINT_FREQ == 0:
|
| 94 |
visualize_training_process(training_df, training_bundle, epoch)
|
| 95 |
+
visualize_training_process(training_df, training_bundle, num_epochs-1, final=True)
|
| 96 |
|
| 97 |
|
| 98 |
def visualize_training_process(training_df, training_bundle, epoch, final=False):
|
|
|
|
| 101 |
model.eval()
|
| 102 |
y = model(X_demo)
|
| 103 |
df = df.append(make_dataframe(X_demo, y, criterion_name), ignore_index=True)
|
| 104 |
+
visualize(training_df, df, 'train_log/{:05d}.png'.format(epoch + 1), fast=True, epoch=epoch)
|
| 105 |
if final:
|
| 106 |
visualize(training_df, df, 'regression_result.png', fast=False)
|
| 107 |
|
|
|
|
| 200 |
return len(self.inputs)
|
| 201 |
|
| 202 |
|
| 203 |
+
def visualize(training_df, df, save_path, fast=False, epoch=None):
|
| 204 |
if fast:
|
| 205 |
f = plt.figure(figsize=(3, 3))
|
| 206 |
g = f.add_subplot(111)
|
|
|
|
| 208 |
plt.xlim((Y_LB - B) / K, (Y_UB - B) / K)
|
| 209 |
plt.ylim(Y_LB, Y_UB)
|
| 210 |
else:
|
| 211 |
+
g = sns.jointplot(data=training_df, x='x', y='y', color='#003ea1', alpha=0.1, linewidths=0, s=50,
|
| 212 |
marginal_kws=dict(bins=torch.linspace(Y_LB, Y_UB, steps=NUM_SEG + 1)),
|
| 213 |
xlim=((Y_LB - B) / K, (Y_UB - B) / K),
|
| 214 |
ylim=(Y_LB, Y_UB),
|
| 215 |
space=0.1,
|
| 216 |
+
height=5,
|
| 217 |
ratio=2,
|
| 218 |
estimator=None, ci=None,
|
| 219 |
legend=False,
|
| 220 |
)
|
| 221 |
g.ax_marg_x.remove()
|
| 222 |
g_line = sns.lineplot(data=df, x='x', y='y', hue='Method', ax=g.ax_joint, estimator=None, ci=None)
|
| 223 |
+
if epoch is not None:
|
| 224 |
+
g_line.legend(loc='upper left', title="Epoch {:03d}".format(epoch+1))
|
| 225 |
+
else:
|
| 226 |
+
g_line.legend(loc='upper left')
|
| 227 |
+
plt.gca().axes.set_xlabel(r'$x$')
|
| 228 |
+
plt.gca().axes.set_ylabel(r'$y$')
|
| 229 |
|
| 230 |
+
plt.savefig(save_path, bbox_inches='tight', dpi=200)
|
| 231 |
plt.close()
|
| 232 |
|
| 233 |
|
|
|
|
| 269 |
iface = gr.Interface(
|
| 270 |
fn=run,
|
| 271 |
inputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
gr.inputs.Slider(0, 100, default=20, step=0.1, label='Label percentage in [8, 10)'),
|
| 273 |
+
gr.inputs.Slider(0, 100, default=20, step=0.1, label='Label percentage in [6, 8)'),
|
| 274 |
+
gr.inputs.Slider(0, 100, default=20, step=0.1, label='Label percentage in [4, 6)'),
|
| 275 |
+
gr.inputs.Slider(0, 100, default=20, step=0.1, label='Label percentage in [2, 4)'),
|
| 276 |
+
gr.inputs.Slider(0, 100, default=20, step=0.1, label='Label percentage in [0, 2)'),
|
| 277 |
gr.inputs.Number(default=0, label='Random Seed', optional=False),
|
| 278 |
gr.inputs.Radio(['Prepare Training Data', 'Start Regressing!'],
|
| 279 |
type="index", default=None, label='Mode', optional=False),
|