temp / load_lambda_model.py
aakashjapi's picture
Add lambda keras model
15bec80 verified
import argparse
import os
from typing import Optional, Sequence

import tensorflow as tf

try:
    from huggingface_hub import snapshot_download
except Exception:  # pragma: no cover
    snapshot_download = None  # type: ignore
def load_model_from_hub(
    repo_id: str,
    token: Optional[str] = None,
    revision: Optional[str] = None,
    local_dir: Optional[str] = None,
    filename: str = "lambda_model.keras",
) -> tf.keras.Model:
    """Download a repo snapshot from the Hugging Face Hub and load its Keras model.

    Args:
        repo_id: Repository like `username/lambda-keras-model`.
        token: Optional HF token, forwarded to `snapshot_download`.
        revision: Optional git revision, tag, or commit.
        local_dir: Optional directory to place the downloaded snapshot.
        filename: Path of the model file relative to the snapshot root.
            Defaults to `lambda_model.keras`.

    Returns:
        The model returned by `tf.keras.models.load_model`.

    Raises:
        RuntimeError: If `huggingface_hub` could not be imported.
        FileNotFoundError: If `filename` is absent from the downloaded snapshot.
    """
    if snapshot_download is None:
        raise RuntimeError(
            "huggingface-hub is not installed. Add it to dependencies and reinstall."
        )

    # `local_dir_use_symlinks` is deliberately not passed: it is a deprecated
    # no-op in recent huggingface_hub releases, and passing it triggers a
    # warning there and a TypeError on releases that dropped the parameter.
    snapshot_dir = snapshot_download(
        repo_id=repo_id,
        token=token,
        revision=revision,
        local_dir=local_dir,
    )

    model_path = os.path.join(snapshot_dir, filename)
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found in repo: {model_path}")

    # NOTE(review): if the saved model contains Lambda layers, newer Keras
    # versions may refuse to deserialize them in safe mode -- confirm against
    # the Keras version in use before relaxing that.
    return tf.keras.models.load_model(model_path)
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line options for this script.

    Args:
        argv: Argument strings to parse (without the program name). When
            None, argparse falls back to `sys.argv[1:]`, which is the
            behavior of calling this function with no arguments.

    Returns:
        Namespace with `repo_id`, `hf_token`, `revision`, `local_dir`
        and the boolean `run` flag.
    """
    parser = argparse.ArgumentParser(
        description="Load the Lambda tf.keras model from Hugging Face Hub"
    )
    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="Repo id, e.g. username/lambda-keras-model",
    )
    parser.add_argument(
        "--hf-token",
        type=str,
        default=None,
        help="Hugging Face token (optional)",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        help="Git revision, tag, or commit (optional)",
    )
    parser.add_argument(
        "--local-dir",
        type=str,
        default=None,
        help="Optional local directory for download",
    )
    parser.add_argument(
        "--run",
        action="store_true",
        help="Run a quick forward pass as a smoke test",
    )
    return parser.parse_args(argv)
def _example_from_shape(shape, unknown_dim: int = 4):
    """Return a batch-of-one tensor of ones for a single input shape.

    The leading (batch) axis of `shape` is dropped and replaced by 1; any
    remaining dimension reported as None is filled with `unknown_dim`.
    """
    feature_dims = tuple(unknown_dim if dim is None else dim for dim in shape[1:])
    return tf.ones((1,) + feature_dims)


def _build_example_inputs(model):
    """Build example input(s) mirroring the structure of `model.input_shape`.

    A single shape yields one tensor; a list/tuple of shapes yields a list of
    tensors; a dict of shapes yields a dict of tensors with the same keys.
    """
    shapes = model.input_shape
    if isinstance(shapes, dict):
        return {name: _example_from_shape(shape) for name, shape in shapes.items()}
    # A single shape is a flat sequence of ints/None; a multi-input model
    # reports a sequence whose elements are themselves shape sequences.
    if shapes and isinstance(shapes[0], (list, tuple)):
        return [_example_from_shape(shape) for shape in shapes]
    return _example_from_shape(shapes)


def main() -> None:
    """Load the model named on the command line, print its summary, and
    optionally run a forward pass on an all-ones example as a smoke test."""
    args = parse_args()
    model = load_model_from_hub(
        repo_id=args.repo_id,
        token=args.hf_token,
        revision=args.revision,
        local_dir=args.local_dir,
    )
    model.summary()

    if not args.run:
        return

    # Attempt a quick forward pass using shape(s) derived from the model input.
    example = _build_example_inputs(model)
    prediction = model(example)

    # map_structure keeps this working when inputs/outputs are lists or dicts
    # of tensors rather than a single tensor.
    print("Example input:", tf.nest.map_structure(lambda t: t.numpy(), example))
    print("Model output:", tf.nest.map_structure(lambda t: t.numpy(), prediction))
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()