|
|
"`fastai.core` contains essential util functions to format and split data" |
|
|
from .imports.core import * |
|
|
|
|
|
warnings.filterwarnings("ignore", message="numpy.dtype size changed") |
|
|
warnings.filterwarnings("ignore", message="numpy.ufunc size changed") |
|
|
|
|
|
AnnealFunc = Callable[[Number,Number,float], Number] |
|
|
ArgStar = Collection[Any] |
|
|
BatchSamples = Collection[Tuple[Collection[int], int]] |
|
|
DataFrameOrChunks = Union[DataFrame, pd.io.parsers.TextFileReader] |
|
|
FilePathList = Collection[Path] |
|
|
Floats = Union[float, Collection[float]] |
|
|
ImgLabel = str |
|
|
ImgLabels = Collection[ImgLabel] |
|
|
IntsOrStrs = Union[int, Collection[int], str, Collection[str]] |
|
|
KeyFunc = Callable[[int], int] |
|
|
KWArgs = Dict[str,Any] |
|
|
ListOrItem = Union[Collection[Any],int,float,str] |
|
|
ListRules = Collection[Callable[[str],str]] |
|
|
ListSizes = Collection[Tuple[int,int]] |
|
|
NPArrayableList = Collection[Union[np.ndarray, list]] |
|
|
NPArrayList = Collection[np.ndarray] |
|
|
NPArrayMask = np.ndarray |
|
|
NPImage = np.ndarray |
|
|
OptDataFrame = Optional[DataFrame] |
|
|
OptListOrItem = Optional[ListOrItem] |
|
|
OptRange = Optional[Tuple[float,float]] |
|
|
OptStrTuple = Optional[Tuple[str,str]] |
|
|
OptStats = Optional[Tuple[np.ndarray, np.ndarray]] |
|
|
PathOrStr = Union[Path,str] |
|
|
PathLikeOrBinaryStream = Union[PathOrStr, BufferedWriter, BytesIO] |
|
|
PBar = Union[MasterBar, ProgressBar] |
|
|
Point=Tuple[float,float] |
|
|
Points=Collection[Point] |
|
|
Sizes = List[List[int]] |
|
|
SplitArrayList = List[Tuple[np.ndarray,np.ndarray]] |
|
|
StartOptEnd=Union[float,Tuple[float,float]] |
|
|
StrList = Collection[str] |
|
|
Tokens = Collection[Collection[str]] |
|
|
OptStrList = Optional[StrList] |
|
|
np.set_printoptions(precision=6, threshold=50, edgeitems=4, linewidth=120) |
|
|
|
|
|
def num_cpus()->int: |
|
|
"Get number of cpus" |
|
|
try: return len(os.sched_getaffinity(0)) |
|
|
except AttributeError: return os.cpu_count() |
|
|
|
|
|
_default_cpus = min(16, num_cpus()) |
|
|
defaults = SimpleNamespace(cpus=_default_cpus, cmap='viridis', return_fig=False, silent=False) |
|
|
|
|
|
def is_listy(x:Any)->bool: return isinstance(x, (tuple,list)) |
|
|
def is_tuple(x:Any)->bool: return isinstance(x, tuple) |
|
|
def is_dict(x:Any)->bool: return isinstance(x, dict) |
|
|
def is_pathlike(x:Any)->bool: return isinstance(x, (str,Path)) |
|
|
def noop(x): return x |
|
|
|
|
|
class PrePostInitMeta(type): |
|
|
"A metaclass that calls optional `__pre_init__` and `__post_init__` methods" |
|
|
def __new__(cls, name, bases, dct): |
|
|
x = super().__new__(cls, name, bases, dct) |
|
|
old_init = x.__init__ |
|
|
def _pass(self): pass |
|
|
@functools.wraps(old_init) |
|
|
def _init(self,*args,**kwargs): |
|
|
self.__pre_init__() |
|
|
old_init(self, *args,**kwargs) |
|
|
self.__post_init__() |
|
|
x.__init__ = _init |
|
|
if not hasattr(x,'__pre_init__'): x.__pre_init__ = _pass |
|
|
if not hasattr(x,'__post_init__'): x.__post_init__ = _pass |
|
|
return x |
|
|
|
|
|
def chunks(l:Collection, n:int)->Iterable: |
|
|
"Yield successive `n`-sized chunks from `l`." |
|
|
for i in range(0, len(l), n): yield l[i:i+n] |
|
|
|
|
|
def recurse(func:Callable, x:Any, *args, **kwargs)->Any: |
|
|
if is_listy(x): return [recurse(func, o, *args, **kwargs) for o in x] |
|
|
if is_dict(x): return {k: recurse(func, v, *args, **kwargs) for k,v in x.items()} |
|
|
return func(x, *args, **kwargs) |
|
|
|
|
|
def first_el(x: Any)->Any: |
|
|
"Recursively get the first element of `x`." |
|
|
if is_listy(x): return first_el(x[0]) |
|
|
if is_dict(x): return first_el(x[list(x.keys())[0]]) |
|
|
return x |
|
|
|
|
|
def to_int(b:Any)->Union[int,List[int]]: |
|
|
"Recursively convert `b` to an int or list/dict of ints; raises exception if not convertible." |
|
|
return recurse(lambda x: int(x), b) |
|
|
|
|
|
def ifnone(a:Any,b:Any)->Any: |
|
|
"`a` if `a` is not None, otherwise `b`." |
|
|
return b if a is None else a |
|
|
|
|
|
def is1d(a:Collection)->bool: |
|
|
"Return `True` if `a` is one-dimensional" |
|
|
return len(a.shape) == 1 if hasattr(a, 'shape') else len(np.array(a).shape) == 1 |
|
|
|
|
|
def uniqueify(x:Series, sort:bool=False)->List: |
|
|
"Return sorted unique values of `x`." |
|
|
res = list(OrderedDict.fromkeys(x).keys()) |
|
|
if sort: res.sort() |
|
|
return res |
|
|
|
|
|
def idx_dict(a): |
|
|
"Create a dictionary value to index from `a`." |
|
|
return {v:k for k,v in enumerate(a)} |
|
|
|
|
|
def find_classes(folder:Path)->FilePathList: |
|
|
"List of label subdirectories in imagenet-style `folder`." |
|
|
classes = [d for d in folder.iterdir() |
|
|
if d.is_dir() and not d.name.startswith('.')] |
|
|
assert(len(classes)>0) |
|
|
return sorted(classes, key=lambda d: d.name) |
|
|
|
|
|
def arrays_split(mask:NPArrayMask, *arrs:NPArrayableList)->SplitArrayList: |
|
|
"Given `arrs` is [a,b,...] and `mask`index - return[(a[mask],a[~mask]),(b[mask],b[~mask]),...]." |
|
|
assert all([len(arr)==len(arrs[0]) for arr in arrs]), 'All arrays should have same length' |
|
|
mask = array(mask) |
|
|
return list(zip(*[(a[mask],a[~mask]) for a in map(np.array, arrs)])) |
|
|
|
|
|
def random_split(valid_pct:float, *arrs:NPArrayableList)->SplitArrayList: |
|
|
"Randomly split `arrs` with `valid_pct` ratio. good for creating validation set." |
|
|
assert (valid_pct>=0 and valid_pct<=1), 'Validation set percentage should be between 0 and 1' |
|
|
is_train = np.random.uniform(size=(len(arrs[0]),)) > valid_pct |
|
|
return arrays_split(is_train, *arrs) |
|
|
|
|
|
def listify(p:OptListOrItem=None, q:OptListOrItem=None): |
|
|
"Make `p` listy and the same length as `q`." |
|
|
if p is None: p=[] |
|
|
elif isinstance(p, str): p = [p] |
|
|
elif not isinstance(p, Iterable): p = [p] |
|
|
|
|
|
else: |
|
|
try: a = len(p) |
|
|
except: p = [p] |
|
|
n = q if type(q)==int else len(p) if q is None else len(q) |
|
|
if len(p)==1: p = p * n |
|
|
assert len(p)==n, f'List len mismatch ({len(p)} vs {n})' |
|
|
return list(p) |
|
|
|
|
|
_camel_re1 = re.compile('(.)([A-Z][a-z]+)') |
|
|
_camel_re2 = re.compile('([a-z0-9])([A-Z])') |
|
|
def camel2snake(name:str)->str: |
|
|
"Change `name` from camel to snake style." |
|
|
s1 = re.sub(_camel_re1, r'\1_\2', name) |
|
|
return re.sub(_camel_re2, r'\1_\2', s1).lower() |
|
|
|
|
|
def even_mults(start:float, stop:float, n:int)->np.ndarray: |
|
|
"Build log-stepped array from `start` to `stop` in `n` steps." |
|
|
mult = stop/start |
|
|
step = mult**(1/(n-1)) |
|
|
return np.array([start*(step**i) for i in range(n)]) |
|
|
|
|
|
def extract_kwargs(names:Collection[str], kwargs:KWArgs): |
|
|
"Extract the keys in `names` from the `kwargs`." |
|
|
new_kwargs = {} |
|
|
for arg_name in names: |
|
|
if arg_name in kwargs: |
|
|
arg_val = kwargs.pop(arg_name) |
|
|
new_kwargs[arg_name] = arg_val |
|
|
return new_kwargs, kwargs |
|
|
|
|
|
def partition(a:Collection, sz:int)->List[Collection]: |
|
|
"Split iterables `a` in equal parts of size `sz`" |
|
|
return [a[i:i+sz] for i in range(0, len(a), sz)] |
|
|
|
|
|
def partition_by_cores(a:Collection, n_cpus:int)->List[Collection]: |
|
|
"Split data in `a` equally among `n_cpus` cores" |
|
|
return partition(a, len(a)//n_cpus + 1) |
|
|
|
|
|
def series2cat(df:DataFrame, *col_names): |
|
|
"Categorifies the columns `col_names` in `df`." |
|
|
for c in listify(col_names): df[c] = df[c].astype('category').cat.as_ordered() |
|
|
|
|
|
TfmList = Union[Callable, Collection[Callable]] |
|
|
|
|
|
class ItemBase(): |
|
|
"Base item type in the fastai library." |
|
|
def __init__(self, data:Any): self.data=self.obj=data |
|
|
def __repr__(self)->str: return f'{self.__class__.__name__} {str(self)}' |
|
|
def show(self, ax:plt.Axes, **kwargs): |
|
|
"Subclass this method if you want to customize the way this `ItemBase` is shown on `ax`." |
|
|
ax.set_title(str(self)) |
|
|
def apply_tfms(self, tfms:Collection, **kwargs): |
|
|
"Subclass this method if you want to apply data augmentation with `tfms` to this `ItemBase`." |
|
|
if tfms: raise Exception(f"Not implemented: you can't apply transforms to this type of item ({self.__class__.__name__})") |
|
|
return self |
|
|
def __eq__(self, other): return recurse_eq(self.data, other.data) |
|
|
|
|
|
def recurse_eq(arr1, arr2): |
|
|
if is_listy(arr1): return is_listy(arr2) and len(arr1) == len(arr2) and np.all([recurse_eq(x,y) for x,y in zip(arr1,arr2)]) |
|
|
else: return np.all(np.atleast_1d(arr1 == arr2)) |
|
|
|
|
|
def download_url(url:str, dest:str, overwrite:bool=False, pbar:ProgressBar=None, |
|
|
show_progress=True, chunk_size=1024*1024, timeout=4, retries=5)->None: |
|
|
"Download `url` to `dest` unless it exists and not `overwrite`." |
|
|
if os.path.exists(dest) and not overwrite: return |
|
|
|
|
|
s = requests.Session() |
|
|
s.mount('http://',requests.adapters.HTTPAdapter(max_retries=retries)) |
|
|
u = s.get(url, stream=True, timeout=timeout) |
|
|
try: file_size = int(u.headers["Content-Length"]) |
|
|
except: show_progress = False |
|
|
|
|
|
with open(dest, 'wb') as f: |
|
|
nbytes = 0 |
|
|
if show_progress: pbar = progress_bar(range(file_size), auto_update=False, leave=False, parent=pbar) |
|
|
try: |
|
|
for chunk in u.iter_content(chunk_size=chunk_size): |
|
|
nbytes += len(chunk) |
|
|
if show_progress: pbar.update(nbytes) |
|
|
f.write(chunk) |
|
|
except requests.exceptions.ConnectionError as e: |
|
|
fname = url.split('/')[-1] |
|
|
from fastai.datasets import Config |
|
|
data_dir = Config().data_path() |
|
|
timeout_txt =(f'\n Download of {url} has failed after {retries} retries\n' |
|
|
f' Fix the download manually:\n' |
|
|
f'$ mkdir -p {data_dir}\n' |
|
|
f'$ cd {data_dir}\n' |
|
|
f'$ wget -c {url}\n' |
|
|
f'$ tar -zxvf {fname}\n\n' |
|
|
f'And re-run your code once the download is successful\n') |
|
|
print(timeout_txt) |
|
|
import sys;sys.exit(1) |
|
|
|
|
|
def range_of(x): |
|
|
"Create a range from 0 to `len(x)`." |
|
|
return list(range(len(x))) |
|
|
def arange_of(x): |
|
|
"Same as `range_of` but returns an array." |
|
|
return np.arange(len(x)) |
|
|
|
|
|
Path.ls = lambda x: list(x.iterdir()) |
|
|
|
|
|
def join_path(fname:PathOrStr, path:PathOrStr='.')->Path: |
|
|
"Return `Path(path)/Path(fname)`, `path` defaults to current dir." |
|
|
return Path(path)/Path(fname) |
|
|
|
|
|
def join_paths(fnames:FilePathList, path:PathOrStr='.')->Collection[Path]: |
|
|
"Join `path` to every file name in `fnames`." |
|
|
path = Path(path) |
|
|
return [join_path(o,path) for o in fnames] |
|
|
|
|
|
def loadtxt_str(path:PathOrStr)->np.ndarray: |
|
|
"Return `ndarray` of `str` of lines of text from `path`." |
|
|
with open(path, 'r') as f: lines = f.readlines() |
|
|
return np.array([l.strip() for l in lines]) |
|
|
|
|
|
def save_texts(fname:PathOrStr, texts:Collection[str]): |
|
|
"Save in `fname` the content of `texts`." |
|
|
with open(fname, 'w') as f: |
|
|
for t in texts: f.write(f'{t}\n') |
|
|
|
|
|
def df_names_to_idx(names:IntsOrStrs, df:DataFrame): |
|
|
"Return the column indexes of `names` in `df`." |
|
|
if not is_listy(names): names = [names] |
|
|
if isinstance(names[0], int): return names |
|
|
return [df.columns.get_loc(c) for c in names] |
|
|
|
|
|
def one_hot(x:Collection[int], c:int): |
|
|
"One-hot encode `x` with `c` classes." |
|
|
res = np.zeros((c,), np.float32) |
|
|
res[listify(x)] = 1. |
|
|
return res |
|
|
|
|
|
def index_row(a:Union[Collection,pd.DataFrame,pd.Series], idxs:Collection[int])->Any: |
|
|
"Return the slice of `a` corresponding to `idxs`." |
|
|
if a is None: return a |
|
|
if isinstance(a,(pd.DataFrame,pd.Series)): |
|
|
res = a.iloc[idxs] |
|
|
if isinstance(res,(pd.DataFrame,pd.Series)): return res.copy() |
|
|
return res |
|
|
return a[idxs] |
|
|
|
|
|
def func_args(func)->bool: |
|
|
"Return the arguments of `func`." |
|
|
code = func.__code__ |
|
|
return code.co_varnames[:code.co_argcount] |
|
|
|
|
|
def has_arg(func, arg)->bool: |
|
|
"Check if `func` accepts `arg`." |
|
|
return arg in func_args(func) |
|
|
|
|
|
def split_kwargs_by_func(kwargs, func): |
|
|
"Split `kwargs` between those expected by `func` and the others." |
|
|
args = func_args(func) |
|
|
func_kwargs = {a:kwargs.pop(a) for a in args if a in kwargs} |
|
|
return func_kwargs, kwargs |
|
|
|
|
|
def array(a, dtype:type=None, **kwargs)->np.ndarray: |
|
|
"Same as `np.array` but also handles generators. `kwargs` are passed to `np.array` with `dtype`." |
|
|
if not isinstance(a, collections.abc.Sized) and not getattr(a,'__array_interface__',False): |
|
|
a = list(a) |
|
|
if np.int_==np.int32 and dtype is None and is_listy(a) and len(a) and isinstance(a[0],int): |
|
|
dtype=np.int64 |
|
|
return np.array(a, dtype=dtype, **kwargs) |
|
|
|
|
|
class EmptyLabel(ItemBase): |
|
|
"Should be used for a dummy label." |
|
|
def __init__(self): self.obj,self.data = 0,0 |
|
|
def __str__(self): return '' |
|
|
def __hash__(self): return hash(str(self)) |
|
|
|
|
|
class Category(ItemBase): |
|
|
"Basic class for single classification labels." |
|
|
def __init__(self,data,obj): self.data,self.obj = data,obj |
|
|
def __int__(self): return int(self.data) |
|
|
def __str__(self): return str(self.obj) |
|
|
def __hash__(self): return hash(str(self)) |
|
|
|
|
|
class MultiCategory(ItemBase): |
|
|
"Basic class for multi-classification labels." |
|
|
def __init__(self,data,obj,raw): self.data,self.obj,self.raw = data,obj,raw |
|
|
def __str__(self): return ';'.join([str(o) for o in self.obj]) |
|
|
def __hash__(self): return hash(str(self)) |
|
|
|
|
|
class FloatItem(ItemBase): |
|
|
"Basic class for float items." |
|
|
def __init__(self,obj): self.data,self.obj = np.array(obj).astype(np.float32),obj |
|
|
def __str__(self): return str(self.obj) |
|
|
def __hash__(self): return hash(str(self)) |
|
|
|
|
|
def _treat_html(o:str)->str: |
|
|
o = str(o) |
|
|
to_replace = {'\n':'\\n', '<':'<', '>':'>', '&':'&'} |
|
|
for k,v in to_replace.items(): o = o.replace(k, v) |
|
|
return o |
|
|
|
|
|
def text2html_table(items:Collection[Collection[str]])->str: |
|
|
"Put the texts in `items` in an HTML table, `widths` are the widths of the columns in %." |
|
|
html_code = f"""<table border="1" class="dataframe">""" |
|
|
html_code += f""" <thead>\n <tr style="text-align: right;">\n""" |
|
|
for i in items[0]: html_code += f" <th>{_treat_html(i)}</th>" |
|
|
html_code += f" </tr>\n </thead>\n <tbody>" |
|
|
html_code += " <tbody>" |
|
|
for line in items[1:]: |
|
|
html_code += " <tr>" |
|
|
for i in line: html_code += f" <td>{_treat_html(i)}</td>" |
|
|
html_code += " </tr>" |
|
|
html_code += " </tbody>\n</table>" |
|
|
return html_code |
|
|
|
|
|
def parallel(func, arr:Collection, max_workers:int=None, leave=False): |
|
|
"Call `func` on every element of `arr` in parallel using `max_workers`." |
|
|
max_workers = ifnone(max_workers, defaults.cpus) |
|
|
if max_workers<2: results = [func(o,i) for i,o in progress_bar(enumerate(arr), total=len(arr), leave=leave)] |
|
|
else: |
|
|
with ProcessPoolExecutor(max_workers=max_workers) as ex: |
|
|
futures = [ex.submit(func,o,i) for i,o in enumerate(arr)] |
|
|
results = [] |
|
|
for f in progress_bar(concurrent.futures.as_completed(futures), total=len(arr), leave=leave): |
|
|
results.append(f.result()) |
|
|
if any([o is not None for o in results]): return results |
|
|
|
|
|
def subplots(rows:int, cols:int, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, title=None, **kwargs): |
|
|
"Like `plt.subplots` but with consistent axs shape, `kwargs` passed to `fig.suptitle` with `title`" |
|
|
figsize = ifnone(figsize, (imgsize*cols, imgsize*rows)) |
|
|
fig, axs = plt.subplots(rows,cols,figsize=figsize) |
|
|
if rows==cols==1: axs = [[axs]] |
|
|
elif (rows==1 and cols!=1) or (cols==1 and rows!=1): axs = [axs] |
|
|
if title is not None: fig.suptitle(title, **kwargs) |
|
|
return array(axs) |
|
|
|
|
|
def show_some(items:Collection, n_max:int=5, sep:str=','): |
|
|
"Return the representation of the first `n_max` elements in `items`." |
|
|
if items is None or len(items) == 0: return '' |
|
|
res = sep.join([f'{o}' for o in items[:n_max]]) |
|
|
if len(items) > n_max: res += '...' |
|
|
return res |
|
|
|
|
|
def get_tmp_file(dir=None): |
|
|
"Create and return a tmp filename, optionally at a specific path. `os.remove` when done with it." |
|
|
with tempfile.NamedTemporaryFile(delete=False, dir=dir) as f: return f.name |
|
|
|
|
|
def compose(funcs:List[Callable])->Callable: |
|
|
"Compose `funcs`" |
|
|
def compose_(funcs, x, *args, **kwargs): |
|
|
for f in listify(funcs): x = f(x, *args, **kwargs) |
|
|
return x |
|
|
return partial(compose_, funcs) |
|
|
|
|
|
class PrettyString(str): |
|
|
"Little hack to get strings to show properly in Jupyter." |
|
|
def __repr__(self): return self |
|
|
|
|
|
def float_or_x(x): |
|
|
"Tries to convert to float, returns x if it can't" |
|
|
try: return float(x) |
|
|
except:return x |
|
|
|
|
|
def bunzip(fn:PathOrStr): |
|
|
"bunzip `fn`, raising exception if output already exists" |
|
|
fn = Path(fn) |
|
|
assert fn.exists(), f"{fn} doesn't exist" |
|
|
out_fn = fn.with_suffix('') |
|
|
assert not out_fn.exists(), f"{out_fn} already exists" |
|
|
with bz2.BZ2File(fn, 'rb') as src, out_fn.open('wb') as dst: |
|
|
for d in iter(lambda: src.read(1024*1024), b''): dst.write(d) |
|
|
|
|
|
@contextmanager |
|
|
def working_directory(path:PathOrStr): |
|
|
"Change working directory to `path` and return to previous on exit." |
|
|
prev_cwd = Path.cwd() |
|
|
os.chdir(path) |
|
|
try: yield |
|
|
finally: os.chdir(prev_cwd) |
|
|
|
|
|
|