Spaces:
Sleeping
Add training-abort early-return to both trainers, complete the @app.put("/stop_training") route decorator, and move label extraction before the train/test split
Browse files
src/image_classification/image_classification_trainer.py
CHANGED
|
@@ -52,6 +52,9 @@ class ImageClassificationTrainer(AbstractTrainer):
|
|
| 52 |
|
| 53 |
self.__train_model(images, parameters)
|
| 54 |
|
|
|
|
|
|
|
|
|
|
| 55 |
self.get_status().update_status(100, "Training completed")
|
| 56 |
|
| 57 |
except Exception as e:
|
|
|
|
| 52 |
|
| 53 |
self.__train_model(images, parameters)
|
| 54 |
|
| 55 |
+
if(self.get_status().is_training_aborted()):
|
| 56 |
+
return
|
| 57 |
+
|
| 58 |
self.get_status().update_status(100, "Training completed")
|
| 59 |
|
| 60 |
except Exception as e:
|
src/main.py
CHANGED
|
@@ -75,7 +75,7 @@ async def get_task_status(token_data: dict = Depends(verify_token)):
|
|
| 75 |
"status": status.get_status().value
|
| 76 |
}
|
| 77 |
|
| 78 |
-
@app.
|
| 79 |
async def stop_task(token_data: dict = Depends(verify_token)):
|
| 80 |
""" Stop the currently running training (if any). """
|
| 81 |
try:
|
|
|
|
| 75 |
"status": status.get_status().value
|
| 76 |
}
|
| 77 |
|
| 78 |
+
@app.put("/stop_training")
|
| 79 |
async def stop_task(token_data: dict = Depends(verify_token)):
|
| 80 |
""" Stop the currently running training (if any). """
|
| 81 |
try:
|
src/text_classification/text_classification_trainer.py
CHANGED
|
@@ -44,6 +44,9 @@ class TextClassificationTrainer(AbstractTrainer):
|
|
| 44 |
|
| 45 |
self.__train_model(tokenized_dataset, labels, label2id, id2label, parameters)
|
| 46 |
|
|
|
|
|
|
|
|
|
|
| 47 |
self.get_status().update_status(100, "Training completed")
|
| 48 |
|
| 49 |
except Exception as e:
|
|
@@ -66,6 +69,18 @@ class TextClassificationTrainer(AbstractTrainer):
|
|
| 66 |
dataset = load_dataset('csv', data_files=parameters.get_training_csv_file_path(), delimiter=parameters.get_training_csv_limiter())
|
| 67 |
|
| 68 |
dataset = dataset["train"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
dataset = dataset.train_test_split(test_size=0.2)
|
| 70 |
|
| 71 |
logger.info(dataset)
|
|
@@ -79,15 +94,6 @@ class TextClassificationTrainer(AbstractTrainer):
|
|
| 79 |
|
| 80 |
tokenized_dataset = dataset.map(preprocess_function, batched=True)
|
| 81 |
|
| 82 |
-
# Extract the labels
|
| 83 |
-
labels = tokenized_dataset['train'].unique('target')
|
| 84 |
-
label2id, id2label = dict(), dict()
|
| 85 |
-
for i, label in enumerate(labels):
|
| 86 |
-
label2id[label] = i
|
| 87 |
-
id2label[i] = label
|
| 88 |
-
|
| 89 |
-
logger.info(id2label)
|
| 90 |
-
|
| 91 |
# Rename the Target column to labels and remove unnecessary columns
|
| 92 |
tokenized_dataset = tokenized_dataset.rename_column('target', 'labels')
|
| 93 |
|
|
|
|
| 44 |
|
| 45 |
self.__train_model(tokenized_dataset, labels, label2id, id2label, parameters)
|
| 46 |
|
| 47 |
+
if(self.get_status().is_training_aborted()):
|
| 48 |
+
return
|
| 49 |
+
|
| 50 |
self.get_status().update_status(100, "Training completed")
|
| 51 |
|
| 52 |
except Exception as e:
|
|
|
|
| 69 |
dataset = load_dataset('csv', data_files=parameters.get_training_csv_file_path(), delimiter=parameters.get_training_csv_limiter())
|
| 70 |
|
| 71 |
dataset = dataset["train"]
|
| 72 |
+
|
| 73 |
+
# Extract the labels
|
| 74 |
+
#labels = tokenized_dataset['train'].unique('target')
|
| 75 |
+
labels = dataset.unique('target')
|
| 76 |
+
label2id, id2label = dict(), dict()
|
| 77 |
+
for i, label in enumerate(labels):
|
| 78 |
+
label2id[label] = i
|
| 79 |
+
id2label[i] = label
|
| 80 |
+
|
| 81 |
+
logger.info(id2label)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
dataset = dataset.train_test_split(test_size=0.2)
|
| 85 |
|
| 86 |
logger.info(dataset)
|
|
|
|
| 94 |
|
| 95 |
tokenized_dataset = dataset.map(preprocess_function, batched=True)
|
| 96 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
# Rename the Target column to labels and remove unnecessary columns
|
| 98 |
tokenized_dataset = tokenized_dataset.rename_column('target', 'labels')
|
| 99 |
|