File size: 6,665 Bytes
a91cc9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import asyncio
from functools import wraps
import json
import re
import argparse
import logging
import ast
import os
import logging
import sys
import time


class StdoutToLogger:
    """File-like shim that forwards writes to the root logger.

    Intended as a drop-in replacement for ``sys.stdout`` so stray
    ``print`` output is captured by the logging configuration.
    """

    def write(self, text):
        # Drop whitespace-only writes (e.g. the bare newline print() emits)
        # so they don't become empty log records.
        stripped = text.strip()
        if stripped:
            logging.info(stripped)

    def flush(self):
        """No-op; present only to satisfy the file-object protocol."""


def read_json(filepath):
    """Load and return the deserialized content of a JSON file.

    Parameters
    ----------
    filepath : str | os.PathLike
        Path to the JSON file to read.

    Returns
    -------
    The deserialized Python object (dict, list, ...).

    Raises
    ------
    OSError
        If the file cannot be opened.
    json.JSONDecodeError
        If the content is not valid JSON.
    """
    # JSON is UTF-8 by spec; don't depend on the platform default encoding
    # (e.g. cp1252 on Windows would mangle non-ASCII content).
    with open(filepath, "r", encoding="utf-8") as f:
        return json.load(f)

def write_json(char_name, save_name, data, args: argparse.Namespace, examinator_prompt=None):
    """Write *data* as JSON to ``args.result_dir/char_name/save_name``.

    Parameters
    ----------
    char_name : str
        Sub-directory under ``args.result_dir``.
    save_name : str
        Output file name.
    data
        JSON-serializable payload.
    args : argparse.Namespace
        Must provide ``result_dir``; when *examinator_prompt* is truthy, also
        ``n_cross_examine``, ``model`` and ``n_repeat``.
    examinator_prompt : str, optional
        When truthy, the payload is wrapped together with run metadata.
    """
    filepath = os.path.join(args.result_dir, char_name, save_name)
    # Create the target directory if needed so a missing char_name folder
    # doesn't abort the run with FileNotFoundError.
    os.makedirs(os.path.dirname(filepath), exist_ok=True)
    if examinator_prompt:
        # Wrap the payload with run metadata when an examinator prompt is given.
        data = {
            "examinator_prompt": examinator_prompt,
            "n_cross_examine": args.n_cross_examine,
            "model": args.model,
            "n_repeat": args.n_repeat,
            "data": data
        }
    # ensure_ascii=False emits raw non-ASCII characters, so the file must be
    # opened with an explicit UTF-8 encoding (platform default may differ).
    with open(filepath, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4, ensure_ascii=False)

def concat_str_list(text):
    """Collapse consecutive list lines in *text* onto the preceding line.

    Runs of numbered ("1. foo") or bulleted ("- foo", "* foo", "• foo")
    lines are joined into one space-separated string.  That string is
    appended to the preceding non-empty line (a colon is inserted when that
    line lacks terminal punctuation); when there is no such line, the joined
    run becomes a line of its own.  Numbered items lose the dot after the
    number ("1. foo" -> "1 foo").  All other lines pass through unchanged.
    """
    bullet_re = re.compile(r'^\s*(\d+\.\s+|[-*•]\s+)')
    number_dot_re = re.compile(r'^(\d+)\.(\s+)')

    def _is_item(raw):
        stripped = raw.strip()
        return bool(stripped) and bullet_re.match(stripped) is not None

    lines = text.splitlines()
    out = []
    idx = 0
    total = len(lines)

    while idx < total:
        if not _is_item(lines[idx]):
            out.append(lines[idx])
            idx += 1
            continue

        # Gather the whole run of consecutive list lines and flatten it.
        items = []
        while idx < total and _is_item(lines[idx]):
            items.append(number_dot_re.sub(r'\1\2', lines[idx].strip()))
            idx += 1
        joined = " ".join(items)

        if out and out[-1].strip():
            # Attach the flattened run to the previous non-empty line,
            # adding a colon when it doesn't already end in punctuation.
            prev = out[-1].rstrip()
            if not prev.endswith((':', '.', '!', '?')):
                prev += ':'
            out[-1] = prev + " " + joined
        else:
            out.append(joined)

    return "\n".join(out)

