Spaces:
Sleeping
Sleeping
| import torch.nn as nn | |
def get_loss(name="cross_entropy", device="cuda:0"):
    """Return the registered loss module *name*, moved to *device*.

    Args:
        name: Key into the module-level ``LOSSES`` registry
            (e.g. ``"cross_entropy"`` or ``"binary_ce"``).
        device: Target passed to ``nn.Module.to``; defaults to ``"cuda:0"``.

    Returns:
        The loss module stored in ``LOSSES[name]`` after ``.to(device)``.
        ``nn.Module.to`` moves the module in place and returns it, so every
        call hands back the same shared instance held in ``LOSSES``.

    Raises:
        KeyError: If *name* is not a registered loss. The message lists the
            available names.
    """
    try:
        loss = LOSSES[name]
    except KeyError:
        # Re-raise with an actionable message instead of the bare key.
        # Suppress the internal traceback context with ``from None``.
        available = ", ".join(sorted(LOSSES))
        raise KeyError(
            f"Unknown loss {name!r}; available losses: {available}"
        ) from None
    print(f"Using loss: '{loss}'")
    return loss.to(device)
# Registry of available loss modules, keyed by the name that get_loss() looks up.
# Each entry is a single instance built with default constructor arguments
# at import time.
LOSSES = dict(
    binary_ce=nn.BCEWithLogitsLoss(),
    cross_entropy=nn.CrossEntropyLoss(),
)