Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import torch | |
| import seaborn as sns | |
| import pandas as pd | |
| import os | |
| import os.path as osp | |
| import ffmpeg | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.nn.modules.loss import _Loss | |
| from torch.utils.data import Dataset, DataLoader | |
# --- Synthetic-data and training configuration -------------------------------
NUM_PER_BUCKET = 1000  # samples kept for a label bucket whose slider is at 100%
NOISE_SIGMA = 1  # std of the Gaussian noise separating observed and noise-free labels
Y_UB = 10  # upper bound of the label range
Y_LB = 0  # lower bound of the label range
K = 1  # slope of the oracle line y = K * x + B
B = 0  # intercept of the oracle line y = K * x + B
NUM_SEG = 5  # number of equal-width label buckets between Y_LB and Y_UB
NUM_EPOCHS = 100
# A frame is rendered every PRINT_FREQ epochs (about 20 frames per run).
# Clamp to 1 so that `(epoch + 1) % PRINT_FREQ` never divides by zero when
# NUM_EPOCHS is set below 20.
PRINT_FREQ = max(1, NUM_EPOCHS // 20)
# Number of candidates drawn per bucket before truncating to the requested count.
NUM_TRAIN_SAMPLES = NUM_PER_BUCKET * NUM_SEG
BATCH_SIZE = 256
def make_dataframe(x, y, method=None):
    """Pack the first column of two 2-D tensors into a plotting DataFrame.

    Args:
        x: tensor indexed as ``x[:, 0]``; becomes the ``'x'`` column.
        y: tensor indexed as ``y[:, 0]``; becomes the ``'y'`` column.
        method: optional label repeated on every row in a ``'Method'`` column.
            The column is omitted entirely when ``method`` is ``None``.

    Returns:
        A ``pandas.DataFrame`` with columns ``x``, ``y`` and, optionally,
        ``Method``.
    """
    columns = {
        'x': list(x[:, 0].detach().numpy()),
        'y': list(y[:, 0].detach().numpy()),
    }
    if method is not None:
        columns['Method'] = [method] * len(columns['x'])
    return pd.DataFrame(columns)
# Two end points are enough to draw the noise-free line y = K * x + B; the same
# X_demo inputs are reused to plot every model's current prediction.
Y_demo = torch.linspace(Y_LB, Y_UB, steps=2).reshape(-1, 1)
X_demo = (Y_demo - B) / K
df_oracle = make_dataframe(X_demo, Y_demo, method='Oracle')
def prepare_data(sel_num):
    """Sample a noisy, possibly imbalanced, 1-D linear-regression dataset.

    The label range ``[Y_LB, Y_UB]`` is split into ``NUM_SEG`` equal-width
    buckets, visited from the highest bucket down to the lowest. For bucket
    ``i``, labels are drawn uniformly inside the bucket, Gaussian noise with
    std ``NOISE_SIGMA`` is subtracted to obtain the noise-free label, and the
    input is recovered through the oracle line ``y = K * x + B``.

    Args:
        sel_num: sequence of ``NUM_SEG`` integers; ``sel_num[i]`` is the number
            of samples to keep for bucket ``i``.

    Returns:
        ``(all_x, all_y, prob)`` where ``all_x`` and ``all_y`` have shape
        ``(N, 1)`` and ``prob`` has shape ``(N,)``. Each entry of ``prob``
        holds the requested size of the bucket its sample came from (a count,
        not a normalised probability). When nothing is selected the tensors
        are empty rather than an error being raised.
    """
    interval = (Y_UB - Y_LB) / NUM_SEG
    noise_distribution = torch.distributions.Normal(loc=0, scale=NOISE_SIGMA)
    xs, ys, probs = [], [], []
    for i in range(NUM_SEG):
        count = sel_num[i]
        label_distribution = torch.distributions.Uniform(
            Y_UB - (i + 1) * interval, Y_UB - i * interval)
        # Always draw NUM_TRAIN_SAMPLES and truncate, so the amount of
        # randomness consumed per bucket does not depend on the slider values.
        y_uniform = label_distribution.sample((NUM_TRAIN_SAMPLES, 1))[:count]
        noise = noise_distribution.sample((NUM_TRAIN_SAMPLES, 1))[:count]
        y_uniform_oracle = y_uniform - noise
        x_uniform = (y_uniform_oracle - B) / K
        xs.append(x_uniform)
        ys.append(y_uniform)
        # One weight per row actually kept, so prob always lines up with x / y.
        probs.append(torch.full((y_uniform.shape[0],), float(count)))
    # Concatenating per-bucket tensors yields the same values and order as
    # stacking individual rows, but also works when every bucket is empty.
    all_x = torch.cat(xs)
    all_y = torch.cat(ys)
    prob = torch.cat(probs)
    return all_x, all_y, prob
def unzip_dataloader(training_loader):
    """Concatenate every batch of a loader back into two full tensors.

    Args:
        training_loader: iterable yielding ``(data, label, extra)`` triples;
            the third element is ignored.

    Returns:
        ``(all_x, all_y)``: the ``data`` and ``label`` batches concatenated
        along the first dimension, in iteration order.
    """
    # Walk the loader exactly once; a shuffling loader would otherwise hand
    # back a different order on a second pass.
    batches = [(data, label) for data, label, _ in training_loader]
    inputs = torch.cat([data for data, _ in batches])
    labels = torch.cat([label for _, label in batches])
    return inputs, labels
def train(train_loader, training_df, training_bundle, num_epochs):
    """Train every model in ``training_bundle`` and render progress frames.

    A frame is rendered before training starts, after every ``PRINT_FREQ``-th
    epoch, and once more after the last epoch (with ``final=True``).

    Args:
        train_loader: iterable yielding ``(data, target, prob)`` batches.
        training_df: DataFrame of the training samples, forwarded to the
            visualisation helper.
        training_bundle: list of
            ``(model, optimizer, scheduler, criterion, criterion_name)`` tuples.
        num_epochs: number of passes over ``train_loader``.
    """
    # Frame for the untrained models.
    visualize_training_process(training_df, training_bundle, -1)
    for epoch in range(num_epochs):
        for model, optimizer, scheduler, criterion, criterion_name in training_bundle:
            # Only the re-weighting loss consumes the per-sample bucket size.
            needs_prob = criterion_name == 'Reweight'
            model.train()
            for data, target, prob in train_loader:
                optimizer.zero_grad()
                pred = model(data)
                loss = criterion(pred, target, prob) if needs_prob else criterion(pred, target)
                loss.backward()
                optimizer.step()
            # NOTE(review): indentation was lost in this copy of the file; the
            # scheduler is assumed to step once per epoch, which matches
            # T_max=NUM_EPOCHS in prepare_model.
            scheduler.step()
        if (epoch + 1) % PRINT_FREQ == 0:
            visualize_training_process(training_df, training_bundle, epoch)
    visualize_training_process(training_df, training_bundle, num_epochs - 1, final=True)
def visualize_training_process(training_df, training_bundle, epoch, final=False):
    """Plot the current prediction of every model against the oracle line.

    Each model is evaluated on ``X_demo``; its predictions are appended to the
    oracle rows and the result is rendered as a frame under ``train_log/``.

    Args:
        training_df: DataFrame of the training samples, forwarded to
            ``visualize``.
        training_bundle: list of
            ``(model, optimizer, scheduler, criterion, criterion_name)`` tuples.
        epoch: zero-based epoch index; the frame is numbered ``epoch + 1``.
        final: when true, additionally write ``regression_result.png`` using
            the non-fast plot.
    """
    frames = [df_oracle]
    for model, _optimizer, _scheduler, _criterion, criterion_name in training_bundle:
        model.eval()
        # Plot-only forward pass: no autograd graph is needed.
        with torch.no_grad():
            y = model(X_demo)
        frames.append(make_dataframe(X_demo, y, criterion_name))
    # DataFrame.append was removed in pandas 2.0; pd.concat is the replacement
    # and is also available in older pandas versions.
    df = pd.concat(frames, ignore_index=True)
    visualize(training_df, df, 'train_log/{:05d}.png'.format(epoch + 1), fast=True, epoch=epoch)
    if final:
        visualize(training_df, df, 'regression_result.png', fast=False)
def make_video():
    """Encode the PNG frames in ``train_log/`` into ``movie.mp4`` via ffmpeg."""
    # Frames are matched by glob pattern and played at 3 frames per second.
    frames = ffmpeg.input('train_log/*.png', pattern_type='glob', framerate=3)
    movie = frames.output('movie.mp4')
    movie.run()
class ReweightL2(_Loss):
    """Squared-error loss with per-sample weights derived from ``prob``.

    Args:
        reweight: ``'inverse'`` weights each sample by ``prob ** -1``;
            ``'sqrt_inv'`` by ``prob ** -0.5``. Any other value makes
            ``forward`` raise ``NotImplementedError``.
    """

    def __init__(self, reweight='inverse'):
        super().__init__()
        self.reweight = reweight

    def forward(self, pred, target, prob):
        """Return the weighted sum of per-sample squared errors.

        The weights are normalised to sum to one over the batch, the squared
        error is summed over the last dimension of ``pred`` / ``target``, and
        the weighted per-sample losses are summed into a scalar.

        Raises:
            NotImplementedError: if ``self.reweight`` is not a supported mode.
        """
        if self.reweight == 'inverse':
            exponent = -1
        elif self.reweight == 'sqrt_inv':
            exponent = -0.5
        else:
            raise NotImplementedError
        weights = prob.pow(exponent)
        weights = weights / weights.sum()
        per_sample = F.mse_loss(pred, target, reduction='none').sum(-1)
        return (per_sample * weights).sum()
class LinearModel(nn.Module):
    """A single affine layer, wrapped in ``nn.Sequential`` under ``self.mlp``."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        affine = nn.Linear(input_dim, output_dim)
        self.mlp = nn.Sequential(affine)

    def forward(self, x):
        """Apply the affine layer to ``x`` and return the result."""
        return self.mlp(x)
def prepare_model():
    """Create a fresh 1-in / 1-out linear model with its optimiser and scheduler.

    Returns:
        ``(model, optimizer, scheduler)``: a ``LinearModel``, an SGD optimiser
        (lr 1e-2, momentum 0.9) over its parameters, and a cosine-annealing
        learning-rate scheduler with ``T_max=NUM_EPOCHS``.
    """
    net = LinearModel(input_dim=1, output_dim=1)
    sgd = torch.optim.SGD(net.parameters(), lr=1e-2, momentum=0.9)
    cosine = torch.optim.lr_scheduler.CosineAnnealingLR(sgd, T_max=NUM_EPOCHS)
    return net, sgd, cosine
class BMCLoss(_Loss):
    """Module wrapper around ``bmc_loss`` with a fixed noise scale.

    The noise standard deviation is read from ``NOISE_SIGMA`` at construction
    time and squared on every forward call.
    """

    def __init__(self):
        super().__init__()
        self.noise_sigma = NOISE_SIGMA

    def forward(self, pred, target):
        """Flatten both inputs to column vectors and return ``bmc_loss``."""
        noise_var = self.noise_sigma ** 2
        return bmc_loss(pred.reshape(-1, 1), target.reshape(-1, 1), noise_var)
def bmc_loss(pred, target, noise_var):
    """Batch-based cross-entropy form of the Balanced MSE loss.

    Every prediction is compared with every target in the batch; the negative
    scaled squared distances act as logits, and the "correct class" of
    prediction ``i`` is target ``i``.

    Args:
        pred: predictions, shape ``(batch, 1)``.
        target: targets, shape ``(batch, 1)``.
        noise_var: variance of the assumed label noise.

    Returns:
        Scalar tensor: mean cross-entropy over the batch, multiplied by
        ``2 * noise_var``.
    """
    # (batch, 1) - (1, batch) broadcasts to all pairwise differences.
    logits = - 0.5 * (pred - target.T).pow(2) / noise_var
    # Build the index targets on the same device as the logits so the loss
    # also works when the model is not on the CPU.
    labels = torch.arange(pred.shape[0], device=pred.device)
    loss = F.cross_entropy(logits, labels)
    return loss * (2 * noise_var)
def regress(train_loader, training_df):
    """Train one fresh linear model per loss function on the same data.

    Three losses are compared: plain MSE, inverse-frequency re-weighted L2
    (``'Reweight'``) and ``'Balanced MSE'``. Each gets its own model,
    optimiser and scheduler, and all are trained for ``NUM_EPOCHS`` epochs.

    Args:
        train_loader: iterable yielding ``(data, target, prob)`` batches.
        training_df: DataFrame of the training samples, used for plotting.
    """
    criterions = {
        'MSE': torch.nn.MSELoss(),
        'Reweight': ReweightL2(),
        'Balanced MSE': BMCLoss(),
    }
    training_bundle = [
        (*prepare_model(), criterion, criterion_name)
        for criterion_name, criterion in criterions.items()
    ]
    train(train_loader, training_df, training_bundle, NUM_EPOCHS)
class DummyDataset(Dataset):
    """Dataset over three parallel, index-aligned containers."""

    def __init__(self, inputs, targets, prob):
        self.inputs, self.targets, self.prob = inputs, targets, prob

    def __getitem__(self, index):
        """Return the ``(input, target, prob)`` triple stored at ``index``."""
        return tuple(column[index] for column in (self.inputs, self.targets, self.prob))

    def __len__(self):
        """Return the number of stored inputs."""
        return len(self.inputs)
def visualize(training_df, df, save_path, fast=False, epoch=None):
    """Render the regression lines in ``df`` and save the figure.

    Args:
        training_df: DataFrame of training samples; only used by the non-fast
            plot, where it is drawn as a joint scatter plot.
        df: DataFrame with ``x``, ``y`` and ``Method`` columns; one line is
            drawn per ``Method``.
        save_path: destination image path.
        fast: when true, draw only the lines on a small 3x3 figure; otherwise
            draw the lines on top of a seaborn joint plot of the training data.
        epoch: optional zero-based epoch index shown (as ``epoch + 1``) in the
            legend title.
    """
    x_limits = ((Y_LB - B) / K, (Y_UB - B) / K)
    y_limits = (Y_LB, Y_UB)
    line_kwargs = dict(data=df, x='x', y='y', hue='Method', estimator=None, ci=None)

    if not fast:
        grid = sns.jointplot(
            data=training_df, x='x', y='y', color='#003ea1', alpha=0.1, linewidths=0, s=50,
            marginal_kws=dict(bins=torch.linspace(Y_LB, Y_UB, steps=NUM_SEG + 1)),
            xlim=x_limits,
            ylim=y_limits,
            space=0.1,
            height=5,
            ratio=2,
            estimator=None, ci=None,
            legend=False,
        )
        # Drop the marginal plot attached to the x axis of the joint plot.
        grid.ax_marg_x.remove()
        line_ax = sns.lineplot(ax=grid.ax_joint, **line_kwargs)
    else:
        figure = plt.figure(figsize=(3, 3))
        line_ax = sns.lineplot(ax=figure.add_subplot(111), **line_kwargs)
        plt.xlim(*x_limits)
        plt.ylim(*y_limits)

    legend_options = {'loc': 'upper left'}
    if epoch is not None:
        legend_options['title'] = "Epoch {:03d}".format(epoch+1)
    line_ax.legend(**legend_options)

    current_axes = plt.gca().axes
    current_axes.set_xlabel(r'$x$')
    current_axes.set_ylabel(r'$y$')
    plt.savefig(save_path, bbox_inches='tight', dpi=200)
    plt.close()
def clean_up_logs():
    """Reset the output artefacts of a previous run.

    Ensures the ``train_log`` directory exists, deletes the files directly
    inside it, and deletes ``regression_result.png``, ``training_data.png``
    and ``movie.mp4`` from the working directory when present.
    Sub-directories of ``train_log`` are left untouched.
    """
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs('train_log', exist_ok=True)
    for name in os.listdir('train_log'):
        path = osp.join('train_log', name)
        # os.remove cannot delete directories; skip them instead of crashing.
        if osp.isfile(path) or osp.islink(path):
            os.remove(path)
    for path in ('regression_result.png', 'training_data.png', 'movie.mp4'):
        if osp.isfile(path):
            os.remove(path)
def run(num1, num2, num3, num4, num5, random_seed, mode):
    """Gradio callback: build the dataset and either preview it or regress on it.

    Args:
        num1, num2, num3, num4, num5: slider percentages (0-100), one per label
            bucket, ordered from the highest bucket to the lowest.
        random_seed: seed passed to ``torch.manual_seed`` after ``int()``.
        mode: ``0`` to plot the training data, ``1`` to train the models and
            build the video; any other value only regenerates the data.

    Returns:
        ``(training_data_plot, output, video, text)``: paths of the
        training-data image, the regression-result image and the training
        video (each ``None`` when not produced in this mode), plus a hint for
        the user.
    """
    percentages = [num1, num2, num3, num4, num5]
    # round() rather than int(): float error in e.g. 0.7 / 100 * 1000 can land
    # just below the integer, and truncation would silently drop a sample.
    sel_num = [int(round(pct / 100 * NUM_PER_BUCKET)) for pct in percentages]
    if sum(sel_num) == 0:
        # With every slider at zero there is nothing to sample, plot or fit.
        return None, None, None, "Move at least one slider above 0 to create some training data."

    torch.manual_seed(int(random_seed))
    all_x, all_y, prob = prepare_data(sel_num)
    train_loader = DataLoader(DummyDataset(all_x, all_y, prob), BATCH_SIZE, shuffle=True)
    training_df = make_dataframe(all_x, all_y)
    clean_up_logs()

    if mode == 0:
        visualize(training_df, df_oracle, 'training_data.png')
        text = "Press \"Start Regressing\" if you are happy with the training data. Regression takes ~30s."
        return 'training_data.png', None, None, text

    text = "Press \"Prepare Training Data\" before moving the sliders. You may also change the random seed."
    if mode == 1:
        regress(train_loader, training_df)
        make_video()
        return None, 'regression_result.png', 'movie.mp4', text
    return None, None, None, text
if __name__ == '__main__':
    # One slider per label bucket, listed from the highest bucket to the
    # lowest; this is the order in which prepare_data walks the buckets.
    bucket_ranges = ['[8, 10)', '[6, 8)', '[4, 6)', '[2, 4)', '[0, 2)']
    percentage_sliders = [
        gr.inputs.Slider(0, 100, default=20, step=0.1, label='Label percentage in ' + bucket)
        for bucket in bucket_ranges
    ]
    seed_input = gr.inputs.Number(default=0, label='Random Seed', optional=False)
    # type="index" hands the position of the chosen option to run() as `mode`.
    mode_input = gr.inputs.Radio(['Prepare Training Data', 'Start Regressing!'],
                                 type="index", default=None, label='Mode', optional=False)

    demo_outputs = [
        gr.outputs.Image(type="file", label="Training data"),
        gr.outputs.Image(type="file", label="Regression result"),
        gr.outputs.Video(type='mp4', label='Training process'),
        gr.outputs.Textbox(type="auto", label='What\' s next?'),
    ]

    iface = gr.Interface(
        fn=run,
        inputs=percentage_sliders + [seed_input, mode_input],
        outputs=demo_outputs,
        live=True,
        allow_flagging='never',
        title="Balanced MSE for Imbalanced Visual Regression [CVPR 2022]",
        description="Welcome to the demo of Balanced MSE ⚖. In this demo, we will work on a simple task: imbalanced <i>linear</i> regression. <br>"
                    "To get started, move the sliders 🎚 to create your training data "
                    "or click the examples 📕 at the bottom of the page 👇👇",
        examples=[
            [0.1, 0.8, 6.4, 51.2, 100, 0, 'Prepare Training Data'],
            [1, 10, 100, 10, 1, 0, 'Prepare Training Data'],
        ],
        css=".output-image, .image-preview {height: 500px !important}",
        article="<p style='text-align: center'><a href='https://github.com/jiawei-ren/BalancedMSE' target='_blank'>Balanced MSE @ GitHub</a></p> "
    )
    iface.launch()