Spaces:
Running on Zero
Running on Zero
| import json | |
| import logging | |
| import sys | |
| import unicodedata | |
| from collections import defaultdict | |
| from datetime import datetime | |
| from importlib import util | |
| from pathlib import Path | |
| from typing import TYPE_CHECKING, Any | |
| from datasets.dataset_dict import DatasetDict | |
| from huggingface_hub import HfApi | |
| logger = logging.getLogger(__name__) | |
| LLAMA_CPP_DIR = Path(__file__).parent / "distillation" / "llama-cpp" / "models" | |
| if TYPE_CHECKING: | |
| from types import ModuleType | |
def get_config_dir() -> str:
    """Return the ``config`` directory path as a string.

    The directory is located next to this file's parent directory, i.e. two
    levels up from this file and then into ``config``.
    """
    return str(Path(__file__).parent.parent / "config")
def get_log_file_path() -> str:
    """
    Return the file path used by the root logger's first FileHandler.

    Handlers are inspected in registration order; the ``baseFilename`` of the
    first ``logging.FileHandler`` instance (or subclass instance) is returned.

    Raises:
        ValueError: If the root logger has no FileHandler attached.
    """
    root_logger = logging.getLogger()  # Root logger, not the module logger
    file_handlers = (h for h in root_logger.handlers if isinstance(h, logging.FileHandler))
    first_file_handler = next(file_handlers, None)
    if first_file_handler is None:
        raise ValueError("No FileHandler found in the logger's handlers")
    return first_file_handler.baseFilename
