Upload utils.py
Browse files
utils.py
CHANGED
|
@@ -125,56 +125,57 @@ def tensor_to_base64(img_tensor: torch.Tensor) -> str:
|
|
| 125 |
|
| 126 |
|
| 127 |
def load_image_from_url(image_url):
|
| 128 |
-
|
| 129 |
-
# 下载并读取图像
|
| 130 |
response = requests.get(image_url)
|
| 131 |
response.raise_for_status()
|
| 132 |
-
|
| 133 |
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
|
|
|
|
|
|
|
|
|
| 137 |
has_alpha = False
|
| 138 |
|
| 139 |
-
|
| 140 |
-
|
|
|
|
|
|
|
| 141 |
|
| 142 |
-
if
|
| 143 |
-
|
| 144 |
-
|
| 145 |
|
| 146 |
-
if
|
| 147 |
-
w, h =
|
| 148 |
-
|
|
|
|
| 149 |
continue
|
| 150 |
|
| 151 |
-
|
| 152 |
-
|
|
|
|
|
|
|
| 153 |
has_alpha = True
|
| 154 |
-
|
|
|
|
| 155 |
else:
|
| 156 |
-
mask_tensor = torch.zeros((64, 64), dtype=torch.float32)
|
| 157 |
|
| 158 |
output_images.append(image_tensor)
|
| 159 |
output_masks.append(mask_tensor.unsqueeze(0))
|
| 160 |
|
| 161 |
-
# 合并帧
|
| 162 |
if len(output_images) > 1 and getattr(img, "format", None) not in excluded_formats:
|
| 163 |
output_image = torch.cat(output_images, dim=0)
|
| 164 |
output_mask = torch.cat(output_masks, dim=0)
|
| 165 |
else:
|
| 166 |
-
output_image
|
|
|
|
| 167 |
|
| 168 |
-
# 融合 Alpha 通道
|
| 169 |
if has_alpha:
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
image = image.permute(1, 2, 0)
|
| 173 |
-
h, w = output_mask.shape
|
| 174 |
-
rgba = torch.zeros(h, w, 4)
|
| 175 |
-
rgba[:, :, :3] = image
|
| 176 |
-
rgba[:, :, 3] = output_mask
|
| 177 |
-
output_image = rgba
|
| 178 |
|
| 179 |
return output_image
|
| 180 |
|
|
@@ -196,4 +197,9 @@ def extract_json_from_text(text: str) -> dict:
|
|
| 196 |
return json.loads(json_str)
|
| 197 |
except json.JSONDecodeError as e:
|
| 198 |
print(f"JSON 解析失败: {e}")
|
| 199 |
-
return {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
|
| 126 |
|
| 127 |
def load_image_from_url(image_url):
|
| 128 |
+
# 下载图片数据
|
|
|
|
| 129 |
response = requests.get(image_url)
|
| 130 |
response.raise_for_status()
|
| 131 |
+
img_data = BytesIO(response.content)
|
| 132 |
|
| 133 |
+
# 使用 PIL 打开图像
|
| 134 |
+
img = pillow(Image.open, img_data)
|
| 135 |
+
|
| 136 |
+
output_images = []
|
| 137 |
+
output_masks = []
|
| 138 |
+
w, h = None, None
|
| 139 |
has_alpha = False
|
| 140 |
|
| 141 |
+
excluded_formats = ['MPO']
|
| 142 |
+
|
| 143 |
+
for i in ImageSequence.Iterator(img):
|
| 144 |
+
i = pillow(ImageOps.exif_transpose, i)
|
| 145 |
|
| 146 |
+
if i.mode == 'I':
|
| 147 |
+
i = i.point(lambda i: i * (1 / 255))
|
| 148 |
+
image = i.convert("RGB")
|
| 149 |
|
| 150 |
+
if len(output_images) == 0:
|
| 151 |
+
w, h = image.size
|
| 152 |
+
|
| 153 |
+
if image.size != (w, h):
|
| 154 |
continue
|
| 155 |
|
| 156 |
+
image_np = np.array(image).astype(np.float32) / 255.0
|
| 157 |
+
image_tensor = torch.from_numpy(image_np)[None, ...]
|
| 158 |
+
|
| 159 |
+
if 'A' in i.getbands():
|
| 160 |
has_alpha = True
|
| 161 |
+
mask_np = np.array(i.getchannel('A')).astype(np.float32) / 255.0
|
| 162 |
+
mask_tensor = 1. - torch.from_numpy(mask_np)
|
| 163 |
else:
|
| 164 |
+
mask_tensor = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
|
| 165 |
|
| 166 |
output_images.append(image_tensor)
|
| 167 |
output_masks.append(mask_tensor.unsqueeze(0))
|
| 168 |
|
|
|
|
| 169 |
if len(output_images) > 1 and getattr(img, "format", None) not in excluded_formats:
|
| 170 |
output_image = torch.cat(output_images, dim=0)
|
| 171 |
output_mask = torch.cat(output_masks, dim=0)
|
| 172 |
else:
|
| 173 |
+
output_image = output_images[0]
|
| 174 |
+
output_mask = output_masks[0]
|
| 175 |
|
|
|
|
| 176 |
if has_alpha:
|
| 177 |
+
output_mask = output_mask.permute(1, 2, 0).unsqueeze(0)
|
| 178 |
+
output_image = torch.cat([output_image, output_mask], dim=-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
|
| 180 |
return output_image
|
| 181 |
|
|
|
|
| 197 |
return json.loads(json_str)
|
| 198 |
except json.JSONDecodeError as e:
|
| 199 |
print(f"JSON 解析失败: {e}")
|
| 200 |
+
return {}
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
if __name__ == "__main__":
|
| 204 |
+
output_image = load_image_from_url("https://obs-large.mtlab.meitu.com/mtopen/fb348748a0ca48cc9ee1ff15059ff499/5264072a-668c-40d5-48c3-624020f88845.png")
|
| 205 |
+
print(output_image.shape)
|