import unittest

from detectron2.solver.build import _expand_param_groups, reduce_param_groups


class TestOptimizer(unittest.TestCase):
    """Checks for the optimizer parameter-group helpers in ``detectron2.solver.build``."""

    def testExpandParamsGroups(self):
        """Overlapping groups should expand to one group per parameter.

        The fixture lists "p1", "p2" and "p3" in more than one group. The
        expected output has exactly one entry per parameter, and wherever
        several input groups set the same key, the value from the group
        listed later is the one expected (e.g. "p1" ends up with
        ``weight_decay=4.0`` rather than ``3.0``).
        """
        params = [
            {
                "params": ["p1", "p2", "p3", "p4"],
                "lr": 1.0,
                "weight_decay": 3.0,
            },
            {
                "params": ["p2", "p3", "p5"],
                "lr": 2.0,
                "momentum": 2.0,
            },
            {
                "params": ["p1"],
                "weight_decay": 4.0,
            },
        ]
        out = _expand_param_groups(params)
        gt = [
            dict(params=["p1"], lr=1.0, weight_decay=4.0),
            dict(params=["p2"], lr=2.0, weight_decay=3.0, momentum=2.0),
            dict(params=["p3"], lr=2.0, weight_decay=3.0, momentum=2.0),
            dict(params=["p4"], lr=1.0, weight_decay=3.0),
            dict(params=["p5"], lr=2.0, momentum=2.0),
        ]
        self.assertEqual(out, gt)

    def testReduceParamGroups(self):
        """Groups with identical hyperparameters should be merged.

        The second and third input groups share the same ``lr``,
        ``weight_decay`` and ``momentum``, so the expected output holds a
        single group whose ``params`` is their concatenation
        (``["p2", "p6", "p3"]``). The remaining groups are expected
        unchanged and in their original relative order.
        """
        params = [
            dict(params=["p1"], lr=1.0, weight_decay=4.0),
            dict(params=["p2", "p6"], lr=2.0, weight_decay=3.0, momentum=2.0),
            dict(params=["p3"], lr=2.0, weight_decay=3.0, momentum=2.0),
            dict(params=["p4"], lr=1.0, weight_decay=3.0),
            dict(params=["p5"], lr=2.0, momentum=2.0),
        ]
        gt_groups = [
            {
                "lr": 1.0,
                "weight_decay": 4.0,
                "params": ["p1"],
            },
            {
                "lr": 2.0,
                "weight_decay": 3.0,
                "momentum": 2.0,
                "params": ["p2", "p6", "p3"],
            },
            {
                "lr": 1.0,
                "weight_decay": 3.0,
                "params": ["p4"],
            },
            {
                "lr": 2.0,
                "momentum": 2.0,
                "params": ["p5"],
            },
        ]
        out = reduce_param_groups(params)
        self.assertEqual(out, gt_groups)