# TestTranslator / lib /report.py
# yujuanqin's picture
# fix ZeroDivisionError
# b27f71f
from datetime import datetime, timedelta
from typing import List
from dataclasses import dataclass, astuple
from tabulate import tabulate
from lib.log_parser import LogTag, LogItem, WebItem
from lib.utils import *
class LogReport:
    """Collect LogItem entries parsed from a log file and export them as CSV rows."""

    def __init__(self):
        self.items: List[LogItem] = []

    def from_logfile(self, log_file, start_line=0):
        """Parse the tagged lines of ``log_file`` into LogItem objects.

        Every line from ``start_line`` onwards is checked against each LogTag
        value; each tag whose value occurs in the line yields one LogItem
        (built by ``LogItem.from_log``) that is appended to ``self.items``.

        Args:
            log_file: Path of the log file to read.
            start_line: Index of the first line to inspect; earlier lines are
                skipped.

        Returns:
            A ``(items, line_count)`` tuple: ``self.items`` and the total
            number of lines in the file, which the next case can use as its
            ``start_line``.
        """
        print(f"generate LogReport from logfile: {log_file}")
        with open(log_file, "r") as f:
            lines = f.readlines()
        print(f"read log file lines {start_line}:{len(lines)}")
        for line in lines[start_line:]:
            for tag in LogTag:
                if tag.value in line:
                    self.items.append(LogItem.from_log(tag, line))
        return self.items, len(lines)

    def item_to_rows(self):
        """Group ``self.items`` into CSV rows, starting a new row at every audio_end.

        ``load_start`` / ``load_end`` items are skipped. Every other item
        contributes three cells to the current row: tag name, timestamp and
        content.

        Returns:
            A list of rows (lists of cells). No empty row is emitted, and the
            row that is still being built when the items run out is included.
        """
        rows = []
        current_line = []
        for item in self.items:
            if item.tag in [LogTag.load_start, LogTag.load_end]:
                continue
            # Every audio_end closes the row built so far and opens a new one.
            # Nothing is emitted while the pending row is still empty, so the
            # very first audio_end does not produce a blank leading row.
            if item.tag == LogTag.audio_end:
                if current_line:
                    rows.append(current_line)
                current_line = []
            current_line += [item.tag.name, item.timestamp, item.content]
        # Flush the last row; no further audio_end will arrive to close it.
        if current_line:
            rows.append(current_line)
        return rows

    def to_csv(self, csv_path=None):
        """Write the rows produced by ``item_to_rows`` to ``csv_path`` via ``save_csv``.

        The header always lists every mapped column; a row shorter than the
        mapping simply contributes fewer cells.
        """
        # Column name -> position of that cell inside a row from item_to_rows().
        # Comment out an entry to leave the corresponding column out of the CSV.
        header_mapping = {
            "audio_end_tag": 0,
            "audio_end_tsp": 1,
            "audio_length": 2,
            "transcribe_cost_tag": 3,
            "transcribe_cost_tsp": 4,
            "transcribe_cost": 5,
            "transcribe_end_tag": 6,
            "transcribe_end_tsp": 7,
            "transcribe_output": 8,
            "translate_start_tag": 9,
            "translate_start_tsp": 10,
            "translate_input": 11,
            "translate_cost_tag": 12,
            "translate_cost_tsp": 13,
            "translate_cost": 14,
            "translate_end_tag": 15,
            "translate_end_tsp": 16,
            "translate_output": 17,
        }
        header = list(header_mapping)
        positions = list(header_mapping.values())
        rows = [
            [row[pos] for pos in positions if pos < len(row)]
            for row in self.item_to_rows()
        ]
        save_csv(csv_path, header, rows)
