File size: 3,068 Bytes
e7069ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e42bc71
e7069ae
 
 
 
 
 
 
 
e42bc71
e7069ae
 
 
 
 
 
 
 
e42bc71
e7069ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import re
import random
import numpy as np
from typing import Union
from datetime import datetime, timedelta

# NOTE: torch removed — set_seed only uses random + numpy for reproducibility.
# This is sufficient for API-based agents (GPT, Gemini) that don't run local models.

from . import colorstr
from ..registry.detection_key import DDX_DETECT_KEYS


def set_seed(seed: int) -> None:
    """
    Seed Python's ``random`` module and NumPy's global RNG with the same value.

    Only these two CPU-side generators are seeded (random first, then NumPy);
    no torch/CUDA state is touched, so no torch install is required.

    Args:
        seed: Value passed unchanged to both ``random.seed`` and ``np.random.seed``.
    """
    for seeder in (random.seed, np.random.seed):
        seeder(seed)


def split_string(string: Union[str, list], delimiter: str = ",") -> list:
    """
    Normalize a delimited string or a list into a list of whitespace-stripped items.

    Args:
        string: Either a single string to split on ``delimiter``, or a list whose
            items are stripped individually (items must support ``.strip()``).
        delimiter: Separator used only when ``string`` is a ``str``.

    Returns:
        A new list of stripped items. Empty segments are kept (``""`` -> ``[""]``).

    Raises:
        ValueError: If ``string`` is neither a ``str`` nor a ``list``.
    """
    if isinstance(string, str):
        pieces = string.split(delimiter)
    elif isinstance(string, list):
        pieces = string
    else:
        raise ValueError(colorstr("red", "Input must be a string or a list of strings."))
    return [piece.strip() for piece in pieces]


def prompt_valid_check(prompt: str, data_dict: dict) -> None:
    """
    Verify that every ``{placeholder}`` in ``prompt`` has a matching key in ``data_dict``.

    Placeholders are found with a non-greedy ``{...}`` regex, so the whole text
    between braces is used as the key. NOTE(review): format specs such as
    ``{x:.2f}`` or escaped ``{{...}}`` are not special-cased — confirm prompts
    only use plain ``{name}`` placeholders.

    Args:
        prompt: Template text containing ``{name}`` placeholders.
        data_dict: Mapping expected to provide a value for every placeholder.

    Raises:
        ValueError: Listing every placeholder (in order of appearance, duplicates
            included) that has no key in ``data_dict``.
    """
    placeholders = re.findall(r'\{(.*?)\}', prompt)
    absent = []
    for name in placeholders:
        if name not in data_dict:
            absent.append(name)
    if not absent:
        return
    raise ValueError(colorstr("red", f"Missing keys in the prompt: {absent}. Please ensure all required keys are present in the data dictionary."))


def detect_ed_termination(text: str) -> bool:
    """
    Report whether ``text`` signals the end of the ED conversation.

    Matching is case-insensitive. The text counts as terminal when either
    condition holds:

    * it contains a ``[ddx]: <number>. <something>`` entry, or
    * it contains any of ``DDX_DETECT_KEYS`` as a substring.

    Args:
        text: Model output to inspect.

    Returns:
        ``True`` if either termination marker is present, else ``False``.
    """
    lowered = text.lower()
    # Both checks are always evaluated before combining the results.
    has_keyword = any(key.lower() in lowered for key in DDX_DETECT_KEYS)
    ddx_entry = re.search(r'\[ddx\]:\s*\d+\.\s*.+', lowered, re.IGNORECASE)
    return has_keyword or ddx_entry is not None


def detect_op_termination(text: str) -> bool:
    """
    Report whether ``text`` contains an ``Answer: <number>. <something>`` entry.

    The match is case-sensitive. This is a best-effort check: any error
    raised while searching (e.g. ``text`` is ``None`` or not a ``str``) yields
    ``False`` instead of propagating.

    Args:
        text: Model output to inspect.

    Returns:
        ``True`` if the answer pattern is found, else ``False``.
    """
    try:
        hit = re.search(r'Answer:\s*\d+\.\s*(.+)', text)
    except Exception:
        return False
    return hit is not None


