| import torch |
| from torch.utils.data import Dataset,DataLoader |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.optim as optim |
| import numpy as np |
| import datetime |
|
|
class AugmentationParams():
    """Plain holder for data-augmentation settings.

    Every constructor argument is stored unchanged as an attribute of the
    same name.  The arguments come in groups sharing a prefix:

    tshift_*      shift settings (max, prob)
    tmask_*       mask settings (min, max, prob)
    fshift_*      shift settings (max, prob)
    fmask_*       mask settings (min, max, prob)
    scale_*       scaling settings (min, max, prob)
    fgmix_*       mixing settings (weight_min, weight_max, prob)
    bgmix_*       mixing settings (weight_min, weight_max, prob)
    pixelnoise_*  noise settings (rate and intensity bounds, prob)

    With the defaults every ``*_prob`` is 0 and the scale bounds are 1.
    """

    def __init__(self, tshift_max=0, tshift_prob=0,
                 tmask_min=0, tmask_max=0, tmask_prob=0,
                 fshift_max=0, fshift_prob=0,
                 fmask_min=0, fmask_max=0, fmask_prob=0,
                 scale_min=1, scale_max=1, scale_prob=0,
                 fgmix_weight_min=0, fgmix_weight_max=0, fgmix_prob=0,
                 bgmix_weight_min=0, bgmix_weight_max=0, bgmix_prob=0,
                 pixelnoise_rate_min=0, pixelnoise_rate_max=0,
                 pixelnoise_intensity_min=0, pixelnoise_intensity_max=0,
                 pixelnoise_prob=0):
        # shifts
        self.tshift_max, self.tshift_prob = tshift_max, tshift_prob
        self.fshift_max, self.fshift_prob = fshift_max, fshift_prob

        # masks
        self.tmask_min, self.tmask_max = tmask_min, tmask_max
        self.tmask_prob = tmask_prob
        self.fmask_min, self.fmask_max = fmask_min, fmask_max
        self.fmask_prob = fmask_prob

        # scaling
        self.scale_min, self.scale_max = scale_min, scale_max
        self.scale_prob = scale_prob

        # mixing
        self.fgmix_weight_min, self.fgmix_weight_max = fgmix_weight_min, fgmix_weight_max
        self.fgmix_prob = fgmix_prob
        self.bgmix_weight_min, self.bgmix_weight_max = bgmix_weight_min, bgmix_weight_max
        self.bgmix_prob = bgmix_prob

        # pixel noise
        self.pixelnoise_rate_min, self.pixelnoise_rate_max = pixelnoise_rate_min, pixelnoise_rate_max
        self.pixelnoise_intensity_min = pixelnoise_intensity_min
        self.pixelnoise_intensity_max = pixelnoise_intensity_max
        self.pixelnoise_prob = pixelnoise_prob
|
|
|
|
def _random_mask_width(low, high):
    """Draw a mask width from [low, high); a collapsed range (low == high) yields ``low``."""
    if low == high:
        # np.random.randint rejects an empty range, so handle it directly.
        return low
    return np.random.randint(low=low, high=high)


