Respair's picture
Upload folder using huggingface_hub
b386992 verified
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Union
import torch
_str_to_dtype: Dict[str, torch.dtype] = dict(
float32=torch.float32,
float=torch.float32,
float64=torch.float64,
double=torch.float64,
float16=torch.float16,
half=torch.float16,
bfloat16=torch.bfloat16,
bf16=torch.bfloat16,
uint8=torch.uint8,
byte=torch.uint8,
int8=torch.int8,
char=torch.int8,
int16=torch.int16,
short=torch.int16,
int32=torch.int32,
int=torch.int32,
int64=torch.int64,
long=torch.int64,
bool=torch.bool,
)
def str_to_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
"""Convert a data type name to a PyTorch data type"""
if isinstance(dtype, torch.dtype):
return dtype
name = str(dtype).strip().lower()
if name.startswith("torch."):
name = name.replace("torch.", "", 1)
if name.startswith("fp"):
name = name.replace("fp", "float", 1)
if name not in _str_to_dtype:
raise ValueError(f"Unrecognized dtype ({name})")
return _str_to_dtype[name]