Spaces:
Runtime error
Runtime error
| from ..torch_core import * | |
| from ..basic_data import * | |
| from ..basic_train import * | |
| from ..train import ClassificationInterpretation | |
| import matplotlib.cm as cm | |
| __all__ = ['TextClassificationInterpretation'] | |
def value2rgba(x:float, cmap:Callable=cm.RdYlGn, alpha_mult:float=1.0)->Tuple:
    """Convert a value `x` from 0 to 1 (inclusive) to an RGBA tuple according to `cmap` times transparency `alpha_mult`.

    The three colour channels are scaled to 0-255 and truncated to integers; the last component of the
    colormap output is multiplied by `alpha_mult` and returned as a builtin `float`.
    """
    c = cmap(x)
    rgb = (np.array(c[:-1]) * 255).astype(int)
    # `piece_attn_html` turns this tuple into CSS with `str(...)`, which uses the `repr` of every element.
    # A colormap may hand back NumPy scalars, and NumPy >= 2 renders those as `np.float64(0.5)`, which
    # is not a valid CSS number -- so force a plain Python float here.
    a = float(c[-1] * alpha_mult)
    return tuple(rgb.tolist() + [a])
def piece_attn_html(pieces:List[str], attns:List[float], sep:str=' ', **kwargs)->str:
    """Build an HTML snippet where each item of `pieces` is shaded according to the matching item of `attns`.

    Every piece is HTML-escaped and wrapped in its own `<span>` whose tooltip shows the weight with three
    decimals and whose background comes from `value2rgba` (half transparency; `kwargs` are forwarded to it).
    The spans are joined with `sep` inside one monospace outer `<span>`. Pairing uses `zip`, so surplus
    items in the longer of the two sequences are ignored.
    """
    def _render(piece, weight):
        # One coloured span for a single (piece, weight) pair.
        safe = html.escape(piece)
        colour = str(value2rgba(weight, alpha_mult=0.5, **kwargs))
        return f'<span title="{weight:.3f}" style="background-color: rgba{colour};">{safe}</span>'

    body = sep.join(_render(piece, weight) for piece, weight in zip(pieces, attns))
    return '<span style="font-family: monospace;">' + body + '</span>'
def show_piece_attn(*args, **kwargs):
    "Display the markup produced by `piece_attn_html(*args, **kwargs)` through IPython's rich display."
    # Imported lazily so IPython is only needed when something is actually shown.
    from IPython.display import display, HTML
    markup = piece_attn_html(*args, **kwargs)
    display(HTML(markup))
