File size: 4,000 Bytes
b2a5882
3703b6a
 
b2a5882
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3703b6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import ast
import os, re

from utils.client import GPTClient
from utils.filesys_utils import json_load, get_files



def make_dialog_dict(dir_path: str = './dialogs') -> dict:
    """
    Build a mapping from model name to that model's postprocessed dialogs.

    Each immediate sub-directory of ``dir_path`` is treated as one model. The
    ``.json`` files returned by ``get_files(model_path, ext='.json')`` are loaded,
    every value of each file is passed through ``dialog_postprocessing``, and
    empty results are dropped.

    Args:
        dir_path (str, optional): Path to the directory whose sub-directories hold
            per-model dialog JSON files. Defaults to './dialogs'.

    Returns:
        dict: Model (sub-directory) name -> list of cleaned dialog strings.
            Models with no JSON files or no usable dialogs are omitted. Keys are
            inserted in sorted order so the result is deterministic.
    """
    dialog_dict = {}

    # os.listdir order is filesystem-dependent; sort for a reproducible key order.
    for model in sorted(os.listdir(dir_path)):
        model_path = os.path.join(dir_path, model)
        # Stray files next to the model folders (e.g. .DS_Store, README) are not models.
        if not os.path.isdir(model_path):
            continue

        dialog_files = get_files(model_path, ext='.json')
        if not dialog_files:
            continue

        # Assumes json_load returns a mapping whose values are raw dialog strings
        # (it is called with .values()); TODO confirm against the JSON producer.
        dialogues = [
            content
            for file in dialog_files
            for content in map(dialog_postprocessing, json_load(file).values())
            if content
        ]

        if dialogues:
            dialog_dict[model] = dialogues

    return dialog_dict



# Departments a dialog's answer line may name; matched case-insensitively by prefix.
_DEPARTMENT_CANDIDATES = (
    "gastroenterology",
    "cardiology",
    "pulmonary",
    "endocrinology/metabolism",
    "nephrology",
    "hematology/oncology",
    "allergy",
    "infectious diseases",
    "rheumatology",
)
# Captures the department text in a line like "Answer: 2. Cardiology".
_ANSWER_PATTERN = re.compile(r'Answer:\s*\d+\.\s*(.+)')
# Marks where the dialog proper ends and the answer section begins.
_SPLIT_PATTERN = re.compile(r'\bAnswer:')


def dialog_postprocessing(dialog: str) -> str:
    """
    Postprocess a raw dialog into labelled HTML ending with a department referral.

    The text before the first ``Answer:`` marker is kept. A referral sentence
    naming the department from the ``Answer: <n>. <department>`` line is
    appended. ``Staff:``/``Patient:`` labels are replaced with coloured HTML
    spans, and newlines are replaced with ``<br>``.

    Args:
        dialog (str): The raw dialog string to be processed.

    Returns:
        str: The cleaned HTML dialog string, or '' when the input is not a
            string, has no answer line, or names a department that does not
            start with one of the known candidates.
    """
    if not isinstance(dialog, str):
        return ''

    match = _ANSWER_PATTERN.search(dialog)
    if match is None:
        return ''

    # Answers may carry trailing text ("Cardiology department"), so match by prefix
    # and normalise to the canonical lowercase candidate name.
    answer = match.group(1).lower()
    department = next((c for c in _DEPARTMENT_CANDIDATES if answer.startswith(c)), None)
    if department is None:
        return ''

    before_answer = _SPLIT_PATTERN.split(dialog, maxsplit=1)[0].strip()
    before_answer += f' I will introduce you to a physician who works in the {department}.'

    before_answer = before_answer.replace("Staff:", "<span style='color:rgb(0,102,204); font-weight:bold'>Staff</span>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;:")
    before_answer = before_answer.replace("Patient:", "<span style='color:rgb(204,0,102); font-weight:bold'>Patient</span>:")
    before_answer = before_answer.replace("\n", "<br>")

    return before_answer



def dialog_translate(dialog: str) -> str:
    """
    Translate a dialogue between a hospital staff member and a patient from English to Korean.

    The speaker spans are stripped from each '<br>'-separated line, and the bare
    utterances are sent to the LLM as a Python list literal. The translated list
    is then re-labelled with the speaker spans and joined back with '<br>'.

    Args:
        dialog (str): A string containing the dialogue in HTML format.
                      Staff lines are marked with a blue "Staff" span and
                      patient lines with a pink "Patient" span, separated by '<br>' tags.

    Raises:
        TypeError: If the LLM call fails, its reply cannot be parsed as a Python
            literal, or the parsed reply is not a list/tuple of strings. The
            underlying exception, if any, is chained as ``__cause__``.

    Returns:
        str: Translated dialogue lines in Korean, each prefixed with a speaker
            span and joined with '<br>'.
    """
    user_prompt = "아래 문장 리스트는 원무과 직원과 환자의 대화야. 순서대로 한국어로 번역해주고 string list로 반환해줘. 다른 결과, 미사여구 붙이지말고 딱 string list만 반환 해줘야해.\n\n{lines}"
    staff_flag = "<span style='color:rgb(0,102,204); font-weight:bold'>Staff</span>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;:"
    patient_flag = "<span style='color:rgb(204,0,102); font-weight:bold'>Patient</span>:"
    # Strip the speaker span from each line; lines carrying neither span become ''.
    lines = [
        ''.join(line.split(staff_flag)[1:]).strip()
        if staff_flag in line
        else ''.join(line.split(patient_flag)[1:]).strip()
        for line in dialog.split('<br>')
    ]
    user_prompt = user_prompt.format(lines=lines)
    client = GPTClient('gpt-5-nano')

    try:
        response = client(user_prompt, reasoning_effort='minimal')
        # The reply is untrusted model output: parse it as a literal only, never eval() it.
        translated = ast.literal_eval(response)
    except Exception as err:
        raise TypeError('Failed to obtain a parsable translation from the LLM.') from err

    if not isinstance(translated, (list, tuple)) or not all(isinstance(r, str) for r in translated):
        raise TypeError(
            f'Expected a list of strings from the LLM, got {type(translated).__name__}.'
        )

    # NOTE(review): speakers are assigned by position (even -> Staff, odd -> Patient),
    # which assumes the dialogue strictly alternates starting with Staff.
    return '<br>'.join(
        f'{staff_flag if i % 2 == 0 else patient_flag} {r}'
        for i, r in enumerate(translated)
    )