def augment(img1, ap):
    """Return a randomly augmented copy of a 2-D image.

    The input is copied first, so ``img1`` itself is never modified.  Each
    step is applied independently when its ``*_prob`` setting exceeds a fresh
    uniform draw from [0, 1), in this order:

    1. circular shift along axis 0 by 0..``tshift_max`` rows;
    2. shift along axis 1 by -``fshift_max``..``fshift_max`` columns, filling
       the vacated columns with the edge column;
    3. rescaling so that the maximum becomes a value drawn from
       [``scale_min``, ``scale_max``) (skipped for an image whose maximum is
       0, which cannot be rescaled);
    4. zeroing a band of rows (``tmask_min`` <= width < ``tmask_max``);
    5. zeroing a band of columns (``fmask_min`` <= width < ``fmask_max``);
    6. adding uniform noise to randomly chosen pixels (``pixelnoise_*``).

    img1: original image, a 2-D array.  Going by the parameter names axis 0
        is time and axis 1 is frequency.  The in-place scaling and noise
        steps presumably expect a floating-point dtype -- TODO confirm.
    ap: AugmentationParams (any object exposing the same attributes works)
    returns: the augmented array
    """
    img = np.copy(img1)
    ntime, nfreq = img.shape

    if ap.tshift_prob > np.random.uniform():
        tshift = np.random.randint(low=0, high=ap.tshift_max + 1)
        # Rows that fall off the start wrap around to the end.
        img = np.concatenate((img[tshift:, :], img[:tshift, :]), axis=0)

    if ap.fshift_prob > np.random.uniform():
        fshift = np.random.randint(low=-ap.fshift_max, high=ap.fshift_max + 1)
        if fshift < 0:
            # Shift towards column 0; pad on the right with the last column.
            nshift = -fshift
            pad = np.repeat(img[:, nfreq - 1], nshift).reshape((ntime, nshift))
            img = np.concatenate((img[:, nshift:], pad), axis=1)
        elif fshift > 0:
            # Shift away from column 0; pad on the left with the first column.
            pad = np.repeat(img[:, 0], fshift).reshape((ntime, fshift))
            img = np.concatenate((pad, img[:, :-fshift]), axis=1)

    if ap.scale_prob > np.random.uniform():
        origmax = np.max(img)
        w = np.random.uniform(low=ap.scale_min, high=ap.scale_max)
        # Dividing by a zero maximum would fill the image with NaN/inf.
        if origmax != 0:
            img *= w / origmax

    if ap.tmask_prob > np.random.uniform():
        tpos = np.random.randint(low=0, high=ntime)
        twidth = _random_mask_width(ap.tmask_min, ap.tmask_max)
        tpos2 = min(tpos + twidth, ntime)
        img[tpos:tpos2, :] = 0

    if ap.fmask_prob > np.random.uniform():
        fpos = np.random.randint(low=0, high=nfreq)
        fwidth = _random_mask_width(ap.fmask_min, ap.fmask_max)
        fpos2 = min(fpos + fwidth, nfreq)
        img[:, fpos:fpos2] = 0

    if ap.pixelnoise_prob > np.random.uniform():
        rate = np.random.uniform(low=ap.pixelnoise_rate_min, high=ap.pixelnoise_rate_max)
        # Built-in int: the np.int alias no longer exists in current NumPy.
        npix = int(ntime * nfreq * rate)
        rows = np.random.randint(0, ntime, size=npix)
        cols = np.random.randint(0, nfreq, size=npix)
        # A pixel picked more than once still receives only one noise value.
        img[rows, cols] += np.random.uniform(low=ap.pixelnoise_intensity_min,
                                             high=ap.pixelnoise_intensity_max,
                                             size=npix)

    return img
|
|
def normalize(data):
    """Standardize ``data`` to zero mean and unit standard deviation.

    The mean and standard deviation are taken over all elements of ``data``.
    When the standard deviation is zero (constant input) only the mean is
    removed, so the result is all zeros and not the NaN/inf values a
    division by zero would produce.
    """
    centered = data - np.mean(data)
    spread = np.std(data)
    if spread == 0:
        return centered
    return centered / spread
|
|
class Dataset(Dataset):
    """Dataset of (image, soft one-hot label) pairs with optional augmentation.

    data and labels are given as numpy arrays; they are converted into tensors.
    n1: number of fg samples, n2: number of bg samples to train clean bg class,
    n3: number of bg samples only to be mixed with fg.

    Samples are addressed by position: indices below n1 are treated as fg and
    may be mixed with another fg sample (drawn from [0, n1)) and with a bg
    sample (drawn from [n1, n1+n2)).

    NOTE(review): the n3 description above does not match the visible code --
    bg mix partners are drawn from the n2 range only, and __len__ includes
    n3, so those samples are served like any other.  Confirm the intended
    layout before relying on n3.
    """

    def __init__(self, data, labels, n1, n2=0, n3=0, ntime=512, ap=None, eps=0.01, nclasses=22):
        """Store the data and precompute the label matrix.

        data: array indexed by sample; each sample is a 2-D image
        labels: integer class index per sample (any integer dtype)
        n1, n2, n3: sizes of the consecutive sample groups (see class docstring)
        ntime: number of leading rows of each image that are returned
        ap: AugmentationParams, or None to disable augmentation and mixing
        eps: one-hot targets are clamped into [eps, 1-eps]
        nclasses: width of the one-hot encoding
        """
        self.data = data
        # one_hot only accepts int64 index tensors, so cast other integer dtypes.
        indices = torch.from_numpy(labels).long()
        onehot = torch.nn.functional.one_hot(indices, num_classes=nclasses).float()
        self.labels = torch.clamp(onehot, min=eps, max=1.0 - eps)
        self.ap = ap
        self.n1 = n1
        self.n2 = n2
        self.n3 = n3
        self.ntime = ntime

    def __len__(self):
        return self.n1 + self.n2 + self.n3

    def _mix_in(self, img, lab, low, high):
        """Overlay one augmented sample drawn from indices [low, high) on (img, lab)."""
        partner = np.random.randint(low=low, high=high)
        other = augment(self.data[partner], self.ap)
        # Element-wise maximum of both the images and the label vectors.
        return np.maximum(img, other), torch.maximum(lab, self.labels[partner])

    def __getitem__(self, idx):
        """Return (image tensor with a leading channel axis, label tensor) for ``idx``."""
        img = self.data[idx]
        lab = self.labels[idx]

        if self.ap:
            img = augment(img, self.ap)

            # Only fg samples (idx < n1) are mixed with other samples.
            if idx < self.n1 and self.ap.fgmix_prob > np.random.uniform():
                img, lab = self._mix_in(img, lab, 0, self.n1)

            # With n2 == 0 there is no bg sample to draw from; skip the mix
            # instead of asking randint for an empty range.
            if idx < self.n1 and self.ap.bgmix_prob > np.random.uniform() and self.n2 > 0:
                img, lab = self._mix_in(img, lab, self.n1, self.n1 + self.n2)

        img = torch.from_numpy(img[:self.ntime]).unsqueeze(0)
        return img, lab
