"""CNN classifier for 2-D (time x frequency) images with on-the-fly augmentation.

Contains the augmentation parameter container, the augmentation routine, a
torch ``Dataset`` wrapper, the network definition and training / evaluation /
inference helpers.
"""
import datetime
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader


@dataclass(eq=False)
class AugmentationParams:
    """Settings for the random augmentations applied by ``augment`` and ``Dataset``.

    Every ``*_prob`` field is the probability with which the corresponding
    augmentation is applied; the defaults disable all augmentations.

    The ``fgmix_weight_*`` and ``bgmix_weight_*`` fields are stored but not
    used by the code in this file: mixing takes the element-wise maximum of
    the two images.
    """

    # circular shift along the time axis by 0..tshift_max rows
    tshift_max: int = 0
    tshift_prob: float = 0
    # zeroed band along the time axis, width drawn from [tmask_min, tmask_max)
    tmask_min: int = 0
    tmask_max: int = 0
    tmask_prob: float = 0
    # shift along the frequency axis by -fshift_max..fshift_max columns
    fshift_max: int = 0
    fshift_prob: float = 0
    # zeroed band along the frequency axis, width drawn from [fmask_min, fmask_max)
    fmask_min: int = 0
    fmask_max: int = 0
    fmask_prob: float = 0
    # rescale the image so that its maximum lies in [scale_min, scale_max)
    scale_min: float = 1
    scale_max: float = 1
    scale_prob: float = 0
    # mixing with another foreground sample
    fgmix_weight_min: float = 0
    fgmix_weight_max: float = 0
    fgmix_prob: float = 0
    # mixing with a background sample
    bgmix_weight_min: float = 0
    bgmix_weight_max: float = 0
    bgmix_prob: float = 0
    # additive noise on a random fraction of the pixels
    pixelnoise_rate_min: float = 0
    pixelnoise_rate_max: float = 0
    pixelnoise_intensity_min: float = 0
    pixelnoise_intensity_max: float = 0
    pixelnoise_prob: float = 0


def _random_width(low, high):
    """Draw a mask width from [low, high); return ``low`` when ``low == high``."""
    if low == high:
        # np.random.randint raises ValueError for an empty range
        return low
    return np.random.randint(low=low, high=high)


def augment(img1, ap):
    """Return a randomly augmented copy of a 2-D image.

    Args:
        img1: original image, a 2-D numpy array indexed ``[time, frequency]``.
            It is not modified.
        ap: ``AugmentationParams`` controlling which augmentations may be
            applied and how strongly.

    Returns:
        A new array with the same shape and dtype as ``img1``.
    """
    img = np.copy(img1)
    ntime, nfreq = img.shape

    if ap.tshift_prob > np.random.uniform():
        # circular shift along the time axis
        tshift = np.random.randint(low=0, high=ap.tshift_max + 1)
        img = np.concatenate((img[tshift:, :], img[:tshift, :]), axis=0)

    if ap.fshift_prob > np.random.uniform():
        # shift along the frequency axis, padding with the edge column
        fshift = np.random.randint(low=-ap.fshift_max, high=ap.fshift_max + 1)
        if fshift < 0:
            # shift down
            fshift2 = -fshift
            pad = np.repeat(img[:, nfreq - 1], fshift2).reshape((ntime, fshift2))
            img = np.concatenate((img[:, fshift2:], pad), axis=1)
        elif fshift > 0:
            # shift up
            pad = np.repeat(img[:, 0], fshift).reshape((ntime, fshift))
            img = np.concatenate((pad, img[:, :-fshift]), axis=1)

    if ap.scale_prob > np.random.uniform():
        # scale max between scale_min and scale_max
        origmax = np.max(img)
        w = np.random.uniform(low=ap.scale_min, high=ap.scale_max)
        if origmax != 0:
            # an all-zero image cannot be rescaled; dividing would produce NaN
            img *= w / origmax

    if ap.tmask_prob > np.random.uniform():
        tpos = np.random.randint(low=0, high=ntime)
        twidth = _random_width(ap.tmask_min, ap.tmask_max)
        tpos2 = min(tpos + twidth, ntime)
        img[tpos:tpos2, :] = 0

    if ap.fmask_prob > np.random.uniform():
        fpos = np.random.randint(low=0, high=nfreq)
        fwidth = _random_width(ap.fmask_min, ap.fmask_max)
        fpos2 = min(fpos + fwidth, nfreq)
        img[:, fpos:fpos2] = 0

    if ap.pixelnoise_prob > np.random.uniform():
        r = np.random.uniform(low=ap.pixelnoise_rate_min,
                              high=ap.pixelnoise_rate_max)
        # builtin int: the np.int alias was removed from NumPy
        npix = int(ntime * nfreq * r)
        ii = np.random.randint(0, ntime, size=npix)
        jj = np.random.randint(0, nfreq, size=npix)
        img[ii, jj] += np.random.uniform(low=ap.pixelnoise_intensity_min,
                                         high=ap.pixelnoise_intensity_max,
                                         size=npix)
    return img


