|
|
| import os |
| import sys |
| import pytest |
| import torch |
| from open_clip_train.main import main |
|
|
# Expose no GPUs to CUDA so every training smoke test in this module runs on CPU.
# NOTE(review): this executes after `import torch`; it only takes effect because
# CUDA has (presumably) not been initialized yet at this point.
os.environ.update(CUDA_VISIBLE_DEVICES="")

# The private JIT executor toggles are only touched when this torch build has them.
# Keep the profiling executor enabled but turn profiling mode off.
# NOTE(review): presumably chosen to keep TorchScript compilation of large models
# fast in the `--torchscript` test below; confirm.
if hasattr(torch._C, '_jit_set_profiling_executor'):
    torch._C._jit_set_profiling_executor(True)
    torch._C._jit_set_profiling_mode(False)
|
|
@pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals")
def test_training(tmp_path):
    """Run one epoch of RN50 CLIP training on 16 synthetic samples.

    Smoke test: it passes if ``main`` returns without raising. Checkpoint
    saving and zero-shot evaluation are requested every epoch.

    ``--logs`` sends this run's output directory to pytest's per-test
    ``tmp_path``. Without it, the trainer uses its default log directory
    (presumably ``./logs/`` under the CWD), where checkpoints pile up across
    runs and parallel runs can collide on the same auto-generated experiment
    name.
    """
    main([
        '--save-frequency', '1',
        '--zeroshot-frequency', '1',
        '--dataset-type', "synthetic",
        '--train-num-samples', '16',
        '--warmup', '1',
        '--batch-size', '4',
        '--lr', '1e-3',
        '--wd', '0.1',
        '--epochs', '1',
        '--workers', '2',
        '--model', 'RN50',
        '--logs', str(tmp_path),
    ])
|
|
@pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals")
def test_training_coca(tmp_path):
    """Run one epoch of ``coca_ViT-B-32`` training on 16 synthetic samples.

    Smoke test: it passes if ``main`` returns without raising. Checkpoint
    saving and zero-shot evaluation are requested every epoch.

    ``--logs`` sends this run's output directory to pytest's per-test
    ``tmp_path``. Without it, the trainer uses its default log directory
    (presumably ``./logs/`` under the CWD), where checkpoints pile up across
    runs and parallel runs can collide on the same auto-generated experiment
    name.
    """
    main([
        '--save-frequency', '1',
        '--zeroshot-frequency', '1',
        '--dataset-type', "synthetic",
        '--train-num-samples', '16',
        '--warmup', '1',
        '--batch-size', '4',
        '--lr', '1e-3',
        '--wd', '0.1',
        '--epochs', '1',
        '--workers', '2',
        '--model', 'coca_ViT-B-32',
        '--logs', str(tmp_path),
    ])
|
|
@pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals")
def test_training_mt5(tmp_path):
    """Run one epoch of ``mt5-base-ViT-B-32`` training with a locked text tower.

    The run passes ``--lock-text`` with ``--lock-text-unlocked-layers 2``,
    presumably leaving two text-tower layers trainable. Smoke test: it passes
    if ``main`` returns without raising.

    NOTE(review): the mT5 text tower presumably comes from Hugging Face, so
    this test may need network access to fetch weights; confirm.

    ``--logs`` sends this run's output directory to pytest's per-test
    ``tmp_path``. Without it, the trainer uses its default log directory
    (presumably ``./logs/`` under the CWD).
    """
    main([
        '--save-frequency', '1',
        '--zeroshot-frequency', '1',
        '--dataset-type', "synthetic",
        '--train-num-samples', '16',
        '--warmup', '1',
        '--batch-size', '4',
        '--lr', '1e-3',
        '--wd', '0.1',
        '--epochs', '1',
        '--workers', '2',
        '--model', 'mt5-base-ViT-B-32',
        '--lock-text',
        '--lock-text-unlocked-layers', '2',
        '--logs', str(tmp_path),
    ])
|
|
|
|
|
|
@pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals")
def test_training_unfreezing_vit(tmp_path):
    """Run one epoch of ViT-B-32 training with a partially locked image tower.

    The run passes ``--lock-image`` with ``--lock-image-unlocked-groups 5``,
    presumably leaving five layer groups trainable. ``--accum-freq 2``
    presumably enables gradient accumulation over two steps. Smoke test: it
    passes if ``main`` returns without raising.

    ``--logs`` sends this run's output directory to pytest's per-test
    ``tmp_path``. This test uses the same model/lr/batch/workers as the
    TorchScript test, so with a shared default log directory the two could
    auto-generate the same experiment name when run concurrently.
    """
    main([
        '--save-frequency', '1',
        '--zeroshot-frequency', '1',
        '--dataset-type', "synthetic",
        '--train-num-samples', '16',
        '--warmup', '1',
        '--batch-size', '4',
        '--lr', '1e-3',
        '--wd', '0.1',
        '--epochs', '1',
        '--workers', '2',
        '--model', 'ViT-B-32',
        '--lock-image',
        '--lock-image-unlocked-groups', '5',
        '--accum-freq', '2',
        '--logs', str(tmp_path),
    ])
|
|
|
|
@pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals")
def test_training_clip_with_jit(tmp_path):
    """Run one epoch of ViT-B-32 training with ``--torchscript`` enabled.

    ``--torchscript`` presumably scripts the model with TorchScript. Smoke
    test: it passes if ``main`` returns without raising.

    ``--logs`` sends this run's output directory to pytest's per-test
    ``tmp_path``. This test uses the same model/lr/batch/workers as the
    image-unlocking test, so with a shared default log directory the two
    could auto-generate the same experiment name when run concurrently.
    """
    main([
        '--save-frequency', '1',
        '--zeroshot-frequency', '1',
        '--dataset-type', "synthetic",
        '--train-num-samples', '16',
        '--warmup', '1',
        '--batch-size', '4',
        '--lr', '1e-3',
        '--wd', '0.1',
        '--epochs', '1',
        '--workers', '2',
        '--model', 'ViT-B-32',
        '--torchscript',
        '--logs', str(tmp_path),
    ])
|
|