| |
| import argparse |
| from huggingface_hub import HfApi |
|
|
|
|
def main(api, model_id):
    """Return the names of all branches of a Hub repo other than ``main``.

    Args:
        api: An ``HfApi``-like client exposing ``list_repo_refs``.
        model_id: Repository id on the Hub, e.g. ``gpt2`` or
            ``facebook/wav2vec2-base-960h``.

    Returns:
        A sorted list of unique branch names, excluding ``main``.
    """
    refs = api.list_repo_refs(model_id)
    # The set removes duplicate names. Sorting gives a deterministic order,
    # because iterating a set directly yields an arbitrary order.
    branches = {branch.name for branch in refs.branches} - {"main"}
    return sorted(branches)
|
|
|
|
if __name__ == "__main__":
    DESCRIPTION = """
    Print the given model id if its Hub repo has an `fp16` branch.
    """
    parser = argparse.ArgumentParser(description=DESCRIPTION)
    parser.add_argument(
        "--model_id",
        type=str,
        # Without this, a missing flag would pass None to the Hub API and fail
        # with an obscure error. argparse now reports a clear usage error instead.
        required=True,
        help="The name of the model on the hub to retrieve the branches from. E.g. `gpt2` or `facebook/wav2vec2-base-960h`",
    )

    args = parser.parse_args()
    model_id = args.model_id
    api = HfApi()
    branches = main(api, model_id)

    # Output is only the model id, and only when an `fp16` branch exists.
    if "fp16" in branches:
        print(model_id)
| |
| |
| |
|
|