def normalize(data):
    """Return ``data`` shifted to zero mean and divided by its standard deviation."""
    return (data - np.mean(data)) / np.std(data)


class Dataset(torch.utils.data.Dataset):
    """Torch dataset over numpy images and integer class labels.

    Data and labels are given as numpy arrays and are converted into tensors.
    Samples are expected to be ordered as foreground, then background:

    n1: number of fg samples,
    n2: number of bg samples to train clean bg class,
    n3: number of bg samples only to be mixed with fg
    """

    def __init__(self, data, labels, n1, n2=0, n3=0, ntime=512, ap=None,
                 eps=0.01, nclasses=22):
        """Store the data and build smoothed one-hot label tensors.

        Args:
            data: array of images, indexed by sample.
            labels: numpy array of integer class indices, one per sample.
            n1, n2, n3: sample counts, see the class docstring.
            ntime: number of leading time rows kept from each image.
            ap: ``AugmentationParams``, or None to disable augmentation.
            eps: one-hot labels are clamped to ``[eps, 1 - eps]``.
            nclasses: number of classes.
        """
        self.data = data
        # label is float because BCEWithLogitsLoss supports labels that are
        # probabilities
        one_hot = F.one_hot(torch.from_numpy(labels), num_classes=nclasses)
        self.labels = torch.clamp(one_hot.float(), min=eps, max=1.0 - eps)
        self.ap = ap
        self.n1 = n1
        self.n2 = n2
        self.n3 = n3
        self.ntime = ntime

    def __len__(self):
        return self.n1 + self.n2 + self.n3

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is a tensor of shape ``(1, <=ntime, nfreq)``. When
        augmentation is enabled, foreground samples (``idx < n1``) may also be
        mixed with another foreground and/or a background sample; images and
        labels are combined with an element-wise maximum.
        """
        img = self.data[idx]
        lab = self.labels[idx]
        if self.ap:
            img = augment(img, self.ap)
            # sample 2nd image only when it is needed...
            if idx < self.n1 and self.ap.fgmix_prob > np.random.uniform():
                # mix only foreground species
                idx2 = np.random.randint(low=0, high=self.n1)
                img2 = augment(self.data[idx2], self.ap)
                img = np.maximum(img, img2)
                lab = torch.maximum(lab, self.labels[idx2])
            if idx < self.n1 and self.ap.bgmix_prob > np.random.uniform():
                # mix with background
                # NOTE(review): only the n2 block is sampled here although the
                # class docstring says the n3 samples exist for mixing; with
                # n2 == 0 this randint raises ValueError -- confirm intent.
                idx2 = np.random.randint(low=self.n1, high=self.n1 + self.n2)
                img2 = augment(self.data[idx2], self.ap)
                img = np.maximum(img, img2)
                lab = torch.maximum(lab, self.labels[idx2])
        # add the channel dimension expected by Conv2d
        img = torch.from_numpy(img[:self.ntime]).unsqueeze(0)
        return img, lab


class Net(nn.Module):
    """Six-layer CNN followed by three fully connected layers.

    Input is ``(batch, 1, ntime, nfreq)``; output is ``(batch, nclasses)``
    raw logits.
    """

    # number of pooling layers that halve the time / frequency axis in forward()
    _N_TIME_POOLS = 6
    _N_FREQ_POOLS = 4

    def __init__(self, ntime=512, nfreq=128, nclasses=22):
        super().__init__()
        # 1 input image channel, 32 output channels, 3x3 square convolution kernel
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv4 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv5 = nn.Conv2d(256, 512, 3, padding=1)
        self.conv6 = nn.Conv2d(512, 512, 3, padding=1)
        # image dimension after maxpool layers: the time axis is halved by all
        # six pools, the frequency axis only by the first four (the last two
        # use a (2, 1) window). Repeated floor-halving equals one floor
        # division by the matching power of two.
        nt = ntime // 2 ** self._N_TIME_POOLS
        nr = nfreq // 2 ** self._N_FREQ_POOLS
        self.fc1 = nn.Linear(512 * nt * nr, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, nclasses)

    def forward(self, x):
        """Return logits for a batch ``x`` of shape ``(batch, 1, ntime, nfreq)``."""
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the size is a square, you can specify with a single number,
        # default stride is kernel size
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = F.max_pool2d(F.relu(self.conv3(x)), 2)
        x = F.max_pool2d(F.relu(self.conv4(x)), (2, 2))
        # the last two pools reduce the time axis only
        x = F.max_pool2d(F.relu(self.conv5(x)), (2, 1))
        x = F.max_pool2d(F.relu(self.conv6(x)), (2, 1))
        # flatten all dimensions except the batch dimension; with the default
        # sizes this is an 8*8 image (64*512 filter outputs)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def train(net, loss_fn, optimizer, train_dataloader, validation_dataloader,
          device, nepochs=10, info=0, model_outfile=None):
    """Train ``net`` and evaluate it after every epoch.

    Args:
        net: model to train; expected to already live on ``device``.
        loss_fn: callable ``loss_fn(outputs, labels)`` returning the loss.
        optimizer: optimizer over the parameters of ``net``.
        train_dataloader: iterable of ``(inputs, labels)`` training batches.
        validation_dataloader: iterable of ``(inputs, labels)`` validation batches.
        device: device the batches are moved to.
        nepochs: number of epochs.
        info: 0 prints nothing, a truthy value prints the accuracies after
            each epoch, a value > 1 also prints the mean training loss every
            100 batches.
        model_outfile: if given, the state dict is saved there whenever the
            validation accuracy improves on the best value so far.

    Returns:
        ``(acc_train, acc_valid)``: numpy arrays with the accuracy on the
        training and validation data after each epoch.
    """
    acc_train = np.zeros(nepochs)
    acc_valid = np.zeros(nepochs)
    bestval = 0
    for epoch in range(nepochs):
        net.train()
        running_loss = 0.0
        for i, data in enumerate(train_dataloader, 0):
            inputs, labels = data
            inputs = inputs.to(device=device)
            labels = labels.to(device=device)
            outputs = net(inputs)
            loss = loss_fn(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if info > 1:
                running_loss += loss.item()
                if i % 100 == 99:
                    print('[%d, %5d] train_loss: %.3f' %
                          (epoch + 1, i + 1, running_loss / 100))
                    running_loss = 0.0
        net.eval()
        train_accuracy = validate(net, train_dataloader, device)
        validation_accuracy = validate(net, validation_dataloader, device)
        acc_train[epoch] = train_accuracy
        acc_valid[epoch] = validation_accuracy
        if info:
            print(f'{datetime.datetime.now()} epoch {epoch}, train_accuracy: {train_accuracy:.3f} validation_accuracy: {validation_accuracy:.3f}')
        if validation_accuracy > bestval:
            bestval = validation_accuracy
            if model_outfile:
                torch.save(net.state_dict(), model_outfile)
    return acc_train, acc_valid


def validate1(net, dataloader, device):
    """Return top-1 accuracy, assuming only one given label per sample.

    Returns 0.0 when the dataloader yields no samples.
    """
    correct = 0
    total = 0
    net.eval()
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device=device)
            labels = labels.to(device=device)
            outputs = net(inputs)
            predicted = torch.argmax(outputs, 1)
            given_labels = torch.argmax(labels, 1)
            total += labels.shape[0]
            correct += int((predicted == given_labels).sum())
    return correct / total if total else 0.0


def validate(net, dataloader, device):
    """Return the mean per-sample accuracy, allowing multi-labeling.

    For a sample with k labels above 0.5, its score is the fraction of those
    labels found among the k highest outputs. Samples without any label above
    0.5 have no defined score and are left out of the mean.

    Returns:
        The mean score as a Python float; 0.0 if no sample could be scored.
    """
    correct = 0.0
    total = 0
    net.eval()
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device=device)
            labels = labels.to(device=device)
            outputs = net(inputs)
            for out, lab in zip(outputs, labels):
                target = lab > 0.5
                ntarget = int(target.sum())
                if ntarget == 0:
                    # would otherwise be 0/0 and turn the whole result into NaN
                    continue
                _, ind = torch.sort(out, descending=True)
                # plain Python numbers keep the result off the device
                correct += int(target[ind[:ntarget]].sum()) / ntarget
                total += 1
    return correct / total if total else 0.0


def classify(net, dataloader, device, nclasses=22):
    """Compute logits for every sample served by ``dataloader``.

    Returns:
        numpy array of shape ``(len(dataloader.dataset), nclasses)``.
    """
    n = len(dataloader.dataset)
    out = np.zeros((n, nclasses))
    i1 = 0
    net.eval()
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device=device)
            outputs = net(inputs)
            i2 = i1 + len(outputs)
            # .cpu() is required before .numpy() when device is not the CPU
            out[i1:i2] = outputs.detach().cpu().numpy()
            i1 = i2
    return out


def classify1_cpu(dat, net, nclasses=22):
    """Compute logits for a data matrix one sample at a time (cpu only).

    Returns:
        numpy array of shape ``(len(dat), nclasses)``.
    """
    net.eval()
    n = len(dat)
    out = np.zeros((n, nclasses))
    # inference only: do not record an autograd graph
    with torch.no_grad():
        for i in range(n):
            # add batch and channel dimensions
            sample = torch.from_numpy(dat[i]).unsqueeze(0).unsqueeze(0)
            out[i] = net(sample).detach().numpy()
    return out


def classify_cpu(dat, net):
    """Compute logits for the entire data matrix in one batch (cpu only)."""
    net.eval()
    # inference only: do not record an autograd graph
    with torch.no_grad():
        # add dimension for number of channels (1) so that tensor is
        # [num_segments num_channels ntime nfreq]
        out = net(torch.unsqueeze(torch.from_numpy(dat), 1)).detach().numpy()
    return out