Spaces:
Running
Running
File size: 1,642 Bytes
325d063 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 | import gzip
import io
import os
import torch
from typing import Optional
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
def _parse_key(key_str: str) -> bytes:
key_str = key_str.strip()
try:
key = bytes.fromhex(key_str)
if len(key) == 32:
return key
except ValueError:
pass
key = key_str.encode("utf-8")
if len(key) == 32:
return key
raise ValueError("Key must be either a 64-character hex string or a 32-character raw string.")
def _get_key(key: Optional[str] = None, env_var: str = "MODEL_KEY") -> bytes:
if key is not None:
return _parse_key(key)
env_value = os.environ.get(env_var)
if not env_value:
raise RuntimeError("Missing key. Provide key=... or set environment variable {}.".format(env_var))
return _parse_key(env_value)
def decrypt_and_decompress_to_bytes(path: str, key: Optional[str] = None, env_var: str = "MODEL_KEY") -> bytes:
    """Decrypt an AES-GCM encrypted, gzip-compressed file and return its contents.

    Expected file layout: a 12-byte nonce followed by the AES-GCM output
    (ciphertext with its 16-byte authentication tag appended). The decrypted
    plaintext must be a gzip stream.

    Args:
        path: Path of the encrypted file.
        key: Key string (see ``_parse_key``). When ``None``, the key is read
            from the environment variable named by ``env_var``.
        env_var: Environment variable consulted when ``key`` is ``None``.

    Returns:
        The decompressed plaintext bytes.

    Raises:
        RuntimeError: If no key is given and ``env_var`` is unset or empty.
        ValueError: If the key is malformed, or the file is too short to hold
            a nonce plus an authentication tag.
        cryptography.exceptions.InvalidTag: If authentication fails (wrong
            key, or a corrupted or tampered file).
        OSError: If the file cannot be read, or (as ``gzip.BadGzipFile``) the
            decrypted payload is not a gzip stream; other gzip/zlib errors are
            possible for truncated streams.
    """
    nonce_size = 12  # nonce length this file format stores up front
    tag_size = 16    # AESGCM appends a 16-byte authentication tag
    key_bytes = _get_key(key=key, env_var=env_var)
    aesgcm = AESGCM(key_bytes)
    with open(path, "rb") as f:
        data = f.read()
    # Anything shorter than nonce + tag cannot be a valid AES-GCM payload.
    # Reject it here with a clear message rather than letting decrypt() fail
    # with an opaque InvalidTag.
    if len(data) < nonce_size + tag_size:
        raise ValueError("Encrypted file is too short or invalid.")
    nonce = data[:nonce_size]
    ciphertext = data[nonce_size:]
    compressed = aesgcm.decrypt(nonce, ciphertext, None)
    return gzip.decompress(compressed)
def secure_torch_load(path: str, *args, key: Optional[str] = None, env_var: str = "MODEL_KEY", **kwargs):
    """Run ``torch.load`` on the decrypted, decompressed contents of ``path``.

    The decrypted bytes are wrapped in an in-memory ``io.BytesIO`` buffer and
    passed to ``torch.load`` in place of a file path.

    Args:
        path: Path of the encrypted file.
        *args: Extra positional arguments forwarded to ``torch.load``.
        key: Key string. When ``None``, the key comes from ``env_var``.
        env_var: Environment variable consulted when ``key`` is ``None``.
        **kwargs: Extra keyword arguments forwarded to ``torch.load``
            (e.g. ``map_location``).

    Returns:
        Whatever ``torch.load`` returns for the decrypted payload.
    """
    # NOTE(review): torch.load may unpickle arbitrary objects, depending on
    # the installed torch version's weights_only default and on kwargs.
    # AES-GCM authentication means only a key holder can produce bytes that
    # reach this call, so the trust boundary is the key itself.
    decrypted = decrypt_and_decompress_to_bytes(path, key=key, env_var=env_var)
    return torch.load(io.BytesIO(decrypted), *args, **kwargs)
|