"""Load a tf.keras model stored on the Hugging Face Hub.

Run as a script to download a repository snapshot, print the model summary
and, optionally, run a forward pass on all-ones inputs as a smoke test.
"""

import os
import argparse
from typing import Optional

import tensorflow as tf

try:
    from huggingface_hub import snapshot_download
except ImportError:  # pragma: no cover
    snapshot_download = None  # type: ignore

# Model file looked up inside the downloaded snapshot unless overridden.
DEFAULT_MODEL_FILENAME = "lambda_model.keras"

# Size substituted for unknown (None) non-batch dimensions in the smoke test.
_SMOKE_TEST_FILL_DIM = 4


def load_model_from_hub(
    repo_id: str,
    token: Optional[str] = None,
    revision: Optional[str] = None,
    local_dir: Optional[str] = None,
    *,
    filename: str = DEFAULT_MODEL_FILENAME,
    safe_mode: bool = True,
) -> tf.keras.Model:
    """Download artifacts from Hugging Face Hub and load the Keras model.

    Args:
        repo_id: Repository like `username/lambda-keras-model`.
        token: Optional HF token; when None, huggingface_hub falls back to its
            own token resolution (cached login / environment).
        revision: Optional git revision, tag, or commit.
        local_dir: Optional directory to place downloaded snapshot.
        filename: Path of the model file relative to the snapshot root.
        safe_mode: Forwarded to Keras only when False. Keras refuses to
            deserialize Lambda layers (arbitrary Python code) from `.keras`
            files unless this is False; only disable it for trusted repos.

    Returns:
        Loaded tf.keras.Model

    Raises:
        RuntimeError: If huggingface_hub is not installed.
        FileNotFoundError: If `filename` is missing from the snapshot.
    """
    if snapshot_download is None:
        raise RuntimeError(
            "huggingface-hub is not installed. Add it to dependencies and reinstall."
        )

    download_kwargs = {"repo_id": repo_id, "token": token, "revision": revision}
    if local_dir is not None:
        # Ask for real files in local_dir rather than symlinks into the shared
        # cache. The flag only has an effect together with local_dir, and
        # newer huggingface_hub releases deprecate it, so it is not sent
        # otherwise.
        download_kwargs["local_dir"] = local_dir
        download_kwargs["local_dir_use_symlinks"] = False
    snapshot_path = snapshot_download(**download_kwargs)

    model_path = os.path.join(snapshot_path, filename)
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found in repo: {model_path}")

    if safe_mode:
        # Leave `safe_mode` unset so tf.keras versions that predate the
        # argument keep working with the default behavior.
        return tf.keras.models.load_model(model_path)
    return tf.keras.models.load_model(model_path, safe_mode=False)


def parse_args() -> argparse.Namespace:
    """Parse command-line options for the loader script."""
    parser = argparse.ArgumentParser(
        description="Load the Lambda tf.keras model from Hugging Face Hub"
    )
    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="Repo id, e.g. username/lambda-keras-model",
    )
    parser.add_argument(
        "--hf-token", type=str, default=None, help="Hugging Face token (optional)"
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        help="Git revision, tag, or commit (optional)",
    )
    parser.add_argument(
        "--local-dir",
        type=str,
        default=None,
        help="Optional local directory for download",
    )
    parser.add_argument(
        "--filename",
        type=str,
        default=DEFAULT_MODEL_FILENAME,
        help="Model file inside the repo (default: %(default)s)",
    )
    parser.add_argument(
        "--unsafe-deserialization",
        action="store_true",
        help="Load with safe_mode=False (needed for Lambda layers; trusted repos only)",
    )
    parser.add_argument(
        "--run", action="store_true", help="Run a quick forward pass as a smoke test"
    )
    return parser.parse_args()


def _example_inputs(model: tf.keras.Model, fill_dim: int = _SMOKE_TEST_FILL_DIM):
    """Build batch-size-1 all-ones input(s) matching `model.input_shape`.

    Unknown (None) non-batch dimensions are replaced by `fill_dim`. A
    multi-input model reports a list of shapes, which yields a list of tensors.
    """

    def _ones(shape) -> tf.Tensor:
        concrete = tuple(fill_dim if dim is None else dim for dim in shape[1:])
        return tf.ones((1,) + concrete)

    shapes = model.input_shape
    if isinstance(shapes, list):
        return [_ones(shape) for shape in shapes]
    return _ones(shapes)


def _to_numpy(structure):
    """Convert a tensor, or a nested list/tuple/dict of tensors, to NumPy."""
    return tf.nest.map_structure(lambda tensor: tensor.numpy(), structure)


def main() -> None:
    """Download and summarize the model; optionally run a smoke-test pass."""
    args = parse_args()
    model = load_model_from_hub(
        repo_id=args.repo_id,
        token=args.hf_token,
        revision=args.revision,
        local_dir=args.local_dir,
        filename=args.filename,
        safe_mode=not args.unsafe_deserialization,
    )
    model.summary()

    if args.run:
        # Quick forward pass on inputs shaped after the model's declared inputs.
        example = _example_inputs(model)
        prediction = model(example)
        print("Example input:", _to_numpy(example))
        print("Model output:", _to_numpy(prediction))


if __name__ == "__main__":
    main()