File size: 2,720 Bytes
15bec80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import argparse
import os
from typing import Optional, Sequence

import tensorflow as tf

try:
    from huggingface_hub import snapshot_download
except Exception:  # pragma: no cover
    snapshot_download = None  # type: ignore


def load_model_from_hub(
    repo_id: str,
    token: Optional[str] = None,
    revision: Optional[str] = None,
    local_dir: Optional[str] = None,
    filename: str = "lambda_model.keras",
    safe_mode: bool = True,
) -> tf.keras.Model:
    """Download a repository snapshot from the Hugging Face Hub and load a Keras model.

    Args:
        repo_id: Repository like `username/lambda-keras-model`.
        token: Optional HF token; when None, huggingface_hub uses whatever
            cached/configured credentials it has.
        revision: Optional git revision, tag, or commit.
        local_dir: Optional directory to place the downloaded snapshot in.
        filename: Model file to load, relative to the repository root.
        safe_mode: When False, `safe_mode=False` is forwarded to
            `tf.keras.models.load_model`. Keras refuses to deserialize Python
            `lambda` bytecode (e.g. inside a `Lambda` layer) unless this is
            relaxed. Only do that for repositories you trust: it permits
            arbitrary code execution from the downloaded file.

    Returns:
        Loaded tf.keras.Model

    Raises:
        RuntimeError: huggingface_hub could not be imported.
        FileNotFoundError: `filename` is not a file in the downloaded snapshot.
    """
    if snapshot_download is None:
        raise RuntimeError(
            "huggingface-hub is not installed. Add it to dependencies and reinstall."
        )

    # `local_dir_use_symlinks` is intentionally not passed. It is deprecated and
    # ignored by recent huggingface_hub releases (which write real files into
    # `local_dir`) and it never had an effect when `local_dir` is None. On very
    # old releases, large files may therefore be symlinked into the cache, which
    # does not affect loading.
    snapshot_path = snapshot_download(
        repo_id=repo_id,
        token=token,
        revision=revision,
        local_dir=local_dir,
    )

    model_path = os.path.join(snapshot_path, filename)
    # isfile (not exists) so a directory with the model's name is rejected here
    # instead of failing obscurely inside load_model.
    if not os.path.isfile(model_path):
        raise FileNotFoundError(f"Model file not found in repo: {model_path}")

    # Only pass `safe_mode` when relaxing it, so the default call stays exactly
    # `load_model(model_path)` and works on Keras versions lacking the keyword.
    load_kwargs = {} if safe_mode else {"safe_mode": False}
    return tf.keras.models.load_model(model_path, **load_kwargs)


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line options for this script.

    Args:
        argv: Arguments to parse, excluding the program name. When None,
            argparse reads `sys.argv[1:]` (the previous behavior).

    Returns:
        Namespace with `repo_id`, `hf_token`, `revision`, `local_dir`, and `run`.
    """
    parser = argparse.ArgumentParser(description="Load the Lambda tf.keras model from Hugging Face Hub")
    parser.add_argument("--repo-id", type=str, required=True, help="Repo id, e.g. username/lambda-keras-model")
    parser.add_argument("--hf-token", type=str, default=None, help="Hugging Face token (optional)")
    parser.add_argument("--revision", type=str, default=None, help="Git revision, tag, or commit (optional)")
    parser.add_argument("--local-dir", type=str, default=None, help="Optional local directory for download")
    parser.add_argument("--run", action="store_true", help="Run a quick forward pass as a smoke test")
    return parser.parse_args(argv)


def _example_input(model: "tf.keras.Model", fill_dim: int = 4):
    """Build an all-ones smoke-test input (batch size 1) from `model.input_shape`.

    Unknown (None) non-batch dimensions are replaced by `fill_dim`. A model with
    several inputs reports a list of shapes; a list of tensors is returned for it,
    in the same order.
    """

    def ones_for(shape):
        # shape[0] is the batch axis; replace it with 1 and fill unknown dims.
        dims = tuple(fill_dim if dim is None else dim for dim in shape[1:])
        return tf.ones((1,) + dims)

    shapes = model.input_shape
    if isinstance(shapes, list):
        return [ones_for(shape) for shape in shapes]
    return ones_for(shapes)


def _to_numpy(structure):
    """Convert every tensor in a (possibly nested) list/tuple/dict to a NumPy array."""
    return tf.nest.map_structure(lambda tensor: tensor.numpy(), structure)


def main() -> None:
    """CLI entry point: load the model, print its summary, optionally run a smoke test."""
    args = parse_args()
    model = load_model_from_hub(
        repo_id=args.repo_id,
        token=args.hf_token,
        revision=args.revision,
        local_dir=args.local_dir,
    )

    model.summary()

    if args.run:
        # Quick forward pass with inputs shaped from the model's declared input(s);
        # works for single- and multi-input models and any output structure.
        example = _example_input(model)
        prediction = model(example)
        print("Example input:", _to_numpy(example))
        print("Model output:", _to_numpy(prediction))


if __name__ == "__main__":
    main()