|
|
|
|
class Net(nn.Module):
    """CNN classifier: six 3x3 convolutions followed by three linear layers.

    Expects input of shape (batch, 1, ntime, nfreq) and returns raw scores of
    shape (batch, nclasses); no softmax/sigmoid is applied.
    """

    # Number of 2x poolings forward() applies along each image axis.
    _TIME_POOLS = 6
    _FREQ_POOLS = 4

    def __init__(self, ntime=512, nfreq=128, nclasses=22):
        """Build the layers for images of ``ntime`` x ``nfreq`` pixels."""
        super().__init__()

        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv4 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv5 = nn.Conv2d(256, 512, 3, padding=1)
        self.conv6 = nn.Conv2d(512, 512, 3, padding=1)

        # Size of the feature map entering fc1.  The padded convolutions keep
        # the spatial size; every pooling floors the pooled axis to half.  The
        # time axis is pooled in all six stages, the frequency axis only in
        # the first four, so the halvings are counted per axis (halving the
        # frequency six times and multiplying by four loses the floor for
        # sizes that are not multiples of 64).
        nt = ntime // 2 ** self._TIME_POOLS
        nr = nfreq // 2 ** self._FREQ_POOLS

        self.fc1 = nn.Linear(512 * nt * nr, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, nclasses)

    def forward(self, x):
        """Compute class scores for a batch ``x`` of shape (batch, 1, ntime, nfreq)."""
        # Stages 1-4 pool both axes.
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = F.max_pool2d(F.relu(self.conv3(x)), 2)
        x = F.max_pool2d(F.relu(self.conv4(x)), (2, 2))
        # Stages 5-6 pool the time axis only.
        x = F.max_pool2d(F.relu(self.conv5(x)), (2, 1))
        x = F.max_pool2d(F.relu(self.conv6(x)), (2, 1))

        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
