LorenzoBioinfo commited on
Commit
46a01d1
·
1 Parent(s): 4388c3f

Trainmodel

Browse files
Files changed (1) hide show
  1. src/train_model.py +3 -3
src/train_model.py CHANGED
@@ -24,7 +24,7 @@ def compute_metrics(eval_pred):
24
  return {"accuracy": acc["accuracy"], "f1": f1["f1"]}
25
 
26
 
27
- def train_model(additional_data=None,sample_train_size=1000, sample_eval_size=300):
28
  print("Caricamento dataset Tweet eval preprocessato")
29
  dataset = load_from_disk(DATA_PATH)
30
  if additional_data is not None:
@@ -64,8 +64,8 @@ def train_model(additional_data=None,sample_train_size=1000, sample_eval_size=30
64
 
65
  trainer.train()
66
 
67
- os.makedirs(OUTPUT_DIR, exist_ok=True)
68
- trainer.save_model(OUTPUT_DIR)
69
  print(f"Modello salvato in: {OUTPUT_DIR}")
70
 
71
  if __name__ == "__main__":
 
24
  return {"accuracy": acc["accuracy"], "f1": f1["f1"]}
25
 
26
 
27
+ def train_model(additional_data=None,sample_train_size=1000, sample_eval_size=300,output_dir=OUTPUT_DIR):
28
  print("Caricamento dataset Tweet eval preprocessato")
29
  dataset = load_from_disk(DATA_PATH)
30
  if additional_data is not None:
 
64
 
65
  trainer.train()
66
 
67
+ os.makedirs(output_dir, exist_ok=True)
68
+ trainer.save_model(output_dir)
69
  print(f"Modello salvato in: {OUTPUT_DIR}")
70
 
71
  if __name__ == "__main__":