# human_interview/utils.py
# Author: sujeongim — commit a91cc9f ("add : required files")
import asyncio
from functools import wraps
import json
import re
import argparse
import logging
import ast
import os
import logging
import sys
import time
class StdoutToLogger:
    """File-like shim that forwards writes to the root logger at INFO level.

    Intended to be assigned to ``sys.stdout`` so stray ``print`` output is
    captured by the logging configuration instead of bypassing it.
    """

    def write(self, text):
        # Drop whitespace-only writes (e.g. the bare newline print() emits).
        stripped = text.strip()
        if stripped:
            logging.info(stripped)

    def flush(self):
        # Nothing is buffered here; present only to satisfy the
        # file-object protocol expected of stdout replacements.
        pass
def read_json(filepath):
    """Load and return the JSON content of *filepath*.

    The file is opened as UTF-8 explicitly so results do not depend on the
    platform's default encoding (e.g. cp1252 on Windows would otherwise
    break non-ASCII JSON written by ``write_json``).

    Args:
        filepath: path to a JSON file.

    Returns:
        The deserialized JSON value (typically a dict or list).
    """
    with open(filepath, "r", encoding="utf-8") as f:
        return json.load(f)
def write_json(char_name, save_name, data, args: argparse.Namespace, examinator_prompt=None):
    """Write *data* as pretty-printed UTF-8 JSON to args.result_dir/char_name/save_name.

    When *examinator_prompt* is given, the payload is wrapped in a metadata
    envelope recording the prompt and the run settings read from *args*
    (``n_cross_examine``, ``model``, ``n_repeat``).

    Args:
        char_name: per-character subdirectory under ``args.result_dir``.
        save_name: output filename.
        data: JSON-serializable payload.
        args: namespace providing ``result_dir`` (and, when wrapping,
            ``n_cross_examine``, ``model``, ``n_repeat``).
        examinator_prompt: optional prompt text; truthy value triggers the
            metadata envelope.
    """
    target_dir = os.path.join(args.result_dir, char_name)
    # Create the per-character directory on demand; previously a missing
    # directory made open() raise FileNotFoundError.
    os.makedirs(target_dir, exist_ok=True)
    filepath = os.path.join(target_dir, save_name)
    if examinator_prompt:
        data = {
            "examinator_prompt": examinator_prompt,
            "n_cross_examine": args.n_cross_examine,
            "model": args.model,
            "n_repeat": args.n_repeat,
            "data": data
        }
    # Explicit UTF-8: ensure_ascii=False emits raw non-ASCII characters,
    # which would corrupt under a non-UTF-8 platform default encoding.
    with open(filepath, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4, ensure_ascii=False)
def concat_str_list(text):
    """Flatten consecutive list lines and attach them to the preceding sentence.

    Runs of numbered ("1. item") or bulleted ("- item", "* item", "• item")
    lines are joined into one space-separated string; the dot after a
    numeric marker is removed ("1." -> "1").  If the line before the run is
    non-empty, the flattened run is appended to it (with a ':' added when it
    lacks terminal punctuation); otherwise the run becomes its own line.
    All other lines pass through unchanged.

    Args:
        text: multi-line input string.

    Returns:
        The transformed text joined with "\\n".
    """
    # NOTE(fix): the bullet in the character class had been mojibake'd
    # ("โ€ข" — the UTF-8 bytes of "•" misdecoded), so "•"-bulleted lines
    # were no longer recognized; restored to the intended "•".
    item_re = re.compile(r'^\s*(\d+\.\s+|[-*•]\s+)')

    def is_list_item(l):
        return item_re.match(l.strip()) and l.strip()

    lines = text.splitlines()
    result = []
    i = 0
    while i < len(lines):
        line = lines[i]
        if is_list_item(line):
            # Collect the whole run of consecutive list lines.
            list_lines = []
            while i < len(lines) and is_list_item(lines[i]):
                line_content = lines[i].strip()
                # "1. foo" -> "1 foo": drop the dot after the number.
                line_content = re.sub(r'^(\d+)\.(\s+)', r'\1\2', line_content)
                list_lines.append(line_content)
                i += 1
            flat_list = " ".join(list_lines)
            if result and result[-1].strip():
                # Glue the list onto the previous non-empty line, adding a
                # colon when it lacks terminal punctuation.
                prev_line = result[-1].rstrip()
                if not prev_line.endswith((':', '.', '!', '?')):
                    prev_line += ':'
                result[-1] = prev_line + " " + flat_list
            else:
                result.append(flat_list)
        else:
            result.append(line)
            i += 1
    return "\n".join(result)