|
|
def train(net, loss_fn, optimizer, train_dataloader, validation_dataloader, device, nepochs=10, info=0, model_outfile=None):
    """Train ``net`` and return the per-epoch accuracy history.

    net: model to train; it is updated in place
    loss_fn: callable ``loss_fn(outputs, labels)`` returning the loss tensor
    optimizer: optimizer stepping the parameters of ``net``
    train_dataloader: batches of (inputs, labels) used for the updates and
        for the train accuracy
    validation_dataloader: batches of (inputs, labels) used for the
        validation accuracy
    device: device every batch is moved to
    nepochs: number of passes over ``train_dataloader``
    info: 0 prints nothing, a truthy value prints one summary line per epoch,
        a value above 1 also prints the mean train loss every 100 batches
    model_outfile: if defined, the state dict of the model with the best
        validation accuracy so far is saved there

    returns: (acc_train, acc_valid), two arrays of length ``nepochs`` with
        the accuracy on train and validation data after each epoch
    """
    acc_train = np.zeros(nepochs)
    acc_valid = np.zeros(nepochs)
    bestval = 0
    for epoch in range(nepochs):
        net.train()
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(train_dataloader):
            inputs = inputs.to(device=device)
            labels = labels.to(device=device)
            outputs = net(inputs)
            loss = loss_fn(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if info > 1:
                running_loss += loss.item()
                if i % 100 == 99:
                    print('[%d, %5d] train_loss: %.3f' % (epoch + 1, i + 1, running_loss / 100))
                    running_loss = 0.0

        net.eval()
        # validate() may hand back a 0-dim tensor, possibly living on
        # `device`; convert to plain floats so the history arrays and the
        # best-score comparison work on host numbers.
        train_accuracy = float(validate(net, train_dataloader, device))
        validation_accuracy = float(validate(net, validation_dataloader, device))
        acc_train[epoch] = train_accuracy
        acc_valid[epoch] = validation_accuracy
        if info:
            print(f'{datetime.datetime.now()} epoch {epoch}, train_accuracy: {train_accuracy:.3f} validation_accuracy: {validation_accuracy:.3f}')

        # Keep only the checkpoint with the best validation accuracy so far.
        if validation_accuracy > bestval:
            bestval = validation_accuracy
            if model_outfile:
                torch.save(net.state_dict(), model_outfile)
    return acc_train, acc_valid
|
|
| |
def validate1(net, dataloader, device):
    """Top-1 accuracy of ``net`` over ``dataloader``.

    Assumes only one given label per sample: the arg-max of each label row is
    compared with the arg-max of the corresponding output row.

    returns: fraction of samples whose predicted class equals the given class
    """
    net.eval()
    n_hit = 0
    n_seen = 0
    with torch.no_grad():
        for batch_inputs, batch_labels in dataloader:
            batch_inputs = batch_inputs.to(device=device)
            batch_labels = batch_labels.to(device=device)
            scores = net(batch_inputs)
            matches = scores.argmax(dim=1) == batch_labels.argmax(dim=1)
            n_hit += int(matches.sum())
            n_seen += batch_labels.shape[0]

    return n_hit / n_seen
|
|
|
|
def validate(net, dataloader, device):
    """Accuracy of ``net`` over ``dataloader``; allows multi-labeling.

    For a sample with k labels above 0.5, the k highest-scoring outputs are
    taken as the prediction and the sample scores (number of those that are
    true labels) / k.  The result is the mean score over all scored samples.

    Samples without any label above 0.5 have nothing to predict (k == 0) and
    are left out of the mean; scoring them would add 0/0 = NaN to the total.

    returns: mean score as a Python float
    raises: ZeroDivisionError if the dataloader yields no scorable sample
    """
    correct = 0.0
    total = 0
    net.eval()
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device=device)
            labels = labels.to(device=device)
            outputs = net(inputs)
            for scores, lab in zip(outputs, labels):
                target = lab > 0.5
                ntarget = int(target.sum())
                if ntarget == 0:
                    continue
                _, order = torch.sort(scores, descending=True)
                # How many of the top-k predictions are true labels.
                hits = int(target[order[:ntarget]].sum())
                correct += hits / ntarget
                total += 1

    return correct / total
|
|
|
|
def classify(net, dataloader, device, nclasses=22):
    """Compute logits for every sample served by ``dataloader``.

    net: model, expected to live on ``device`` already
    dataloader: yields (inputs, labels) batches; the labels are ignored.
        ``len(dataloader.dataset)`` fixes the number of output rows.
    device: device the inputs are moved to before the forward pass
    nclasses: number of columns of the result

    returns: float array of shape (len(dataloader.dataset), nclasses), rows
        in the order the dataloader serves the samples
    """
    n = len(dataloader.dataset)
    out = np.zeros((n, nclasses))
    start = 0
    net.eval()
    with torch.no_grad():
        for inputs, _ in dataloader:
            inputs = inputs.to(device=device)
            outputs = net(inputs)
            stop = start + len(outputs)
            # Bring the batch back to host memory first: .numpy() only works
            # on CPU tensors.
            out[start:stop] = outputs.detach().cpu().numpy()
            start = stop
    return out
|
|
def classify1_cpu(dat, net, nclasses=22):
    """Compute logits for data matrix using a for loop (cpu only).

    dat: numpy array indexed by sample; each sample is passed to ``net`` on
        its own with two leading singleton axes added
    net: model on the CPU
    nclasses: number of columns of the result

    returns: float array of shape (len(dat), nclasses)
    """
    net.eval()
    n = len(dat)
    out = np.zeros((n, nclasses))
    # Inference only: skip autograd bookkeeping, as classify() does.
    with torch.no_grad():
        for i in range(n):
            sample = torch.from_numpy(dat[i]).unsqueeze(0).unsqueeze(0)
            out[i] = net(sample).detach().numpy()
    return out
|
|
def classify_cpu(dat, net):
    """Compute logits for the entire data matrix in one forward pass (cpu only).

    dat: numpy array indexed by sample; a singleton axis is inserted at
        position 1 before the array is passed to ``net``
    net: model on the CPU

    returns: the output of ``net`` as a numpy array
    """
    net.eval()
    # Inference only: without no_grad the autograd graph for the whole data
    # matrix would be kept alive during the forward pass.
    with torch.no_grad():
        batch = torch.unsqueeze(torch.from_numpy(dat), 1)
        out = net(batch).detach().numpy()
    return out
|
|
|
|