Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn.functional as F | |
| # Soft aggregation from STM | |
def aggregate(prob, dim, return_logits=False):
    """Merge probabilities along ``dim`` together with a complement term.

    A complement entry, the product of ``1 - prob`` taken along ``dim``, is
    placed in front of ``prob`` along that same dimension (presumably the
    "background" / no-object slot in STM-style soft aggregation -- confirm
    against callers).  The stacked values are clamped to
    ``[1e-7, 1 - 1e-7]``, converted to log-odds and normalised with a
    softmax over ``dim``.

    Args:
        prob: Tensor of values expected to lie in ``[0, 1]`` (assumed; the
            clamp below only guards the stacked result, not the input).
        dim: Dimension along which entries are combined.  The outputs have
            one more entry than ``prob`` along this dimension, with the
            complement term at index 0.
        return_logits: When true, also return the pre-softmax log-odds.

    Returns:
        The softmax-normalised tensor, or ``(logits, normalised)`` when
        ``return_logits`` is true.
    """
    eps = 1e-7

    # Complement term: product of (1 - p) over every entry along `dim`.
    complement = torch.prod(1 - prob, dim=dim, keepdim=True)

    # Complement goes first, followed by the original entries.
    stacked = torch.cat([complement, prob], dim)

    # Keep values strictly inside (0, 1) so the log-odds stay finite.
    stacked = stacked.clamp(eps, 1 - eps)

    odds = stacked / (1 - stacked)
    logits = torch.log(odds)
    normalised = F.softmax(logits, dim=dim)

    return (logits, normalised) if return_logits else normalised