Spaces:
Sleeping
Sleeping
| import io | |
| import json | |
| import safetensors | |
| import torch | |
| from safetensors.torch import serialize | |
| from .torch_tools import get_target_dtype_ref | |
def read_safetensors_metadata(lora_upload: io.BytesIO) -> dict:
    """Return the ``__metadata__`` mapping stored in a safetensors header.

    The file starts with an 8-byte little-endian unsigned integer N, followed
    by N bytes of UTF-8 JSON (surrounding whitespace is stripped before
    parsing). Only that header is read; tensor data is not touched.

    Args:
        lora_upload: Seekable binary stream holding a safetensors file.

    Returns:
        The header's ``__metadata__`` entry, or ``{}`` if the header has none.

    Raises:
        ValueError: If the stream is too short to hold the 8-byte length
            prefix or the declared header. Invalid UTF-8 or JSON in the
            header propagates as ``UnicodeDecodeError`` / ``JSONDecodeError``,
            both of which are ``ValueError`` subclasses.

    The stream is rewound to position 0 before returning, including when an
    exception is raised.
    """
    try:
        lora_upload.seek(0)
        length_prefix = lora_upload.read(8)
        if len(length_prefix) < 8:
            raise ValueError(
                "Not a safetensors file: missing 8-byte header length prefix"
            )
        metadata_length = int.from_bytes(length_prefix, byteorder="little")
        # Reading the prefix already left the stream at offset 8.
        metadata_raw = lora_upload.read(metadata_length)
        if len(metadata_raw) < metadata_length:
            raise ValueError(
                f"Truncated safetensors header: expected {metadata_length} "
                f"bytes, got {len(metadata_raw)}"
            )
        metadata_dict = json.loads(metadata_raw.decode("utf-8").strip())
    finally:
        # We were only peeking at the header; always hand the buffer back
        # rewound, even if parsing failed.
        lora_upload.seek(0)
    return metadata_dict.get("__metadata__", {})
def rescale_lora_alpha(lora_upload: io.BytesIO, output_dtype, target_weight: float = 1.0) -> dict:
    """Scale every ``.alpha`` tensor by ``target_weight`` and recast all tensors.

    Args:
        lora_upload: In-memory safetensors file; its whole contents are read
            via ``getvalue()``.
        output_dtype: Dtype specifier resolved through the project helper
            ``get_target_dtype_ref`` (see ``torch_tools`` for accepted forms).
        target_weight: Multiplier applied to tensors whose key ends with
            ``".alpha"``. All other tensors are only recast.

    Returns:
        Dict mapping every tensor key, in file order, to a tensor of the
        resolved ``output_dtype``.
    """
    output_dtype = get_target_dtype_ref(output_dtype)
    loaded_tensors = safetensors.torch.load(lora_upload.getvalue())
    new_tensors = {}
    # Single pass: each float32 intermediate can be freed after its cast,
    # instead of holding a float32 copy of the whole file at once.
    for key, tensor in loaded_tensors.items():
        # Every tensor is widened to float32 first; any alpha scaling happens
        # at that precision before the final cast to output_dtype.
        val = tensor.to(dtype=torch.float32)
        if key.endswith(".alpha"):
            # Out-of-place on purpose: Tensor.to returns the same tensor when
            # it is already float32, so `*=` would mutate the loaded tensor.
            val = val * target_weight
        new_tensors[key] = val.to(dtype=output_dtype)
    return new_tensors
def remove_clip_weights(lora_upload: io.BytesIO, output_dtype, *, clip_prefixes: tuple = ("lora_te1", "lora_te2")) -> dict:
    """Drop text-encoder tensors from a LoRA and recast the remaining tensors.

    Args:
        lora_upload: In-memory safetensors file; its whole contents are read
            via ``getvalue()``.
        output_dtype: Dtype specifier resolved through the project helper
            ``get_target_dtype_ref`` (see ``torch_tools`` for accepted forms).
        clip_prefixes: Key prefixes to discard. Accepts a single string or any
            iterable of strings. The default, ``("lora_te1", "lora_te2")``,
            preserves the original behavior.
            NOTE(review): the defaults look like SDXL's two text encoders.
            Other LoRA key layouts (e.g. SD1.x kohya files, believed to use
            ``"lora_te_"``) are not filtered by default; confirm and pass
            ``clip_prefixes`` if needed.

    Returns:
        Dict mapping every kept tensor key, in file order, to a tensor of the
        resolved ``output_dtype``.
    """
    output_dtype = get_target_dtype_ref(output_dtype)
    # str.startswith needs a str or a tuple; guard against a bare string
    # being split into single characters by tuple().
    if isinstance(clip_prefixes, str):
        prefixes = (clip_prefixes,)
    else:
        prefixes = tuple(clip_prefixes)
    loaded_tensors = safetensors.torch.load(lora_upload.getvalue())
    filtered_tensors = {}
    for key, tensor in loaded_tensors.items():
        # Filter before converting, so discarded tensors are never copied.
        if key.startswith(prefixes):
            continue
        # Keep the original float32 hop so values match the previous
        # fp -> float32 -> output_dtype conversion exactly.
        filtered_tensors[key] = tensor.to(dtype=torch.float32).to(dtype=output_dtype)
    return filtered_tensors
if __name__ == "__main__":
    # This file is a library module; when it is executed as the main module
    # it does no work and just prints a notice.
    print('__main__ not allowed in modules')