Spaces:
Sleeping
Sleeping
LorenzoBioinfo
commited on
Commit
·
46a01d1
1
Parent(s):
4388c3f
Trainmodel
Browse files- src/train_model.py +3 -3
src/train_model.py
CHANGED
|
@@ -24,7 +24,7 @@ def compute_metrics(eval_pred):
|
|
| 24 |
return {"accuracy": acc["accuracy"], "f1": f1["f1"]}
|
| 25 |
|
| 26 |
|
| 27 |
-
def train_model(additional_data=None,sample_train_size=1000, sample_eval_size=300):
|
| 28 |
print("Caricamento dataset Tweet eval preprocessato")
|
| 29 |
dataset = load_from_disk(DATA_PATH)
|
| 30 |
if additional_data is not None:
|
|
@@ -64,8 +64,8 @@ def train_model(additional_data=None,sample_train_size=1000, sample_eval_size=30
|
|
| 64 |
|
| 65 |
trainer.train()
|
| 66 |
|
| 67 |
-
os.makedirs(
|
| 68 |
-
trainer.save_model(
|
| 69 |
print(f"Modello salvato in: {OUTPUT_DIR}")
|
| 70 |
|
| 71 |
if __name__ == "__main__":
|
|
|
|
| 24 |
return {"accuracy": acc["accuracy"], "f1": f1["f1"]}
|
| 25 |
|
| 26 |
|
| 27 |
+
def train_model(additional_data=None,sample_train_size=1000, sample_eval_size=300,output_dir=OUTPUT_DIR):
|
| 28 |
print("Caricamento dataset Tweet eval preprocessato")
|
| 29 |
dataset = load_from_disk(DATA_PATH)
|
| 30 |
if additional_data is not None:
|
|
|
|
| 64 |
|
| 65 |
trainer.train()
|
| 66 |
|
| 67 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 68 |
+
trainer.save_model(output_dir)
|
| 69 |
print(f"Modello salvato in: {OUTPUT_DIR}")
|
| 70 |
|
| 71 |
if __name__ == "__main__":
|