landmarkclassifier / src /optimization.py
mewhenmonkeyavatar's picture
real initial commit.
4b7c478
Raw
History Blame Contribute Delete
3.97 kB
import torch
import torch.nn as nn
import torch.optim
def get_loss():
    """
    Get an instance of the CrossEntropyLoss (useful for classification).

    The loss module is moved to the GPU whenever ``torch.cuda.is_available()``
    reports a CUDA device. There is no ``use_cuda`` argument: the decision is
    made automatically.

    :return: an ``nn.CrossEntropyLoss`` instance
    """
    # Cross entropy is the loss selected here for the classification task.
    loss = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        # nn.Module.cuda() modifies the module in place, so the return value
        # does not need to be re-assigned.
        loss.cuda()

    return loss
def get_optimizer(
    model: nn.Module,
    optimizer: str = "SGD",
    learning_rate: float = 0.01,
    momentum: float = 0.5,
    weight_decay: float = 0,
) -> torch.optim.Optimizer:
    """
    Returns an optimizer instance

    :param model: the model to optimize
    :param optimizer: one of 'SGD', 'Adam' or 'AdamW' (matched case-insensitively)
    :param learning_rate: the learning rate
    :param momentum: the momentum (only passed to SGD; ignored for Adam/AdamW)
    :param weight_decay: regularization coefficient
    :return: the optimizer, built over ``model.parameters()``
    :raises ValueError: if ``optimizer`` is not one of the supported names
    """
    # Normalize once so every branch compares against the same lowercase name.
    name = optimizer.lower()

    if name == "sgd":
        opt = torch.optim.SGD(
            model.parameters(),
            lr=learning_rate,
            momentum=momentum,
            weight_decay=weight_decay,
        )
    elif name == "adam":
        # NOTE: Adam has no ``momentum`` argument (it keeps its own running
        # moment estimates), so ``momentum`` is intentionally not forwarded.
        opt = torch.optim.Adam(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
        )
    elif name == "adamw":
        # NOTE: as with Adam, AdamW has no ``momentum`` argument, so
        # ``momentum`` is intentionally not forwarded.
        opt = torch.optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
        )
    else:
        # Report the name exactly as the caller spelled it.
        raise ValueError(f"Optimizer {optimizer} not supported")

    return opt
######################################################################################
# TESTS
######################################################################################
# pytest is imported here, next to the tests that need it.
import pytest


@pytest.fixture(scope="session")
def fake_model():
    """Provide a small linear layer (16 inputs, 256 outputs) shared by the tests below."""
    return nn.Linear(16, 256)
def test_get_loss():
    """get_loss() must hand back a CrossEntropyLoss module."""
    loss = get_loss()

    assert isinstance(
        loss, nn.CrossEntropyLoss
    ), f"Expected cross entropy loss, found {type(loss)}"
def test_get_optimizer_type(fake_model):
    """With default arguments get_optimizer() must build an SGD optimizer."""
    opt = get_optimizer(fake_model)

    assert isinstance(opt, torch.optim.SGD), f"Expected SGD optimizer, got {type(opt)}"
def test_get_optimizer_is_linked_with_model(fake_model):
    """The optimizer's first parameter must be the model's (256, 16) weight matrix."""
    opt = get_optimizer(fake_model)

    first_param_shape = opt.param_groups[0]["params"][0].shape
    assert first_param_shape == torch.Size(
        [256, 16]
    ), f"Expected first optimized parameter of shape [256, 16], got {first_param_shape}"
def test_get_optimizer_returns_adam(fake_model):
    """Requesting "adam" must yield an Adam optimizer bound to the model's parameters."""
    opt = get_optimizer(fake_model, optimizer="adam")

    assert opt.param_groups[0]["params"][0].shape == torch.Size([256, 16])
    # The failure message names the optimizer this test actually expects.
    assert isinstance(opt, torch.optim.Adam), f"Expected Adam optimizer, got {type(opt)}"
def test_get_optimizer_sets_learning_rate(fake_model):
    """The learning_rate argument must end up in the optimizer's first param group."""
    opt = get_optimizer(fake_model, optimizer="adam", learning_rate=0.123)

    assert (
        opt.param_groups[0]["lr"] == 0.123
    ), "get_optimizer is not setting the learning rate appropriately. Check your code."
def test_get_optimizer_sets_momentum(fake_model):
    """The momentum argument must end up in the SGD optimizer's first param group."""
    opt = get_optimizer(fake_model, optimizer="SGD", momentum=0.123)

    assert (
        opt.param_groups[0]["momentum"] == 0.123
    ), "get_optimizer is not setting the momentum appropriately. Check your code."
def test_get_optimizer_sets_weight_decat(fake_model):
    """The weight_decay argument must end up in the SGD optimizer's first param group."""
    # NOTE: the function name keeps its original spelling so the test ID stays stable.
    opt = get_optimizer(fake_model, optimizer="SGD", weight_decay=0.123)

    assert (
        opt.param_groups[0]["weight_decay"] == 0.123
    ), "get_optimizer is not setting the weight_decay appropriately. Check your code."