import sys
from pathlib import Path

# make the project root importable before pulling in project modules
file = Path(__file__).resolve()
parent, root = file.parent, file.parents[1]
sys.path.append(str(root))

# the wildcard imports are expected to provide ImgDataset, tokenizer,
# feature_extractor, img_transforms, compute_metrics, config, and PIL's Image
from utilities.utilities_common import *
from config.core import *
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers import default_data_collator, VisionEncoderDecoderModel

PACKAGE_ROOT = Path(__file__).resolve().parent
ROOT = PACKAGE_ROOT.parent
CONFIG_FILE_PATH = PACKAGE_ROOT / "config.yml"

DATASET_DIR = PACKAGE_ROOT / "dataset"
TRAINED_MODEL_DIR = PACKAGE_ROOT / "trained_models"
CAPTIONS_DIR = DATASET_DIR / "captions.txt"  # despite the name, this points at the captions file
IMAGES_DIR = DATASET_DIR / "Images"

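# NOTE: compute_metrics used by the trainer below comes from the wildcard
# imports above. For reference, a minimal sketch of what such a function
# might look like, assuming ROUGE scoring with the `evaluate` package as in
# the standard HF captioning examples (kept commented out so it does not
# shadow the imported implementation):
#
#   import evaluate
#   rouge = evaluate.load("rouge")
#
#   def compute_metrics(pred):
#       pred_ids, label_ids = pred.predictions, pred.label_ids
#       label_ids[label_ids == -100] = tokenizer.pad_token_id  # restore padding
#       pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
#       label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
#       return rouge.compute(predictions=pred_str, references=label_str)
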
def run_training(str_image_dir_path, df_train, df_validation, device):

    # wrap the training and validation dataframes in datasets that load,
    # transform, and tokenize each (image, caption) pair on the fly
    train_dataset = ImgDataset(df_train, root_dir=str_image_dir_path, tokenizer=tokenizer, feature_extractor=feature_extractor, transform=img_transforms)
    validation_dataset = ImgDataset(df_validation, root_dir=str_image_dir_path, tokenizer=tokenizer, feature_extractor=feature_extractor, transform=img_transforms)

    print("Encoder : ", config.lmodel_config.ENCODER)
    print("Decoder : ", config.lmodel_config.DECODER)

    # initialize the model from the pretrained encoder and decoder checkpoints
    model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(config.lmodel_config.ENCODER, config.lmodel_config.DECODER)
    print("Vocab Size : ", model.config.decoder.vocab_size)
    # set special-token config parameters
    model.config.decoder_start_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id
    model.config.eos_token_id = tokenizer.sep_token_id
    # make sure vocab size is set correctly
    model.config.vocab_size = model.config.decoder.vocab_size
    # set generation (beam search) parameters
    model.config.max_length = config.lmodel_config.MAX_LEN
    model.config.early_stopping = config.lmodel_config.EARLY_STOPPING
    model.config.no_repeat_ngram_size = config.lmodel_config.NGRAM_SIZE
    model.config.length_penalty = config.lmodel_config.LEN_PENALTY
    model.config.num_beams = config.lmodel_config.NUM_BEAMS

    # define training arguments
    training_args = Seq2SeqTrainingArguments(
        output_dir=str(TRAINED_MODEL_DIR / 'VIT_large_gpt2'),
        per_device_train_batch_size=config.lmodel_config.TRAIN_BATCH_SIZE,
        per_device_eval_batch_size=config.lmodel_config.VAL_BATCH_SIZE,
        predict_with_generate=True,
        evaluation_strategy="epoch",
        do_train=True,
        do_eval=True,
        logging_steps=config.lmodel_config.NUM_LOGGING_STEPS,
        save_steps=2 * config.lmodel_config.NUM_LOGGING_STEPS,
        warmup_steps=config.lmodel_config.NUM_LOGGING_STEPS,
        learning_rate=5e-5,
        max_steps=1500,  # remove for full training; overrides num_train_epochs while set
        num_train_epochs=config.lmodel_config.EPOCHS,
        overwrite_output_dir=True,
        save_total_limit=1,
    )

    # instantiate the trainer; the feature extractor is passed in the
    # tokenizer slot so that it is saved alongside the model checkpoints
    trainer = Seq2SeqTrainer(
        tokenizer=feature_extractor,
        model=model,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=validation_dataset,
        data_collator=default_data_collator,
    )
    trainer.train()

    # save the trained model
    trainer.save_model(str(TRAINED_MODEL_DIR / 'VIT_large_gpt2'))
    print("Image dir : ", IMAGES_DIR)

    # quick smoke test: caption one sample image with the freshly trained model
    img = Image.open(IMAGES_DIR / "1000268201_693b08cb0e.jpg").convert("RGB")
    img.show()
    pixel_values = feature_extractor(img, return_tensors="pt").pixel_values.to(device)
    generated_ids = model.generate(pixel_values)
    generated_caption = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    print('\033[96m' + generated_caption + '\033[0m')
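

# minimal usage sketch (hypothetical): assumes captions.txt is a
# Flickr8k-style CSV with "image" and "caption" columns and that a
# simple random 90/10 split is acceptable
if __name__ == "__main__":
    import pandas as pd
    import torch

    df_captions = pd.read_csv(CAPTIONS_DIR)
    df_train = df_captions.sample(frac=0.9, random_state=42)
    df_validation = df_captions.drop(df_train.index)

    run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    run_training(str(IMAGES_DIR), df_train, df_validation, run_device)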