@dataclass
class DelaySummary:
    """Per-case summary record of the delay report.

    The field order is significant: instances are flattened with ``astuple``
    and written positionally under the summary header of the CSV.
    """

    # Identification of the case.
    audio_name: str = ""
    trans_type: str = ""
    audio_length: str = ""

    # Load timing; the timestamps stay None until they are assigned.
    load_start: datetime = None
    load_end: datetime = None
    load: float = 0

    # Aggregates over the detail rows of the case.
    avg_audio_len: float = 0
    total_tsb: float = 0
    avg_tsb_per_second: float = 0
    total_tsl: float = 0
    avg_tsl_per_second: float = 0
    total_web: float = 0
    avg_web_per_second: float = 0
    avg_web_freq: float = 0
@dataclass
class DelayDetailRow:
    """One detail row of the delay report.

    The field order is significant: rows are flattened with ``astuple`` and
    written positionally under the detail header of the CSV. Timestamp fields
    default to an empty string until a value is assigned.
    """

    # Audio segment.
    audio_end_tsp: datetime = ""
    audio_length: float = 0

    # "tsb" fields.
    tsb_end_tsp: datetime = ""
    tsb_opt: str = ""
    tsb_cost: float = 0
    tsb_cost_per_second: float = 0

    # "tsl" fields.
    tsl_ipt: str = ""
    tsl_end_tsp: datetime = ""
    tsl_opt: str = ""
    tsl_cost: float = 0
    tsl_cost_per_second: float = 0

    # "web" fields.
    web_tsp: datetime = ""
    web_src: str = ""
    web_dst: str = ""
    web_delay: float = 0
    web_delay_per_second: float = 0
    web_freq: float = 0

    def __repr__(self):
        """Return a short representation showing only audio_length and tsb_opt."""
        audio_length = self.audio_length
        tsb_opt = self.tsb_opt
        return f"Row(audio_length={audio_length}, tsb_opt={tsb_opt})"
