hotel / train.py
Daddario's picture
Create train.py
2680fc6 verified
from transformers import BertTokenizer, BertForTokenClassification, Trainer, TrainingArguments
from datasets import load_dataset
# Carica il dataset dalla repository (sostituisci con il nome corretto dei file)
dataset = load_dataset("daddario/hotel", data_files="entity_dataset.json")
# Carica il tokenizer pre-addestrato
tokenizer = BertTokenizer.from_pretrained('dbmdz/bert-base-italian-uncased')
# Funzione per tokenizzare il testo e allineare le etichette
def tokenize_and_align_labels(examples):
tokenized_inputs = tokenizer(examples['query'], truncation=True, padding=True, is_split_into_words=False)
labels = []
for i, word in enumerate(examples['query'].split()):
label = 'O' # Default 'O' (no entity)
for entity, value in examples['entities'].items():
if word in value:
label = entity
break
labels.append(label)
tokenized_inputs["labels"] = labels
return tokenized_inputs
# Preprocessamento dei dati
dataset = dataset.map(tokenize_and_align_labels, batched=True)
# Carica il modello per il riconoscimento delle entità
model = BertForTokenClassification.from_pretrained('dbmdz/bert-base-italian-uncased', num_labels=len(dataset['train'].features['labels'].feature))
# Configurazione dell'addestramento
training_args = TrainingArguments(
output_dir='./results',
evaluation_strategy="epoch",
learning_rate=2e-5,
per_device_train_batch_size=16,
num_train_epochs=3,
logging_dir='./logs',
logging_steps=10,
weight_decay=0.01,
save_steps=10_000,
load_best_model_at_end=True,
)
# Crea il Trainer per addestrare il modello
trainer = Trainer(
model=model,
args=training_args,
train_dataset=dataset['train'],
eval_dataset=dataset['test'],
)
# Addestramento del modello
trainer.train()
# Salva il modello e il tokenizer nella repository
model.save_pretrained("daddario/hotel")
tokenizer.save_pretrained("daddario/hotel")