def setup_logging(
    level: int = logging.INFO, include_timestamp: bool = False, file_suffix: str = "linalg_zero.log"
) -> None:  # pragma: no cover
    """
    Set up simple logging configuration. Will log to console and file.

    The log file is written to ``logs/<YYYYmmdd_HHMMSS>_<file_suffix>`` relative to the
    current working directory; the ``logs`` directory is created if missing. Any handlers
    already attached to the root logger are replaced (``force=True``).

    Args:
        level: Logging level (default: INFO)
        include_timestamp: Whether to include timestamp in logs
        file_suffix: Trailing part of the log file name, appended after the timestamp
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Build the log path once so the handler and the announcement below cannot drift apart.
    log_dir = Path("logs")
    log_dir.mkdir(exist_ok=True)
    log_file = log_dir / f"{timestamp}_{file_suffix}"

    format_string = "%(asctime)s - %(levelname)s: %(message)s" if include_timestamp else "%(levelname)s: %(message)s"

    logging.basicConfig(
        level=level,
        format=format_string,
        force=True,
        handlers=[
            logging.StreamHandler(sys.stdout),
            # Pin UTF-8: log messages may contain non-ASCII characters, which the
            # platform default encoding cannot always represent.
            logging.FileHandler(log_file, encoding="utf-8"),
        ],
    )

    logging.info(f"Logging to {log_file}")
def get_logger(name: str) -> logging.Logger:
    """Return the ``logging.Logger`` registered under ``name``."""
    named_logger = logging.getLogger(name)
    return named_logger
def normalize_text(s: str, normalize_unicode: bool) -> str:
    """
    Apply NFKC normalization to ``s`` and map U+2212 (minus sign) to ASCII ``-``.

    The input is returned untouched when ``normalize_unicode`` is false or when
    ``s`` is not a string.
    """
    if normalize_unicode and isinstance(s, str):
        # NFKC leaves U+2212 as-is, so the minus sign is replaced explicitly.
        return unicodedata.normalize("NFKC", s).replace("\u2212", "-")
    return s
def get_representative_examples_indices(dataset: Any, per_category: int, include_remaining: bool = True) -> list[int]:
    """Get representative indices first (per_category samples per problem type), then all remaining indices.

    The dataset is traversed exactly once, so any iterable of examples works
    (it does not need to support ``len()``).

    Args:
        dataset: Iterable of examples; each example is indexed with ``"problem_type"``.
        per_category: Maximum number of representative examples kept per problem type.
        include_remaining: When true, the non-representative indices are appended
            (in ascending order) after the representative ones.

    Returns:
        Representative indices in encounter order, optionally followed by all other indices.
    """
    taken_per_category: defaultdict[str, int] = defaultdict(int)
    representative_indices: list[int] = []
    remaining_indices: list[int] = []

    # Single pass: each index is either one of the first `per_category` examples
    # of its problem type, or it belongs to the remainder.
    for idx, example in enumerate(dataset):
        task = example["problem_type"]
        if taken_per_category[task] < per_category:
            taken_per_category[task] += 1
            representative_indices.append(idx)
        else:
            remaining_indices.append(idx)

    print(f"Number of representative indices: {len(representative_indices)}")

    if include_remaining:
        return representative_indices + remaining_indices
    return representative_indices
def get_libpath() -> Path:
    """Return the path of ``lib.py``, the function library that sits beside this file."""
    module_dir = Path(__file__).parent
    return module_dir / "lib.py"
def load_module_from_path(path: Path) -> "ModuleType":
    """Loads a python module from a given path.

    The file is executed as a module named ``"module.name"``; the resulting
    module object is returned but is not inserted into ``sys.modules`` here.

    Args:
        path: Location of the Python source file to load.

    Returns:
        The freshly executed module object.

    Raises:
        ImportError: If no import spec or loader can be created for ``path``
            (for example, when the file does not have a recognised Python suffix).
    """
    spec = util.spec_from_file_location("module.name", path)
    # Explicit check instead of `assert`: assertions are stripped under `python -O`,
    # which would turn this into an opaque AttributeError on None.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load a Python module from {path}")

    module = util.module_from_spec(spec)
    # NOTE: this runs the file's top-level code; only load trusted paths.
    spec.loader.exec_module(module)
    return module
def get_function_schema() -> str:
    """Return a JSON string with the full tool function schemas (sorted by name).

    The tool list comes from ``get_tools()`` in the library module located by
    ``get_libpath()``. Each entry is indexed with ``"function"``, and only that
    inner object (which carries a ``"name"`` key) is serialized.
    """
    # TODO: verify loaded functions
    tool_library = load_module_from_path(get_libpath())

    # Keep only the inner function objects; they read cleaner in prompts than the wrappers.
    function_specs = [entry["function"] for entry in tool_library.get_tools()]

    # Sort by name so the output is deterministic.
    function_specs.sort(key=lambda spec: spec["name"])

    return json.dumps(function_specs, indent=2)
def push_to_hub(
    dataset: DatasetDict | dict, hub_dataset_name: str, private: bool = False, config_path: str | None = None
) -> None:
    """Push the dataset to Hugging Face Hub, optionally including entropy settings.

    Args:
        dataset: A ``DatasetDict`` or a plain dict, which is wrapped in ``DatasetDict``.
        hub_dataset_name: Target dataset repository id on the Hub.
        private: Forwarded to ``DatasetDict.push_to_hub``.
        config_path: Optional local file uploaded to the dataset repo as
            ``entropy_settings.json``. A warning is logged if the file does not exist.

    Raises:
        Exception: Any failure during the push or upload is logged and re-raised.
    """
    hub_dataset = DatasetDict(dataset) if isinstance(dataset, dict) else dataset

    try:
        hub_dataset.push_to_hub(hub_dataset_name, private=private)
        logger.info(f"Successfully pushed dataset to: https://huggingface.co/datasets/{hub_dataset_name}")

        # No settings file requested: the dataset push alone is the whole job.
        if not config_path:
            return

        if not Path(config_path).exists():
            logger.warning(f"Warning: Entropy settings file not found at {config_path}")
            return

        # Upload entropy settings as an additional file alongside the dataset
        HfApi().upload_file(
            path_or_fileobj=config_path,
            path_in_repo="entropy_settings.json",
            repo_id=hub_dataset_name,
            repo_type="dataset",
        )
        logger.info(
            f"✅ Successfully uploaded entropy settings to: https://huggingface.co/datasets/{hub_dataset_name}"
        )
    except Exception:
        logger.exception("Failed to push dataset to Hugging Face Hub.")
        raise