sin30 commited on
Commit
4188837
·
verified ·
1 Parent(s): 840a3a6

Upload utils.py

Browse files
Files changed (1) hide show
  1. utils.py +194 -0
utils.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import json
3
+ import requests
4
+ from io import BytesIO
5
+ from PIL import Image, ImageOps, ImageSequence
6
+ import numpy as np
7
+ import torch
8
+ import base64
9
+ from PIL import ImageFile, UnidentifiedImageError
10
+ import hashlib
11
+ import time
12
+
13
+
14
+ def pillow(fn, arg):
15
+ prev_value = None
16
+ try:
17
+ x = fn(arg)
18
+ except (OSError, UnidentifiedImageError, ValueError): #PIL issues #4472 and #2445, also fixes ComfyUI issue #3416
19
+ prev_value = ImageFile.LOAD_TRUNCATED_IMAGES
20
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
21
+ x = fn(arg)
22
+ finally:
23
+ if prev_value is not None:
24
+ ImageFile.LOAD_TRUNCATED_IMAGES = prev_value
25
+ return x
26
+
27
+
28
+ def image_file_to_base64(path: str) -> str:
29
+ """
30
+ 读取本地图片文件,并返回其 Base64 编码字符串(不带 data URI 头)。
31
+ """
32
+ with open(path, 'rb') as f:
33
+ data = f.read()
34
+ # 将二进制数据编码为 Base64 字符串,并解码为 str
35
+ base64_str = base64.b64encode(data).decode('utf-8')
36
+ return base64_str
37
+
38
+
39
+ def tensor_to_url(img_tensor: torch.Tensor) -> str:
40
+ # 批量→单张
41
+ if img_tensor.dim() == 4:
42
+ img_tensor = img_tensor[0]
43
+
44
+ # 确保在 CPU、float
45
+ img = img_tensor.detach().cpu().float()
46
+
47
+ # 只有三维时才考虑 permute
48
+ if img.dim() == 3:
49
+ d0, d1, d2 = img.shape
50
+ # 如果第一维是通道(常见 1,3,4),就把 (C,H,W)→(H,W,C)
51
+ if d0 in (1, 3, 4):
52
+ img = img.permute(1, 2, 0)
53
+ # 否则,如果最后一维是通道,就假设已经是 (H,W,C) 了
54
+ elif d2 in (1, 3, 4):
55
+ pass
56
+ else:
57
+ raise ValueError(f"Unexpected tensor shape {img.shape} for image conversion.")
58
+
59
+ # 此时 img.shape 应该是 (H, W) 或 (H, W, C)
60
+ arr = img.clamp(0, 1).mul(255).byte().numpy()
61
+
62
+ # 创建 PIL 图像
63
+ pil_img = Image.fromarray(arr)
64
+
65
+ buf = BytesIO()
66
+ pil_img.save(buf, format="JPEG") # 或 "JPEG"
67
+ buf.seek(0)
68
+
69
+ # 3. 用 requests 传文件
70
+ files = {
71
+ "file": ("test.png", buf, "image/jpeg") # (文件名, 文件内容, MIME类型)
72
+ }
73
+
74
+ UPLOAD_URL = "http://deepnet.meitustat.com/data-center/v1/cloud/upload"
75
+
76
+ # 超时重试配置
77
+ max_retries = 3
78
+ timeout = 30 # 30秒超时
79
+ retry_delay = 1 # 重试间隔1秒
80
+
81
+ for attempt in range(max_retries):
82
+ try:
83
+ response = requests.post(UPLOAD_URL, files=files, timeout=timeout)
84
+ response.raise_for_status() # Raise exception for HTTP errors
85
+ return response.json()['data']['url']
86
+ except Exception as e:
87
+ if attempt == max_retries - 1: # 最后一次尝试失败
88
+ raise Exception(f"上传失败,已重试 {max_retries} 次: {str(e)}")
89
+ print(f"上传失败 (尝试 {attempt + 1}/{max_retries}): {str(e)},{retry_delay}秒后重试...")
90
+ time.sleep(retry_delay)
91
+ retry_delay *= 2 # 指数退避,下次重试间隔翻倍
92
+
93
+
94
+ def tensor_to_base64(img_tensor: torch.Tensor) -> str:
95
+ # 批量→单张
96
+ if img_tensor.dim() == 4:
97
+ img_tensor = img_tensor[0]
98
+
99
+ # 确保在 CPU、float
100
+ img = img_tensor.detach().cpu().float()
101
+
102
+ # 只有三维时才考虑 permute
103
+ if img.dim() == 3:
104
+ d0, d1, d2 = img.shape
105
+ # 如果第一维是通道(常见 1,3,4),就把 (C,H,W)→(H,W,C)
106
+ if d0 in (1, 3, 4):
107
+ img = img.permute(1, 2, 0)
108
+ # 否则,如果最后一维是通道,就假设已经是 (H,W,C) 了
109
+ elif d2 in (1, 3, 4):
110
+ pass
111
+ else:
112
+ raise ValueError(f"Unexpected tensor shape {img.shape} for image conversion.")
113
+
114
+ # 此时 img.shape 应该是 (H, W) 或 (H, W, C)
115
+ arr = img.clamp(0, 1).mul(255).byte().numpy()
116
+
117
+ # 创建 PIL 图像
118
+ pil_img = Image.fromarray(arr)
119
+
120
+ # 存到内存、再转 Base64
121
+ buf = BytesIO()
122
+ pil_img.save(buf, format="JPEG")
123
+ b64 = base64.b64encode(buf.getvalue()).decode("ascii")
124
+ return b64
125
+
126
+
127
+ def load_image_from_url(image_url):
128
+ # 下载图片数据
129
+ response = requests.get(image_url)
130
+ response.raise_for_status()
131
+ img_data = BytesIO(response.content)
132
+
133
+ # 使用 PIL 打开图像
134
+ img = pillow(Image.open, img_data)
135
+
136
+ output_images = []
137
+ output_masks = []
138
+ w, h = None, None
139
+
140
+ excluded_formats = ['MPO']
141
+
142
+ for i in ImageSequence.Iterator(img):
143
+ i = pillow(ImageOps.exif_transpose, i)
144
+
145
+ if i.mode == 'I':
146
+ i = i.point(lambda i: i * (1 / 255))
147
+ image = i.convert("RGB")
148
+
149
+ if len(output_images) == 0:
150
+ w, h = image.size
151
+
152
+ if image.size != (w, h):
153
+ continue
154
+
155
+ image_np = np.array(image).astype(np.float32) / 255.0
156
+ image_tensor = torch.from_numpy(image_np)[None, ...]
157
+
158
+ if 'A' in i.getbands():
159
+ mask_np = np.array(i.getchannel('A')).astype(np.float32) / 255.0
160
+ mask_tensor = 1. - torch.from_numpy(mask_np)
161
+ else:
162
+ mask_tensor = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
163
+
164
+ output_images.append(image_tensor)
165
+ output_masks.append(mask_tensor.unsqueeze(0))
166
+
167
+ if len(output_images) > 1 and getattr(img, "format", None) not in excluded_formats:
168
+ output_image = torch.cat(output_images, dim=0)
169
+ output_mask = torch.cat(output_masks, dim=0)
170
+ else:
171
+ output_image = output_images[0]
172
+ output_mask = output_masks[0]
173
+
174
+ return output_image
175
+
176
+
177
+ def extract_json_from_text(text: str) -> dict:
178
+ """
179
+ 从给定的文本中提取第一个 JSON 对象并解析成 Python 字典。
180
+
181
+ :param text: 包含 JSON 的文本
182
+ :return: 解析后的字典,如果未找到 JSON 则返回空字典
183
+ """
184
+ # 匹配第一个花括号及其中所有内容
185
+ match = re.search(r"\{[\s\S]*?\}", text)
186
+ if not match:
187
+ return {}
188
+
189
+ json_str = match.group(0)
190
+ try:
191
+ return json.loads(json_str)
192
+ except json.JSONDecodeError as e:
193
+ print(f"JSON 解析失败: {e}")
194
+ return {}