| from clip.clip import tokenize as _tokenize, load as _load, available_models as _available_models |
| import re |
| import string |
|
|
# Third-party packages this hubconf needs at load time; torch.hub reads the
# module-level `dependencies` list by convention.
dependencies = [
    "torch",
    "torchvision",
    "ftfy",
    "regex",
    "tqdm",
]
|
|
| |
# Map each CLIP model name to a Python identifier by replacing every
# punctuation character with "_" (e.g. "ViT-B/32" -> "ViT_B_32"). These
# identifiers become the names of the hub entrypoints published below.
# re.escape keeps characters such as "]", "\\", "^" and "-" literal inside
# the character class instead of relying on how the raw
# string.punctuation sequence happens to parse as a regex set.
model_functions = {
    model: re.sub(f"[{re.escape(string.punctuation)}]", "_", model)
    for model in _available_models()
}
|
|
| def _create_hub_entrypoint(model): |
| def entrypoint(**kwargs): |
| return _load(model, **kwargs) |
| |
| entrypoint.__doc__ = f"""Loads the {model} CLIP model |
| |
| Parameters |
| ---------- |
| device : Union[str, torch.device] |
| The device to put the loaded model |
| |
| jit : bool |
| Whether to load the optimized JIT model or more hackable non-JIT model (default). |
| |
| download_root: str |
| path to download the model files; by default, it uses "~/.cache/clip" |
| |
| Returns |
| ------- |
| model : torch.nn.Module |
| The {model} CLIP model |
| |
| preprocess : Callable[[PIL.Image], torch.Tensor] |
| A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input |
| """ |
| return entrypoint |
|
|
def tokenize():
    """Return CLIP's text tokenizer function.

    This hub entrypoint takes no arguments. It returns the ``tokenize``
    callable imported from ``clip.clip`` itself, not a tokenized result;
    callers invoke the returned function on their text.

    Returns
    -------
    Callable
        The ``clip.clip.tokenize`` function.
    """
    return _tokenize
|
|
# Build one hub entrypoint per model, keyed by its identifier-safe name.
# Iterating the already-built name mapping (rather than calling
# _available_models() again and re-indexing model_functions) guarantees the
# entrypoint set and the name mapping cannot disagree.
_entrypoints = {
    function_name: _create_hub_entrypoint(model_name)
    for model_name, function_name in model_functions.items()
}

# Publish the entrypoints as module-level callables so torch.hub can find
# them by attribute name on this module.
globals().update(_entrypoints)