Spaces:
Runtime error
Runtime error
Commit
·
012daa9
1
Parent(s):
0a02c34
Update functions/modelling_function.py
Browse files
functions/modelling_function.py
CHANGED
|
@@ -78,7 +78,7 @@ def category_reassign(row, reference_df, checked_category, threshold=70):
|
|
| 78 |
else:
|
| 79 |
return row['category_name']
|
| 80 |
|
| 81 |
-
def train_model(df, stratify=True, model_type='bert', use_existing_model=False, model_name=None):
|
| 82 |
"""
|
| 83 |
This function trains the model using the configuration in config.yaml
|
| 84 |
|
|
@@ -98,7 +98,7 @@ def train_model(df, stratify=True, model_type='bert', use_existing_model=False,
|
|
| 98 |
warnings.filterwarnings('ignore')
|
| 99 |
|
| 100 |
test_size = yaml.load(open('config.yaml'), Loader=yaml.FullLoader)['parameters']['training_args']['test_size']
|
| 101 |
-
train_df, test_df = train_test_split(df, test_size=test_size, stratify=df[
|
| 102 |
|
| 103 |
# Optional model configuration
|
| 104 |
model_config = yaml.load(open('config.yaml'), Loader=yaml.FullLoader)['parameters']['model_args']
|
|
@@ -112,7 +112,7 @@ def train_model(df, stratify=True, model_type='bert', use_existing_model=False,
|
|
| 112 |
|
| 113 |
# Create a ClassificationModel
|
| 114 |
model_detail = yaml.load(open('config.yaml'), Loader=yaml.FullLoader)['parameters']['model_types']
|
| 115 |
-
class_names = yaml.load(open('config.yaml'), Loader=yaml.FullLoader)['parameters']['class_names']
|
| 116 |
|
| 117 |
if use_existing_model:
|
| 118 |
model = ClassificationModel(model_type, model_name, num_labels=len(class_names), args=model_args, use_cuda=False)
|
|
@@ -125,7 +125,7 @@ def train_model(df, stratify=True, model_type='bert', use_existing_model=False,
|
|
| 125 |
# Evaluate the model
|
| 126 |
result, model_outputs, wrong_predictions = model.eval_model(test_df)
|
| 127 |
preds = np.argmax(model_outputs, axis=1)
|
| 128 |
-
class_report =classification_report(test_df[
|
| 129 |
|
| 130 |
return model, preds, class_report, train_df, test_df, class_names
|
| 131 |
|
|
|
|
| 78 |
else:
|
| 79 |
return row['category_name']
|
| 80 |
|
| 81 |
+
def train_model(df, train_type, label_column, stratify=True, model_type='bert', use_existing_model=False, model_name=None):
|
| 82 |
"""
|
| 83 |
This function trains the model using the configuration in config.yaml
|
| 84 |
|
|
|
|
| 98 |
warnings.filterwarnings('ignore')
|
| 99 |
|
| 100 |
test_size = yaml.load(open('config.yaml'), Loader=yaml.FullLoader)['parameters']['training_args']['test_size']
|
| 101 |
+
train_df, test_df = train_test_split(df, test_size=test_size, stratify=df[label_column])
|
| 102 |
|
| 103 |
# Optional model configuration
|
| 104 |
model_config = yaml.load(open('config.yaml'), Loader=yaml.FullLoader)['parameters']['model_args']
|
|
|
|
| 112 |
|
| 113 |
# Create a ClassificationModel
|
| 114 |
model_detail = yaml.load(open('config.yaml'), Loader=yaml.FullLoader)['parameters']['model_types']
|
| 115 |
+
class_names = yaml.load(open('config.yaml'), Loader=yaml.FullLoader)['parameters']['class_names'][train_type]
|
| 116 |
|
| 117 |
if use_existing_model:
|
| 118 |
model = ClassificationModel(model_type, model_name, num_labels=len(class_names), args=model_args, use_cuda=False)
|
|
|
|
| 125 |
# Evaluate the model
|
| 126 |
result, model_outputs, wrong_predictions = model.eval_model(test_df)
|
| 127 |
preds = np.argmax(model_outputs, axis=1)
|
| 128 |
+
class_report =classification_report(test_df[label_column], preds, target_names=class_names)
|
| 129 |
|
| 130 |
return model, preds, class_report, train_df, test_df, class_names
|
| 131 |
|