"""Utility helpers: JSON I/O, image loading, device setup, and metric tracking."""
import json
from collections import OrderedDict
from itertools import repeat
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple
import pandas as pd
import PIL.Image as Image
import torch
def read_json(fname):
fname = Path(fname)
with fname.open('rt') as handle:
return json.load(handle, object_hook=OrderedDict)
def write_json(content, fname):
    """Serialize *content* to *fname* as 4-space-indented JSON, keys in insertion order."""
    with Path(fname).open('wt') as dst:
        json.dump(content, dst, indent=4, sort_keys=False)
def pil_loader(fname: str) -> Image.Image:
    """Load an image from *fname* and force its pixel data into memory.

    ``Image.open`` is lazy: it keeps the underlying file handle open until
    the pixel data is actually read. Calling ``load()`` here reads the data
    immediately so the handle is released promptly — avoids leaking file
    descriptors (e.g. "too many open files" in DataLoader workers).
    """
    img = Image.open(fname)
    img.load()  # eager read; releases the file handle
    return img
def prepare_device(n_gpu_use):
    """Set up the torch device and the GPU indices used for DataParallel.

    Args:
        n_gpu_use: number of GPUs requested by the configuration.

    Returns:
        (device, list_ids): the primary ``torch.device`` and the list of GPU
        indices to hand to ``DataParallel`` (empty when running on CPU).
    """
    n_gpu = torch.cuda.device_count()
    if n_gpu_use > 0 and n_gpu == 0:
        # Fix: the two literals previously joined with no separating space,
        # printing "machine,training will ...".
        print("Warning: There's no GPU available on this machine, "
              "training will be performed on CPU.")
        n_gpu_use = 0
    if n_gpu_use > n_gpu:
        # Clamp the request to what the machine actually has.
        print(f"Warning: The number of GPUs configured to use is {n_gpu_use}, "
              f"but only {n_gpu} are available on this machine.")
        n_gpu_use = n_gpu
    device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
    list_ids = list(range(n_gpu_use))
    return device, list_ids
class TransformMultiple:
    """Apply one transform jointly to a sequence of tensors.

    The tensors are concatenated along dim 0, transformed in a single call,
    and split back into pieces of their original first-dimension sizes.
    With no transform configured, the input is returned untouched.
    """

    def __init__(self, transform: Optional[Callable] = None) -> None:
        self.transform = transform

    def __call__(self, data: Any) -> Tuple:
        # No transform configured: pass the input straight through.
        if self.transform is None:
            return data
        sizes = [item.size()[0] for item in data]
        merged = self.transform(torch.cat(data))
        return torch.split(merged, sizes)

    def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
        # First repr line gets the heading; continuation lines are padded to align.
        first, *rest = repr(transform).splitlines()
        pad = " " * len(head)
        return [f"{head}{first}"] + [f"{pad}{line}" for line in rest]

    def __repr__(self) -> str:
        parts = [self.__class__.__name__]
        if self.transform is not None:
            parts.extend(self._format_transform_repr(self.transform, "Transform: "))
        return "\n".join(parts)
class MetricTracker:
    """Track running totals and averages for a set of named metrics.

    Args:
        *keys: metric names to track.
        writer: optional summary writer exposing ``add_scalar(key, value)``
            (e.g. a TensorBoard writer); used by ``add_scalers``.
    """

    def __init__(self, *keys, writer=None):
        self.writer = writer
        # Explicit float dtype avoids the object-dtype columns produced by an
        # all-NaN DataFrame and the dtype-coercion warnings they cause.
        self._data = pd.DataFrame(index=keys,
                                  columns=['total', 'counts', 'average'],
                                  dtype=float)
        self.reset()

    def reset(self):
        """Zero out all tracked statistics."""
        for col in self._data.columns:
            self._data[col] = 0.0

    def update(self, key, value, n=1):
        """Accumulate *value* (an average over *n* samples) for *key*.

        Uses ``.loc`` rather than the original chained attribute/item access
        (``self._data.total[key] += ...``), which breaks for keys colliding
        with DataFrame attribute names and is deprecated chained assignment
        in modern pandas.
        """
        self._data.loc[key, 'total'] += value * n
        self._data.loc[key, 'counts'] += n
        self._data.loc[key, 'average'] = (self._data.loc[key, 'total']
                                          / self._data.loc[key, 'counts'])

    def add_scalers(self):
        """Push every current average to the writer, if one is configured."""
        if self.writer is not None:
            for key in self._data.index:
                self.writer.add_scalar(key, self._data.loc[key, 'average'])

    def avg(self, key):
        """Current running average for *key*."""
        return self._data.loc[key, 'average']

    def result(self):
        """All current averages as a plain dict keyed by metric name."""
        return dict(self._data['average'])