sin30 commited on
Commit
9ec22e9
·
verified ·
1 Parent(s): b131fcc

Upload utils.py

Browse files
Files changed (1) hide show
  1. utils.py +33 -28
utils.py CHANGED
@@ -125,51 +125,56 @@ def tensor_to_base64(img_tensor: torch.Tensor) -> str:
125
 
126
 
127
  def load_image_from_url(image_url):
128
- # 下载图片数据
 
129
  response = requests.get(image_url)
130
  response.raise_for_status()
131
- img_data = BytesIO(response.content)
132
 
133
- # 使用 PIL 打开图像
134
- img = pillow(Image.open, img_data)
 
 
135
 
136
- output_images = []
137
- output_masks = []
138
- w, h = None, None
139
 
140
- excluded_formats = ['MPO']
 
 
141
 
142
- for i in ImageSequence.Iterator(img):
143
- i = pillow(ImageOps.exif_transpose, i)
144
-
145
- if i.mode == 'I':
146
- i = i.point(lambda i: i * (1 / 255))
147
- image = i.convert("RGB")
148
-
149
- if len(output_images) == 0:
150
- w, h = image.size
151
-
152
- if image.size != (w, h):
153
  continue
154
 
155
- image_np = np.array(image).astype(np.float32) / 255.0
156
- image_tensor = torch.from_numpy(image_np)[None, ...]
157
-
158
- if 'A' in i.getbands():
159
- mask_np = np.array(i.getchannel('A')).astype(np.float32) / 255.0
160
- mask_tensor = 1. - torch.from_numpy(mask_np)
161
  else:
162
- mask_tensor = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
163
 
164
  output_images.append(image_tensor)
165
  output_masks.append(mask_tensor.unsqueeze(0))
166
 
 
167
  if len(output_images) > 1 and getattr(img, "format", None) not in excluded_formats:
168
  output_image = torch.cat(output_images, dim=0)
169
  output_mask = torch.cat(output_masks, dim=0)
170
  else:
171
- output_image = output_images[0]
172
- output_mask = output_masks[0]
 
 
 
 
 
 
 
 
 
 
173
 
174
  return output_image
175
 
 
125
 
126
 
127
  def load_image_from_url(image_url):
128
+ """从 URL 加载图像为 torch.Tensor,支持多帧和 Alpha 通道融合。"""
129
+ # 下载并读取图像
130
  response = requests.get(image_url)
131
  response.raise_for_status()
132
+ img = pillow(Image.open, BytesIO(response.content))
133
 
134
+ excluded_formats = {'MPO'}
135
+ output_images, output_masks = [], []
136
+ w = h = None
137
+ has_alpha = False
138
 
139
+ for frame in ImageSequence.Iterator(img):
140
+ frame = pillow(ImageOps.exif_transpose, frame)
 
141
 
142
+ if frame.mode == 'I':
143
+ frame = frame.point(lambda px: px * (1 / 255))
144
+ rgb_image = frame.convert("RGB")
145
 
146
+ if w is None:
147
+ w, h = rgb_image.size
148
+ if rgb_image.size != (w, h):
 
 
 
 
 
 
 
 
149
  continue
150
 
151
+ image_tensor = torch.from_numpy(np.array(rgb_image, dtype=np.float32) / 255.0)[None, ...]
152
+ if 'A' in frame.getbands():
153
+ has_alpha = True
154
+ mask_tensor = 1.0 - torch.from_numpy(np.array(frame.getchannel('A'), dtype=np.float32) / 255.0)
 
 
155
  else:
156
+ mask_tensor = torch.zeros((64, 64), dtype=torch.float32)
157
 
158
  output_images.append(image_tensor)
159
  output_masks.append(mask_tensor.unsqueeze(0))
160
 
161
+ # 合并帧
162
  if len(output_images) > 1 and getattr(img, "format", None) not in excluded_formats:
163
  output_image = torch.cat(output_images, dim=0)
164
  output_mask = torch.cat(output_masks, dim=0)
165
  else:
166
+ output_image, output_mask = output_images[0], output_masks[0]
167
+
168
+ # 融合 Alpha 通道
169
+ if has_alpha:
170
+ image = output_image.squeeze(0)
171
+ if image.dim() == 3 and image.shape[0] in (1, 3, 4):
172
+ image = image.permute(1, 2, 0)
173
+ h, w = output_mask.shape
174
+ rgba = torch.zeros(h, w, 4)
175
+ rgba[:, :, :3] = image
176
+ rgba[:, :, 3] = output_mask
177
+ output_image = rgba
178
 
179
  return output_image
180