H-AdminSim_Arena / utils /postprocess.py
ljm565's picture
feat: Updated rating criteria
3703b6a
import os, re
from utils.client import GPTClient
from utils.filesys_utils import json_load, get_files
def make_dialog_dict(dir_path: str = './dialogs') -> dict:
    """
    Collect postprocessed dialogs for every model entry under ``dir_path``.

    Each entry name in ``dir_path`` is treated as a model name. The ``.json``
    files that ``get_files`` reports for that entry are loaded with
    ``json_load``. Every value of each loaded mapping is passed through
    ``dialog_postprocessing``, and empty results are dropped.

    Args:
        dir_path (str, optional): Directory holding one sub-directory per model.
            Defaults to './dialogs'.

    Returns:
        dict: Maps each model name to the list of non-empty postprocessed
            dialog strings found for it. Models with no JSON files, or whose
            dialogs all postprocess to an empty string, are omitted.
    """
    collected = {}
    for model_name in os.listdir(dir_path):
        json_paths = get_files(os.path.join(dir_path, model_name), ext='.json')
        if not json_paths:
            continue

        cleaned = []
        for json_path in json_paths:
            # json_load(...) must return a mapping: only its values are used.
            for raw_dialog in json_load(json_path).values():
                processed = dialog_postprocessing(raw_dialog)
                if processed:
                    cleaned.append(processed)

        if cleaned:
            collected[model_name] = cleaned
    return collected
# Department labels accepted after "Answer: <n>."; the first candidate that
# prefixes the lower-cased answer text is used as the department name.
_DEPARTMENT_CANDIDATES = (
    "gastroenterology",
    "cardiology",
    "pulmonary",
    "endocrinology/metabolism",
    "nephrology",
    "hematology/oncology",
    "allergy",
    "infectious diseases",
    "rheumatology",
)
# Captures the rest of the line after the numbered choice,
# e.g. "Answer: 2. Cardiology" -> "Cardiology".
_ANSWER_PATTERN = re.compile(r'Answer:\s*\d+\.\s*(.+)')
# Marks where the conversation ends and the final answer begins.
_ANSWER_SPLIT_PATTERN = re.compile(r'\bAnswer:')
# HTML replacements for the plain-text speaker labels.
_STAFF_TAG = "<span style='color:rgb(0,102,204); font-weight:bold'>Staff</span>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;:"
_PATIENT_TAG = "<span style='color:rgb(204,0,102); font-weight:bold'>Patient</span>:"


def dialog_postprocessing(dialog: str) -> str:
    """
    Convert a raw staff/patient transcript into an HTML snippet for display.

    The department is read from the first ``Answer: <number>. <text>`` match
    and must start with one of the known department candidates
    (case-insensitive). The conversation before the first ``Answer:`` marker
    is kept. A closing referral sentence naming the department is appended.
    ``Staff:`` / ``Patient:`` labels become colored HTML spans, and newlines
    become ``<br>``.

    Args:
        dialog (str): The raw dialog string to be processed.

    Returns:
        str: The HTML-formatted dialog, or an empty string when ``dialog`` is
            not a string, contains no ``Answer: <n>. <department>`` line, or
            names a department outside the candidate list.
    """
    # Unparsable inputs are skipped (empty string) rather than raising, so
    # callers can filter them out. Checks are explicit instead of `assert`,
    # which would be stripped under `python -O`.
    if not isinstance(dialog, str):
        return ''
    match = _ANSWER_PATTERN.search(dialog)
    if match is None:
        return ''
    answer = match.group(1).lower()
    department = next(
        (candidate for candidate in _DEPARTMENT_CANDIDATES if answer.startswith(candidate)),
        None,
    )
    if department is None:
        return ''

    before_answer = _ANSWER_SPLIT_PATTERN.split(dialog, maxsplit=1)[0].strip()
    before_answer += f' I will introduce you to a physician who works in the {department}.'
    before_answer = before_answer.replace("Staff:", _STAFF_TAG)
    before_answer = before_answer.replace("Patient:", _PATIENT_TAG)
    return before_answer.replace("\n", "<br>")
def dialog_translate(dialog: str) -> str:
    """
    Translate an HTML-formatted staff/patient dialogue from English to Korean.

    The dialogue is split on ``<br>``. For each line the speaker span is
    detected and removed; a line without the staff span is treated as a
    patient line, mirroring how its text is extracted. The texts are sent to
    ``gpt-5-nano`` as a list. The model's reply is parsed as a Python literal
    list of strings, and each translated line is re-prefixed with its own
    original speaker span.

    Args:
        dialog (str): A string containing the dialogue in HTML format.
            Staff lines are marked with a blue "Staff" span and
            patient lines with a pink "Patient" span, separated by '<br>' tags.

    Raises:
        TypeError: If the model call fails, or its reply is not a literal
            list of strings with exactly one entry per input line. The
            underlying error is chained as ``__cause__``.

    Returns:
        str: The translated dialogue lines in Korean, joined with '<br>'.
    """
    import ast

    # Prompt (Korean): "The sentence list below is a conversation between a
    # hospital admin clerk and a patient. Translate it into Korean in order and
    # return it as a string list. Return only the string list, with no extra
    # output or embellishment."
    user_prompt = "아래 문장 리스트는 원무과 직원과 환자의 대화야. 순서대로 한국어로 번역해주고 string list로 반환해줘. 다른 결과, 미사여구 붙이지말고 딱 string list만 반환 해줘야해.\n\n{lines}"
    staff_flag = "<span style='color:rgb(0,102,204); font-weight:bold'>Staff</span>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;:"
    patient_flag = "<span style='color:rgb(204,0,102); font-weight:bold'>Patient</span>:"

    # Remember each line's speaker so the translation is re-labelled from the
    # source rather than by assuming strict Staff/Patient alternation.
    speakers, lines = [], []
    for line in dialog.split('<br>'):
        flag = staff_flag if staff_flag in line else patient_flag
        speakers.append(flag)
        # Text after the speaker span (empty when the span is absent).
        lines.append(''.join(line.split(flag)[1:]).strip())
    user_prompt = user_prompt.format(lines=lines)

    client = GPTClient('gpt-5-nano')
    try:
        response = client(user_prompt, reasoning_effort='minimal')
        # The model's reply is untrusted text: parse it as a literal only and
        # never execute it as code (the previous `eval` allowed arbitrary code).
        translated = ast.literal_eval(response.strip())
        if not isinstance(translated, (list, tuple)) or not all(isinstance(t, str) for t in translated):
            raise ValueError('model reply is not a list of strings')
        if len(translated) != len(lines):
            raise ValueError(f'expected {len(lines)} translated lines, got {len(translated)}')
    except Exception as e:
        raise TypeError('Failed to translate dialog') from e

    return '<br>'.join(f'{flag} {text}' for flag, text in zip(speakers, translated))