# Cpp4App_test / CDM / run_single.py
# Author: HaochenGong
# Change: time cost count (commit 1c42b13)
from os.path import join as pjoin
import cv2
import os
import shutil
import time
import json
import numpy as np
import CDM.detect_compo.ip_region_proposal as ip
import CDM.detect_classify.classification as clf
import pandas as pd
import openai
# def summarize_segment(segment):
# openai.api_key = os.environ.get('openai_key')
#
# prompt = f"Shorten this paragraph: \"{str(segment)}\"."
#
# response = openai.ChatCompletion.create(
# # engine="text-davinci-002",
# model="gpt-3.5-turbo",
# messages=[
# # {"role": "system", "content": "You are a helpful assistant."},
# {"role": "user", "content": prompt}
# ],
# max_tokens=400,
# n=1,
# stop=None,
# temperature=0,
# )
#
# shortened_segment = response.choices[0].message['content']
#
# return shortened_segment
# model = clf.get_clf_model("ViT")
def resize_height_by_longest_edge(img_path, resize_length=800):
    """Compute the resize height that scales the image's longest edge to
    ``resize_length`` while preserving aspect ratio.

    Args:
        img_path: Path to an image file readable by OpenCV.
        resize_length: Target length (px) for the longest edge. Default 800.

    Returns:
        int: the target height. If the image is portrait (height > width)
        the height itself becomes ``resize_length``; otherwise the height is
        scaled down proportionally so the width would equal ``resize_length``.

    Raises:
        FileNotFoundError: if the image cannot be read (cv2.imread returns
        None for missing/corrupt files instead of raising).
    """
    org = cv2.imread(img_path)
    if org is None:
        # cv2.imread silently returns None on failure; fail loudly instead
        # of hitting an opaque AttributeError on org.shape below.
        raise FileNotFoundError("cannot read image: %s" % img_path)
    height, width = org.shape[:2]
    if height > width:
        return resize_length
    return int(resize_length * (height / width))
def run_single_img(input_img, output_root, segment_root):
    """Run the full CDM pipeline on a single screenshot.

    Stages (each gated by a local toggle): OCR text detection, non-text
    component detection, merging of the two result sets, and component
    classification. Per-component rows are appended to
    ``<output_root>/output.csv``.

    Args:
        input_img: Path to the screenshot image.
        output_root: Output directory; it is DELETED and recreated on every
            call, so all artifacts from a previous run are lost.
        segment_root: Directory with the text segments used by the
            classification stage.

    Returns:
        The GUI boards produced by the classification stage (also published
        as the module-level global ``output_boards``), or None when the
        classification stage is disabled.
    """
    global output_boards
    # Defaults so the timing print and the return below cannot hit a
    # NameError when individual stages are toggled off.
    output_boards = None
    detection_cost = 0.0
    classification_cost = 0.0

    # Start from a clean output directory.
    if os.path.exists(output_root):
        shutil.rmtree(output_root)
    os.makedirs(output_root)

    key_params = {'min-grad': 4, 'ffl-block': 5, 'min-ele-area': 50, 'merge-contained-ele': True,
                  'max-word-inline-gap': 10, 'max-line-ingraph-gap': 4, 'remove-top-bar': False}

    # Stage toggles (is_clf is the legacy CNN classifier path, unused here).
    is_ip = True
    is_clf = False
    is_ocr = True
    is_merge = True
    is_classification = True

    if is_ocr:
        # Imported lazily: the OCR backend is only needed when OCR runs.
        import CDM.detect_text.text_detection as text

    # Per-stage time-cost accumulators (kept for parity with batch runs).
    img_time_cost_all = []
    ocr_time_cost_all = []
    ic_time_cost_all = []
    ts_time_cost_all = []
    cd_time_cost_all = []

    resize_by_height = 800
    output_data = pd.DataFrame(columns=['screenshot', 'id', 'label', 'index', 'text', 'sentences'])
    this_img_start_time = time.time()

    resized_height = resize_height_by_longest_edge(input_img, resize_by_height)
    # Screenshot id = base file name without its 3-character extension.
    index = input_img.split('/')[-1][:-4]

    if is_ocr:
        os.makedirs(pjoin(output_root, 'ocr'), exist_ok=True)
        this_ocr_time_cost = text.text_detection(input_img, output_root, show=False, method='google')  # or: pytesseract
        ocr_time_cost_all.append(this_ocr_time_cost)

    if is_ip:
        os.makedirs(pjoin(output_root, 'ip'), exist_ok=True)
        this_cd_time_cost = ip.compo_detection(input_img, output_root, key_params,
                                               resize_by_height=resized_height, show=False)
        cd_time_cost_all.append(this_cd_time_cost)

    # Everything up to here counts as "detection" for the timing print.
    detection_cost = time.time() - this_img_start_time

    if is_merge:
        import CDM.detect_merge.merge as merge
        os.makedirs(pjoin(output_root, 'merge'), exist_ok=True)
        compo_path = pjoin(output_root, 'ip', str(index) + '.json')
        ocr_path = pjoin(output_root, 'ocr', str(index) + '.json')
        board_merge, components_merge = merge.merge(input_img, compo_path, ocr_path, pjoin(output_root, 'merge'),
                                                    is_remove_top_bar=key_params['remove-top-bar'], show=False)

    if is_classification:
        os.makedirs(pjoin(output_root, 'classification'), exist_ok=True)
        merge_path = pjoin(output_root, 'merge', str(index) + '.json')
        # Context manager instead of a bare json.load(open(...)): the
        # original leaked the file handle.
        with open(merge_path, 'r') as merge_file:
            merge_json = json.load(merge_file)
        os.makedirs(pjoin(output_root, 'classification', 'GUI'), exist_ok=True)
        this_time_cost_ic, this_time_cost_ts, output_data, output_boards, classification_cost = clf.compo_classification(
            input_img, output_root, segment_root, merge_json, output_data,
            resize_by_height=resize_by_height, clf_model="ViT", model=clf.get_clf_model("ViT"))
        ic_time_cost_all.append(this_time_cost_ic)
        ts_time_cost_all.append(this_time_cost_ts)

    this_img_time_cost = time.time() - this_img_start_time
    img_time_cost_all.append(this_img_time_cost)
    # Total detection + classification wall time (runtime message kept
    # verbatim; it reads "detection + classification took: ... s").
    print("检测+分类共花费: %2.2f s" % (classification_cost + detection_cost))

    # Append to the CSV when it already exists, otherwise create it with a
    # header row.
    csv_path = pjoin(output_root, 'output.csv')
    if os.path.isfile(csv_path):
        output_data.to_csv(csv_path, index=False, mode='a', header=False)
    else:
        output_data.to_csv(csv_path, index=False, mode='w')

    return output_boards