@dataclass
class DelayItem:
    """Result of a single case in the delay report.

    Combines the log items of one case with the web items captured for it and
    turns them into one DelaySummary plus a list of DelayDetailRow objects.
    """
    translation_type: str = ''
    audio: str = ""
    audio_length: str = ""
    web_items: List[WebItem] = None
    log_items: List[LogItem] = None

    @staticmethod
    def _per_second(cost, audio_length):
        """Return ``cost / audio_length``, or 0 when ``audio_length`` is falsy."""
        return cost / audio_length if audio_length else 0

    @staticmethod
    def _mean(values):
        """Return the arithmetic mean of ``values``, or 0 when there are none."""
        return sum(values) / len(values) if values else 0

    def to_rows(self):
        """Merge the log and web results of this case.

        Returns:
            A ``(summary, detail_rows)`` tuple. ``summary`` is a DelaySummary
            holding the audio information, the load time and the aggregates
            computed by ``get_summary``. ``detail_rows`` holds one
            DelayDetailRow per audio_end item whose audio length is greater
            than 0, including the row still open when the log items run out.
        """
        # Both item lists default to None on the dataclass; treat that as empty.
        log_items = self.log_items or []
        web_items = self.web_items or []
        print(f"length of log_items: {len(log_items)}")
        # Web items are looked up by source text + translated text.
        web_items_dict = {w.src_text + w.dst_text: w for w in web_items}
        summary = DelaySummary(audio_name=self.audio, trans_type=self.translation_type,
                               audio_length=self.audio_length)
        detail_rows = []
        current_row = DelayDetailRow()
        for item in log_items:
            tag = item.tag
            if tag == LogTag.load_start:
                summary.load_start = item.timestamp
            elif tag == LogTag.load_end:
                summary.load_end = item.timestamp
                # The load duration needs both ends; a load_end whose
                # load_start was never seen leaves ``load`` at its default.
                if summary.load_start is not None:
                    summary.load = (summary.load_end - summary.load_start).total_seconds()
            elif tag == LogTag.audio_end:
                if current_row.audio_length > 0:
                    detail_rows.append(current_row)
                # Every audio_end starts a new row.
                current_row = DelayDetailRow()
                current_row.audio_end_tsp = item.timestamp
                current_row.audio_length = time_to_float(item.content)
            elif tag == LogTag.transcribe_end:
                current_row.tsb_end_tsp = item.timestamp
                current_row.tsb_opt = item.content
            elif tag == LogTag.transcribe_cost:
                current_row.tsb_cost = time_to_float(item.content)
                current_row.tsb_cost_per_second = self._per_second(
                    current_row.tsb_cost, current_row.audio_length)
            elif tag == LogTag.translate_start:
                current_row.tsl_ipt = item.content
            elif tag in [LogTag.translate_end, LogTag.translate_large_end]:
                current_row.tsl_end_tsp = item.timestamp
                current_row.tsl_opt = item.content
                # Assumes that a row which has a translation result already
                # has its ASR result.
                key = current_row.tsb_opt + current_row.tsl_opt
                if web_item := web_items_dict.get(key):
                    current_row.web_tsp = web_item.timestamp
                    current_row.web_src = web_item.src_text
                    current_row.web_dst = web_item.dst_text
                    # audio_end_tsp is still "" when no audio_end preceded
                    # this item; the delay cannot be computed in that case.
                    if current_row.audio_end_tsp:
                        current_row.web_delay = (
                            current_row.web_tsp - current_row.audio_end_tsp).total_seconds()
                        current_row.web_delay_per_second = self._per_second(
                            current_row.web_delay, current_row.audio_length)
                    # Drop the matched entry so it cannot be matched again.
                    web_items_dict.pop(key)
                    if detail_rows and detail_rows[-1].web_tsp:
                        current_row.web_freq = (
                            current_row.web_tsp - detail_rows[-1].web_tsp).total_seconds()
            elif tag in [LogTag.translate_cost, LogTag.translate_large_cost]:
                current_row.tsl_cost = time_to_float(item.content)
                current_row.tsl_cost_per_second = self._per_second(
                    current_row.tsl_cost, current_row.audio_length)
        # Flush the last row: no further audio_end will arrive to close it.
        if current_row.audio_length > 0:
            detail_rows.append(current_row)
        summary = self.get_summary(summary, detail_rows)
        return summary, detail_rows

    def get_summary(self, summary: DelaySummary, detail_rows):
        """Fill the aggregate fields of ``summary`` from ``detail_rows``.

        Only truthy values take part: a row whose value for a field is 0 (or
        otherwise falsy) is ignored for that field. Averages over no values
        are 0.

        Args:
            summary: The DelaySummary to update in place.
            detail_rows: Rows providing the per-row measurements.

        Returns:
            The same ``summary`` object.
        """
        def collect(attr):
            # Truthy values of ``attr`` across all rows, in row order.
            return [value for value in (getattr(row, attr) for row in detail_rows) if value]

        summary.avg_audio_len = self._mean(collect("audio_length"))
        summary.total_tsb = sum(collect("tsb_cost"))
        summary.avg_tsb_per_second = self._mean(collect("tsb_cost_per_second"))
        summary.total_tsl = sum(collect("tsl_cost"))
        summary.avg_tsl_per_second = self._mean(collect("tsl_cost_per_second"))
        summary.total_web = sum(collect("web_delay"))
        summary.avg_web_per_second = self._mean(collect("web_delay_per_second"))
        summary.avg_web_freq = self._mean(collect("web_freq"))
        return summary
