File size: 5,650 Bytes
0dd6c2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import json
import logging
import sys
import unicodedata
from collections import defaultdict
from datetime import datetime
from importlib import util
from pathlib import Path
from typing import TYPE_CHECKING, Any

from datasets.dataset_dict import DatasetDict
from huggingface_hub import HfApi

logger = logging.getLogger(__name__)

LLAMA_CPP_DIR = Path(__file__).parent / "distillation" / "llama-cpp" / "models"
if TYPE_CHECKING:
    from types import ModuleType


def get_config_dir() -> str:
    """Return the path of the project's config directory as a string."""
    # The config directory sits one level above this module's package.
    return str(Path(__file__).parent.parent / "config")


def get_log_file_path() -> str:
    """
    Finds and returns the file path of the first FileHandler found in the logger's handlers.
    Raises ValueError if no FileHandler is found.
    """
    logger = logging.getLogger()  # Get root logger
    for handler in logger.handlers:
        if isinstance(handler, logging.FileHandler):
            return handler.baseFilename
    raise ValueError("No FileHandler found in the logger's handlers")


def setup_logging(
    level: int = logging.INFO, include_timestamp: bool = False, file_suffix: str = "linalg_zero.log"
) -> None:  # pragma: no cover
    """
    Configure simple logging to both stdout and a timestamped file under ./logs.

    Args:
        level: Logging level (default: INFO)
        include_timestamp: Whether to include timestamp in logs
        file_suffix: Suffix appended to the timestamped log file name
    """
    log_dir = Path("logs")
    log_dir.mkdir(exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = log_dir / f"{timestamp}_{file_suffix}"

    if include_timestamp:
        format_string = "%(asctime)s - %(levelname)s: %(message)s"
    else:
        format_string = "%(levelname)s: %(message)s"

    # force=True replaces any handlers a previous configuration installed.
    logging.basicConfig(
        level=level,
        format=format_string,
        force=True,
        handlers=[logging.StreamHandler(sys.stdout), logging.FileHandler(str(log_file))],
    )

    logging.info(f"Logging to {log_file}")


def get_logger(name: str) -> logging.Logger:
    """Return the logging.Logger registered under ``name``."""
    return logging.getLogger(name)


def normalize_text(s: str, normalize_unicode: bool) -> str:
    """
    Apply NFKC Unicode normalization and map U+2212 (minus sign) to ASCII '-'.

    Non-string inputs and disabled normalization pass through unchanged.
    """
    if normalize_unicode and isinstance(s, str):
        normalized = unicodedata.normalize("NFKC", s)
        return normalized.replace("\u2212", "-")
    return s


def get_representative_examples_indices(dataset: Any, per_category: int, include_remaining: bool = True) -> list[int]:
    """
    Return dataset indices ordered so representative examples come first.

    Picks up to ``per_category`` indices per ``problem_type`` (in dataset
    order), then optionally appends every remaining index.

    Args:
        dataset: Iterable of examples, each a mapping with a "problem_type"
            key; must also support ``len()``.
        per_category: Maximum number of representative examples per problem type.
        include_remaining: If True, append all non-representative indices
            after the representatives; if False, return only representatives.

    Returns:
        List of integer indices into ``dataset``.
    """
    categories: defaultdict[str, list[int]] = defaultdict(list)
    representative_indices: list[int] = []

    # First pass: take the first `per_category` examples of each problem type.
    for idx, example in enumerate(dataset):
        task = example["problem_type"]
        if len(categories[task]) < per_category:
            categories[task].append(idx)
            representative_indices.append(idx)

    # Second pass: everything not chosen above, in original order.
    representative_set = set(representative_indices)
    remaining_indices = [i for i in range(len(dataset)) if i not in representative_set]
    # Use the module's logging convention instead of print() so the message
    # respects configured handlers/levels; lazy %-args avoid eager formatting.
    logging.getLogger(__name__).info("Number of representative indices: %d", len(representative_indices))

    if include_remaining:
        return representative_indices + remaining_indices
    return representative_indices


def get_libpath() -> Path:
    """Return the path to the library of functions (lib.py beside this module)."""
    return Path(__file__).with_name("lib.py")


def load_module_from_path(path: Path) -> "ModuleType":
    """
    Load and execute a Python module from an arbitrary file path.

    Args:
        path: Filesystem path to a .py file.

    Returns:
        The fully executed module object.

    Raises:
        ImportError: If no import spec (or loader) can be created for ``path``.
    """
    spec = util.spec_from_file_location("module.name", path)
    # Raise explicitly rather than assert: asserts are stripped under -O,
    # which would turn a bad path into an opaque AttributeError below.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create an import spec for {path}")
    module = util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


def get_function_schema() -> str:
    """Return a JSON string with the full tool function schemas (sorted by name)."""
    # TODO: verify loaded functions
    lib = load_module_from_path(get_libpath())
    # Sort by function name so the output is deterministic and readable.
    ordered_tools = sorted(lib.get_tools(), key=lambda tool: tool["function"]["name"])
    # Prompts only need the inner function object, not the tool wrapper.
    extracted = [tool["function"] for tool in ordered_tools]
    return json.dumps(extracted, indent=2)


def push_to_hub(
    dataset: DatasetDict | dict, hub_dataset_name: str, private: bool = False, config_path: str | None = None
) -> None:
    """Push the dataset to Hugging Face Hub, optionally including entropy settings."""

    # Accept a plain dict of splits and promote it to a DatasetDict.
    if isinstance(dataset, dict):
        dataset = DatasetDict(dataset)

    try:
        dataset.push_to_hub(hub_dataset_name, private=private)
        logger.info(f"Successfully pushed dataset to: https://huggingface.co/datasets/{hub_dataset_name}")

        settings_file = Path(config_path) if config_path else None
        if settings_file is not None and settings_file.exists():
            # Attach the entropy settings JSON alongside the dataset files.
            HfApi().upload_file(
                path_or_fileobj=config_path,
                path_in_repo="entropy_settings.json",
                repo_id=hub_dataset_name,
                repo_type="dataset",
            )
            logger.info(
                f"✅ Successfully uploaded entropy settings to: https://huggingface.co/datasets/{hub_dataset_name}"
            )
        elif settings_file is not None:
            logger.warning(f"Warning: Entropy settings file not found at {config_path}")
    except Exception:
        logger.exception("Failed to push dataset to Hugging Face Hub.")
        raise