| def _eval_dropouts(mod): | |
| module_name = mod.__class__.__name__ | |
| if 'Dropout' in module_name or 'BatchNorm' in module_name: mod.training = False | |
| for module in mod.children(): _eval_dropouts(module) | |
class TextClassificationInterpretation(ClassificationInterpretation):
    """Provides an interpretation of classification based on input sensitivity.
    This was designed for AWD-LSTM only for the moment, because Transformer already has its own attentional model.
    """

    def __init__(self, learn: Learner, preds: Tensor, y_true: Tensor, losses: Tensor, ds_type: DatasetType = DatasetType.Valid):
        "Forward all arguments to the base interpretation and keep a reference to `learn.model`."
        super(TextClassificationInterpretation, self).__init__(learn,preds,y_true,losses,ds_type)
        self.model = learn.model

    @classmethod
    def from_learner(cls, learn: Learner, ds_type:DatasetType=DatasetType.Valid, activ:nn.Module=None):
        "Gets preds, y_true, losses to construct base class from a learner"
        preds_res = learn.get_preds(ds_type=ds_type, activ=activ, with_loss=True, ordered=True)
        # Record which split the predictions came from: `show_top_losses` looks items up through `self.ds_type`.
        return cls(learn, *preds_res, ds_type=ds_type)

    def intrinsic_attention(self, text:str, class_id:int=None):
        """Calculate the intrinsic attention of the input w.r.t to an output `class_id`, or the classification given by the model if `None`.
        For reference, see the Sequential Jacobian section at https://www.cs.toronto.edu/~graves/preprint.pdf

        Returns `(tokens, attn)`: the reconstructed input item and one weight per input position, obtained by
        summing the absolute gradient of the chosen class probability over the embedding axis and dividing by
        the largest weight.
        """
        # Train mode with dropout/batchnorm layers switched off.
        # NOTE(review): presumably train mode is needed so the recurrent layers support backward -- confirm.
        self.model.train()
        _eval_dropouts(self.model)
        self.model.zero_grad()
        self.model.reset()
        ids = self.data.one_item(text)[0]
        # Detach the embeddings and make them a leaf so the gradient w.r.t. the input is stored in `emb.grad`.
        emb = self.model[0].module.encoder(ids).detach().requires_grad_(True)
        lstm_output = self.model[0].module(emb, from_embeddings=True)
        self.model.eval()
        # The classifier head is fed the encoder outputs plus an all-zero byte tensor shaped like `ids`
        # (assumed to be the padding mask, i.e. "nothing is padding" -- TODO confirm against the head's signature).
        cl = self.model[1](lstm_output + (torch.zeros_like(ids).byte(),))[0].softmax(dim=-1)
        if class_id is None: class_id = cl.argmax()
        cl[0][class_id].backward()
        # Drop only the leading batch axis (the code indexes it as `cl[0]` / `ids[0]`); a bare `squeeze()`
        # would also drop the sequence axis for a single-token input and leave a 0-dim result.
        attn = emb.grad.squeeze(0).abs().sum(dim=-1)
        attn /= attn.max()
        tokens = self.data.single_ds.reconstruct(ids[0])
        return tokens, attn

    def html_intrinsic_attention(self, text:str, class_id:int=None, **kwargs)->str:
        "Return the HTML built by `piece_attn_html` for the intrinsic attention of `text`; `kwargs` are forwarded to it."
        text, attn = self.intrinsic_attention(text, class_id)
        return piece_attn_html(text.text.split(), to_np(attn), **kwargs)

    def show_intrinsic_attention(self, text:str, class_id:int=None, **kwargs)->None:
        "Display the intrinsic attention of `text` with `show_piece_attn`; `kwargs` are forwarded to it."
        text, attn = self.intrinsic_attention(text, class_id)
        show_piece_attn(text.text.split(), to_np(attn), **kwargs)

    def show_top_losses(self, k:int, max_len:int=70)->None:
        """
        Create a tabulation showing the first `k` texts in top_losses along with their prediction, actual,loss, and probability of
        actual class. `max_len` is the maximum number of tokens displayed.
        """
        from IPython.display import display, HTML
        _,tl_idx = self.top_losses()
        classes = self.data.classes
        dataset = self.data.dl(self.ds_type).dataset
        rows = []
        # `max(k, 0)` keeps the "nothing for k <= 0" behaviour instead of slicing from the end.
        for idx in tl_idx[:max(k, 0)]:
            tx,cl = dataset[idx]
            cl = cl.data
            txt = ' '.join(tx.text.split(' ')[:max_len]) if max_len is not None else tx.text
            rows.append([txt, f'{classes[self.pred_class[idx]]}', f'{classes[cl]}', f'{self.losses[idx]:.2f}',
                         f'{self.preds[idx][cl]:.2f}'])
        names = ['Text', 'Prediction', 'Actual', 'Loss', 'Probability']
        # Building straight from the row list also works when `rows` is empty (yields an empty table).
        df = pd.DataFrame(rows, columns=names)
        try:
            # pandas >= 1.0 spells "do not truncate" as None and newer releases reject negative widths.
            with pd.option_context('display.max_colwidth', None): table = df.to_html(index=False)
        except ValueError:
            # Older pandas only accepts integers for this option; -1 was its "unlimited" value.
            with pd.option_context('display.max_colwidth', -1): table = df.to_html(index=False)
        display(HTML(table))