File size: 3,968 Bytes
4b7c478 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 | import torch
import torch.nn as nn
import torch.optim
def get_loss():
    """
    Build the loss function used for classification.

    Creates an ``nn.CrossEntropyLoss`` and, when ``torch.cuda.is_available()``
    reports a usable GPU, moves it there. There is no flag to opt out: the
    device choice depends only on CUDA availability.

    :return: the ``nn.CrossEntropyLoss`` instance
    """
    criterion = nn.CrossEntropyLoss()
    # nn.Module.cuda() moves the module in place and returns the module itself,
    # so returning its result is the same object as `criterion`.
    return criterion.cuda() if torch.cuda.is_available() else criterion
def get_optimizer(
    model: nn.Module,
    optimizer: str = "SGD",
    learning_rate: float = 0.01,
    momentum: float = 0.5,
    weight_decay: float = 0,
):
    """
    Build an optimizer bound to the parameters of ``model``.

    :param model: the model whose parameters will be optimized
    :param optimizer: optimizer name, matched case-insensitively; one of
        'SGD', 'Adam' or 'AdamW'
    :param learning_rate: the learning rate
    :param momentum: the momentum; only forwarded to SGD
    :param weight_decay: regularization coefficient
    :return: the configured optimizer instance
    :raises ValueError: if ``optimizer`` is not one of the supported names
    """
    kind = optimizer.lower()
    # Settings that every supported optimizer receives.
    common = {"lr": learning_rate, "weight_decay": weight_decay}

    if kind == "sgd":
        return torch.optim.SGD(model.parameters(), momentum=momentum, **common)

    # NOTE: `momentum` is intentionally not passed to Adam / AdamW; they do not
    # take a momentum parameter and compute it themselves.
    if kind == "adam":
        return torch.optim.Adam(model.parameters(), **common)

    if kind == "adamw":
        return torch.optim.AdamW(model.parameters(), **common)

    raise ValueError(f"Optimizer {optimizer} not supported")
######################################################################################
# TESTS
######################################################################################
import pytest
@pytest.fixture(scope="session")
def fake_model():
    """Session-scoped fixture: a single ``nn.Linear(16, 256)`` layer used as a stand-in model."""
    linear_layer = nn.Linear(16, 256)
    return linear_layer
def test_get_loss():
    """Check that get_loss() returns an nn.CrossEntropyLoss instance."""
    loss = get_loss()
    is_cross_entropy = isinstance(loss, nn.CrossEntropyLoss)
    assert is_cross_entropy, f"Expected cross entropy loss, found {type(loss)}"
def test_get_optimizer_type(fake_model):
    """Check that get_optimizer() called with defaults returns an SGD optimizer."""
    opt = get_optimizer(fake_model)
    is_sgd = isinstance(opt, torch.optim.SGD)
    assert is_sgd, f"Expected SGD optimizer, got {type(opt)}"
def test_get_optimizer_is_linked_with_model(fake_model):
    """Check that the optimizer's first parameter has the shape of the fixture model's weight."""
    # fake_model is nn.Linear(16, 256), so the first registered parameter
    # is expected to be the (256, 16) weight matrix.
    first_group = get_optimizer(fake_model).param_groups[0]
    first_param = first_group["params"][0]
    assert first_param.shape == torch.Size([256, 16])
def test_get_optimizer_returns_adam(fake_model):
    """Check that requesting "adam" yields an Adam optimizer bound to the model's parameters."""
    opt = get_optimizer(fake_model, optimizer="adam")
    # The first parameter of the nn.Linear(16, 256) fixture is its (256, 16) weight.
    assert opt.param_groups[0]["params"][0].shape == torch.Size([256, 16])
    # The failure message previously said "SGD" although this test checks for
    # Adam, which would mislead anyone debugging a failure.
    assert isinstance(
        opt, torch.optim.Adam
    ), f"Expected Adam optimizer, got {type(opt)}"
def test_get_optimizer_sets_learning_rate(fake_model):
    """Check that the learning_rate argument ends up in the optimizer's first param group."""
    opt = get_optimizer(fake_model, optimizer="adam", learning_rate=0.123)
    configured_lr = opt.param_groups[0]["lr"]
    assert (
        configured_lr == 0.123
    ), "get_optimizer is not setting the learning rate appropriately. Check your code."
def test_get_optimizer_sets_momentum(fake_model):
    """Check that the momentum argument ends up in the SGD optimizer's first param group."""
    opt = get_optimizer(fake_model, optimizer="SGD", momentum=0.123)
    configured_momentum = opt.param_groups[0]["momentum"]
    assert (
        configured_momentum == 0.123
    ), "get_optimizer is not setting the momentum appropriately. Check your code."
def test_get_optimizer_sets_weight_decat(fake_model):
    """Check that the weight_decay argument ends up in the SGD optimizer's first param group."""
    opt = get_optimizer(fake_model, optimizer="SGD", weight_decay=0.123)
    configured_decay = opt.param_groups[0]["weight_decay"]
    assert (
        configured_decay == 0.123
    ), "get_optimizer is not setting the weight_decay appropriately. Check your code."
|