HaochenGong commited on
Commit
e4c79c3
·
1 Parent(s): d32a207

batch images

Browse files
.idea/workspace.xml CHANGED
@@ -4,7 +4,12 @@
4
  <option name="autoReloadType" value="SELECTIVE" />
5
  </component>
6
  <component name="ChangeListManager">
7
- <list default="true" id="5e4481c0-7ba2-42e4-bbe6-4c36a0d36baa" name="Changes" comment="remove dotenv" />
 
 
 
 
 
8
  <option name="SHOW_DIALOG" value="false" />
9
  <option name="HIGHLIGHT_CONFLICTS" value="true" />
10
  <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
@@ -99,14 +104,7 @@
99
  <workItem from="1723386970626" duration="451000" />
100
  <workItem from="1723387453009" duration="27328000" />
101
  <workItem from="1723978405562" duration="51637000" />
102
- <workItem from="1725925225487" duration="135179000" />
103
- </task>
104
- <task id="LOCAL-00011" summary="google drive">
105
- <created>1723456219009</created>
106
- <option name="number" value="00011" />
107
- <option name="presentableId" value="LOCAL-00011" />
108
- <option name="project" value="LOCAL" />
109
- <updated>1723456219009</updated>
110
  </task>
111
  <task id="LOCAL-00012" summary="google drive">
112
  <created>1723456794848</created>
@@ -444,7 +442,14 @@
444
  <option name="project" value="LOCAL" />
445
  <updated>1728915832031</updated>
446
  </task>
447
- <option name="localTasksCounter" value="60" />
 
 
 
 
 
 
 
448
  <servers />
449
  </component>
450
  <component name="TypeScriptGeneratedFilesManager">
@@ -481,9 +486,10 @@
481
  <MESSAGE value="output img type" />
482
  <MESSAGE value="." />
483
  <MESSAGE value="remove dotenv" />
484
- <option name="LAST_COMMIT_MESSAGE" value="remove dotenv" />
 
485
  </component>
486
  <component name="com.intellij.coverage.CoverageDataManagerImpl">
487
- <SUITE FILE_PATH="coverage/Cpp4App_test$app.coverage" NAME="app Coverage Results" MODIFIED="1728910172560" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
488
  </component>
489
  </project>
 
4
  <option name="autoReloadType" value="SELECTIVE" />
5
  </component>
6
  <component name="ChangeListManager">
7
+ <list default="true" id="5e4481c0-7ba2-42e4-bbe6-4c36a0d36baa" name="Changes" comment="font">
8
+ <change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
9
+ <change beforePath="$PROJECT_DIR$/CDM/detect_classify/classification.py" beforeDir="false" afterPath="$PROJECT_DIR$/CDM/detect_classify/classification.py" afterDir="false" />
10
+ <change beforePath="$PROJECT_DIR$/app.py" beforeDir="false" afterPath="$PROJECT_DIR$/app.py" afterDir="false" />
11
+ <change beforePath="$PROJECT_DIR$/logs/app.log" beforeDir="false" afterPath="$PROJECT_DIR$/logs/app.log" afterDir="false" />
12
+ </list>
13
  <option name="SHOW_DIALOG" value="false" />
14
  <option name="HIGHLIGHT_CONFLICTS" value="true" />
15
  <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
 
104
  <workItem from="1723386970626" duration="451000" />
105
  <workItem from="1723387453009" duration="27328000" />
106
  <workItem from="1723978405562" duration="51637000" />
107
+ <workItem from="1725925225487" duration="136670000" />
 
 
 
 
 
 
 
108
  </task>
109
  <task id="LOCAL-00012" summary="google drive">
110
  <created>1723456794848</created>
 
442
  <option name="project" value="LOCAL" />
443
  <updated>1728915832031</updated>
444
  </task>
445
+ <task id="LOCAL-00060" summary="font">
446
+ <created>1728916518798</created>
447
+ <option name="number" value="00060" />
448
+ <option name="presentableId" value="LOCAL-00060" />
449
+ <option name="project" value="LOCAL" />
450
+ <updated>1728916518798</updated>
451
+ </task>
452
+ <option name="localTasksCounter" value="61" />
453
  <servers />
