Commit
·
06e46d2
1
Parent(s):
026cb58
Update Sentiment_analysis_with_bert.py
Browse files- Sentiment_analysis_with_bert.py +40 -42
Sentiment_analysis_with_bert.py
CHANGED
|
@@ -303,48 +303,46 @@ def eval_model(model, data_loader, loss_fn, device, n_examples):
|
|
| 303 |
|
| 304 |
return correct_predictions.double() / n_examples, np.mean(losses)
|
| 305 |
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
# torch.save(model.state_dict(), 'best_model_state.bin')
|
| 347 |
-
# best_accuracy = val_acc
|
| 348 |
|
| 349 |
print(history['train_acc'])
|
| 350 |
|
|
|
|
| 303 |
|
| 304 |
return correct_predictions.double() / n_examples, np.mean(losses)
|
| 305 |
|
| 306 |
+
%%time
|
| 307 |
+
history = defaultdict(list)
|
| 308 |
+
best_accuracy = 0
|
| 309 |
+
|
| 310 |
+
for epoch in range(EPOCHS):
|
| 311 |
+
|
| 312 |
+
print(f'Epoch {epoch + 1}/{EPOCHS}')
|
| 313 |
+
print('-' * 10)
|
| 314 |
+
|
| 315 |
+
train_acc, train_loss = train_epoch(
|
| 316 |
+
model,
|
| 317 |
+
train_data_loader,
|
| 318 |
+
loss_fn,
|
| 319 |
+
optimizer,
|
| 320 |
+
device,
|
| 321 |
+
scheduler,
|
| 322 |
+
len(df_train)
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
print(f'Train loss {train_loss} accuracy {train_acc}')
|
| 326 |
+
|
| 327 |
+
val_acc, val_loss = eval_model(
|
| 328 |
+
model,
|
| 329 |
+
val_data_loader,
|
| 330 |
+
loss_fn,
|
| 331 |
+
device,
|
| 332 |
+
len(df_val)
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
print(f'Val loss {val_loss} accuracy {val_acc}')
|
| 336 |
+
print()
|
| 337 |
+
|
| 338 |
+
history['train_acc'].append(train_acc)
|
| 339 |
+
history['train_loss'].append(train_loss)
|
| 340 |
+
history['val_acc'].append(val_acc)
|
| 341 |
+
history['val_loss'].append(val_loss)
|
| 342 |
+
|
| 343 |
+
if val_acc > best_accuracy:
|
| 344 |
+
torch.save(model.state_dict(), 'best_model_state.bin')
|
| 345 |
+
best_accuracy = val_acc
|
|
|
|
|
|
|
| 346 |
|
| 347 |
print(history['train_acc'])
|
| 348 |
|