def parse_output(output: str):
    """Extract and return a JSON object from a model response string.

    Strategy, in order:
      1. direct ``json.loads`` on the whole (stripped) output;
      2. for each extraction pattern (```json fence, plain ``` fence, first
         ``{...}``-looking span): try ``json.loads``, then
         ``ast.literal_eval`` (accepts Python-dict syntax such as single
         quotes), then ``json.loads`` again after stripping control
         characters.

    Args:
        output: raw model output, possibly wrapping JSON in markdown fences.

    Returns:
        The parsed object (a dict for the literal_eval path).

    Raises:
        ValueError: if the output is empty, or no candidate parses.
    """
    output = output.strip()
    if not output:
        raise ValueError("Output is empty or only whitespace.")
    # First, attempt direct JSON parse.
    try:
        return json.loads(output, strict=False)
    except json.JSONDecodeError:
        pass  # Proceed to regex extraction
    # Attempt to extract JSON from code blocks.
    code_block_patterns = [
        r"```json\s*([\s\S]+?)\s*```",  # triple-backtick with json
        r"```([\s\S]+?)\s*```",         # triple-backtick fallback
        r"(\{[\s\S]*?\})"               # any JSON-looking dict
    ]
    last_error = None
    for pattern in code_block_patterns:
        match = re.search(pattern, output)
        if not match:
            continue
        json_str = match.group(1).strip()
        # Try JSON decode first.
        try:
            return json.loads(json_str, strict=False)
        except json.JSONDecodeError:
            pass
        # Try ast.literal_eval as a fallback (safe, unlike eval).
        try:
            parsed = ast.literal_eval(json_str)
            if isinstance(parsed, dict):
                return parsed
        except Exception:
            pass
        # Remove control characters and try once more.
        json_str_cleaned = re.sub(r'[\x00-\x1f\x7f-\x9f]', '', json_str)
        try:
            return json.loads(json_str_cleaned, strict=False)
        except Exception as e:
            # FIX: the original raised here immediately, so later (broader)
            # patterns were never attempted, and the code after the final
            # raise was unreachable.  Remember the error and keep trying.
            last_error = e
    if last_error is not None:
        raise ValueError(f"Failed to parse JSON after cleaning: {last_error}")
    raise ValueError("No valid JSON object found in output.")
def retry_on_connection_error(max_retries: int = 3, delay: float = 3.0, backoff_factor: float = 2.0):
    """Decorator that retries an async function on connection-related errors.

    An exception is treated as connection-related when its message (lower-
    cased) contains one of a fixed set of keywords (e.g. "timeout",
    "connection refused").  Such errors are retried up to *max_retries*
    times with an exponentially growing sleep; any other exception is
    re-raised immediately.

    Args:
        max_retries: number of retries after the first attempt.
        delay: initial sleep in seconds before the first retry.
        backoff_factor: multiplier applied to the sleep after each retry.
    """
    # Lower-cased substrings identifying a transient connection failure.
    connection_keywords = (
        'connection reset by peer', 'connection refused', 'timeout',
        'network', 'rpc', 'statuscode.unknown', 'put', 'read tcp',
        'broken pipe', 'ws_recv', 'ws_send'
    )

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            last_exception = None
            current_delay = delay
            for attempt in range(max_retries + 1):
                try:
                    return await func(*args, **kwargs)
                except Exception as e:
                    last_exception = e
                    error_msg = str(e).lower()
                    # Non-connection errors are never retried; bare `raise`
                    # preserves the original traceback (was `raise e`).
                    if not any(keyword in error_msg for keyword in connection_keywords):
                        raise
                    if attempt < max_retries:
                        logging.warning(f"Connection error on attempt {attempt + 1}/{max_retries}: {e}")
                        logging.info(f"Retrying in {current_delay} seconds...")
                        await asyncio.sleep(current_delay)
                        current_delay *= backoff_factor
                    else:
                        logging.error(f"Max retries ({max_retries}) reached. Final error: {e}")
            # All attempts exhausted on a connection error.
            raise last_exception
        return wrapper
    return decorator
def setup_logging(log_to_file: bool, process_name: str = None):
    """Configure root logging to stdout, optionally also to a dated log file.

    When *log_to_file* is true, a ``logs/YYYY-MM-DD/`` directory is created
    and a ``<process_name>_<timestamp>.log`` file handler is added alongside
    the stdout handler.  Chatty third-party loggers are raised to WARNING
    either way.

    Args:
        log_to_file: also write log records to a file under ``logs/``.
        process_name: filename prefix (only used when log_to_file is true).
    """
    handlers = [logging.StreamHandler(sys.stdout)]
    if log_to_file:
        # Capture the date once so directory and filename cannot disagree
        # across a midnight boundary (the original called strftime twice).
        day = time.strftime("%Y-%m-%d")
        os.makedirs(f'logs/{day}', exist_ok=True)
        log_filename = f'logs/{day}/{process_name}_{time.strftime("%Y-%m-%d_%H-%M-%S")}.log'
        handlers.append(logging.FileHandler(log_filename))
    # Single basicConfig call for both modes (was duplicated per branch).
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s %(levelname)s %(message)s',
        handlers=handlers
    )
    # Silence chatty third-party libraries.
    for noisy in ("LiteLLM", "httpx", "google", "urllib3"):
        logging.getLogger(noisy).setLevel(logging.WARNING)