def parse_output(output: str):
    """Parse a (possibly markdown-wrapped) JSON object from model output.

    Tries, in order:
      1. direct ``json.loads`` on the whole string,
      2. the contents of a ```` ```json ```` fenced block,
      3. the contents of any ``` fenced block,
      4. the first ``{...}``-looking span,
    and for each extracted candidate falls back from ``json.loads`` to
    ``ast.literal_eval`` (accepts Python-style dicts with single quotes)
    and finally to a control-character-stripped ``json.loads``.

    Parameters
    ----------
    output : str
        Raw model output.

    Returns
    -------
    The parsed object (a dict for the literal-eval path).

    Raises
    ------
    ValueError
        If *output* is empty/whitespace or no candidate can be parsed.
    """
    output = output.strip()
    if not output:
        raise ValueError("Output is empty or only whitespace.")

    # First, attempt direct JSON parse of the whole string.
    try:
        return json.loads(output, strict=False)
    except json.JSONDecodeError:
        pass  # fall through to extraction

    code_block_patterns = [
        r"```json\s*([\s\S]+?)\s*```",  # triple-backtick with json tag
        r"```([\s\S]+?)\s*```",         # any triple-backtick block
        r"(\{[\s\S]*?\})"               # any JSON-looking dict
    ]

    for pattern in code_block_patterns:
        match = re.search(pattern, output)
        if not match:
            continue
        json_str = match.group(1).strip()

        # Standard JSON decode.
        try:
            return json.loads(json_str, strict=False)
        except json.JSONDecodeError:
            pass

        # Python-literal fallback (single quotes, tuples-as-values, ...);
        # only accept dicts so we don't return unrelated literals.
        try:
            parsed = ast.literal_eval(json_str)
            if isinstance(parsed, dict):
                return parsed
        except Exception:
            pass

        # Strip control characters and try once more.
        json_str_cleaned = re.sub(r'[\x00-\x1f\x7f-\x9f]', '', json_str)
        try:
            return json.loads(json_str_cleaned, strict=False)
        except Exception as e:
            raise ValueError(f"Failed to parse JSON after cleaning: {e}")

    raise ValueError("No valid JSON object found in output.")


def retry_on_connection_error(max_retries: int = 3, delay: float = 3.0, backoff_factor: float = 2.0):
    """Decorator that retries an async callable on transient connection errors.

    Only exceptions whose message contains a connection-related keyword are
    retried; anything else propagates immediately.  The wait between
    attempts starts at *delay* seconds and is multiplied by *backoff_factor*
    after each retry.  After *max_retries* retries the last exception is
    re-raised.
    """
    # NOTE(review): some keywords ('put', 'timeout') are broad substring
    # matches and may catch unrelated messages — confirm against the errors
    # actually produced by the client library.
    connection_keywords = (
        'connection reset by peer', 'connection refused', 'timeout',
        'network', 'rpc', 'statuscode.unknown', 'put', 'read tcp',
        'broken pipe', 'ws_recv', 'ws_send'
    )

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            wait = delay
            final_error = None
            attempt = 0
            while attempt <= max_retries:
                try:
                    return await func(*args, **kwargs)
                except Exception as exc:
                    final_error = exc
                    message = str(exc).lower()
                    if not any(k in message for k in connection_keywords):
                        # Not a connection problem: surface it immediately.
                        raise exc
                    if attempt < max_retries:
                        logging.warning(f"Connection error on attempt {attempt + 1}/{max_retries}: {exc}")
                        logging.info(f"Retrying in {wait} seconds...")
                        await asyncio.sleep(wait)
                        wait *= backoff_factor
                    else:
                        logging.error(f"Max retries ({max_retries}) reached. Final error: {exc}")
                attempt += 1
            raise final_error
        return wrapper
    return decorator

def setup_logging(log_to_file: bool, process_name: str = None):
    """Configure root logging to stdout, optionally also to a dated log file.

    Parameters
    ----------
    log_to_file : bool
        When True, additionally write to
        ``logs/YYYY-MM-DD/<process_name>_<timestamp>.log`` (directories are
        created as needed).
    process_name : str, optional
        Prefix for the log filename; only used when *log_to_file* is True.
    """
    # Build the handler list once instead of duplicating basicConfig calls.
    # File handler first to preserve the original handler order.
    handlers = []
    if log_to_file:
        day = time.strftime("%Y-%m-%d")
        os.makedirs(f'logs/{day}', exist_ok=True)
        log_filename = f'logs/{day}/{process_name}_{time.strftime("%Y-%m-%d_%H-%M-%S")}.log'
        handlers.append(logging.FileHandler(log_filename))
    handlers.append(logging.StreamHandler(sys.stdout))
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s  %(levelname)s  %(message)s',
        handlers=handlers
    )
    # Quiet noisy third-party loggers regardless of the root configuration.
    for noisy in ("LiteLLM", "httpx", "google", "urllib3"):
        logging.getLogger(noisy).setLevel(logging.WARNING)