Spaces:
Runtime error
Runtime error
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch | |
| import math | |
| class Attention(nn.Module): | |
| """ | |
| Compute 'Scaled Dot Product Attention | |
| """ | |
| def forward(self, query, key, value, mask=None, dropout=None): | |
| scores = torch.matmul(query, key.transpose(-2, -1)) \ | |
| / math.sqrt(query.size(-1)) | |
| if mask is not None: | |
| scores = scores.masked_fill(mask == 0, -1e9) | |
| p_attn = F.softmax(scores, dim=-1) | |
| if dropout is not None: | |
| p_attn = dropout(p_attn) | |
| return torch.matmul(p_attn, value), p_attn | |