sin30 commited on
Commit
7b5eca7
·
verified ·
1 Parent(s): 9ec22e9

Upload utils.py

Browse files
Files changed (1) hide show
  1. utils.py +36 -30
utils.py CHANGED
@@ -125,56 +125,57 @@ def tensor_to_base64(img_tensor: torch.Tensor) -> str:
125
 
126
 
127
  def load_image_from_url(image_url):
128
- """从 URL 加载图像为 torch.Tensor,支持多帧和 Alpha 通道融合。"""
129
- # 下载并读取图像
130
  response = requests.get(image_url)
131
  response.raise_for_status()
132
- img = pillow(Image.open, BytesIO(response.content))
133
 
134
- excluded_formats = {'MPO'}
135
- output_images, output_masks = [], []
136
- w = h = None
 
 
 
137
  has_alpha = False
138
 
139
- for frame in ImageSequence.Iterator(img):
140
- frame = pillow(ImageOps.exif_transpose, frame)
 
 
141
 
142
- if frame.mode == 'I':
143
- frame = frame.point(lambda px: px * (1 / 255))
144
- rgb_image = frame.convert("RGB")
145
 
146
- if w is None:
147
- w, h = rgb_image.size
148
- if rgb_image.size != (w, h):
 
149
  continue
150
 
151
- image_tensor = torch.from_numpy(np.array(rgb_image, dtype=np.float32) / 255.0)[None, ...]
152
- if 'A' in frame.getbands():
 
 
153
  has_alpha = True
154
- mask_tensor = 1.0 - torch.from_numpy(np.array(frame.getchannel('A'), dtype=np.float32) / 255.0)
 
155
  else:
156
- mask_tensor = torch.zeros((64, 64), dtype=torch.float32)
157
 
158
  output_images.append(image_tensor)
159
  output_masks.append(mask_tensor.unsqueeze(0))
160
 
161
- # 合并帧
162
  if len(output_images) > 1 and getattr(img, "format", None) not in excluded_formats:
163
  output_image = torch.cat(output_images, dim=0)
164
  output_mask = torch.cat(output_masks, dim=0)
165
  else:
166
- output_image, output_mask = output_images[0], output_masks[0]
 
167
 
168
- # 融合 Alpha 通道
169
  if has_alpha:
170
- image = output_image.squeeze(0)
171
- if image.dim() == 3 and image.shape[0] in (1, 3, 4):
172
- image = image.permute(1, 2, 0)
173
- h, w = output_mask.shape
174
- rgba = torch.zeros(h, w, 4)
175
- rgba[:, :, :3] = image
176
- rgba[:, :, 3] = output_mask
177
- output_image = rgba
178
 
179
  return output_image
180
 
@@ -196,4 +197,9 @@ def extract_json_from_text(text: str) -> dict:
196
  return json.loads(json_str)
197
  except json.JSONDecodeError as e:
198
  print(f"JSON 解析失败: {e}")
199
- return {}
 
 
 
 
 
 
125
 
126
 
127
  def load_image_from_url(image_url):
128
+ # 下载图片数据
 
129
  response = requests.get(image_url)
130
  response.raise_for_status()
131
+ img_data = BytesIO(response.content)
132
 
133
+ # 使用 PIL 打开图像
134
+ img = pillow(Image.open, img_data)
135
+
136
+ output_images = []
137
+ output_masks = []
138
+ w, h = None, None
139
  has_alpha = False
140
 
141
+ excluded_formats = ['MPO']
142
+
143
+ for i in ImageSequence.Iterator(img):
144
+ i = pillow(ImageOps.exif_transpose, i)
145
 
146
+ if i.mode == 'I':
147
+ i = i.point(lambda i: i * (1 / 255))
148
+ image = i.convert("RGB")
149
 
150
+ if len(output_images) == 0:
151
+ w, h = image.size
152
+
153
+ if image.size != (w, h):
154
  continue
155
 
156
+ image_np = np.array(image).astype(np.float32) / 255.0
157
+ image_tensor = torch.from_numpy(image_np)[None, ...]
158
+
159
+ if 'A' in i.getbands():
160
  has_alpha = True
161
+ mask_np = np.array(i.getchannel('A')).astype(np.float32) / 255.0
162
+ mask_tensor = 1. - torch.from_numpy(mask_np)
163
  else:
164
+ mask_tensor = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
165
 
166
  output_images.append(image_tensor)
167
  output_masks.append(mask_tensor.unsqueeze(0))
168
 
 
169
  if len(output_images) > 1 and getattr(img, "format", None) not in excluded_formats:
170
  output_image = torch.cat(output_images, dim=0)
171
  output_mask = torch.cat(output_masks, dim=0)
172
  else:
173
+ output_image = output_images[0]
174
+ output_mask = output_masks[0]
175
 
 
176
  if has_alpha:
177
+ output_mask = output_mask.permute(1, 2, 0).unsqueeze(0)
178
+ output_image = torch.cat([output_image, output_mask], dim=-1)
 
 
 
 
 
 
179
 
180
  return output_image
181
 
 
197
  return json.loads(json_str)
198
  except json.JSONDecodeError as e:
199
  print(f"JSON 解析失败: {e}")
200
+ return {}
201
+
202
+
203
+ if __name__ == "__main__":
204
+ output_image = load_image_from_url("https://obs-large.mtlab.meitu.com/mtopen/fb348748a0ca48cc9ee1ff15059ff499/5264072a-668c-40d5-48c3-624020f88845.png")
205
+ print(output_image.shape)