|
|
import os |
|
|
import argparse |
|
|
from typing import Optional |
|
|
|
|
|
import tensorflow as tf |
|
|
|
|
|
# huggingface_hub is an optional dependency: if it is not installed, defer the
# failure until load_model_from_hub() is actually called (which raises a
# RuntimeError with install instructions) instead of crashing at import time.
try:


    from huggingface_hub import snapshot_download


except Exception:


    snapshot_download = None
|
|
|
|
|
|
|
|
def load_model_from_hub(
    repo_id: str,
    token: Optional[str] = None,
    revision: Optional[str] = None,
    local_dir: Optional[str] = None,
) -> tf.keras.Model:
    """Download artifacts from Hugging Face Hub and load the Keras model.

    Args:
        repo_id: Repository like `username/lambda-keras-model`.
        token: Optional HF token; otherwise use cached.
        revision: Optional git revision, tag, or commit.
        local_dir: Optional directory to place downloaded snapshot.

    Returns:
        Loaded tf.keras.Model

    Raises:
        RuntimeError: If huggingface-hub is not installed.
        FileNotFoundError: If `lambda_model.keras` is missing from the snapshot.
    """
    if snapshot_download is None:
        raise RuntimeError(
            "huggingface-hub is not installed. Add it to dependencies and reinstall."
        )

    # NOTE: the former `local_dir_use_symlinks=False` kwarg is deprecated and
    # ignored in huggingface_hub >= 0.23 (downloads into `local_dir` are real
    # files by default), so it is intentionally not passed anymore.
    cache_dir = snapshot_download(
        repo_id=repo_id,
        token=token,
        revision=revision,
        local_dir=local_dir,
    )

    model_path = os.path.join(cache_dir, "lambda_model.keras")
    # isfile (not exists) so a stray directory with this name is also rejected.
    if not os.path.isfile(model_path):
        raise FileNotFoundError(f"Model file not found in repo: {model_path}")

    return tf.keras.models.load_model(model_path)
|
|
|
|
|
|
|
|
def parse_args(argv: Optional[list] = None) -> argparse.Namespace:
    """Parse command-line arguments for the model-loading CLI.

    Args:
        argv: Optional explicit argument list (useful for testing). Defaults
            to None, in which case argparse reads `sys.argv[1:]` as before.

    Returns:
        Parsed argparse.Namespace with repo_id, hf_token, revision,
        local_dir, and run attributes.
    """
    parser = argparse.ArgumentParser(description="Load the Lambda tf.keras model from Hugging Face Hub")
    parser.add_argument("--repo-id", type=str, required=True, help="Repo id, e.g. username/lambda-keras-model")
    parser.add_argument("--hf-token", type=str, default=None, help="Hugging Face token (optional)")
    parser.add_argument("--revision", type=str, default=None, help="Git revision, tag, or commit (optional)")
    parser.add_argument("--local-dir", type=str, default=None, help="Optional local directory for download")
    parser.add_argument("--run", action="store_true", help="Run a quick forward pass as a smoke test")
    return parser.parse_args(argv)
|
|
|
|
|
|
|
|
def main() -> None:
    """CLI entry point: download the model, print its summary, and optionally
    run a single dummy forward pass as a smoke test."""
    args = parse_args()

    model = load_model_from_hub(
        repo_id=args.repo_id,
        token=args.hf_token,
        revision=args.revision,
        local_dir=args.local_dir,
    )

    model.summary()

    if not args.run:
        return

    # Build a batch-of-one dummy input, substituting 4 for any unknown
    # (None) dimension reported by the model's input shape.
    dims = [4 if size is None else size for size in model.input_shape[1:]]
    example = tf.ones([1] + dims)
    prediction = model(example)
    print("Example input:", example.numpy())
    print("Model output:", prediction.numpy())
|
|
|
|
|
|
|
|
# Allow running this module directly as a script: `python <file> --repo-id ...`.
if __name__ == "__main__":


    main()
|
|
|
|
|
|
|
|
|