454
  </component>
455
  <component name="TypeScriptGeneratedFilesManager">
 
486
  <MESSAGE value="output img type" />
487
  <MESSAGE value="." />
488
  <MESSAGE value="remove dotenv" />
489
+ <MESSAGE value="font" />
490
+ <option name="LAST_COMMIT_MESSAGE" value="font" />
491
  </component>
492
  <component name="com.intellij.coverage.CoverageDataManagerImpl">
493
+ <SUITE FILE_PATH="coverage/Cpp4App_test$app.coverage" NAME="app Coverage Results" MODIFIED="1728918751288" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
494
  </component>
495
  </project>
CDM/detect_classify/classification.py CHANGED
@@ -209,10 +209,6 @@ def compo_classification(input_img, output_root, segment_root, merge_json, outpu
209
 
210
  # --------- classification ----------
211
 
212
- classification_start_time = time.process_time()
213
-
214
- model = get_clf_model(clf_model)
215
-
216
  # for compo in compos:
217
  #
218
  # # comp_grey = grey[compo.row_min:compo.row_max, compo.col_min:compo.col_max]
@@ -290,70 +286,131 @@ def compo_classification(input_img, output_root, segment_root, merge_json, outpu
290
  # else:
291
  # print("clf_model has to be ResNet18 or ViT")
292
 
293
- for compo in compos:
294
- compo_start_time = time.process_time()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
 
296
- # 计时预处理部分
297
- preprocess_start = time.process_time()
 
 
 
 
 
 
 
298
  comp_grey = grey[compo.row_min:compo.row_max, compo.col_min:compo.col_max]
299
- preprocess_time = time.process_time() - preprocess_start
300
 
301
- # 计时图像调整部分
302
- resize_start = time.process_time()
303
  if clf_model == "ResNet18":
304
  comp_crop = cv2.resize(comp_grey, (32, 32))
305
  elif clf_model == "ViT":
306
  comp_crop = cv2.resize(comp_grey, (224, 224))
307
- resize_time = time.process_time() - resize_start
308
 
309
- # 计时张量转换部分
310
- tensor_start = time.process_time()
311
  if clf_model == "ResNet18":
312
  comp_crop = comp_crop.reshape(1, 1, 32, 32)
313
- comp_tensor = torch.tensor(comp_crop)
314
- comp_tensor = comp_tensor.permute(0, 1, 3, 2)
315
  elif clf_model == "ViT":
316
  comp_tensor = torch.from_numpy(comp_crop)
317
- comp_tensor = comp_tensor.view(1, 224, 224).repeat(3, 1, 1)
318
- comp_tensor = comp_tensor.unsqueeze(0)
319
- tensor_time = time.process_time() - tensor_start
320
-
321
- # 计时模型推理部分
322
- inference_start = time.process_time()
323
- with torch.no_grad():
324
- if clf_model == "ResNet18":
325
- pred_label = model(comp_tensor)
326
- output = pred_label
327
- elif clf_model == "ViT":
328
- output = model(comp_tensor)
329
- inference_time = time.process_time() - inference_start
330
-
331
- # 计时后处理部分
332
- postprocess_start = time.process_time()
333
  if clf_model == "ResNet18":
334
- predicted = np.argmax(output.cpu().data.numpy(), axis=1)[0]
335
  elif clf_model == "ViT":
336
- _, predicted = torch.max(output.logits, 1)
337
- predicted = predicted.cpu().numpy()[0]
338
 
339
- if str(predicted) in label_dic.keys():
340
- compo.label = label_dic[str(predicted)]
 
 
 
 
 
 
 
 
 
 
341
  elements.append(compo)
342
  else:
343
- compo.label = str(predicted)
344
- postprocess_time = time.process_time() - postprocess_start
345
-
346
- compo_total_time = time.process_time() - compo_start_time
347
-
348
- # 输出每个部分的耗时
349
- print("==============================================")
350
- print(f"Component processing time: {compo_total_time:.4f}s")
351
- print(f" Preprocessing time: {preprocess_time:.4f}s")
352
- print(f" Resize time: {resize_time:.4f}s")
353
- print(f" Tensor conversion time: {tensor_time:.4f}s")
354
- print(f" Inference time: {inference_time:.4f}s")
355
- print(f" Post-processing time: {postprocess_time:.4f}s\n")
356
- print("==============================================")
357
 
358
  time_cost_ic = time.process_time() - classification_start_time
359
  print("time cost for icon classification: %2.2f s" % time_cost_ic)
 
209
 
210
  # --------- classification ----------
211
 
 
 
 
 
212
  # for compo in compos:
213
  #
214
  # # comp_grey = grey[compo.row_min:compo.row_max, compo.col_min:compo.col_max]
 
286
  # else:
287
  # print("clf_model has to be ResNet18 or ViT")
288
 
289
+ # =============================================================================
290
+
291
+ # classification_start_time = time.process_time()
292
+ #
293
+ # model = get_clf_model(clf_model)
294
+ #
295
+ # for compo in compos:
296
+ # compo_start_time = time.process_time()
297
+ #
298
+ # # 计时预处理部分
299
+ # preprocess_start = time.process_time()
300
+ # comp_grey = grey[compo.row_min:compo.row_max, compo.col_min:compo.col_max]
301
+ # preprocess_time = time.process_time() - preprocess_start
302
+ #
303
+ # # 计时图像调整部分
304
+ # resize_start = time.process_time()
305
+ # if clf_model == "ResNet18":
306
+ # comp_crop = cv2.resize(comp_grey, (32, 32))
307
+ # elif clf_model == "ViT":
308
+ # comp_crop = cv2.resize(comp_grey, (224, 224))
309
+ # resize_time = time.process_time() - resize_start
310
+ #
311
+ # # 计时张量转换部分
312
+ # tensor_start = time.process_time()
313
+ # if clf_model == "ResNet18":
314
+ # comp_crop = comp_crop.reshape(1, 1, 32, 32)
315
+ # comp_tensor = torch.tensor(comp_crop)
316
+ # comp_tensor = comp_tensor.permute(0, 1, 3, 2)
317
+ # elif clf_model == "ViT":
318
+ # comp_tensor = torch.from_numpy(comp_crop)
319
+ # comp_tensor = comp_tensor.view(1, 224, 224).repeat(3, 1, 1)
320
+ # comp_tensor = comp_tensor.unsqueeze(0)
321
+ # tensor_time = time.process_time() - tensor_start
322
+ #
323
+ # # 计时模型推理部分
324
+ # inference_start = time.process_time()
325
+ # with torch.no_grad():
326
+ # if clf_model == "ResNet18":
327
+ # pred_label = model(comp_tensor)
328
+ # output = pred_label
329
+ # elif clf_model == "ViT":
330
+ # output = model(comp_tensor)
331
+ # inference_time = time.process_time() - inference_start
332
+ #
333
+ # # 计时后处理部分
334
+ # postprocess_start = time.process_time()
335
+ # if clf_model == "ResNet18":
336
+ # predicted = np.argmax(output.cpu().data.numpy(), axis=1)[0]
337
+ # elif clf_model == "ViT":
338
+ # _, predicted = torch.max(output.logits, 1)
339
+ # predicted = predicted.cpu().numpy()[0]
340
+ #
341
+ # if str(predicted) in label_dic.keys():
342
+ # compo.label = label_dic[str(predicted)]
343
+ # elements.append(compo)
344
+ # else:
345
+ # compo.label = str(predicted)
346
+ # postprocess_time = time.process_time() - postprocess_start
347
+ #
348
+ # compo_total_time = time.process_time() - compo_start_time
349
+ #
350
+ # # 输出每个部分的耗时
351
+ # print("==============================================")
352
+ # print(f"Component processing time: {compo_total_time:.4f}s")
353
+ # print(f" Preprocessing time: {preprocess_time:.4f}s")
354
+ # print(f" Resize time: {resize_time:.4f}s")
355
+ # print(f" Tensor conversion time: {tensor_time:.4f}s")
356
+ # print(f" Inference time: {inference_time:.4f}s")
357
+ # print(f" Post-processing time: {postprocess_time:.4f}s\n")
358
+ # print("==============================================")
359
+
360
+ # =============================================================================
361
 
362
+ classification_start_time = time.process_time()
363
+
364
+ model = get_clf_model(clf_model)
365
+ comp_tensors = []
366
+ elements = []
367
+
368
+ # 收集所有张量
369
+ for compo in compos:
370
+ # 预处理
371
  comp_grey = grey[compo.row_min:compo.row_max, compo.col_min:compo.col_max]
 
372
 
373
+ # 调整图像大小
 
374
  if clf_model == "ResNet18":
375
  comp_crop = cv2.resize(comp_grey, (32, 32))
376
  elif clf_model == "ViT":
377
  comp_crop = cv2.resize(comp_grey, (224, 224))
 
378
 
379
+ # 张量转换
 
380
  if clf_model == "ResNet18":
381
  comp_crop = comp_crop.reshape(1, 1, 32, 32)
382
+ comp_tensor = torch.tensor(comp_crop).permute(0, 1, 3, 2)
 
383
  elif clf_model == "ViT":
384
  comp_tensor = torch.from_numpy(comp_crop)
385
+ comp_tensor = comp_tensor.view(1, 224, 224).repeat(3, 1, 1).unsqueeze(0)
386
+
387
+ comp_tensors.append(comp_tensor)
388
+
389
+ # 将张量堆叠成批次
390
+ batch_tensor = torch.cat(comp_tensors, dim=0)
391
+
392
+ # 模型推理
393
+ with torch.no_grad():
 
 
 
 
 
 
 
394
  if clf_model == "ResNet18":
395
+ output = model(batch_tensor)
396
  elif clf_model == "ViT":
397
+ output = model(batch_tensor)
 
398
 
399
+ # 后处理
400
+ if clf_model == "ResNet18":
401
+ predicted = np.argmax(output.cpu().numpy(), axis=1)
402
+ elif clf_model == "ViT":
403
+ _, predicted = torch.max(output.logits, 1)
404
+ predicted = predicted.cpu().numpy()
405
+
406
+ # 为组件分配标签
407
+ for idx, compo in enumerate(compos):
408
+ pred_label = predicted[idx]
409
+ if str(pred_label) in label_dic.keys():
410
+ compo.label = label_dic[str(pred_label)]
411
  elements.append(compo)
412
  else:
413
+ compo.label = str(pred_label)
 
 
 
 
 
 
 
 
 
 
 
 
 
414
 
415
  time_cost_ic = time.process_time() - classification_start_time
416
  print("time cost for icon classification: %2.2f s" % time_cost_ic)
app.py CHANGED
@@ -29,7 +29,7 @@ from googleapiclient.http import MediaFileUpload
29
  # from dotenv import load_dotenv
30
  import os
31
 
32
- # 加载 .env 文件中的环境变量
33
  # load_dotenv()
34
 
35
  title = "Cpp4App_test"
 
29
  # from dotenv import load_dotenv
30
  import os
31
 
32
+ # # 加载 .env 文件中的环境变量
33
  # load_dotenv()
34
 
35
  title = "Cpp4App_test"
logs/app.log CHANGED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ 2024-10-15 02:12:39 - INFO - Application started
2
+ 2024-10-15 02:12:39 - INFO - file_cache is only supported with oauth2client<4.0.0
3
+ 2024-10-15 02:12:39 - INFO - HTTP Request: GET http://127.0.0.1:7860/startup-events "HTTP/1.1 200 OK"
4
+ 2024-10-15 02:12:39 - INFO - HTTP Request: HEAD http://127.0.0.1:7860/ "HTTP/1.1 200 OK"
5
+ 2024-10-15 02:12:40 - INFO - HTTP Request: GET https://checkip.amazonaws.com/ "HTTP/1.1 200 "
6
+ 2024-10-15 02:12:40 - INFO - HTTP Request: GET https://api.gradio.app/pkg-version "HTTP/1.1 200 OK"
7
+ 2024-10-15 02:12:40 - INFO - HTTP Request: GET https://checkip.amazonaws.com/ "HTTP/1.1 200 "