File size: 1,642 Bytes
325d063
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import gzip
import io
import os
import torch
from typing import Optional
from cryptography.hazmat.primitives.ciphers.aead import AESGCM


def _parse_key(key_str: str) -> bytes:
    """Turn the textual form of a 256-bit AES key into raw key bytes.

    Surrounding whitespace is stripped first. Two interpretations are then
    tried in order:

    1. Hexadecimal text that decodes (via ``bytes.fromhex``) to exactly
       32 bytes, i.e. normally 64 hex digits.
    2. The text itself, UTF-8 encoded, if that encoding is exactly 32 bytes.
       Note the length is measured in encoded bytes, not characters.

    Args:
        key_str: Key text in either of the forms above.

    Returns:
        The 32-byte key.

    Raises:
        ValueError: if neither interpretation yields 32 bytes.
    """
    expected_len = 32
    text = key_str.strip()

    try:
        from_hex = bytes.fromhex(text)
    except ValueError:
        # Not valid hexadecimal; only the raw interpretation is left.
        from_hex = None

    if from_hex is not None and len(from_hex) == expected_len:
        return from_hex

    # Valid hex of the wrong length (e.g. 32 hex digits -> 16 bytes) also
    # falls through here and is reconsidered as a raw key.
    as_raw = text.encode("utf-8")
    if len(as_raw) == expected_len:
        return as_raw

    raise ValueError("Key must be either a 64-character hex string or a 32-character raw string.")


def _get_key(key: Optional[str] = None, env_var: str = "MODEL_KEY") -> bytes:
    """Resolve the AES key from an explicit argument or the environment.

    Args:
        key: Key text. When given (even as an empty string) it is used
            as-is and the environment is not consulted.
        env_var: Name of the environment variable read when ``key`` is None.

    Returns:
        The 32-byte key produced by ``_parse_key``.

    Raises:
        RuntimeError: if ``key`` is None and ``env_var`` is unset or empty.
        ValueError: propagated from ``_parse_key`` for malformed keys.
    """
    if key is None:
        key = os.environ.get(env_var)
        # An unset variable (None) and an empty value are both treated as missing.
        if not key:
            raise RuntimeError("Missing key. Provide key=... or set environment variable {}.".format(env_var))
    return _parse_key(key)


def decrypt_and_decompress_to_bytes(path: str, key: Optional[str] = None, env_var: str = "MODEL_KEY") -> bytes:
    """Read an encrypted file, decrypt it with AES-GCM and gunzip the result.

    File layout expected here: a 12-byte nonce, followed by the AES-GCM
    ciphertext. As produced by ``AESGCM.encrypt``, the ciphertext ends with
    a 16-byte authentication tag. No associated data is used. The
    authenticated plaintext must be gzip-compressed.

    Args:
        path: Path of the encrypted file.
        key: Key text accepted by ``_parse_key``; when None the key is read
            from the environment variable named by ``env_var``.
        env_var: Environment variable consulted when ``key`` is None.

    Returns:
        The decompressed plaintext bytes.

    Raises:
        RuntimeError / ValueError: from key resolution.
        ValueError: if the file is too short to contain a nonce and a tag.
        cryptography.exceptions.InvalidTag: wrong key, or the data was altered.
        Exceptions from ``gzip.decompress`` if the plaintext is not valid gzip.
    """
    nonce_size = 12  # nonce bytes stored at the start of the file
    tag_size = 16    # GCM tag bytes AESGCM appends to every ciphertext

    key_bytes = _get_key(key=key, env_var=env_var)
    aesgcm = AESGCM(key_bytes)

    with open(path, "rb") as f:
        data = f.read()

    # The smallest decryptable file is nonce + tag (empty plaintext). The old
    # bound of 13 ignored the tag, so 13-27 byte files slipped past this check.
    if len(data) < nonce_size + tag_size:
        raise ValueError("Encrypted file is too short or invalid.")

    nonce, ciphertext = data[:nonce_size], data[nonce_size:]
    compressed = aesgcm.decrypt(nonce, ciphertext, None)
    return gzip.decompress(compressed)


def secure_torch_load(path: str, *args, key: Optional[str] = None, env_var: str = "MODEL_KEY", **kwargs):
    """Decrypt an encrypted checkpoint in memory and deserialize it with ``torch.load``.

    The plaintext never touches disk: it is wrapped in an in-memory
    ``io.BytesIO`` buffer and handed to ``torch.load``.

    Args:
        path: Path of the encrypted, gzip-compressed checkpoint file.
        *args: Extra positional arguments forwarded to ``torch.load``.
        key: Key text; when None it is read from ``env_var``.
        env_var: Environment variable consulted when ``key`` is None.
        **kwargs: Extra keyword arguments forwarded to ``torch.load``
            (e.g. ``map_location``, ``weights_only``).

    Returns:
        Whatever ``torch.load`` returns for the decrypted payload.

    Security note (review): depending on ``weights_only`` and the installed
    PyTorch version, ``torch.load`` may unpickle arbitrary objects. AES-GCM
    authentication means only holders of the key can produce a file that
    decrypts. Safety therefore rests on keeping the key secret; consider
    passing ``weights_only=True`` when the payload is a plain state dict.
    """
    decrypted = decrypt_and_decompress_to_bytes(path, key=key, env_var=env_var)
    return torch.load(io.BytesIO(decrypted), *args, **kwargs)