Daddario commited on
Commit
2680fc6
·
verified ·
1 Parent(s): b69d3c2

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +60 -0
train.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BertTokenizer, BertForTokenClassification, Trainer, TrainingArguments
2
+ from datasets import load_dataset
3
+
4
+ # Carica il dataset dalla repository (sostituisci con il nome corretto dei file)
5
+ dataset = load_dataset("daddario/hotel", data_files="entity_dataset.json")
6
+
7
+ # Carica il tokenizer pre-addestrato
8
+ tokenizer = BertTokenizer.from_pretrained('dbmdz/bert-base-italian-uncased')
9
+
10
+ # Funzione per tokenizzare il testo e allineare le etichette
11
+ def tokenize_and_align_labels(examples):
12
+ tokenized_inputs = tokenizer(examples['query'], truncation=True, padding=True, is_split_into_words=False)
13
+ labels = []
14
+
15
+ for i, word in enumerate(examples['query'].split()):
16
+ label = 'O' # Default 'O' (no entity)
17
+
18
+ for entity, value in examples['entities'].items():
19
+ if word in value:
20
+ label = entity
21
+ break
22
+ labels.append(label)
23
+
24
+ tokenized_inputs["labels"] = labels
25
+ return tokenized_inputs
26
+
27
+ # Preprocessamento dei dati
28
+ dataset = dataset.map(tokenize_and_align_labels, batched=True)
29
+
30
+ # Carica il modello per il riconoscimento delle entità
31
+ model = BertForTokenClassification.from_pretrained('dbmdz/bert-base-italian-uncased', num_labels=len(dataset['train'].features['labels'].feature))
32
+
33
+ # Configurazione dell'addestramento
34
+ training_args = TrainingArguments(
35
+ output_dir='./results',
36
+ evaluation_strategy="epoch",
37
+ learning_rate=2e-5,
38
+ per_device_train_batch_size=16,
39
+ num_train_epochs=3,
40
+ logging_dir='./logs',
41
+ logging_steps=10,
42
+ weight_decay=0.01,
43
+ save_steps=10_000,
44
+ load_best_model_at_end=True,
45
+ )
46
+
47
+ # Crea il Trainer per addestrare il modello
48
+ trainer = Trainer(
49
+ model=model,
50
+ args=training_args,
51
+ train_dataset=dataset['train'],
52
+ eval_dataset=dataset['test'],
53
+ )
54
+
55
+ # Addestramento del modello
56
+ trainer.train()
57
+
58
+ # Salva il modello e il tokenizer nella repository
59
+ model.save_pretrained("daddario/hotel")
60
+ tokenizer.save_pretrained("daddario/hotel")