| from math import pi |
|
|
| try: |
| import torch |
| except ImportError: |
| torch = None |
|
|
| try: |
| import numpy |
| except ImportError: |
| numpy = None |
|
|
| if numpy is None and torch is None: |
| raise ImportError("Must have either Numpy or PyTorch but both not found") |
|
|
|
|
| def set_framework_dependencies(x): |
| if type(x) is numpy.ndarray: |
| to_dtype = lambda a: a |
| fw = numpy |
| else: |
| to_dtype = lambda a: a.to(x.dtype) |
| fw = torch |
| eps = fw.finfo(fw.float32).eps |
| return fw, to_dtype, eps |
|
|
|
|
| def support_sz(sz): |
| def wrapper(f): |
| f.support_sz = sz |
| return f |
| return wrapper |
|
|
| @support_sz(4) |
| def cubic(x): |
| fw, to_dtype, eps = set_framework_dependencies(x) |
| absx = fw.abs(x) |
| absx2 = absx ** 2 |
| absx3 = absx ** 3 |
| return ((1.5 * absx3 - 2.5 * absx2 + 1.) * to_dtype(absx <= 1.) + |
| (-0.5 * absx3 + 2.5 * absx2 - 4. * absx + 2.) * |
| to_dtype((1. < absx) & (absx <= 2.))) |
|
|
| @support_sz(4) |
| def lanczos2(x): |
| fw, to_dtype, eps = set_framework_dependencies(x) |
| return (((fw.sin(pi * x) * fw.sin(pi * x / 2) + eps) / |
| ((pi**2 * x**2 / 2) + eps)) * to_dtype(abs(x) < 2)) |
|
|
| @support_sz(6) |
| def lanczos3(x): |
| fw, to_dtype, eps = set_framework_dependencies(x) |
| return (((fw.sin(pi * x) * fw.sin(pi * x / 3) + eps) / |
| ((pi**2 * x**2 / 3) + eps)) * to_dtype(abs(x) < 3)) |
|
|
| @support_sz(2) |
| def linear(x): |
| fw, to_dtype, eps = set_framework_dependencies(x) |
| return ((x + 1) * to_dtype((-1 <= x) & (x < 0)) + (1 - x) * |
| to_dtype((0 <= x) & (x <= 1))) |
|
|
| @support_sz(1) |
| def box(x): |
| fw, to_dtype, eps = set_framework_dependencies(x) |
| return to_dtype((-1 <= x) & (x < 0)) + to_dtype((0 <= x) & (x <= 1)) |
|
|