| import urllib.request |
| import tarfile |
| from tqdm import tqdm |
| import os |
| import yaml |
| from ruamel.yaml import YAML |
|
|
def read_plainconfig(configname):
    """Parse the YAML file at ``configname`` with ruamel.yaml and return it.

    Parameters
    ----------
    configname : str
        Path of the YAML file to read.

    Returns
    -------
    object
        Whatever ``ruamel.yaml.YAML().load`` produces for the file contents.

    Raises
    ------
    FileNotFoundError
        When nothing exists at ``configname``.
    """
    if os.path.exists(configname):
        with open(configname) as handle:
            loader = YAML()
            return loader.load(handle)
    message = f"Config {configname} is not found. Please make sure that the file exists."
    raise FileNotFoundError(message)
|
|
def DownloadModel(
    modelname, target_dir, config_path="./model/pretrained_model_urls.yaml"
):
    """Download a DeepLabCut Model Zoo project and unpack it into ``target_dir``.

    The archive URL for ``modelname`` is looked up in the YAML mapping loaded
    from ``config_path``. The archive is streamed once into an anonymous
    temporary file, opened as a gzip tarball, and extracted with its single
    top-level folder (e.g. ``xyz-trainsetxyshufflez/``) stripped, so the
    folder's contents land directly in ``target_dir``.

    Parameters
    ----------
    modelname : str
        Key of the model in the URL config.
    target_dir : str
        Directory the archive contents are extracted into.
    config_path : str, optional
        YAML file mapping model names to archive URLs. The default is
        resolved relative to the current working directory.

    Returns
    -------
    str
        ``target_dir``, unchanged. It is also returned when ``modelname`` is
        not in the config; in that case nothing is downloaded and the
        selectable model names are printed instead.

    Raises
    ------
    FileNotFoundError
        If ``config_path`` does not exist (raised by ``read_plainconfig``).
    """
    import tempfile  # local import so the module's import block is untouched

    chunk_size = 1024 * 1024  # bytes read from the network per iteration

    def strip_root_folder(tarf):
        """Yield ``tarf``'s members with the archive's top-level folder removed.

        The folder name is the first path component of the first member.
        The folder's own directory entry is skipped, and members that are not
        inside it are yielded unchanged.
        """
        root = None
        prefix = None
        for member in tarf.getmembers():
            if root is None:
                root = member.path.split("/", 1)[0]
                prefix = root + "/"
            if member.path == root:
                # The root directory entry itself: its contents go straight
                # into target_dir, so there is nothing to extract for it.
                continue
            # Match on "root/" rather than "root" so that a sibling such as
            # "rootX/file" is not truncated mid-name.
            if member.path.startswith(prefix):
                member.path = member.path[len(prefix):]
            yield member

    # An empty YAML file loads as None; treat it as "no models configured".
    neturls = read_plainconfig(config_path) or {}

    if modelname not in neturls:
        models = [
            fn for fn in neturls if "resnet_" not in fn and "mobilenet_" not in fn
        ]
        print("Model does not exist: ", modelname)
        print("Pick one of the following: ", models)
        return target_dir

    url = neturls[modelname]
    print(url)
    # A single request both sizes and downloads the archive. The temporary
    # file is deleted automatically when it is closed.
    with urllib.request.urlopen(url) as response, tempfile.TemporaryFile() as tmp:
        print(
            "Downloading the model from the DeepLabCut server @Harvard -> Go Crimson!!! {}....".format(
                url
            )
        )
        content_length = response.getheader("Content-Length")
        # Servers may omit Content-Length; tqdm accepts total=None.
        total_size = int(content_length) if content_length else None
        with tqdm(unit="B", total=total_size, position=0) as pbar:
            while True:
                chunk = response.read(chunk_size)
                if not chunk:
                    break
                tmp.write(chunk)
                pbar.update(len(chunk))  # count the bytes actually received
        tmp.seek(0)
        with tarfile.open(fileobj=tmp, mode="r:gz") as tar:
            extract_kwargs = {}
            # The archive comes from the network: when this Python supports
            # extraction filters, reject absolute paths, "..", and links that
            # escape target_dir.
            if hasattr(tarfile, "data_filter"):
                extract_kwargs["filter"] = "data"
            tar.extractall(
                target_dir, members=strip_root_folder(tar), **extract_kwargs
            )
    return target_dir
|
|