class DelayReport:
    """Holds the results of every case in the delay report."""
    start_line = 0
    # NOTE(review): this list is a class attribute, so it is shared by every
    # instance that does not rebind ``items`` itself.
    items: List[DelayItem] = []

    def print_summary(self, data):
        """Print ``data`` to stdout formatted by ``tabulate``."""
        table = tabulate(data)
        print(table)

    def to_csv(self, csv_path):
        """Print the summary table and save summaries and details to ``csv_path``.

        The saved rows are the summary header and one summary row per item,
        then an empty row, then the detail header followed by the detail rows
        of each item, with an empty row after each item's details.
        """
        summary_header = ["audio_name", "translation", "audio_length",
                          "load_start", "load_end", "load", "avg_audio_len",
                          "total_tsb", "avg_tsb_per_sec", "total_tsl", "avg_tsl_per_sec",
                          "total_web", "avg_web_per_sec", "avg_web_freq"]
        detail_header = ["audio_end_tsp", "audio_length",
                         "tsb_end_tsp", "tsp_opt", "tsb_cost", "tsb_cost_per_sec",
                         "tsl_ipt", "tsl_end_tsp", "tsl_opt", "tsl_cost", "tsl_cost_per_sec",
                         "web_tsp", "web_src", "web_dst", "web_delay", "web_delay_per_sec", "web_freq"]
        summaries = [summary_header]
        details = [detail_header]
        for delay_item in self.items:
            summary, detail_rows = delay_item.to_rows()
            summaries.append(astuple(summary))
            details.extend(astuple(row) for row in detail_rows)
            # Blank row separating the details of consecutive items.
            details.append([])
        self.print_summary(summaries)
        all_rows = summaries + [[]] + details
        save_csv(csv_path, [], all_rows)
@dataclass
class AccuracyItem:
    """Result of a single case in the accuracy report.

    ``asr_accuracy`` and ``text_compare`` are recomputed in ``__post_init__``
    from ``audio_text`` and ``src_text``, so values passed for them to the
    constructor are overwritten.
    """
    translation_type: str = ''
    audio: str = ""
    audio_length: str = ""
    audio_text: str = ""
    src_text: str = ""
    dst_text: str = ""
    asr_accuracy: tuple = (0, 1)
    text_compare: str = ""

    def __post_init__(self):
        """Clean both texts, then store their distance and highlighted diff.

        The "en2zh" translation type uses the English cleaner and a single
        space as the diff separator; every other type uses the Chinese cleaner
        and an empty separator.
        """
        source_is_english = self.translation_type == "en2zh"
        clean = clean_text_for_comparison_en if source_is_english else clean_text_for_comparison_zh
        separator = " " if source_is_english else ""
        expected = clean(self.audio_text)
        recognized = clean(self.src_text)
        self.asr_accuracy = run_textdistance(expected, recognized)
        self.text_compare = highlight_diff(expected, recognized, separator)

    def to_list(self):
        """Return this item as a flat list of eight CSV cells.

        Order: audio, translation type, audio length, the first two elements
        of ``asr_accuracy``, source text, audio text, text comparison.
        """
        first, second = self.asr_accuracy[0], self.asr_accuracy[1]
        return [
            self.audio,
            self.translation_type,
            self.audio_length,
            first,
            second,
            self.src_text,
            self.audio_text,
            self.text_compare,
        ]
class AccuracyReport:
    """Holds the AccuracyItem results and prints or saves them."""
    # NOTE(review): this list is a class attribute, so it is shared by every
    # instance that does not rebind ``items`` itself.
    items: List[AccuracyItem] = []

    def print_summary(self):
        """Print one table row per item: audio and the first two asr_accuracy values."""
        header = ["audio", "distance", "normalized distance"]
        rows = []
        for entry in self.items:
            rows.append([entry.audio, entry.asr_accuracy[0], entry.asr_accuracy[1]])
        print(tabulate(rows, header))

    def to_csv(self, csv_path):
        """Print the item count and summary table, then save all items to ``csv_path``."""
        print("accuracy item length: ", len(self.items))
        self.print_summary()
        header = ["audio_name", "translation", "audio_length",
                  "distance", "normalized distance",
                  "src text", "audio text", "text compare"]
        rows = [entry.to_list() for entry in self.items]
        save_csv(csv_path, header, rows)