| | """ Finetuning example. |
| | """ |
from __future__ import print_function

import sys
from os.path import abspath, dirname

# Make the repository root importable before pulling in torchmoji.
sys.path.insert(0, dirname(dirname(abspath(__file__))))

import json
import math
import os

import numpy as np

from torchmoji.model_def import torchmoji_transfer
from torchmoji.global_variables import PRETRAINED_PATH, VOCAB_PATH
from torchmoji.finetuning import (
    load_benchmark,
    finetune)
from torchmoji.class_avg_finetuning import class_avg_finetune
| |
|
def roundup(x):
    """Round *x* up to the nearest multiple of 10 and return it as an int."""
    tens = math.ceil(x / 10.0)
    return 10 * int(tens)
| |
|
| |
|
| | |
| | |
| | |
| | |
# Benchmark datasets to finetune on.
# Each entry is a tuple:
#   (dataset name,
#    path to the raw pickled benchmark,
#    number of target classes,
#    whether to optimise class-average F1 instead of accuracy).
DATASETS = [
    ('SS-Youtube', '../data/SS-Youtube/raw.pickle', 2, False),
]
| |
|
# Directory where per-run result files are written.
RESULTS_DIR = 'results'

# Transfer-learning method; must be one of 'last', 'new', 'full' or
# 'chain-thaw' (see the dispatch in the training loop below).
FINETUNE_METHOD = 'last'
VERBOSE = 1

# Expected vocabulary size — checked against the vocab loaded from VOCAB_PATH.
nb_tokens = 50000
# NOTE(review): nb_epochs/epoch_size are not referenced in the visible code;
# presumably consumed by the finetuning helpers or kept for manual tweaking.
nb_epochs = 1000
epoch_size = 1000
| |
|
# Load the pretrained vocabulary once; it is shared by every run.
with open(VOCAB_PATH, 'r') as f:
    vocab = json.load(f)

# The pretrained model was trained with a fixed vocabulary size.  Validate it
# once, outside the loops (the check is loop-invariant), and raise explicitly
# rather than `assert`, which is stripped under `python -O`.
if len(vocab) != nb_tokens:
    raise ValueError('Vocabulary size {} does not match expected {}'
                     .format(len(vocab), nb_tokens))

# Make sure the output directory exists before any (long) training starts,
# so results are not lost to an IOError at the very end of a run.
if not os.path.isdir(RESULTS_DIR):
    os.makedirs(RESULTS_DIR)

for rerun_iter in range(5):
    for dset, path, nb_classes, use_f1_score in DATASETS:
        # 'last' keeps the pretrained embeddings as-is; the other methods
        # allow extending the vocabulary with up to 10000 dataset words.
        if FINETUNE_METHOD == 'last':
            extend_with = 0
        elif FINETUNE_METHOD in ['new', 'full', 'chain-thaw']:
            extend_with = 10000
        else:
            raise ValueError('Finetuning method not recognised!')

        # Load the benchmark dataset, tokenised with (and possibly
        # extending) the pretrained vocabulary.
        data = load_benchmark(path, vocab, extend_with=extend_with)

        # Train/val/test splits.  NOTE(review): these locals are not used
        # below (the finetuning helpers take data['texts']/data['labels']
        # directly); kept for inspection/debugging convenience.
        (X_train, y_train) = (data['texts'][0], data['labels'][0])
        (X_val, y_val) = (data['texts'][1], data['labels'][1])
        (X_test, y_test) = (data['texts'][2], data['labels'][2])

        # 'new' trains from scratch, so no pretrained weights are loaded.
        weight_path = PRETRAINED_PATH if FINETUNE_METHOD != 'new' else None
        # Class-average F1 finetuning treats the task as binary
        # (one-vs-rest), hence a 2-way output head.
        nb_model_classes = 2 if use_f1_score else nb_classes
        model = torchmoji_transfer(
            nb_model_classes,
            weight_path,
            extend_embedding=data['added'])
        print(model)

        # Train the model and evaluate with the dataset's chosen metric.
        print('Training: {}'.format(path))
        if use_f1_score:
            model, result = class_avg_finetune(model, data['texts'],
                                               data['labels'],
                                               nb_classes, data['batch_size'],
                                               FINETUNE_METHOD,
                                               verbose=VERBOSE)
            metric_name = 'F1'
            print('Overall F1 score (dset = {}): {}'.format(dset, result))
        else:
            model, result = finetune(model, data['texts'], data['labels'],
                                     nb_classes, data['batch_size'],
                                     FINETUNE_METHOD, metric='acc',
                                     verbose=VERBOSE)
            metric_name = 'Acc'
            print('Test accuracy (dset = {}): {}'.format(dset, result))

        # Persist the score so the five independent reruns can be
        # aggregated afterwards.
        result_path = '{}/{}_{}_{}_results.txt'.format(
            RESULTS_DIR, dset, FINETUNE_METHOD, rerun_iter)
        with open(result_path, "w") as f:
            f.write("{}: {}\n".format(metric_name, result))