def str_to_datetime(iso_time: Union[str, datetime]) -> datetime:
    """
    Parse an ISO 8601 string into a ``datetime``; pass non-string input through.

    Args:
        iso_time: A string accepted by ``datetime.fromisoformat`` (e.g.
            ``"2000-01-31"`` or ``"2000-01-31T12:00:00"``), or an object that is
            already a ``datetime``.

    Returns:
        The parsed ``datetime`` for string input. Any non-string input is
        returned unchanged (it is not validated).

    Raises:
        ValueError: If ``iso_time`` is a string that ``datetime.fromisoformat``
            cannot parse. The original parsing error is chained as the cause.
    """
    if not isinstance(iso_time, str):
        # Expected to already be a datetime; returned as-is, matching the type hint.
        return iso_time
    try:
        return datetime.fromisoformat(iso_time)
    except ValueError as err:
        # Only a malformed *string* can reach this point, so report the value itself
        # rather than its type (which is always str here).
        raise ValueError(colorstr("red", f"`iso_time` must be an ISO 8601 formatted string or a datetime, but could not parse {iso_time!r}")) from err


def datetime_to_str(iso_time: Union[str, datetime], format: str) -> str:
    """
    Render a datetime-like object with ``strftime``; pass strings through as-is.

    Args:
        iso_time: A string (returned unchanged, ``format`` is NOT applied) or an
            object providing ``strftime`` (e.g. ``datetime``/``date``).
        format: ``strftime`` format string, used only for non-string input.

    Returns:
        The formatted string, or ``iso_time`` itself when it is already a string.

    Raises:
        ValueError: If calling ``iso_time.strftime(format)`` raises for any reason.
    """
    if isinstance(iso_time, str):
        return iso_time
    try:
        rendered = iso_time.strftime(format)
    except Exception:
        raise ValueError(colorstr("red", f"`iso_time` must be str or date format, but got {type(iso_time)}"))
    return rendered


def generate_random_date(start_date: Union[str, datetime] = '1960-01-01',
                         end_date: Union[str, datetime] = '2000-12-31') -> str:
    """
    Pick a random day between two bounds (both inclusive) and return it as ``YYYY-MM-DD``.

    The day offset is drawn with ``random.randint`` from the module-level
    ``random`` generator, so results are reproducible after ``set_seed``.

    Args:
        start_date: Lower bound, as an ISO 8601 string or a ``datetime``.
        end_date: Upper bound, as an ISO 8601 string or a ``datetime``.

    Returns:
        The chosen date formatted with ``'%Y-%m-%d'``.

    Raises:
        ValueError: If a bound string cannot be parsed, or if ``end_date`` is
            earlier than ``start_date``.
    """
    start = str_to_datetime(start_date)
    end = str_to_datetime(end_date)
    # ``.days`` floors the difference, so a partial trailing day is not selectable.
    span_days = (end - start).days
    if span_days < 0:
        # Without this check random.randint fails with an opaque "empty range" error.
        raise ValueError(colorstr("red", f"`end_date` ({end_date}) must not be earlier than `start_date` ({start_date})"))
    offset_days = random.randint(0, span_days)
    return datetime_to_str(start + timedelta(days=offset_days), '%Y-%m-%d')


def exponential_backoff(retry_count: int,
                        base_delay: int = 5,
                        max_delay: int = 65,
                        jitter: bool = True) -> float:
    """
    Compute the wait before the next retry using capped exponential backoff.

    The un-jittered delay is ``min(base_delay * 2 ** retry_count, max_delay)``.
    When ``jitter`` is enabled, the result is drawn uniformly from
    ``[0.8 * delay, min(1.2 * delay, max_delay)]``, so ``max_delay`` stays a hard
    ceiling even after jitter is applied.

    Args:
        retry_count: Number of retries already attempted (0 for the first retry).
        base_delay: Delay for ``retry_count == 0`` before jitter.
        max_delay: Upper bound on the returned delay.
        jitter: Whether to randomize the delay (via the module-level ``random``).

    Returns:
        The delay, in the same units as ``base_delay``/``max_delay`` (presumably
        seconds — the caller decides). Without jitter this is the plain capped
        value, which may be an ``int``.
    """
    delay = min(base_delay * (2 ** retry_count), max_delay)
    if jitter:
        # Clamp the upper jitter bound; previously the result could exceed
        # max_delay by up to 20% once the exponential term hit the cap.
        delay = random.uniform(delay * 0.8, min(delay * 1.2, max_delay))
    return delay