Update app.py
Browse files
app.py
CHANGED
|
@@ -12,7 +12,9 @@ def geometry_transform(images: list,
|
|
| 12 |
|
| 13 |
file_names: list = [f.name for f in images]
|
| 14 |
image_list: list = [K.io.load_image(f, K.io.ImageLoadType(0)).float().unsqueeze(0)/255 for f in file_names]
|
| 15 |
-
|
|
|
|
|
|
|
| 16 |
center: torch.Tensor = torch.tensor([x.shape[1:] for x in image_batch])/2
|
| 17 |
translation = torch.tensor(translation).repeat(len(image_list), 2)
|
| 18 |
scale = torch.tensor(scale).repeat(len(image_list), 2)
|
|
@@ -20,7 +22,7 @@ def geometry_transform(images: list,
|
|
| 20 |
affine_matrix: torch.Tensor = KG.get_affine_matrix2d(translation, center, scale, angle)
|
| 21 |
with torch.inference_mode():
|
| 22 |
transformed: torch.Tensor = KG.transform.warp_affine(image_batch, affine_matrix[:, :2], dsize=image_batch.shape[2:])
|
| 23 |
-
concat_images: list = torch.cat(transformed, dim=-1)
|
| 24 |
final_images: np.ndarray = K.tensor_to_image(concat_images*255).astype(np.uint8)
|
| 25 |
|
| 26 |
return final_images
|
|
|
|
| 12 |
|
| 13 |
file_names: list = [f.name for f in images]
|
| 14 |
image_list: list = [K.io.load_image(f, K.io.ImageLoadType(0)).float().unsqueeze(0)/255 for f in file_names]
|
| 15 |
+
if len(image_list) > 1:
|
| 16 |
+
image_list = [K.geometry.resize(x, x.shape[-2:], antialias=True) for x in image_list]
|
| 17 |
+
image_batch: torch.Tensor = torch.cat(image_list, 0)
|
| 18 |
center: torch.Tensor = torch.tensor([x.shape[1:] for x in image_batch])/2
|
| 19 |
translation = torch.tensor(translation).repeat(len(image_list), 2)
|
| 20 |
scale = torch.tensor(scale).repeat(len(image_list), 2)
|
|
|
|
| 22 |
affine_matrix: torch.Tensor = KG.get_affine_matrix2d(translation, center, scale, angle)
|
| 23 |
with torch.inference_mode():
|
| 24 |
transformed: torch.Tensor = KG.transform.warp_affine(image_batch, affine_matrix[:, :2], dsize=image_batch.shape[2:])
|
| 25 |
+
concat_images: list = torch.cat([x for x in transformed], dim=-1)
|
| 26 |
final_images: np.ndarray = K.tensor_to_image(concat_images*255).astype(np.uint8)
|
| 27 |
|
| 28 |
return final_images
|