Spaces:
Running on Zero
Running on Zero
File size: 5,650 Bytes
0dd6c2f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 | import json
import logging
import sys
import unicodedata
from collections import defaultdict
from datetime import datetime
from importlib import util
from pathlib import Path
from typing import TYPE_CHECKING, Any
from datasets.dataset_dict import DatasetDict
from huggingface_hub import HfApi
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Path to the "distillation/llama-cpp/models" directory located next to this file.
LLAMA_CPP_DIR = Path(__file__).parent / "distillation" / "llama-cpp" / "models"
if TYPE_CHECKING:
from types import ModuleType
def get_config_dir() -> str:
    """Return the path of the config directory as a string.

    The directory is named ``config`` and is resolved relative to this file:
    it lives in the parent of the directory that contains this module.
    """
    package_parent = Path(__file__).parent.parent
    config_dir = package_parent / "config"
    return str(config_dir)
def get_log_file_path() -> str:
    """
    Return the file path of the first FileHandler attached to the root logger.

    Handlers are inspected in the order they are registered on the root logger.

    Raises:
        ValueError: If the root logger has no FileHandler.
    """
    root_logger = logging.getLogger()
    file_handlers = (h for h in root_logger.handlers if isinstance(h, logging.FileHandler))
    first_file_handler = next(file_handlers, None)
    if first_file_handler is None:
        raise ValueError("No FileHandler found in the logger's handlers")
    return first_file_handler.baseFilename
def setup_logging(
    level: int = logging.INFO, include_timestamp: bool = False, file_suffix: str = "linalg_zero.log"
) -> None:  # pragma: no cover
    """
    Configure root logging to write to stdout and to a timestamped file under ``logs/``.

    Any handlers already attached to the root logger are replaced (``force=True``).

    Args:
        level: Logging level (default: INFO)
        include_timestamp: Whether to prefix each record with its timestamp
        file_suffix: Trailing part of the log file name; the file is named
            ``<YYYYmmdd_HHMMSS>_<file_suffix>`` inside the ``logs`` directory
    """
    log_dir = Path("logs")
    log_dir.mkdir(exist_ok=True)
    # Build the file name once so the handler and the announcement always agree.
    log_name = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{file_suffix}"

    if include_timestamp:
        record_format = "%(asctime)s - %(levelname)s: %(message)s"
    else:
        record_format = "%(levelname)s: %(message)s"

    console_handler = logging.StreamHandler(sys.stdout)
    file_handler = logging.FileHandler(f"logs/{log_name}")
    logging.basicConfig(
        level=level,
        format=record_format,
        force=True,
        handlers=[console_handler, file_handler],
    )
    logging.info(f"Logging to {log_dir / log_name}")
def get_logger(name: str) -> logging.Logger:
    """Return the ``logging`` logger registered under ``name``."""
    named_logger = logging.getLogger(name)
    return named_logger
def normalize_text(s: str, normalize_unicode: bool) -> str:
    """
    Apply NFKC normalization to ``s`` and turn Unicode minus signs (U+2212) into ASCII hyphens.

    The input is returned untouched when ``normalize_unicode`` is false or when
    ``s`` is not a string.
    """
    if normalize_unicode and isinstance(s, str):
        return unicodedata.normalize("NFKC", s).replace("\u2212", "-")
    return s
def get_representative_examples_indices(dataset: Any, per_category: int, include_remaining: bool = True) -> list[int]:
    """Order dataset indices so that every problem type is represented up front.

    The first ``per_category`` examples of each ``problem_type`` (in dataset
    order) are the representative ones; every other index is "remaining".

    The dataset is traversed exactly once and ``len()`` is never called, so any
    iterable of examples (including generators) is accepted.

    Args:
        dataset: Iterable of examples, each indexable by ``"problem_type"``.
        per_category: Maximum number of representative examples per problem type.
        include_remaining: When true, append the non-representative indices
            (in ascending order) after the representative ones.

    Returns:
        The representative indices, optionally followed by all remaining indices.
    """
    taken_per_type: defaultdict[str, int] = defaultdict(int)
    representative_indices: list[int] = []
    remaining_indices: list[int] = []

    for idx, example in enumerate(dataset):
        task = example["problem_type"]
        if taken_per_type[task] < per_category:
            taken_per_type[task] += 1
            representative_indices.append(idx)
        elif include_remaining:
            # Only collected when the caller actually wants them.
            remaining_indices.append(idx)

    print(f"Number of representative indices: {len(representative_indices)}")
    if include_remaining:
        return representative_indices + remaining_indices
    return representative_indices
def get_libpath() -> Path:
    """Return the path of ``lib.py``, the function library located next to this file."""
    module_dir = Path(__file__).parent
    return module_dir / "lib.py"
def load_module_from_path(path: Path) -> "ModuleType":
    """Load and execute a Python module from a file path.

    The module is created under the fixed spec name ``"module.name"`` and is
    not registered in ``sys.modules``.

    Args:
        path: Location of the Python source file to load.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec or loader can be determined for ``path``
            (for example, when the file does not have a recognised suffix).
    """
    spec = util.spec_from_file_location("module.name", path)
    # Explicit check instead of ``assert``: assertions are stripped under ``-O``,
    # which would turn a bad path into an obscure AttributeError on ``None``.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load a Python module from {path}")
    module = util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
def get_function_schema() -> str:
    """Return the tool function schemas as a JSON string, ordered by function name.

    The tools are obtained by loading the library module and calling its
    ``get_tools()``. Only the inner ``"function"`` object of each tool entry is
    serialized, because it reads more cleanly in prompts than the wrapper.
    """
    # TODO: verify loaded functions
    library = load_module_from_path(get_libpath())

    # Unwrap each tool first, then sort by name so the output is deterministic.
    function_schemas = [tool_info["function"] for tool_info in library.get_tools()]
    function_schemas.sort(key=lambda schema: schema["name"])

    return json.dumps(function_schemas, indent=2)
def push_to_hub(
    dataset: DatasetDict | dict, hub_dataset_name: str, private: bool = False, config_path: str | None = None
) -> None:
    """Push the dataset to the Hugging Face Hub, optionally with an entropy settings file.

    Args:
        dataset: The dataset to publish; dict inputs are wrapped in a ``DatasetDict``.
        hub_dataset_name: Repository id of the target dataset on the Hub.
        private: Forwarded to ``DatasetDict.push_to_hub``.
        config_path: Optional local file uploaded as ``entropy_settings.json``.
            A warning is logged if the path is given but does not exist.

    Raises:
        Exception: Any error raised while pushing or uploading is logged and re-raised.
    """
    hub_dataset = DatasetDict(dataset) if isinstance(dataset, dict) else dataset

    try:
        hub_dataset.push_to_hub(hub_dataset_name, private=private)
        logger.info(f"Successfully pushed dataset to: https://huggingface.co/datasets/{hub_dataset_name}")

        # The entropy settings are an optional extra file alongside the dataset.
        if config_path:
            if Path(config_path).exists():
                HfApi().upload_file(
                    path_or_fileobj=config_path,
                    path_in_repo="entropy_settings.json",
                    repo_id=hub_dataset_name,
                    repo_type="dataset",
                )
                logger.info(
                    f"✅ Successfully uploaded entropy settings to: https://huggingface.co/datasets/{hub_dataset_name}"
                )
            else:
                logger.warning(f"Warning: Entropy settings file not found at {config_path}")
    except Exception:
        logger.exception("Failed to push dataset to Hugging Face Hub.")
        raise
|