| import os |
| import sys |
| if __name__ == "__main__": |
| from train import * |
| else: |
| from .train import * |
|
|
|
|
|
|
|
|
# Resolve which checkpoint(s) to evaluate. The first command-line argument may
# name either a single checkpoint file or a directory of checkpoints. With no
# argument, the script falls back to "./checkpoint_test"; that fallback is only
# permitted when this file is executed directly (not imported as a module).
try:
    test_item = sys.argv[1]
except IndexError:
    # Explicit check rather than `assert`, so the guard is not stripped when
    # Python runs with -O.
    if __name__ != "__main__":
        raise RuntimeError(
            "a checkpoint path must be supplied as the first command-line "
            "argument when this module is imported"
        )
    test_item = "./checkpoint_test"

test_items = []
if os.path.isdir(test_item):
    # os.listdir returns entries in arbitrary order; sort them so checkpoints
    # are evaluated deterministically. Skip sub-directories, which the
    # evaluation loop would otherwise try (and fail) to load as checkpoints.
    for entry in sorted(os.listdir(test_item)):
        item = os.path.join(test_item, entry)
        if os.path.isfile(item):
            test_items.append(item)
    if not test_items:
        print(f"warning: no checkpoint files found in {test_item!r}", file=sys.stderr)
elif os.path.isfile(test_item):
    test_items.append(test_item)
else:
    # Previously a missing path silently produced an empty list and the script
    # evaluated nothing; fail loudly instead.
    raise FileNotFoundError(f"checkpoint path does not exist: {test_item!r}")
|
|
|
|
|
|
|
|
# Evaluate each collected checkpoint with the `model`, `device` and `test`
# names brought in by the star import from `train`.
for item in test_items:
    # NOTE(review): torch.load unpickles the file, and the path comes from the
    # command line; only point this at trusted checkpoints (consider
    # weights_only=True if the installed torch version supports it).
    state = torch.load(item, map_location="cpu")
    # Every stored value is cast to float32 and moved to `device` before being
    # loaded into the model.
    model.load_state_dict(
        dict(
            (param_name, tensor.to(torch.float32).to(device))
            for param_name, tensor in state.items()
        )
    )
    # NOTE(review): these results are rebound on every iteration and are not
    # reported in the visible code; presumably test() reports them itself or
    # later code consumes them — confirm.
    loss, acc, all_targets, all_predicts = test(model=model)