Spaces:
Sleeping
Sleeping
| import io | |
| import torch | |
def get_target_dtype_ref(target_dtype: "str | torch.dtype") -> torch.dtype:
    """Resolve a dtype name to the corresponding ``torch.dtype``.

    Args:
        target_dtype: Either a ``torch.dtype`` (returned unchanged) or one
            of the names ``"float16"``, ``"float32"``, ``"bfloat16"``.
            (The original annotation said ``str`` only, but the body has
            always accepted a ``torch.dtype`` as well.)

    Returns:
        The matching ``torch.dtype`` object.

    Raises:
        ValueError: If ``target_dtype`` is neither a ``torch.dtype`` nor a
            recognized dtype name.
    """
    # Pass through when the caller already supplies a concrete dtype.
    if isinstance(target_dtype, torch.dtype):
        return target_dtype
    # Lookup table replaces the if/elif chain; easy to extend.
    dtype_by_name = {
        "float16": torch.float16,
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
    }
    try:
        return dtype_by_name[target_dtype]
    except KeyError:
        raise ValueError(f"Invalid target_dtype: {target_dtype}") from None
def convert_ckpt_to_safetensors(ckpt_upload: "io.BytesIO | bytes", target_dtype) -> dict:
    """Load a pickled checkpoint and return its tensors cast to ``target_dtype``.

    Args:
        ckpt_upload: Raw checkpoint bytes, or a seekable binary buffer.
            (The original annotation said ``io.BytesIO`` only, but ``bytes``
            has always been accepted and wrapped.)
        target_dtype: Target dtype as a name string or ``torch.dtype``;
            resolved via ``get_target_dtype_ref``.

    Returns:
        dict: Mapping of tensor name -> converted tensor, suitable for
        serialization. When the checkpoint contains a ``'string_to_param'``
        entry (textual-inversion embedding layout), the single embedding
        tensor is returned under the key ``'emb_params'``; if that entry
        lacks a ``'*'`` tensor, an empty dict is returned.
    """
    if isinstance(ckpt_upload, bytes):
        ckpt_upload = io.BytesIO(ckpt_upload)
    target_dtype = get_target_dtype_ref(target_dtype)
    # SECURITY NOTE(review): torch.load unpickles arbitrary objects; on an
    # untrusted upload this permits arbitrary code execution. Prefer
    # torch.load(..., weights_only=True) where the installed torch supports it.
    loaded_dict = torch.load(ckpt_upload, map_location="cpu")
    tensor_dict = {}
    if 'string_to_param' in loaded_dict:
        # Textual-inversion embedding: the tensor lives at string_to_param['*'].
        emb_tensor = loaded_dict.get('string_to_param', {}).get('*', None)
        if emb_tensor is not None:
            tensor_dict = {
                'emb_params': emb_tensor.to(dtype=target_dtype)
            }
    else:
        # Plain state dict: keep tensor values only, casting each one.
        for key, val in loaded_dict.items():
            if isinstance(val, torch.Tensor):
                tensor_dict[key] = val.to(dtype=target_dtype)
    return tensor_dict
# Guard clause: this file is a library module and is not meant to be executed.
if __name__ == "__main__":
    print('__main__ not allowed in modules')