| | import logging |
| | import os |
| | import re |
| | import subprocess |
| | from pathlib import Path |
| | from typing import Any, Dict, Optional, Tuple, Union |
| |
|
| | import numpy as np |
| | from torch.utils.collect_env import get_pretty_env_info |
| | from transformers import __version__ as trans_version |
| |
|
| |
|
# Module-level logger; verbosity is controlled by the application's logging config.
logger = logging.getLogger(__name__)
| |
|
| |
|
def remove_none_pattern(input_string: str) -> Tuple[str, bool]:
    """Strip a trailing ",none" from *input_string* if present.

    Args:
        input_string (str): The string to clean up.

    Returns:
        Tuple[str, bool]: The (possibly shortened) string, and a flag that is
        True when a trailing ",none" was removed and False otherwise.
    """
    # Anchored regex so only a suffix match is removed; `$` also matches just
    # before a final newline, which plain str.endswith would not replicate.
    none_suffix = re.compile(r",none$")
    cleaned = none_suffix.sub("", input_string)
    return cleaned, cleaned != input_string
| |
|
| |
|
| | def _handle_non_serializable(o: Any) -> Union[int, str, list]: |
| | """Handle non-serializable objects by converting them to serializable types. |
| | |
| | Args: |
| | o (Any): The object to be handled. |
| | |
| | Returns: |
| | Union[int, str, list]: The converted object. If the object is of type np.int64 or np.int32, |
| | it will be converted to int. If the object is of type set, it will be converted |
| | to a list. Otherwise, it will be converted to str. |
| | """ |
| | if isinstance(o, np.int64) or isinstance(o, np.int32): |
| | return int(o) |
| | elif isinstance(o, set): |
| | return list(o) |
| | else: |
| | return str(o) |
| |
|
| |
|
def get_commit_from_path(repo_path: Union[Path, str]) -> Optional[str]:
    """Read the current commit hash straight from a repo's ``.git`` metadata.

    Args:
        repo_path (Union[Path, str]): Directory expected to contain ``.git``.

    Returns:
        Optional[str]: The commit hash the current HEAD ref points at, or
        None when the metadata is missing or unreadable.
    """
    try:
        git_dir = Path(repo_path, ".git")
        # Worktrees/submodules keep a pointer *file* named .git instead of a
        # directory; follow it to the real metadata location.
        if git_dir.is_file():
            pointer = (
                git_dir.read_text(encoding="utf-8").split("\n")[0].split(" ")[-1]
            )
            git_dir = Path(git_dir.parent, pointer)
        head_file = Path(git_dir, "HEAD")
        if not head_file.exists():
            return None
        ref_name = (
            head_file.read_text(encoding="utf-8").split("\n")[0].split(" ")[-1]
        )
        return Path(git_dir, ref_name).read_text(encoding="utf-8").replace("\n", "")
    except Exception as err:
        # Any parsing/IO failure is non-fatal: log at debug level and give up.
        logger.debug(
            f"Failed to retrieve a Git commit hash from path: {str(repo_path)}. Error: {err}"
        )
        return None
| |
|
| |
|
def get_git_commit_hash():
    """
    Gets the git commit hash of your current repo (if it exists).
    Source: https://github.com/EleutherAI/gpt-neox/blob/b608043be541602170bfcfb8ec9bf85e8a0799e0/megatron/neox_arguments/neox_args.py#L42
    """
    try:
        output = subprocess.check_output(["git", "describe", "--always"])
        return output.strip().decode()
    except (subprocess.CalledProcessError, FileNotFoundError):
        # `git` binary missing or cwd is not a work tree: fall back to
        # parsing the .git metadata directly.
        return get_commit_from_path(os.getcwd())
| |
|
| |
|
def add_env_info(storage: Dict[str, Any]):
    """Add environment metadata to a results dict (mutates *storage*).

    Records torch's environment report, the installed transformers version,
    and the git commit hash of the parent directory (None if unavailable).
    """
    try:
        env_report = get_pretty_env_info()
    except Exception as err:
        # Best effort: store the failure text rather than abort result logging.
        env_report = str(err)
    storage.update(
        {
            "pretty_env_info": env_report,
            "transformers_version": trans_version,
            "upper_git_hash": get_commit_from_path(Path(os.getcwd(), "..")),
        }
    )
| |
|
| |
|
def add_tokenizer_info(storage: Dict[str, Any], lm):
    """Record the LM tokenizer's special-token metadata into *storage*.

    Args:
        storage (Dict[str, Any]): Results dict to update in place.
        lm: Model wrapper; expected to expose a ``tokenizer`` attribute
            (plus optional ``eot_token_id``/``max_length``) — skipped
            gracefully when it does not.
    """
    tokenizer = getattr(lm, "tokenizer", False)
    if not tokenizer:
        logger.debug(
            "LM does not have a 'tokenizer' attribute, not logging tokenizer metadata to results."
        )
        return
    try:
        # Token ids are stringified so the resulting dict stays JSON-friendly.
        storage.update(
            {
                "tokenizer_pad_token": [tokenizer.pad_token, str(tokenizer.pad_token_id)],
                "tokenizer_eos_token": [tokenizer.eos_token, str(tokenizer.eos_token_id)],
                "tokenizer_bos_token": [tokenizer.bos_token, str(tokenizer.bos_token_id)],
                "eot_token_id": getattr(lm, "eot_token_id", None),
                "max_length": getattr(lm, "max_length", None),
            }
        )
    except Exception as err:
        # Tokenizer metadata is nice-to-have; never fail the run over it.
        logger.debug(f"Logging detailed tokenizer info failed with {err}, skipping...")
| |
|