|
|
from math import pi |
|
|
|
|
|
# Both array backends are optional. A backend that cannot be imported is
# recorded as None so the rest of the module can test for its presence.
try:
    import torch
except ImportError:
    torch = None

try:
    import numpy
except ImportError:
    numpy = None

# The kernels below need at least one backend to operate on; fail at import
# time when neither is available.
if all(backend is None for backend in (numpy, torch)):
    raise ImportError("Must have either Numpy or PyTorch but both not found")
|
|
|
|
|
|
|
|
def set_framework_dependencies(x):
    """Select the array backend matching ``x`` and return helpers bound to it.

    Args:
        x: Input array. A ``numpy.ndarray`` (or a subclass of it) selects
            NumPy; any other object is handed to PyTorch.

    Returns:
        A ``(fw, to_dtype, eps)`` tuple where

        * ``fw`` is the backend module (``numpy`` or ``torch``),
        * ``to_dtype`` is a one-argument callable that the kernels apply to
          boolean masks: the identity for NumPy, ``a.to(x.dtype)`` for
          PyTorch,
        * ``eps`` is ``fw.finfo(fw.float32).eps``.

    Raises:
        TypeError: If ``x`` is not a NumPy array and PyTorch is not installed.
    """
    # ``numpy`` is None on a PyTorch-only install, so it must be checked
    # before ``numpy.ndarray`` is touched; otherwise every torch input would
    # fail with AttributeError on None.
    if numpy is not None and isinstance(x, numpy.ndarray):
        fw = numpy

        def to_dtype(a):
            # NumPy masks are used as they are.
            return a

    elif torch is not None:
        fw = torch

        def to_dtype(a):
            # Cast masks to the input's dtype.
            return a.to(x.dtype)

    else:
        # Only NumPy is installed and ``x`` is not a NumPy array.
        raise TypeError(
            "Unsupported input type %r: expected a numpy.ndarray "
            "(PyTorch is not installed)" % type(x).__name__
        )

    eps = fw.finfo(fw.float32).eps
    return fw, to_dtype, eps
|
|
|
|
|
|
|
|
def support_sz(sz):
    """Build a decorator that records ``sz`` on the function it decorates.

    Args:
        sz: Value stored as the ``support_sz`` attribute of the decorated
            function.

    Returns:
        A decorator that sets ``support_sz`` on the function and returns that
        same function object (no wrapping takes place).
    """

    def attach(kernel):
        setattr(kernel, "support_sz", sz)
        return kernel

    return attach
|
|
|
|
|
|
|
|
@support_sz(4)
def cubic(x):
    """Evaluate a piecewise-cubic interpolation kernel at ``x``.

    With ``a = |x|`` the result is

    * ``1.5*a**3 - 2.5*a**2 + 1``            where ``a <= 1``,
    * ``-0.5*a**3 + 2.5*a**2 - 4*a + 2``     where ``1 < a <= 2``,
    * ``0``                                  elsewhere.

    Args:
        x: Array or tensor accepted by ``set_framework_dependencies``.

    Returns:
        Kernel values, computed element-wise with the backend chosen for
        ``x``.
    """
    fw, to_dtype, _ = set_framework_dependencies(x)

    dist = fw.abs(x)
    dist_sq = dist ** 2
    dist_cu = dist ** 3

    # Polynomial pieces; each is multiplied by the mask of its interval.
    near_piece = 1.5 * dist_cu - 2.5 * dist_sq + 1.
    far_piece = -0.5 * dist_cu + 2.5 * dist_sq - 4. * dist + 2.

    near_mask = to_dtype(dist <= 1.)
    far_mask = to_dtype((dist > 1.) & (dist <= 2.))

    return near_piece * near_mask + far_piece * far_mask
|
|
|
|
|
|
|
|
@support_sz(4)
def lanczos2(x):
    """Evaluate the Lanczos kernel with parameter 2 at ``x``.

    Computes ``(sin(pi*x) * sin(pi*x/2) + eps) / (pi**2 * x**2 / 2 + eps)``
    and zeroes it wherever ``|x| >= 2``. The same ``eps`` is added to the
    numerator and the denominator, so ``x == 0`` evaluates to ``eps / eps``
    rather than ``0 / 0``.

    Args:
        x: Array or tensor accepted by ``set_framework_dependencies``.

    Returns:
        Kernel values, computed element-wise with the backend chosen for
        ``x``.
    """
    fw, to_dtype, eps = set_framework_dependencies(x)

    scaled = pi * x
    numerator = fw.sin(scaled) * fw.sin(scaled / 2) + eps
    denominator = (pi**2 * x**2 / 2) + eps
    inside_support = to_dtype(abs(x) < 2)

    return (numerator / denominator) * inside_support
|
|
|
|
|
|
|
|
@support_sz(6)
def lanczos3(x):
    """Evaluate the Lanczos kernel with parameter 3 at ``x``.

    Computes ``(sin(pi*x) * sin(pi*x/3) + eps) / (pi**2 * x**2 / 3 + eps)``
    and zeroes it wherever ``|x| >= 3``. The same ``eps`` is added to the
    numerator and the denominator, so ``x == 0`` evaluates to ``eps / eps``
    rather than ``0 / 0``.

    Args:
        x: Array or tensor accepted by ``set_framework_dependencies``.

    Returns:
        Kernel values, computed element-wise with the backend chosen for
        ``x``.
    """
    fw, to_dtype, eps = set_framework_dependencies(x)

    scaled = pi * x
    numerator = fw.sin(scaled) * fw.sin(scaled / 3) + eps
    denominator = (pi**2 * x**2 / 3) + eps
    inside_support = to_dtype(abs(x) < 3)

    return (numerator / denominator) * inside_support
|
|
|
|
|
|
|
|
@support_sz(2)
def linear(x):
    """Evaluate the triangular (linear interpolation) kernel at ``x``.

    The result is ``x + 1`` on ``[-1, 0)``, ``1 - x`` on ``[0, 1]`` and ``0``
    elsewhere.

    Args:
        x: Array or tensor accepted by ``set_framework_dependencies``.

    Returns:
        Kernel values, computed element-wise.
    """
    _, to_dtype, _ = set_framework_dependencies(x)

    rising_mask = to_dtype((x >= -1) & (x < 0))
    falling_mask = to_dtype((x >= 0) & (x <= 1))

    return (x + 1) * rising_mask + (1 - x) * falling_mask
|
|
|
|
|
@support_sz(1)
def box(x):
    """Evaluate the box kernel at ``x``.

    Adds the mask of ``[-1, 0)`` to the mask of ``[0, 1]``, each passed
    through ``to_dtype`` first, so the result marks the points with
    ``-1 <= x <= 1``.

    NOTE(review): the mask spans [-1, 1] while ``support_sz`` is 1 -- confirm
    this width is intended.

    Args:
        x: Array or tensor accepted by ``set_framework_dependencies``.

    Returns:
        Element-wise sum of the two converted masks.
    """
    _, to_dtype, _ = set_framework_dependencies(x)

    negative_half = to_dtype((x >= -1) & (x < 0))
    positive_half = to_dtype((x >= 0) & (x <= 1))

    return negative_half + positive_half