Spaces:
Sleeping
Sleeping
| import json | |
| from pathlib import Path | |
| from typing import Optional, Union | |
| from functools import wraps | |
| from huggingface_hub import ( | |
| PyTorchModelHubMixin, | |
| ModelCard, | |
| ModelCardData, | |
| hf_hub_download, | |
| ) | |
| MODEL_CARD = """ | |
| --- | |
| {{ card_data }} | |
| --- | |
| # {{ model_name }} Model Card | |
| Table of Contents: | |
| - [Load trained model](#load-trained-model) | |
| - [Model init parameters](#model-init-parameters) | |
| - [Model metrics](#model-metrics) | |
| - [Dataset](#dataset) | |
| ## Load trained model | |
| ```python | |
| import feature_extractor_models as smp | |
| model = smp.{{ model_name }}.from_pretrained("{{ save_directory | default("<save-directory-or-repo>", true)}}") | |
| ``` | |
| ## Model init parameters | |
| ```python | |
| model_init_params = {{ model_parameters }} | |
| ``` | |
| ## Model metrics | |
| {{ metrics | default("[More Information Needed]", true) }} | |
| ## Dataset | |
| Dataset name: {{ dataset | default("[More Information Needed]", true) }} | |
| ## More Information | |
| - Library: {{ repo_url | default("[More Information Needed]", true) }} | |
| - Docs: {{ docs_url | default("[More Information Needed]", true) }} | |
| This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) | |
| """ | |
| def _format_parameters(parameters: dict): | |
| params = {k: v for k, v in parameters.items() if not k.startswith("_")} | |
| params = [ | |
| f'"{k}": {v}' if not isinstance(v, str) else f'"{k}": "{v}"' | |
| for k, v in params.items() | |
| ] | |
| params = ",\n".join([f" {param}" for param in params]) | |
| params = "{\n" + f"{params}" + "\n}" | |
| return params | |
| class SMPHubMixin(PyTorchModelHubMixin): | |
| def generate_model_card(self, *args, **kwargs) -> ModelCard: | |
| model_parameters_json = _format_parameters(self._hub_mixin_config) | |
| directory = self._save_directory if hasattr(self, "_save_directory") else None | |
| repo_id = self._repo_id if hasattr(self, "_repo_id") else None | |
| repo_or_directory = repo_id if repo_id is not None else directory | |
| metrics = self._metrics if hasattr(self, "_metrics") else None | |
| dataset = self._dataset if hasattr(self, "_dataset") else None | |
| if metrics is not None: | |
| metrics = json.dumps(metrics, indent=4) | |
| metrics = f"```json\n{metrics}\n```" | |
| model_card_data = ModelCardData( | |
| languages=["python"], | |
| library_name="segmentation-models-pytorch", | |
| license="mit", | |
| tags=["semantic-segmentation", "pytorch", "segmentation-models-pytorch"], | |
| pipeline_tag="image-segmentation", | |
| ) | |
| model_card = ModelCard.from_template( | |
| card_data=model_card_data, | |
| template_str=MODEL_CARD, | |
| repo_url="https://github.com/qubvel/segmentation_models.pytorch", | |
| docs_url="https://smp.readthedocs.io/en/latest/", | |
| model_parameters=model_parameters_json, | |
| save_directory=repo_or_directory, | |
| model_name=self.__class__.__name__, | |
| metrics=metrics, | |
| dataset=dataset, | |
| ) | |
| return model_card | |
| def _set_attrs_from_kwargs(self, attrs, kwargs): | |
| for attr in attrs: | |
| if attr in kwargs: | |
| setattr(self, f"_{attr}", kwargs.pop(attr)) | |
| def _del_attrs(self, attrs): | |
| for attr in attrs: | |
| if hasattr(self, f"_{attr}"): | |
| delattr(self, f"_{attr}") | |
| def save_pretrained( | |
| self, save_directory: Union[str, Path], *args, **kwargs | |
| ) -> Optional[str]: | |
| # set additional attributes to be used in generate_model_card | |
| self._save_directory = save_directory | |
| self._set_attrs_from_kwargs(["metrics", "dataset"], kwargs) | |
| # set additional attribute to be used in from_pretrained | |
| self._hub_mixin_config["_model_class"] = self.__class__.__name__ | |
| try: | |
| # call the original save_pretrained | |
| result = super().save_pretrained(save_directory, *args, **kwargs) | |
| finally: | |
| # delete the additional attributes | |
| self._del_attrs(["save_directory", "metrics", "dataset"]) | |
| self._hub_mixin_config.pop("_model_class") | |
| return result | |
| def push_to_hub(self, repo_id: str, *args, **kwargs): | |
| self._repo_id = repo_id | |
| self._set_attrs_from_kwargs(["metrics", "dataset"], kwargs) | |
| result = super().push_to_hub(repo_id, *args, **kwargs) | |
| self._del_attrs(["repo_id", "metrics", "dataset"]) | |
| return result | |
| def config(self): | |
| return self._hub_mixin_config | |
| def from_pretrained(pretrained_model_name_or_path: str, *args, **kwargs): | |
| config_path = hf_hub_download( | |
| pretrained_model_name_or_path, | |
| filename="config.json", | |
| revision=kwargs.get("revision", None), | |
| ) | |
| with open(config_path, "r") as f: | |
| config = json.load(f) | |
| model_class_name = config.pop("_model_class") | |
| import feature_extractor_models as smp | |
| model_class = getattr(smp, model_class_name) | |
| return model_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) | |