light-doc-ocr / core /train.py
nbonetto's picture
fix: trained model to 0.49 cer and fixed streamlit app issues
ee142e9
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from core.model import model, processor
from core.data import load
from core.utils import compute_metrics
from config import OUTPUT_DIR, BATCH_SIZE, EPOCHS
# The project loader returns the training and evaluation splits as a pair.
train_ds, eval_ds = load()

# Arguments for the Seq2Seq fine-tuning run. Output directory, batch size and
# epoch count come from `config`; the remaining values are fixed here.
training_args = Seq2SeqTrainingArguments(
    # --- run / optimisation ---
    output_dir=OUTPUT_DIR,
    num_train_epochs=EPOCHS,
    learning_rate=5e-5,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    fp16=False,
    # --- evaluation ---
    # Evaluate once per epoch, producing predictions with generate() so that
    # compute_metrics scores generated sequences.
    eval_strategy='epoch',
    predict_with_generate=True,
    # --- logging / checkpoints ---
    logging_steps=50,
    # Keep at most one checkpoint under output_dir.
    # NOTE(review): save_strategy is left at its default, so checkpointing is
    # not aligned with the per-epoch evaluation — confirm this is intended.
    save_total_limit=1,
    # --- data ---
    # Keep every dataset column instead of letting the Trainer drop those not
    # named in the model's forward() signature (presumably the dataset yields
    # custom keys — confirm against core.data).
    remove_unused_columns=False,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    compute_metrics=compute_metrics,
    # Only the image processor is handed to the Trainer; the full processor
    # is written out separately by train_save().
    # NOTE(review): this object presumably has no pad_token_id, so padding of
    # generated predictions relies on model.config.pad_token_id — confirm set.
    processing_class=processor.image_processor,
)
def train_save(save_dir='./model'):
    """Fine-tune the model with the module-level trainer, then save it.

    Runs ``trainer.train()``. It then writes the trained model and the full
    processor (not just the image processor given to the trainer) to
    ``save_dir``.

    Args:
        save_dir: Directory that receives the model and processor files.
            Defaults to ``'./model'``, the previously hard-coded location.
    """
    trainer.train()
    trainer.save_model(save_dir)
    # trainer.save_model() only writes from the process selected by
    # args.should_save (the main process). Apply the same gate here so that in
    # a multi-process run (e.g. torchrun) every rank does not write the same
    # processor files concurrently. In a single-process run it is always True.
    if trainer.args.should_save:
        processor.save_pretrained(save_dir)
if __name__ == '__main__':
    # Training only starts when the module is executed as a script; a plain
    # import still runs the module-level dataset/trainer setup, but not this.
    train_save()