gtang666 commited on
Commit
9623cd1
·
verified ·
1 Parent(s): ed1622f
utils/__init__.py ADDED
File without changes
utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (145 Bytes). View file
 
utils/__pycache__/utils.cpython-39.pyc ADDED
Binary file (14.9 kB). View file
 
utils/utils.py ADDED
@@ -0,0 +1,513 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from PIL import Image
3
+ import numpy as np
4
+ from copy import deepcopy
5
+ import cv2
6
+ import os
7
+ from tqdm import tqdm
8
+ import shutil
9
+ import torch
10
+ import torchvision.transforms as T
11
+ from PIL import Image, ImageOps
12
+ from torchvision.transforms.functional import InterpolationMode
13
+ import re
14
+ import imghdr
15
+
16
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
17
+ IMAGENET_STD = (0.229, 0.224, 0.225)
18
+
19
+
20
+ def calculate_iou(boxA, boxB,mini=False):
21
+ # 计算交集矩形的坐标
22
+ xA = max(boxA[0], boxB[0])
23
+ yA = max(boxA[1], boxB[1])
24
+ xB = min(boxA[2], boxB[2])
25
+ yB = min(boxA[3], boxB[3])
26
+
27
+ # 计算交集面积
28
+ interArea = max(0, xB - xA) * max(0, yB - yA)
29
+
30
+ # 计算两个边界框的面积
31
+ boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
32
+ boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
33
+
34
+ # 计算并集面积
35
+ unionArea = boxAArea + boxBArea - interArea
36
+
37
+ # 计算IoU
38
+ iou = interArea / unionArea
39
+ if mini:
40
+ iou=interArea/min(boxAArea,boxBArea)
41
+ return iou
42
+ def get_all_jpgs(folder_path,suffix='.jpg'):
43
+ """得到文件夹中的所有jpg文件路径"""
44
+ files = os.listdir(folder_path)
45
+ jpg_files = [folder_path+f for f in files if os.path.isfile(os.path.join(folder_path, f)) and f.endswith(suffix)]
46
+ return jpg_files
47
+
48
+ def get_all_jsons(folder_path):
49
+ """得到文件夹中的所有json文件路径"""
50
+ files = os.listdir(folder_path)
51
+ json_files = [folder_path+f for f in files if os.path.isfile(os.path.join(folder_path, f)) and f.endswith('json')]
52
+ return json_files
53
+
54
+ def load_json(pth):
55
+ """加载json文件"""
56
+ with open(pth, 'r', encoding='utf-8') as f:
57
+ data = json.load(f)
58
+ return data
59
+ def save_json(pth,data):
60
+ """保存json文件"""
61
+ with open(pth, 'w', encoding='utf-8') as f:
62
+ json.dump(data, f, ensure_ascii=False, indent=4)
63
+
64
+ def shuffle_lists(list1, list2,seed=42):
65
+ import random
66
+ assert len(list1) == len(list2), "两个列表必须等长"
67
+ random.seed(seed)
68
+ # 创建索引列表
69
+ indices = list(range(len(list1)))
70
+
71
+ # 打乱索引列表
72
+ random.shuffle(indices)
73
+
74
+ # 使用打乱后的索引列表重新排列两个列表
75
+ shuffled_list1 = [list1[i] for i in indices]
76
+ shuffled_list2 = [list2[i] for i in indices]
77
+
78
+ return shuffled_list1, shuffled_list2
79
+
80
+ def most_frequent_rgb(image_array):
81
+ """找一张图片中最frequent的rgb,用于填充mask"""
82
+ # Flatten the image array to a 2D array where each row is an RGB tuple
83
+ pixels = image_array.reshape(-1, image_array.shape[-1])
84
+
85
+ # Use np.unique with return_counts to find unique rows and their counts
86
+ unique_pixels, counts = np.unique(pixels, axis=0, return_counts=True)
87
+
88
+ # Find the index of the most frequent pixel
89
+ most_frequent_index = np.argmax(counts)
90
+
91
+ # Get the most frequent pixel and its count
92
+ most_frequent_pixel = unique_pixels[most_frequent_index]
93
+ frequency = counts[most_frequent_index]
94
+ return most_frequent_pixel, frequency
95
+
96
+ def half_divide(img,data):
97
+ """将图片从中分开,mask被穿过的char,并得到对应的左右json文件"""
98
+ left_data={"shapes":[],"imageHeight":data["imageHeight"],"imageWidth":data["imageWidth"]//2}
99
+ right_data={"shapes":[],"imageHeight":data["imageHeight"],"imageWidth":data["imageWidth"]//2}
100
+
101
+ # 获取原始尺寸
102
+ width, height = img.size
103
+
104
+ # 计算切割点
105
+ split_point = width // 2
106
+ image_array = np.array(img)
107
+ color,_=most_frequent_rgb(image_array)
108
+ modified_image=image_array.copy()
109
+
110
+ to_be_mask=[]
111
+ for item in data['shapes']:
112
+ if len(item['points'])!=2 or len(item['points'][0])!=2 or len(item['points'][1])!=2:
113
+ continue
114
+ [x1,y1],[x2,y2]=item['points']
115
+ if x2<split_point:
116
+ left_data['shapes'].append({"points":[[x1,y1],[x2,y2]]})
117
+ elif x1>split_point:
118
+ right_data['shapes'].append({"points":[[x1-split_point,y1],[x2-split_point,y2]]})
119
+ else:
120
+ to_be_mask.append([x1,y1,x2,y2])
121
+
122
+ for coord in to_be_mask:
123
+ x1, y1, x2, y2 = coord
124
+ modified_image[int(y1):int(y2), int(x1):int(x2)] =color
125
+
126
+ modified_image_pil = Image.fromarray(modified_image)
127
+ left_img = modified_image_pil.crop((0, 0, split_point, height))
128
+ right_img =modified_image_pil.crop((split_point, 0, width, height))
129
+ return [left_img,left_data,right_img,right_data]
130
+
131
+ def refine(jpg_path,json_path,save_dir):
132
+ """对一张图片进行half divide,直到子图都不超过300"""
133
+ data=load_json(json_path)
134
+ n=len(data['shapes'])
135
+ name=jpg_path.split('/')[-1].split('.')[0]
136
+ img = Image.open(jpg_path)
137
+ if n<300:
138
+
139
+ img.save(save_dir+name+f'.jpg')
140
+ save_json(save_dir+name+f'.json',data)
141
+ return None
142
+ else:
143
+ left_img,left_data,right_img,right_data=half_divide(img,data)
144
+ ###储存所有当下的子图和子data
145
+ sub_img=[left_img,right_img]
146
+ sub_data=[left_data,right_data]
147
+ i=0
148
+ while True:
149
+ if i==len(sub_img):
150
+ break
151
+ simg=sub_img[i]
152
+ sdata=sub_data[i]
153
+ if len(sdata['shapes'])>=300:
154
+ sub_img.pop(i)
155
+ sub_data.pop(i)
156
+ li,ld,ri,rd=half_divide(simg,sdata)
157
+ sub_img.append(li)
158
+ sub_img.append(ri)
159
+ sub_data.append(ld)
160
+ sub_data.append(rd)
161
+ i-=1
162
+ i+=1
163
+ j=0
164
+ for pic,d in zip(sub_img,sub_data):
165
+ save_json(save_dir+name+f'_{j}.json',d)
166
+ pic.save(save_dir+name+f'_{j}.jpg')
167
+ j+=1
168
+
169
+ def get_union(b1,b2):
170
+ """求box之间的union,用于合并得列"""
171
+ x1,y1,x2,y2=b1[0][0],b1[0][1],b1[1][0],b1[1][1]
172
+ x3,y3,x4,y4=b2[0][0],b2[0][1],b2[1][0],b2[1][1]
173
+ x=min(x1,x2,x3,x4)
174
+ X=max(x1,x2,x3,x4)
175
+ y=min(y1,y2,y3,y4)
176
+ Y=max(y1,y2,y3,y4)
177
+ return [[x,y],[X,Y]]
178
+ def list_union(boxes):
179
+ """求一个box列表的union,得这列的box"""
180
+ result=boxes[0]
181
+ for item in boxes[1:]:
182
+ result=get_union(result,item)
183
+ return result
184
+ def get_col_jsons(json_files,jpg_files,base,destination_jpgs):
185
+ """从gen_data转换为col_data,注意不是构建数据集,而是对每个json从字得列重新储存"""
186
+ for file_path,jpg_path in tqdm(zip(json_files,jpg_files)):
187
+
188
+ os.makedirs(destination_jpgs, exist_ok=True)
189
+
190
+ # 构建源文件的完整路径
191
+ source_file_path = os.path.join(base, jpg_path)
192
+
193
+ # 构建目标文件的完整路径
194
+ destination_file_path = os.path.join(destination_jpgs, jpg_path)
195
+
196
+ # 复制文件到目标文件夹
197
+ shutil.copy2(source_file_path, destination_file_path)
198
+
199
+ i=file_path.split('.')[0]
200
+ with open(base+file_path, 'r', encoding='utf-8') as file:
201
+ data = json.load(file)
202
+ height=data["imageHeight"]
203
+ width=data["imageWidth"]
204
+ content=data['shapes']
205
+ info=[]
206
+ dic={}
207
+ results=[]
208
+ for item in content:
209
+ col=item['col']
210
+ if col not in dic:
211
+ dic[col]=[item['points']]
212
+ else:
213
+ dic[col].append(item['points'])
214
+ for key,value in dic.items():
215
+ union=list_union(value)
216
+ results.append({'label':key,'points':union})
217
+ data['shapes']=results
218
+ save_json(os.path.join(destination_jpgs,file_path ),data)
219
+ def drawBoxes(results,jpg_path,save_path):
220
+ frame = cv2.imread(jpg_path)
221
+ for points in results:
222
+ x1, y1, x2, y2 = int(points[0][0]), int(points[0][1]), int(points[1][0]), int(points[1][1])
223
+ cv2.rectangle(frame, (x1, y1), (x2, y2), thickness=2,color=(255,0,0),lineType=cv2.LINE_AA)
224
+ label_position = ((x1+x2)//2,(y1+y2)//2) # Adjust the position of the label as needed
225
+ #cv2.putText(frame, str(idx), label_position, cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 1, cv2.LINE_AA)
226
+ name=jpg_path.split("/")[-1]
227
+ cv2.imwrite(save_path+"ordered_"+name,frame)
228
+
229
+
230
+ def intersection_length(x1, x3, x2, x4):
231
+ # 计算两个区间的交集起始点和结束点
232
+ start = max(x1, x2)
233
+ end = min(x3, x4)
234
+
235
+ # 如果交集起始点小于结束点,说明有交集
236
+ if start < end:
237
+ return end - start
238
+ else:
239
+ return 0
240
+
241
+
242
+ def union_length(x1, x3, x2, x4):
243
+ # 计算并集起始点和结束点
244
+ start = min(x1, x2)
245
+ end = max(x3, x4)
246
+
247
+ # 计算并集长度
248
+ union_len = end - start
249
+
250
+ return union_len
251
+
252
+
253
+ def distance_or_intersection(x1, x3, x2, x4):
254
+ # 计算不相交两个区间的最短距离
255
+ distance = min(abs(x1 - x4), abs(x2 - x3))
256
+
257
+ # 判断是否相交
258
+ if intersection_length(x1, x3, x2, x4) > 0:
259
+ return 0 # 区间相交,返回0
260
+ else:
261
+ return distance # 区间不相交,返回最短距离
262
+
263
+
264
+ def union(p1, p2):
265
+ [x1, y1], [x2, y2] = p1
266
+ [x3, y3], [x4, y4] = p2
267
+ lx = min(x1, x3)
268
+ ly = min(y1, y3)
269
+ rx = max(x2, x4)
270
+ ry = max(y2, y4)
271
+ return [[lx, ly], [rx, ry]]
272
+
273
+ def merge_boxes(boxes,thresx=0.7, thresy=2):
274
+
275
+
276
+ boxes = sorted(boxes, key=lambda box: (box[0][1]+box[1][1])/2)
277
+
278
+ now_len=len(boxes)
279
+ for _ in range(10):
280
+ ydis_mean = 0
281
+ for item in boxes:
282
+ [x1, y1], [x3, y3] = item
283
+ ydis_mean += abs(y1 - y3)
284
+ length = len(boxes)
285
+ if length==0:
286
+ break
287
+ ydis_mean /= length
288
+ i = 0
289
+ while i < length:
290
+ j = 0
291
+ # 依次遍历除自身外的全部box
292
+ while j < length:
293
+ mainbox = boxes[i]
294
+ if i == j:
295
+ j += 1
296
+ continue
297
+ length = len(boxes)
298
+ # 算x区间上相交的程度
299
+ intersection = intersection_length(mainbox[0][0], mainbox[1][0], boxes[j][0][0], boxes[j][1][0])
300
+ x_rate = intersection / min(abs(mainbox[0][0] - mainbox[1][0]), abs(boxes[j][0][0] - boxes[j][1][0]))
301
+
302
+ # 算y区间上相远离的程度,使用与字的y间距大小平均值的比值
303
+ y_dis = distance_or_intersection(boxes[i][0][1], boxes[i][1][1], boxes[j][0][1], boxes[j][1][1])
304
+ y_rate = y_dis / ydis_mean
305
+ h1=abs(boxes[i][0][0]-boxes[i][1][0])
306
+ h2=abs(boxes[j][0][0]-boxes[j][1][0])
307
+ l1=abs(boxes[i][0][1]-boxes[i][1][1])
308
+ l2=abs(boxes[j][0][1]-boxes[j][1][1])
309
+ s1=h1*l1
310
+ s2=h2*l2
311
+
312
+ y_rate=y_dis/((l1+l2)/2)
313
+ #print(min(s1,s2)/max(s1,s2))
314
+ if x_rate > thresx and y_rate < thresy:
315
+ rm = boxes[j]
316
+
317
+ u = union(mainbox, rm)
318
+ # 更新第boxes[i],删除被合并的boxes[j]
319
+ boxes[i] = u
320
+ boxes.remove(rm)
321
+ # 处理各个指标的改变
322
+ if j < i:
323
+ i -= 1
324
+ length -= 1
325
+ j -= 1
326
+ j += 1
327
+ i += 1
328
+ if now_len==len(boxes):
329
+ break
330
+ now_len=len(boxes)
331
+ return boxes
332
+
333
+ def merge_boxes_new(boxes):
334
+ boxes = sorted(boxes, key=lambda box: (box[0][1]+box[1][1])/2)
335
+
336
+ def combine_boxes(js,jpg):
337
+ data=load_json(js)
338
+ boxes=[]
339
+ h,w=data['imageHeight'],data['imageWidth']
340
+ for item in data['shapes']:
341
+ boxes.append(item['points'])
342
+ columns=merge_boxes(boxes)
343
+ columns=[[item[0][0],item[0][1],item[1][0],item[1][1]] for item in columns]
344
+ drawBoxes(columns,jpg,"/home/tangjq/WORK/boxes_sort/char2columns/")
345
+
346
+ def char2col(jpg_path,boxes):
347
+ columns=merge_boxes(boxes.copy())
348
+ img = cv2.imread(jpg_path)
349
+ h, w, channels = img.shape
350
+
351
+ results={"imageHeight":h,"imageWidth":w,"shapes":[{"points":col} for col in columns]}
352
+ return results
353
+
354
+ def build_transform(input_size):
355
+ MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
356
+ transform = T.Compose([
357
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
358
+ T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
359
+ T.ToTensor(),
360
+ T.Normalize(mean=MEAN, std=STD)
361
+ ])
362
+ return transform
363
+
364
+
365
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
366
+ best_ratio_diff = float('inf')
367
+ best_ratio = (1, 1)
368
+ area = width * height
369
+ for ratio in target_ratios:
370
+ target_aspect_ratio = ratio[0] / ratio[1]
371
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
372
+ if ratio_diff < best_ratio_diff:
373
+ best_ratio_diff = ratio_diff
374
+ best_ratio = ratio
375
+ elif ratio_diff == best_ratio_diff:
376
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
377
+ best_ratio = ratio
378
+ return best_ratio
379
+
380
+
381
+ def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
382
+ orig_width, orig_height = image.size
383
+ aspect_ratio = orig_width / orig_height
384
+
385
+ # calculate the existing image aspect ratio
386
+ target_ratios = set(
387
+ (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
388
+ i * j <= max_num and i * j >= min_num)
389
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
390
+
391
+ # find the closest aspect ratio to the target
392
+ target_aspect_ratio = find_closest_aspect_ratio(
393
+ aspect_ratio, target_ratios, orig_width, orig_height, image_size)
394
+
395
+ # calculate the target width and height
396
+ target_width = image_size * target_aspect_ratio[0]
397
+ target_height = image_size * target_aspect_ratio[1]
398
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
399
+
400
+ # resize the image
401
+ resized_img = image.resize((target_width, target_height))
402
+ processed_images = []
403
+ for i in range(blocks):
404
+ box = (
405
+ (i % (target_width // image_size)) * image_size,
406
+ (i // (target_width // image_size)) * image_size,
407
+ ((i % (target_width // image_size)) + 1) * image_size,
408
+ ((i // (target_width // image_size)) + 1) * image_size
409
+ )
410
+ # split the image
411
+ split_img = resized_img.crop(box)
412
+ processed_images.append(split_img)
413
+ assert len(processed_images) == blocks
414
+ if use_thumbnail and len(processed_images) != 1:
415
+ thumbnail_img = image.resize((image_size, image_size))
416
+ processed_images.append(thumbnail_img)
417
+ return processed_images
418
+
419
+
420
+ def load_image_2(image, input_size=448, max_num=12):
421
+ if isinstance(image,str):
422
+ image=Image.open(image).convert("RGB")
423
+ width, height = image.size
424
+
425
+ # 按比例缩放
426
+ if max(width, height) <= 200:
427
+ scale_factor = 200 / max(width, height)
428
+ elif max(width, height) >= 350:
429
+ scale_factor = 350 / max(width, height)
430
+ else:
431
+ scale_factor = 1.0
432
+
433
+ # 缩放图像
434
+ new_width = int(width * scale_factor)
435
+ new_height = int(height * scale_factor)
436
+ image = image.resize((new_width, new_height))
437
+
438
+ # 居中填充白色
439
+ padded_image = ImageOps.expand(image, border=(
440
+ (input_size - new_width) // 2, # 左边填充
441
+ (input_size - new_height) // 2, # 上边填充
442
+ (input_size - new_width + 1) // 2, # 右边填充
443
+ (input_size - new_height + 1) // 2 # 下边填充
444
+ ), fill=(255, 255, 255)) # 填充为白色
445
+ transform = build_transform(input_size=input_size)
446
+
447
+ # 预处理图像并将结果堆叠为张量
448
+ images = dynamic_preprocess(padded_image, image_size=input_size, use_thumbnail=True, max_num=max_num)
449
+ pixel_values = [transform(image) for image in images]
450
+ pixel_values = torch.stack(pixel_values)
451
+
452
+ return pixel_values
453
+ # transform = build_transform(input_size=input_size)
454
+ # # 看看是否最后的输入resized的整张图片会有影响
455
+ # images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
456
+ # # for i, item in enumerate(images):
457
+ # # item.save(os.path.join('/home/luoyx/InternVL/for_debug', f'{i}.png'))
458
+ # pixel_values = [transform(image) for image in images]
459
+ # pixel_values = torch.stack(pixel_values)
460
+ # return pixel_values
461
+
462
+
463
+ def load_image(image_file, input_size=448, max_num=12):
464
+ if isinstance(image_file,str):
465
+ image = Image.open(image_file).convert('RGB')
466
+ else:
467
+ image=image_file
468
+ # resize图片
469
+ # image = image.resize((448, 448))
470
+
471
+ transform = build_transform(input_size=input_size)
472
+ # 看看是否最后的输入resized的整张图片会有影响
473
+ images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
474
+
475
+ pixel_values = [transform(image) for image in images]
476
+
477
+ pixel_values = torch.stack(pixel_values)
478
+ return pixel_values
479
+
480
+
481
+ def remove_chinese_punctuation(text):
482
+ # 定义中文标点符号的正则表达式
483
+ chinese_punctuation_regex = re.compile(r'[\u3002\uFF1F\uFF01\u3001\uff0c\u300c\u300d\u300e\u300f\u2018\u2019\u201c\u201d\u2013\u2014\u2026\u3010\u3011\u300a\u300b\uff1a\uff1b]')
484
+ # 使用sub函数将匹配到的中文标点替换为空字符串
485
+ return chinese_punctuation_regex.sub('', text)
486
+
487
+ def remove_english_punctuation(text):
488
+
489
+ english_punctuation_regex = re.compile(r'[,\.!?:\'";\(\)\[\]\{\}\-\n\*1234567890]')
490
+
491
+ return english_punctuation_regex.sub('', text)
492
+
493
+ def get_image_paths(folder_path):
494
+ image_paths = []
495
+
496
+ # 遍历文件夹中的所有文件
497
+ for root, dirs, files in os.walk(folder_path):
498
+ for file in files:
499
+ # 检查文件是否为图片
500
+ if imghdr.what(os.path.join(root, file)): # imghdr.what() 可以识别图片文件类型
501
+ image_paths.append(os.path.join(root, file))
502
+
503
+ return image_paths
504
+
505
+ def is_image(file_path):
506
+ try:
507
+ result=imghdr.what(file_path)
508
+ if result is not None:
509
+ return True
510
+ return False
511
+ except:
512
+ return False
513
+