"Provides advanced training extensions to `fastai.basic_train`. Includes half-precision, learning rate finder, mixup, and one-cycle"
from .torch_core import *
from .callbacks import *
from .basic_data import *
from .basic_train import *
__all__ = ['BnFreeze', 'GradientClipping', 'ShowGraph', 'Interpretation', 'ClassificationInterpretation', 'MultiLabelClassificationInterpretation',
'fit_one_cycle', 'lr_find', 'one_cycle_scheduler', 'to_fp16', 'to_fp32', 'mixup', 'AccumulateScheduler']
def one_cycle_scheduler(lr_max:float, **kwargs:Any)->OneCycleScheduler:
"Instantiate a `OneCycleScheduler` with `lr_max`."
return partial(OneCycleScheduler, lr_max=lr_max, **kwargs)
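# Example (illustrative sketch): the returned partial plugs into `callback_fns`,
# so the scheduler is rebuilt for every call to `fit`; `data` and `model` here
# are assumed to be defined elsewhere.
#
#   learn = Learner(data, model, callback_fns=[one_cycle_scheduler(1e-2, pct_start=0.5)])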
def fit_one_cycle(learn:Learner, cyc_len:int, max_lr:Union[Floats,slice]=defaults.lr,
moms:Tuple[float,float]=(0.95,0.85), div_factor:float=25., pct_start:float=0.3, final_div:float=None,
wd:float=None, callbacks:Optional[CallbackList]=None, tot_epochs:int=None, start_epoch:int=None,
batch_multiplier:int=1)->None:
"Fit a model following the 1cycle policy."
max_lr = learn.lr_range(max_lr)
callbacks = listify(callbacks)
callbacks.append(OneCycleScheduler(learn, max_lr, moms=moms, div_factor=div_factor, pct_start=pct_start,
final_div=final_div, tot_epochs=tot_epochs, start_epoch=start_epoch))
learn.fit(cyc_len, max_lr, wd=wd, callbacks=callbacks, batch_multiplier=batch_multiplier)
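# Example (illustrative sketch): assuming `learn` is an existing `Learner`
# (this function is also patched onto `Learner` below), a discriminative
# 1cycle fit over 5 epochs might look like:
#
#   learn.fit_one_cycle(5, max_lr=slice(1e-5, 1e-3), moms=(0.95, 0.85), wd=1e-2)
#
# `slice(1e-5, 1e-3)` is expanded by `lr_range` into per-layer-group learning
# rates; momentum is annealed from the first value in `moms` down to the
# second and back over the cycle.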
def lr_find(learn:Learner, start_lr:Floats=1e-7, end_lr:Floats=10, num_it:int=100, stop_div:bool=True, wd:float=None,
batch_multiplier:int=1):
"Explore lr from `start_lr` to `end_lr` over `num_it` iterations in `learn`. If `stop_div`, stops when loss diverges."
start_lr = learn.lr_range(start_lr)
start_lr = np.array(start_lr) if is_listy(start_lr) else start_lr
end_lr = learn.lr_range(end_lr)
end_lr = np.array(end_lr) if is_listy(end_lr) else end_lr
cb = LRFinder(learn, start_lr, end_lr, num_it, stop_div)
epochs = int(np.ceil(num_it/len(learn.data.train_dl)))
learn.fit(epochs, start_lr, callbacks=[cb], wd=wd, batch_multiplier=batch_multiplier)
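# Example (illustrative sketch): run the finder, then inspect the loss curve
# recorded by the `Recorder` callback:
#
#   learn.lr_find()
#   learn.recorder.plot()  # pick a lr where the loss is still clearly falling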
def to_fp16(learn:Learner, loss_scale:float=None, max_noskip:int=1000, dynamic:bool=True, clip:float=None,
flat_master:bool=False, max_scale:float=2**24)->Learner:
"Put `learn` in FP16 precision mode."
learn.to_fp32()
learn.model = model2half(learn.model)
learn.data.add_tfm(batch_to_half)
learn.mp_cb = MixedPrecision(learn, loss_scale=loss_scale, max_noskip=max_noskip, dynamic=dynamic, clip=clip,
flat_master=flat_master, max_scale=max_scale)
learn.callbacks.append(learn.mp_cb)
return learn
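# Example (illustrative sketch): mixed precision is most useful on GPUs with
# tensor cores; `to_fp32` below undoes the conversion:
#
#   learn = learn.to_fp16(dynamic=True)
#   learn.fit_one_cycle(3)
#   learn = learn.to_fp32()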
def to_fp32(learn:Learner)->Learner:
"Put `learn` back to FP32 precision mode."
learn.data.remove_tfm(batch_to_half)
    # rebuild the list rather than removing items while iterating over it
    learn.callbacks = [cb for cb in learn.callbacks if not isinstance(cb, MixedPrecision)]
learn.model = learn.model.float()
return learn
def mixup(learn:Learner, alpha:float=0.4, stack_x:bool=False, stack_y:bool=True) -> Learner:
"Add mixup https://arxiv.org/abs/1710.09412 to `learn`."
learn.callback_fns.append(partial(MixUpCallback, alpha=alpha, stack_x=stack_x, stack_y=stack_y))
return learn
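# Example (illustrative sketch): mixup is registered as a callback factory, so
# it applies to every subsequent `fit`:
#
#   learn = learn.mixup(alpha=0.4)
#   learn.fit_one_cycle(10)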
Learner.fit_one_cycle = fit_one_cycle
Learner.lr_find = lr_find
Learner.to_fp16 = to_fp16
Learner.to_fp32 = to_fp32
Learner.mixup = mixup
class ShowGraph(LearnerCallback):
"Update a graph of learner stats and metrics after each epoch."
    def on_epoch_end(self, n_epochs:int, last_metrics:MetricsList, **kwargs)->dict:
        "If we have `last_metrics`, plot them in our pbar graph."
if last_metrics is not None and last_metrics[0] is not None:
rec = self.learn.recorder
iters = range_of(rec.losses)
val_iter = np.array(rec.nb_batches).cumsum()
x_bounds = (0, (n_epochs - len(rec.nb_batches)) * rec.nb_batches[-1] + len(rec.losses))
y_bounds = (0, max((max(Tensor(rec.losses)), max(Tensor(rec.val_losses)))))
rec.pbar.update_graph([(iters, rec.losses), (val_iter, rec.val_losses)], x_bounds, y_bounds)
return {}
class BnFreeze(LearnerCallback):
"Freeze moving average statistics in all non-trainable batchnorm layers."
def on_epoch_begin(self, **kwargs:Any)->None:
"Put bn layers in eval mode just after `model.train()`."
set_bn_eval(self.learn.model)
class GradientClipping(LearnerCallback):
"Gradient clipping during training."
def __init__(self, learn:Learner, clip:float = 0.):
super().__init__(learn)
self.clip = clip
def on_backward_end(self, **kwargs):
"Clip the gradient before the optimizer step."
if self.clip: nn.utils.clip_grad_norm_(self.learn.model.parameters(), self.clip)
def clip_grad(learn:Learner, clip:float=0.1)->Learner:
"Add gradient clipping of `clip` during training."
learn.callback_fns.append(partial(GradientClipping, clip=clip))
return learn
Learner.clip_grad = clip_grad
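# Example (illustrative sketch): clip gradient norms at 0.1 for all later fits:
#
#   learn = learn.clip_grad(0.1)
#   learn.fit_one_cycle(5)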
class AccumulateScheduler(LearnerCallback):
"Does accumlated step every nth step by accumulating gradients"
def __init__(self, learn:Learner, n_step:int = 1, drop_last:bool = False):
super().__init__(learn)
self.n_step,self.drop_last = n_step,drop_last
def on_train_begin(self, **kwargs):
"check if loss is reduction"
if hasattr(self.loss_func, "reduction") and (self.loss_func.reduction != "sum"):
warn("For better gradients consider 'reduction=sum'")
def on_epoch_begin(self, **kwargs):
"init samples and batches, change optimizer"
self.acc_samples, self.acc_batches = 0., 0.
def on_batch_begin(self, last_input, last_target, **kwargs):
"accumulate samples and batches"
self.acc_samples += last_input.shape[0]
self.acc_batches += 1
def on_backward_end(self, **kwargs):
"accumulated step and reset samples, True will result in no stepping"
if (self.acc_batches % self.n_step) == 0:
for p in (self.learn.model.parameters()):
if p.requires_grad: p.grad.div_(self.acc_samples)
self.acc_samples = 0
else: return {'skip_step':True, 'skip_zero':True}
    def on_epoch_end(self, **kwargs):
        "Step for any leftover accumulated gradients when the epoch is not divisible by `n_step`."
        if self.acc_samples > 0:  # guard: if the last batch already stepped, grads are zeroed and dividing by zero would corrupt the weights
            for p in self.learn.model.parameters():
                if p.requires_grad: p.grad.div_(self.acc_samples)
            if not self.drop_last: self.learn.opt.step()
            self.learn.opt.zero_grad()
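# Example (illustrative sketch): simulate a 4x larger batch by accumulating
# gradients over 4 steps; a sum-reduced loss keeps the per-sample scaling
# correct (`data` and `model` are assumed to exist; `CrossEntropyFlat` comes
# from `fastai.layers`):
#
#   learn = Learner(data, model,
#                   loss_func=CrossEntropyFlat(reduction='sum'),
#                   callback_fns=[partial(AccumulateScheduler, n_step=4)])
#   learn.fit(1)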
class Interpretation():
"Interpretation base class, can be inherited for task specific Interpretation classes"
def __init__(self, learn:Learner, preds:Tensor, y_true:Tensor, losses:Tensor, ds_type:DatasetType=DatasetType.Valid):
self.data,self.preds,self.y_true,self.losses,self.ds_type, self.learn = \
learn.data,preds,y_true,losses,ds_type,learn
self.ds = (self.data.train_ds if ds_type == DatasetType.Train else
self.data.test_ds if ds_type == DatasetType.Test else
self.data.valid_ds if ds_type == DatasetType.Valid else
self.data.single_ds if ds_type == DatasetType.Single else
self.data.fix_ds)
@classmethod
def from_learner(cls, learn: Learner, ds_type:DatasetType=DatasetType.Valid, activ:nn.Module=None):
"Gets preds, y_true, losses to construct base class from a learner"
preds_res = learn.get_preds(ds_type=ds_type, activ=activ, with_loss=True)
return cls(learn, *preds_res)
def top_losses(self, k:int=None, largest=True):
"`k` largest(/smallest) losses and indexes, defaulting to all losses (sorted by `largest`)."
return self.losses.topk(ifnone(k, len(self.losses)), largest=largest)
# def top_scores(self, metric:Callable=None, k:int=None, largest=True):
# "`k` largest(/smallest) metric scores and indexes, defaulting to all scores (sorted by `largest`)."
# self.scores = metric(self.preds, self.y_true)
# return self.scores.topk(ifnone(k, len(self.scores)), largest=largest)
class ClassificationInterpretation(Interpretation):
"Interpretation methods for classification models."
def __init__(self, learn:Learner, preds:Tensor, y_true:Tensor, losses:Tensor, ds_type:DatasetType=DatasetType.Valid):
        super().__init__(learn, preds, y_true, losses, ds_type)
self.pred_class = self.preds.argmax(dim=1)
def confusion_matrix(self, slice_size:int=1):
"Confusion matrix as an `np.ndarray`."
x=torch.arange(0,self.data.c)
if slice_size is None: cm = ((self.pred_class==x[:,None]) & (self.y_true==x[:,None,None])).sum(2)
else:
cm = torch.zeros(self.data.c, self.data.c, dtype=x.dtype)
for i in range(0, self.y_true.shape[0], slice_size):
cm_slice = ((self.pred_class[i:i+slice_size]==x[:,None])
& (self.y_true[i:i+slice_size]==x[:,None,None])).sum(2)
                cm += cm_slice
return to_np(cm)
def plot_confusion_matrix(self, normalize:bool=False, title:str='Confusion matrix', cmap:Any="Blues", slice_size:int=1,
norm_dec:int=2, plot_txt:bool=True, return_fig:bool=None, **kwargs)->Optional[plt.Figure]:
"Plot the confusion matrix, with `title` and using `cmap`."
# This function is mainly copied from the sklearn docs
cm = self.confusion_matrix(slice_size=slice_size)
if normalize: cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
fig = plt.figure(**kwargs)
plt.imshow(cm, interpolation='nearest', cmap=cmap)
plt.title(title)
tick_marks = np.arange(self.data.c)
plt.xticks(tick_marks, self.data.y.classes, rotation=90)
plt.yticks(tick_marks, self.data.y.classes, rotation=0)
if plot_txt:
thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
coeff = f'{cm[i, j]:.{norm_dec}f}' if normalize else f'{cm[i, j]}'
plt.text(j, i, coeff, horizontalalignment="center", verticalalignment="center", color="white" if cm[i, j] > thresh else "black")
plt.tight_layout()
plt.ylabel('Actual')
plt.xlabel('Predicted')
plt.grid(False)
if ifnone(return_fig, defaults.return_fig): return fig
def most_confused(self, min_val:int=1, slice_size:int=1)->Collection[Tuple[str,str,int]]:
"Sorted descending list of largest non-diagonal entries of confusion matrix, presented as actual, predicted, number of occurrences."
cm = self.confusion_matrix(slice_size=slice_size)
np.fill_diagonal(cm, 0)
res = [(self.data.classes[i],self.data.classes[j],cm[i,j])
for i,j in zip(*np.where(cm>=min_val))]
return sorted(res, key=itemgetter(2), reverse=True)
def _learner_interpret(learn:Learner, ds_type:DatasetType=DatasetType.Valid):
"Create a `ClassificationInterpretation` object from `learner` on `ds_type` with `tta`."
return ClassificationInterpretation.from_learner(learn, ds_type=ds_type)
Learner.interpret = _learner_interpret
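# Example (illustrative sketch): inspect a trained classifier (`learn` is
# assumed to be a fitted classification `Learner`):
#
#   interp = learn.interpret()
#   losses, idxs = interp.top_losses(9)          # worst-predicted items
#   interp.plot_confusion_matrix(slice_size=10)  # memory-friendly slicing
#   interp.most_confused(min_val=2)              # frequent (actual, predicted) pairs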
class MultiLabelClassificationInterpretation(Interpretation):
"Interpretation methods for classification models."
def __init__(self, learn:Learner, preds:Tensor, y_true:Tensor, losses:Tensor, ds_type:DatasetType=DatasetType.Valid,
sigmoid:bool=True, thresh:float=0.3):
        raise NotImplementedError  # WIP stub: the code below is unreachable until this class is finished
        super().__init__(learn, preds, y_true, losses, ds_type)
        self.pred_class = self.preds.sigmoid() > thresh if sigmoid else self.preds > thresh