Upload utils.py
Browse files
utils.py
CHANGED
|
@@ -125,51 +125,56 @@ def tensor_to_base64(img_tensor: torch.Tensor) -> str:
|
|
| 125 |
|
| 126 |
|
| 127 |
def load_image_from_url(image_url):
|
| 128 |
-
|
|
|
|
| 129 |
response = requests.get(image_url)
|
| 130 |
response.raise_for_status()
|
| 131 |
-
|
| 132 |
|
| 133 |
-
|
| 134 |
-
|
|
|
|
|
|
|
| 135 |
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
w, h = None, None
|
| 139 |
|
| 140 |
-
|
|
|
|
|
|
|
| 141 |
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
if i.mode == 'I':
|
| 146 |
-
i = i.point(lambda i: i * (1 / 255))
|
| 147 |
-
image = i.convert("RGB")
|
| 148 |
-
|
| 149 |
-
if len(output_images) == 0:
|
| 150 |
-
w, h = image.size
|
| 151 |
-
|
| 152 |
-
if image.size != (w, h):
|
| 153 |
continue
|
| 154 |
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
mask_np = np.array(i.getchannel('A')).astype(np.float32) / 255.0
|
| 160 |
-
mask_tensor = 1. - torch.from_numpy(mask_np)
|
| 161 |
else:
|
| 162 |
-
mask_tensor = torch.zeros((64, 64), dtype=torch.float32
|
| 163 |
|
| 164 |
output_images.append(image_tensor)
|
| 165 |
output_masks.append(mask_tensor.unsqueeze(0))
|
| 166 |
|
|
|
|
| 167 |
if len(output_images) > 1 and getattr(img, "format", None) not in excluded_formats:
|
| 168 |
output_image = torch.cat(output_images, dim=0)
|
| 169 |
output_mask = torch.cat(output_masks, dim=0)
|
| 170 |
else:
|
| 171 |
-
output_image = output_images[0]
|
| 172 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
|
| 174 |
return output_image
|
| 175 |
|
|
|
|
| 125 |
|
| 126 |
|
| 127 |
def load_image_from_url(image_url):
|
| 128 |
+
"""从 URL 加载图像为 torch.Tensor,支持多帧和 Alpha 通道融合。"""
|
| 129 |
+
# 下载并读取图像
|
| 130 |
response = requests.get(image_url)
|
| 131 |
response.raise_for_status()
|
| 132 |
+
img = pillow(Image.open, BytesIO(response.content))
|
| 133 |
|
| 134 |
+
excluded_formats = {'MPO'}
|
| 135 |
+
output_images, output_masks = [], []
|
| 136 |
+
w = h = None
|
| 137 |
+
has_alpha = False
|
| 138 |
|
| 139 |
+
for frame in ImageSequence.Iterator(img):
|
| 140 |
+
frame = pillow(ImageOps.exif_transpose, frame)
|
|
|
|
| 141 |
|
| 142 |
+
if frame.mode == 'I':
|
| 143 |
+
frame = frame.point(lambda px: px * (1 / 255))
|
| 144 |
+
rgb_image = frame.convert("RGB")
|
| 145 |
|
| 146 |
+
if w is None:
|
| 147 |
+
w, h = rgb_image.size
|
| 148 |
+
if rgb_image.size != (w, h):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
continue
|
| 150 |
|
| 151 |
+
image_tensor = torch.from_numpy(np.array(rgb_image, dtype=np.float32) / 255.0)[None, ...]
|
| 152 |
+
if 'A' in frame.getbands():
|
| 153 |
+
has_alpha = True
|
| 154 |
+
mask_tensor = 1.0 - torch.from_numpy(np.array(frame.getchannel('A'), dtype=np.float32) / 255.0)
|
|
|
|
|
|
|
| 155 |
else:
|
| 156 |
+
mask_tensor = torch.zeros((64, 64), dtype=torch.float32)
|
| 157 |
|
| 158 |
output_images.append(image_tensor)
|
| 159 |
output_masks.append(mask_tensor.unsqueeze(0))
|
| 160 |
|
| 161 |
+
# 合并帧
|
| 162 |
if len(output_images) > 1 and getattr(img, "format", None) not in excluded_formats:
|
| 163 |
output_image = torch.cat(output_images, dim=0)
|
| 164 |
output_mask = torch.cat(output_masks, dim=0)
|
| 165 |
else:
|
| 166 |
+
output_image, output_mask = output_images[0], output_masks[0]
|
| 167 |
+
|
| 168 |
+
# 融合 Alpha 通道
|
| 169 |
+
if has_alpha:
|
| 170 |
+
image = output_image.squeeze(0)
|
| 171 |
+
if image.dim() == 3 and image.shape[0] in (1, 3, 4):
|
| 172 |
+
image = image.permute(1, 2, 0)
|
| 173 |
+
h, w = output_mask.shape
|
| 174 |
+
rgba = torch.zeros(h, w, 4)
|
| 175 |
+
rgba[:, :, :3] = image
|
| 176 |
+
rgba[:, :, 3] = output_mask
|
| 177 |
+
output_image = rgba
|
| 178 |
|
| 179 |
return output_image
|
| 180 |
|