Spaces:
Sleeping
Sleeping
| # ------------------------------------------------------------------- | |
| # Pimcore | |
| # | |
| # This source file is available under two different licenses: | |
| # - GNU General Public License version 3 (GPLv3) | |
| # - Pimcore Commercial License (PCL) | |
| # Full copyright and license information is available in | |
| # LICENSE.md which is distributed with this source code. | |
| # | |
| # @copyright Copyright (c) Pimcore GmbH (http://www.pimcore.org) | |
| # @license http://www.pimcore.org/license GPLv3 and PCL | |
| # ------------------------------------------------------------------- | |
| import logging | |
| from ..progress_callback import ProgressCallback | |
| from ..abstract_trainer import AbstractTrainer | |
| from ..environment_variable_checker import EnvironmentVariableChecker | |
| from .image_classification_parameters import ImageClassificationParameters | |
| import zipfile | |
| import os | |
| import shutil | |
| from datasets import load_dataset | |
| from transformers import AutoImageProcessor, DefaultDataCollator, AutoModelForImageClassification, TrainingArguments, Trainer, TrainerState, TrainerControl | |
| from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor | |
| from huggingface_hub import HfFolder | |
| import evaluate | |
| import numpy as np | |
| logger = logging.getLogger(__name__) | |
| logger.setLevel(logging.DEBUG) | |
class ImageClassificationTrainer(AbstractTrainer):
    """Fine-tunes a Hugging Face image-classification model on a zipped image folder.

    Pipeline: extract the uploaded ZIP into the training-files directory, build
    an ``imagefolder`` dataset with an 80/20 train/test split, fine-tune the
    source model, and upload the result to the Hugging Face Hub. Progress and
    abort state are reported through the status object from ``AbstractTrainer``.
    """

    def start_training(self, parameters: ImageClassificationParameters):
        """Run the full training pipeline.

        The abort flag is checked between stages, so a user-requested abort
        stops the pipeline at the next stage boundary.

        Args:
            parameters: project name, file paths, model names and
                hyper-parameters for this training run.

        Raises:
            RuntimeError: wrapping any exception raised during training.
        """
        logger.info('Start Training...')
        try:
            task = 'Extract training data'
            self.get_status().update_status(0, task, parameters.get_project_name())
            logger.info(task)
            self.__extract_training_data(parameters)
            if self.get_status().is_training_aborted():
                return

            task = 'Prepare Data set'
            self.get_status().update_status(10, task)
            logger.info(task)
            images = self.__prepare_data_set(parameters)
            if self.get_status().is_training_aborted():
                return

            task = 'Start training model'
            self.get_status().update_status(20, task)
            logger.info(task)
            self.__train_model(images, parameters)
            if self.get_status().is_training_aborted():
                return

            self.get_status().update_status(100, "Training completed")
        except Exception as e:
            # logger.exception keeps the traceback; logger.error(e) would not.
            logger.exception("Training failed")
            self.get_status().finalize_abort_training(str(e))
            # Chain with "from e" so the original cause is preserved.
            raise RuntimeError(f"An error occurred: {str(e)}") from e
        finally:
            # Cleanup after processing. ignore_errors=True guards against the
            # directory being absent (e.g. extraction failed early) — a raise
            # here would otherwise mask the original exception.
            logger.info('Cleaning up training files after training')
            shutil.rmtree(parameters.get_training_files_path(), ignore_errors=True)
            if self.get_status().is_training_aborted():
                self.get_status().finalize_abort_training("Training aborted")

    def __extract_training_data(self, parameters: ImageClassificationParameters):
        """Validate and extract the uploaded training ZIP, then delete it.

        Raises:
            RuntimeError: if the uploaded file is not a valid ZIP archive.
        """
        training_file = parameters.get_training_zip_file()

        # Check if it is a valid ZIP file
        if not zipfile.is_zipfile(training_file):
            raise RuntimeError("Uploaded file is not a valid zip file")

        # SECURITY NOTE(review): extractall() on an uploaded archive is
        # vulnerable to zip-slip (entries like "../../x" escaping the target
        # directory). Consider validating member paths before extraction.
        with zipfile.ZipFile(training_file, 'r') as zip_ref:
            zip_ref.extractall(parameters.get_training_files_path())

        # The archive is no longer needed once extracted.
        os.remove(training_file)
        logger.info(os.listdir(parameters.get_training_files_path()))

    def __prepare_data_set(self, parameters: ImageClassificationParameters) -> dict:
        """Load the extracted images as a dataset and attach preprocessing.

        Builds an ``imagefolder`` dataset (labels come from sub-directory
        names), splits it 80/20 into train/test, and attaches an on-the-fly
        transform (random resized crop + tensor conversion + normalization)
        to reduce overfitting.

        Returns:
            The transformed ``DatasetDict`` with "train" and "test" splits.
        """
        dataset = load_dataset("imagefolder", data_dir=parameters.get_training_files_path())
        images = dataset["train"]
        images = images.train_test_split(test_size=0.2)
        logger.info(images)
        logger.info(images["train"][10])

        # Preprocess the images with the source model's processor so the
        # normalization statistics and input size match the pretrained model.
        image_processor = AutoImageProcessor.from_pretrained(parameters.get_source_model_name())

        # Apply some image transformations to the images to make the model
        # more robust against overfitting.
        normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
        # Processors expose either a single "shortest_edge" or explicit
        # height/width, depending on the model family.
        size = (
            image_processor.size["shortest_edge"]
            if "shortest_edge" in image_processor.size
            else (image_processor.size["height"], image_processor.size["width"])
        )
        _transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])

        def transforms(examples):
            # Convert to RGB first: imagefolder may yield grayscale/RGBA images.
            examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
            del examples["image"]
            return examples

        # with_transform applies the preprocessing lazily per batch.
        images = images.with_transform(transforms)
        return images

    def __train_model(self, images: dict, parameters: ImageClassificationParameters):
        """Fine-tune the source model on the prepared dataset and upload it.

        Args:
            images: ``DatasetDict`` with "train" and "test" splits produced by
                ``__prepare_data_set``.
            parameters: model names and hyper-parameters for this run.
        """
        environment_variable_checker = EnvironmentVariableChecker()
        HfFolder.save_token(environment_variable_checker.get_huggingface_token())

        image_processor = AutoImageProcessor.from_pretrained(parameters.get_source_model_name())
        data_collator = DefaultDataCollator()
        # Maps trainer progress onto the 21-89% window of the overall status.
        progress_callback = ProgressCallback(self.get_status(), 21, 89)

        # Evaluate and metrics
        accuracy = evaluate.load("accuracy")

        def compute_metrics(eval_pred):
            predictions, labels = eval_pred
            predictions = np.argmax(predictions, axis=1)
            return accuracy.compute(predictions=predictions, references=labels)

        # Build label <-> id maps from the dataset's class names (string ids
        # follow the Hugging Face config convention).
        labels = images["train"].features["label"].names
        label2id, id2label = dict(), dict()
        for i, label in enumerate(labels):
            label2id[label] = str(i)
            id2label[str(i)] = label
        logger.info(id2label)

        # train the model
        model = AutoModelForImageClassification.from_pretrained(
            parameters.get_source_model_name(),
            num_labels=len(labels),
            id2label=id2label,
            label2id=label2id,
        )

        target_model_id = environment_variable_checker.get_huggingface_organization() + '/' + parameters.get_result_model_name()
        training_args = TrainingArguments(
            output_dir=parameters.get_result_model_name(),
            hub_model_id=target_model_id,
            remove_unused_columns=False,
            eval_strategy="epoch",
            save_strategy="epoch",
            learning_rate=parameters.get_training_parameters().learning_rate,
            per_device_train_batch_size=16,
            gradient_accumulation_steps=4,
            per_device_eval_batch_size=16,
            num_train_epochs=parameters.get_training_parameters().epochs,
            warmup_ratio=0.1,
            logging_steps=10,
            load_best_model_at_end=True,
            metric_for_best_model="accuracy",
            # Upload happens explicitly after training (see below), not per-save.
            push_to_hub=False,
            hub_private_repo=True,
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            data_collator=data_collator,
            train_dataset=images["train"],
            eval_dataset=images["test"],
            tokenizer=image_processor,
            compute_metrics=compute_metrics,
            callbacks=[progress_callback]
        )

        if self.get_status().is_training_aborted():
            return
        trainer.train()

        if self.get_status().is_training_aborted():
            return
        logger.info("Model trained, start uploading")
        self.get_status().update_status(90, "Uploading model to Hugging Face")
        trainer.push_to_hub()