Spaces:
Runtime error
Runtime error
Update codes.py
Browse files
codes.py
CHANGED
|
@@ -1,66 +1,173 @@
|
|
| 1 |
import os, sys
|
| 2 |
import numpy as np
|
| 3 |
-
from PIL import Image
|
| 4 |
-
import itertools
|
| 5 |
-
import glob
|
| 6 |
import random
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
import torch
|
| 8 |
-
import torchvision
|
| 9 |
-
import torchvision.transforms as transforms
|
| 10 |
-
import torch.nn as nn
|
| 11 |
-
import torch.optim as optim
|
| 12 |
-
import torch.nn.functional as F
|
| 13 |
-
from torch.nn.functional import relu as RLU
|
| 14 |
|
| 15 |
-
# --- Experiment configuration ---------------------------------------------
# Registration parameterization used by the model; alternatives listed in
# the inline set below.
registration_method = 'Additive_Recurence' #{'Rawblock', 'matching_points', 'Additive_Recurence', 'Multiplicative_Recurence'} #'recurrent_matrix',
imposed_point = 0
# Backbone architecture selector.
Arch = 'ResNet'
Fix_Torch_Wrap = False
BW_Position = False
# Working crop size (dim) and padded canvas size (dim0); crop_ratio is the
# fraction of the canvas the working crop covers.
dim = 128
dim0 =224
crop_ratio = dim/dim0
|
| 23 |
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
|
|
|
|
|
|
| 26 |
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
def __init__(self):
|
| 29 |
-
super(
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
super(Build_IRmodel_Resnet, self).__init__()
|
| 36 |
-
self.resnet_model = resnet_model
|
| 37 |
-
self.BW_Position = BW_Position
|
| 38 |
-
self.N_parameters = 6
|
| 39 |
-
self.registration_method = registration_method
|
| 40 |
-
self.fc1 =nn.Linear(6, 64)
|
| 41 |
-
self.fc2 =nn.Linear(64, 128*3)
|
| 42 |
-
self.fc3 =nn.Linear(512, self.N_parameters)
|
| 43 |
def forward(self, input_X_batch):
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
def pil_to_numpy(im):
|
| 66 |
im.load()
|
|
@@ -80,93 +187,27 @@ def pil_to_numpy(im):
|
|
| 80 |
raise RuntimeError("encoder error %d in tobytes" % s)
|
| 81 |
return data
|
| 82 |
|
| 83 |
-
def load_image_pil_accelerated(image_path, dim=128):
    # Load an RGB image from disk and return it as a float32 CHW tensor
    # resized to (dim, dim), with pixel values scaled to [0, 1].
    image = Image.open(image_path).convert("RGB")
    # pil_to_numpy is defined elsewhere in this file; presumably it decodes
    # the PIL image into an HWC uint8 array -- confirm against its definition.
    array = pil_to_numpy(image)
    # rollaxis moves the channel axis to the front: HWC -> CHW.
    tensor = torch.from_numpy(np.rollaxis(array,2,0)/255).to(torch.float32)
    tensor = torchvision.transforms.Resize((dim,dim))(tensor)
    return tensor
|
| 89 |
|
| 90 |
|
| 91 |
-
def workaround_matrix(Affine_mtrx0, acc = 2):
    """Convert between a 'correct' affine matrix and its torch-compatible
    counterpart.

    Use acc=2 to go from a correct matrix to the torch-compatible form
    (needed when transforming an image); use acc=0.5 for the inverse
    direction.
    """
    adjusted = inv_AM(Affine_mtrx0)
    # Scale only the translation column of each 2x3 matrix in the batch.
    adjusted[:, :, 2] = adjusted[:, :, 2] * acc
    return adjusted
|
| 97 |
-
|
| 98 |
-
def inv_AM(Affine_mtrx):
    """Invert a (batch of) 2x3 affine matrices.

    The input is promoted to homogeneous 3x3 form via mtrx3, inverted,
    and the top two rows of the inverse are returned.
    """
    full_inverse = torch.linalg.inv(mtrx3(Affine_mtrx))
    return full_inverse[:, :2, :]
|
| 102 |
-
|
| 103 |
-
def mtrx3(Affine_mtrx):
    """Promote a 2x3 affine matrix (or a batch of them) to homogeneous 3x3.

    The appended bottom row is [0, 0, 1]. Any other input rank falls
    through (AM3 never bound) and raises, matching the original contract.
    """
    rank = len(Affine_mtrx.shape)
    if rank == 3:
        batch = Affine_mtrx.shape[0]
        AM3 = torch.zeros([batch, 3, 3])  # stays on CPU, as in the original
        AM3[:, :2, :] = Affine_mtrx
        AM3[:, 2, 2] = 1
    elif rank == 2:
        AM3 = torch.zeros([3, 3])
        AM3[:2, :] = Affine_mtrx
        AM3[2, 2] = 1
    return AM3
|
| 116 |
-
|
| 117 |
-
def standarize_point(d, dim=128, flip = False):
    """Map a pixel coordinate into the normalized range [-0.5, 0.5).

    When flip is true the coordinate is negated before normalizing.
    """
    value = -d if flip else d
    return value / dim - 0.5
|
| 121 |
-
|
| 122 |
-
def destandarize_point(d, dim=128, flip = False):
    """Inverse of standarize_point: map a normalized coordinate back to
    pixel units, optionally negating first when flip is true.
    """
    value = -d if flip else d
    return dim * (value + 0.5)
|
| 126 |
|
| 127 |
-
def generate_standard_elips(N_samples = 100, a= 1,b = 1):
    """Sample N_samples points tracing an ellipse of radius 0.25 centered
    at the origin, stretched by a (x) and b (y).

    The first batch of points walks the upper arc left-to-right, the
    remainder the lower arc right-to-left, so consecutive points are
    adjacent along the curve.
    """
    radius = 0.25
    center = 0
    n_upper = int(N_samples / 2 - 1)
    n_lower = N_samples - n_upper
    sampler = torch.distributions.uniform.Uniform(center - radius, center + radius)
    upper_x = torch.sort(sampler.sample([n_upper])).values
    upper_y = center + b * torch.sqrt(radius**2 - ((upper_x - center) / a)**2)
    lower_x = torch.sort(sampler.sample([n_lower]), descending=True).values
    lower_y = center - b * torch.sqrt(radius**2 - ((lower_x - center) / a)**2)
    return torch.cat([upper_x, lower_x]), torch.cat([upper_y, lower_y])
|
| 141 |
|
| 142 |
-
def transform_standard_points(Affine_mat, x,y):
    """Apply an affine matrix to 2-D point coordinates.

    The points are stacked as homogeneous columns [x; y; 1], multiplied by
    the (CPU, detached) matrix, and the transformed x/y rows returned.
    """
    homogeneous = torch.ones([3, x.shape[0]])
    homogeneous[0, :] = x
    homogeneous[1, :] = y
    matrix = Affine_mat.to('cpu').detach()
    transformed = torch.matmul(matrix, homogeneous)
    return transformed[0], transformed[1]
|
| 150 |
-
|
| 151 |
-
def wrap_points(img, x_source, y_source, l=1, DIM =dim):
    """Stamp black (zero) squares of half-size l onto img at each landmark,
    mutating img in place and returning it.

    Landmarks on or outside the (0, DIM) open interval are skipped.
    """
    n_points = len(y_source)
    for idx in range(n_points):
        col = x_source[idx].int()
        row = y_source[idx].int()
        if 0 < col < DIM and 0 < row < DIM:
            img[:, :, row - l:row + l, col - l:col + l] = 0
    return img
|
| 158 |
|
| 159 |
|
| 160 |
-
def wrap_imge_cropped(Affine_mtrx, source_img, dim1=224, dim2=128):
|
| 161 |
-
source_img224 = torch.nn.ZeroPad2d(int((dim1-dim2)/2))(source_img)
|
| 162 |
-
grd = torch.nn.functional.affine_grid(Affine_mtrx, size=source_img224.shape,align_corners=False)
|
| 163 |
-
wrapped_img = torch.nn.functional.grid_sample(source_img224, grid=grd,
|
| 164 |
-
mode='bilinear', padding_mode='zeros', align_corners=False)
|
| 165 |
-
wrapped_img = torchvision.transforms.CenterCrop((dim2, dim2))(wrapped_img)
|
| 166 |
-
return wrapped_img
|
| 167 |
|
| 168 |
|
| 169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
def preprocess_image(image_path, dim = 128):
|
| 171 |
img = torch.zeros([1,3,dim,dim])
|
| 172 |
img[0] = load_image_pil_accelerated(image_path, dim)
|
|
|
|
| 1 |
import os, sys
|
| 2 |
import numpy as np
|
|
|
|
|
|
|
|
|
|
| 3 |
import random
|
| 4 |
+
import time
|
| 5 |
+
import json
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
import tqdm
|
| 8 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
+
def generate_standard_elips(N_samples = 10, a= 1,b = 1):
    """Sample N_samples points on an ellipse (semi-axes 0.5*a and 0.5*b)
    centered at the origin: upper arc left-to-right, then lower arc
    right-to-left, so points are ordered along the curve.
    """
    radius = 0.5
    center = 0
    n_upper = int(N_samples / 2 - 1)
    low, high = (center - radius) * a, (center + radius) * a
    upper_x = np.sort(np.random.uniform(low, high, size=n_upper))
    upper_y = center + b * np.sqrt(radius**2 - ((upper_x - center) / a)**2)
    # Negate-sort-negate yields a descending ordering.
    lower_x = -np.sort(-np.random.uniform(low, high, size=N_samples - n_upper))
    lower_y = center - b * np.sqrt(radius**2 - ((lower_x - center) / a)**2)
    return (np.concatenate([upper_x, lower_x], axis=0),
            np.concatenate([upper_y, lower_y], axis=0))
|
| 26 |
|
| 27 |
+
def destandarize_point(d, dim=128):
    """Map a coordinate normalized to [-0.5, 0.5] back to pixel units."""
    return (d + 0.5) * dim
|
| 29 |
|
| 30 |
+
def To_pointcloud(x,y,z=0):
    """Stack coordinate vectors into an (N, 3) point cloud.

    Parameters
    ----------
    x, y : arrays of shape (N,)
        Planar coordinates.
    z : scalar or array of shape (N,), default 0
        Height values; the third column stays zero when z is 0.

    Returns
    -------
    np.ndarray of shape (N, 3)
    """
    N_points = x.shape[0]
    point_cloud = np.zeros([N_points, 3])
    point_cloud[:, 0] = x
    point_cloud[:, 1] = y
    # Bug fix: `if not z == 0` raised "truth value of an array is ambiguous"
    # whenever z was an array; np.any handles scalars and arrays uniformly.
    if np.any(np.asarray(z) != 0):
        point_cloud[:, 2] = z
    return point_cloud
|
| 38 |
+
|
| 39 |
+
def To_xyz(point_cloud):
    """Split an (N, 3) point cloud into its x, y and z columns."""
    return point_cloud[:, 0], point_cloud[:, 1], point_cloud[:, 2]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def random_rigid_transformation(dim=2):
    """Draw a random planar rigid transform (with random axis flips).

    Rotation about x and y is fixed at zero, so the transform is 2-D even
    though it is expressed with 3x3 matrices.

    Returns
    -------
    RefRotation : (3, 3) ndarray
        reflection @ rotation.
    Translation : (3,) ndarray
        Translation scaled by dim; the z component is always 0.
    """
    # RNG draw order matters for reproducibility: angle, two shifts, two flips.
    angle_x = 0
    angle_y = 0
    angle_z = random.uniform(0, 2) * np.pi
    shift = np.array([random.uniform(-1, 1) * dim,
                      random.uniform(-1, 1) * dim,
                      0])
    flips = np.diag([random.sample([-1, 1], 1)[0],
                     random.sample([-1, 1], 1)[0],
                     1])
    rot_x = np.array([[1, 0, 0],
                      [0, np.cos(angle_x), -np.sin(angle_x)],
                      [0, np.sin(angle_x), np.cos(angle_x)]])
    rot_y = np.array([[np.cos(angle_y), 0, np.sin(angle_y)],
                      [0, 1, 0],
                      [-np.sin(angle_y), 0, np.cos(angle_y)]])
    rot_z = np.array([[np.cos(angle_z), -np.sin(angle_z), 0],
                      [np.sin(angle_z), np.cos(angle_z), 0],
                      [0, 0, 1]])
    rotation = np.matmul(rot_z, np.matmul(rot_y, rot_x))
    return np.matmul(flips, rotation), shift
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def rigid_2Dtransformation(prdction):
    """Build batched 3x3 rigid-transform matrices from a prediction dict.

    prdction holds per-example tensors: 'rotation' (only the z component,
    index 2, is used), 'reflection' (x/y signs) and 'translation'.

    Returns (reflection @ z-rotation, translation).
    """
    batch = prdction['rotation'].shape[0]
    angle = prdction['rotation'][:, 2]
    cos_a = torch.cos(angle)
    sin_a = torch.sin(angle)

    reflect = torch.zeros([batch, 3, 3])
    reflect[:, 0, 0] = prdction['reflection'][:, 0]
    reflect[:, 1, 1] = prdction['reflection'][:, 1]
    reflect[:, 2, 2] = 1.

    rotate = torch.zeros([batch, 3, 3])
    rotate[:, 0, 0] = cos_a
    rotate[:, 0, 1] = -sin_a
    rotate[:, 1, 0] = sin_a
    rotate[:, 1, 1] = cos_a
    rotate[:, 2, 2] = 1.

    return torch.matmul(reflect, rotate), prdction['translation']
|
| 91 |
+
|
| 92 |
+
def batch_hausdorff_prcnt_distance(batch_point_cloud1, point_cloud2, percentile = 0.95):
    """Percentile Hausdorff distance between each batched cloud and a
    reference cloud.

    batch_point_cloud1 is (B, 3, N1); point_cloud2 is (3, N2). For each
    batch element the nearest-neighbour distances in both directions are
    pooled and their `percentile` quantile is returned, shape (B,).
    """
    assert point_cloud2.shape[0] == 3
    assert batch_point_cloud1.shape[1] == 3
    # (B, 3, 1, N1) - (1, 3, N2, 1) broadcasts to (B, 3, N2, N1);
    # the norm over the coordinate axis gives pairwise (B, N2, N1) distances.
    diffs = batch_point_cloud1[:, :, None, :] - point_cloud2[None, :, :, None]
    pairwise = torch.norm(diffs, dim=1)
    to_ref = torch.min(pairwise, dim=1).values    # nearest ref point, per source point
    from_ref = torch.min(pairwise, dim=2).values  # nearest source point, per ref point
    pooled = torch.cat([to_ref, from_ref], axis=1)
    return torch.quantile(pooled, percentile, interpolation='linear', dim=1)
|
| 101 |
+
|
| 102 |
+
def HDloss(prd, pointcloud_source_norm_torch,pointcloud_target_norm_torch, percentile = 0.95):
    """Per-example Hausdorff-percentile loss between the transformed
    source cloud and the target cloud.

    `prd` is a prediction dict understood by rigid_2Dtransformation.
    Clouds are (N, 3); they are transposed to (3, N) internally.
    """
    rotation, translation = rigid_2Dtransformation(prd)
    warped = torch.matmul(rotation, pointcloud_source_norm_torch.T) + translation[:, :, None]
    return batch_hausdorff_prcnt_distance(warped, pointcloud_target_norm_torch.T, percentile)
|
| 107 |
+
|
| 108 |
+
def Mean_HDloss(prd, pointcloud_source_norm_torch, pointcloud_target_norm_torch, percentile = 0.95):
    """Scalar mean of HDloss over the batch.

    Bug fix: the `percentile` argument is now forwarded to HDloss; it was
    previously hard-coded to 0.95 in the inner call, silently ignoring the
    caller's value.
    """
    loss = HDloss(prd, pointcloud_source_norm_torch, pointcloud_target_norm_torch, percentile=percentile)
    return torch.mean(loss)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def wrap_pointcloud(record, pointcloud_source):
    """Center a source point cloud on its mean, then apply the rigid
    transform described by `record` (a prediction dict).

    Returns the warped cloud as a float32 tensor of shape (B, 3, N).
    """
    # Center the cloud before transforming so rotation acts about the mean.
    centered = pointcloud_source - np.mean(pointcloud_source, axis=0)
    centered_torch = torch.tensor(centered, requires_grad=False).to(torch.float32)
    rotation, translation = rigid_2Dtransformation(record)
    return torch.matmul(rotation, centered_torch.T) + translation[:, :, None]
|
| 122 |
+
|
| 123 |
+
# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def move_dict2device(dictionary,device):
    """Move every tensor value of `dictionary` to `device`, in place,
    and return the same dict for convenience."""
    for key in list(dictionary.keys()):
        dictionary[key] = dictionary[key].to(device)
    return dictionary
|
| 129 |
+
|
| 130 |
+
# Small positive bias so torch.sign never lands on exactly zero when two
# opposing reflection signs are blended.
eps = 0.000001948

class Optimization_model(torch.nn.Module):
    """Learns a rigid-transform correction blended with a candidate guess.

    alpha interpolates between this module's own parameters (alpha -> 1)
    and the candidate transform passed to forward (alpha -> 0).
    """
    def __init__(self):
        super(Optimization_model, self).__init__()
        self.alpha = torch.nn.Parameter(torch.tensor(0.5, requires_grad=True))
        self.rotation = torch.nn.Parameter(torch.tensor([0.0, 0.0, 0.0], requires_grad=True))
        self.translation = torch.nn.Parameter(torch.tensor([0.01, -0.01, 0.0], requires_grad=True))
        # Reflection signs initialized to [+1, -1, +1].
        self.reflection = torch.nn.Parameter(torch.sign(torch.tensor([0.01, -0.01, 1], requires_grad=True)))
        #self.rigid = torch.nn.Parameter(torch.tensor([0.0,1.0,1.0,0.1,0.1], requires_grad=True))

    def forward(self, input_X_batch):
        """Blend own parameters with the candidate dict and return a new
        prediction dict with 'rotation', 'translation' and 'reflection'."""
        own_weight = self.alpha
        guess_weight = 1 - self.alpha
        blended_reflection = (own_weight * self.reflection
                              + guess_weight * input_X_batch['reflection'] + eps)
        return {
            'rotation': own_weight * self.rotation + guess_weight * input_X_batch['rotation'],
            'translation': own_weight * self.translation + guess_weight * input_X_batch['translation'],
            'reflection': torch.sign(blended_reflection),
        }
|
| 147 |
+
|
| 148 |
+
class Dataset(torch.utils.data.Dataset):
    """Synthetic dataset of random rigid-transform 'solutions'.

    Each item is a dict of length-3 tensors: 'rotation' (radians in
    [-pi, pi]), 'translation' (in [-0.1, 0.1]) and 'reflection' (signs in
    {-1, +1}). With N_dim == 2 the out-of-plane components are clamped so
    the transform stays planar.
    """
    def __init__(self, dataset_size, N_dim = 2):
        self.dataset_size = dataset_size
        # Bug fix: the constructor previously hard-coded self.N_dim = 2,
        # silently ignoring the N_dim argument.
        self.N_dim = N_dim

    def __len__(self):
        return int(self.dataset_size)

    def __getitem__(self, index):
        # index is ignored: every access draws a fresh random transform.
        rotation = np.pi*(-1 + 2*torch.rand([3]))
        translation = -0.1 + 0.2*torch.rand([3])
        reflection = torch.sign(torch.rand([3]) - 0.5)
        if self.N_dim == 2:
            # Planar case: no out-of-plane rotation, translation or flip.
            rotation[0:2] = 0
            translation[2] = 0
            reflection[2] = 1
        return {'rotation': rotation,
                'translation': translation,
                'reflection': reflection}
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
|
| 171 |
|
| 172 |
def pil_to_numpy(im):
|
| 173 |
im.load()
|
|
|
|
| 187 |
raise RuntimeError("encoder error %d in tobytes" % s)
|
| 188 |
return data
|
| 189 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
|
| 191 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
|
| 195 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
|
| 197 |
|
| 198 |
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def load_image_pil_accelerated(image_path, dim=128):
    # Load an RGB image and return it as a float32 CHW tensor resized to
    # (dim, dim) with values scaled to [0, 1].
    # NOTE(review): this revision of the file removed the PIL `Image` and
    # `torchvision` imports, yet both are used below -- likely the source
    # of the reported runtime error; confirm the imports at the top of the
    # file.
    image = Image.open(image_path).convert("RGB")
    # pil_to_numpy (defined elsewhere in this file) presumably yields an
    # HWC uint8 array -- verify against its definition.
    array = pil_to_numpy(image)
    # rollaxis moves channels first (HWC -> CHW); /255 rescales to [0, 1].
    tensor = torch.from_numpy(np.rollaxis(array,2,0)/255).to(torch.float32)
    tensor = torchvision.transforms.Resize((dim,dim))(tensor)
    return tensor
|
| 209 |
+
|
| 210 |
+
|
| 211 |
def preprocess_image(image_path, dim = 128):
|
| 212 |
img = torch.zeros([1,3,dim,dim])
|
| 213 |
img[0] = load_image_pil_accelerated(image_path, dim)
|