Spaces:
Sleeping
Sleeping
| # ------------------------------------------------------------------- | |
| # Pimcore | |
| # | |
| # This source file is available under two different licenses: | |
| # - GNU General Public License version 3 (GPLv3) | |
| # - Pimcore Commercial License (PCL) | |
| # Full copyright and license information is available in | |
| # LICENSE.md which is distributed with this source code. | |
| # | |
| # @copyright Copyright (c) Pimcore GmbH (http://www.pimcore.org) | |
| # @license http://www.pimcore.org/license GPLv3 and PCL | |
| # ------------------------------------------------------------------- | |
| from pydantic import BaseModel | |
| from typing import Annotated | |
| from fastapi import Form | |
class TextClassificationTrainingParameters(BaseModel):
    """Training parameters used for the text classification fine tuning."""

    # Number of epochs for the training run.
    epochs: int
    # Learning rate for the training run.
    learning_rate: float
def map_text_classification_training_parameters(
    epocs: Annotated[int, Form(description="Epochs executed during training.")] = 3,
    learning_rate: Annotated[float, Form(description="Learning rate for training.")] = 5e-5,
) -> TextClassificationTrainingParameters:
    """Bundle the given values into a TextClassificationTrainingParameters.

    Args:
        epocs: Number of epochs executed during training; defaults to 3.
        learning_rate: Learning rate for training; defaults to 5e-5.

    Returns:
        A TextClassificationTrainingParameters carrying both values.
    """
    # NOTE(review): `epocs` looks like a typo for `epochs`. It is kept as-is
    # because FastAPI derives the form field name from the parameter name, so
    # renaming it would presumably change the request contract - confirm with
    # API consumers before touching it.
    training_parameters = TextClassificationTrainingParameters(
        learning_rate=learning_rate,
        epochs=epocs,
    )
    return training_parameters
class TextClassificationParameters:
    """Read-only holder for everything a text classification fine tuning needs.

    All values are supplied once through the constructor, stored in private
    (name-mangled) attributes and exposed via the ``get_*`` accessors.
    """

    __training_csv_file_path: str
    __training_csv_limiter: str
    __project_name: str
    __source_model_name: str
    __training_parameters: TextClassificationTrainingParameters

    def __init__(
        self,
        training_csv_file_path: str,
        project_name: str,
        source_model_name: str,
        training_parameters: TextClassificationTrainingParameters,
        training_csv_limiter: str = ';',
    ):
        """Store the given fine tuning settings.

        Args:
            training_csv_file_path: Path of the CSV file with the training data.
            project_name: Name of the project; also used as result model name.
            source_model_name: Name of the model the fine tuning starts from.
            training_parameters: Training parameters (epochs, learning rate).
            training_csv_limiter: Limiter string for the training CSV,
                presumably the column delimiter (TODO confirm); defaults
                to ';'.
        """
        # CSV related settings.
        self.__training_csv_file_path = training_csv_file_path
        self.__training_csv_limiter = training_csv_limiter

        # Model / project related settings.
        self.__project_name = project_name
        self.__source_model_name = source_model_name
        self.__training_parameters = training_parameters

    def get_training_csv_file_path(self) -> str:
        """Return the path of the training CSV file."""
        return self.__training_csv_file_path

    def get_training_csv_limiter(self) -> str:
        """Return the limiter configured for the training CSV file."""
        return self.__training_csv_limiter

    def get_project_name(self) -> str:
        """Return the project name."""
        return self.__project_name

    def get_result_model_name(self) -> str:
        """Return the name of the resulting model, which is the project name."""
        return self.__project_name

    def get_source_model_name(self) -> str:
        """Return the name of the source model."""
        return self.__source_model_name

    def get_training_parameters(self) -> TextClassificationTrainingParameters:
        """Return the training parameters passed to the constructor."""
        return self.__training_parameters