Commit
·
819fc2b
1
Parent(s):
c35bb2c
Update files/functions.py
Browse files- files/functions.py +74 -35
files/functions.py
CHANGED
|
@@ -70,36 +70,36 @@ label2color = {
|
|
| 70 |
|
| 71 |
# bounding boxes start and end of a sequence
|
| 72 |
cls_box = [0, 0, 0, 0]
|
|
|
|
|
|
|
| 73 |
sep_box_lilt = cls_box
|
|
|
|
|
|
|
| 74 |
sep_box_layoutxlm = [1000, 1000, 1000, 1000]
|
|
|
|
| 75 |
|
| 76 |
# models
|
| 77 |
model_id_lilt = "pierreguillou/lilt-xlm-roberta-base-finetuned-with-DocLayNet-base-at-paragraphlevel-ml512"
|
|
|
|
| 78 |
model_id_layoutxlm = "pierreguillou/layout-xlm-base-finetuned-with-DocLayNet-base-at-paragraphlevel-ml512"
|
|
|
|
| 79 |
|
| 80 |
# tokenizer for LayoutXLM
|
| 81 |
tokenizer_id_layoutxlm = "xlm-roberta-base"
|
| 82 |
|
| 83 |
# (tokenization) The maximum length of a feature (sequence)
|
| 84 |
-
if str(384) in model_id_lilt:
|
| 85 |
-
|
| 86 |
-
elif str(512) in model_id_lilt:
|
| 87 |
-
|
| 88 |
-
else:
|
| 89 |
-
print("Error with max_length_lilt of chunks!")
|
| 90 |
-
|
| 91 |
-
if str(384) in model_id_layoutxlm:
|
| 92 |
-
max_length_layoutxlm = 384
|
| 93 |
-
elif str(512) in model_id_layoutxlm:
|
| 94 |
-
max_length_layoutxlm = 512
|
| 95 |
else:
|
| 96 |
-
print("Error with
|
| 97 |
|
| 98 |
# (tokenization) overlap
|
| 99 |
doc_stride = 128 # The authorized overlap between two part of the context when splitting it is needed.
|
| 100 |
|
| 101 |
# max PDF page images that will be displayed
|
| 102 |
-
max_imgboxes =
|
| 103 |
|
| 104 |
# get files
|
| 105 |
examples_dir = 'files/'
|
|
@@ -159,6 +159,9 @@ tokenizer_lilt = AutoTokenizer.from_pretrained(model_id_lilt)
|
|
| 159 |
model_lilt = AutoModelForTokenClassification.from_pretrained(model_id_lilt);
|
| 160 |
model_lilt.to(device);
|
| 161 |
|
|
|
|
|
|
|
|
|
|
| 162 |
## model LayoutXLM
|
| 163 |
from transformers import LayoutLMv2ForTokenClassification # LayoutXLMTokenizerFast,
|
| 164 |
model_layoutxlm = LayoutLMv2ForTokenClassification.from_pretrained(model_id_layoutxlm);
|
|
@@ -172,14 +175,8 @@ feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
|
|
| 172 |
from transformers import AutoTokenizer
|
| 173 |
tokenizer_layoutxlm = AutoTokenizer.from_pretrained(tokenizer_id_layoutxlm)
|
| 174 |
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
label2id_lilt = model_lilt.config.label2id
|
| 178 |
-
num_labels_lilt = len(id2label_lilt)
|
| 179 |
-
|
| 180 |
-
id2label_layoutxlm = model_layoutxlm.config.id2label
|
| 181 |
-
label2id_layoutxlm = model_layoutxlm.config.label2id
|
| 182 |
-
num_labels_layoutxlm = len(id2label_layoutxlm)
|
| 183 |
|
| 184 |
|
| 185 |
# General
|
|
@@ -519,14 +516,10 @@ def extraction_data_from_image(images):
|
|
| 519 |
from datasets import Dataset
|
| 520 |
dataset = Dataset.from_dict({"images_ids": images_ids_list, "images": images_list, "images_pixels": images_pixels_list, "page_no": page_no_list, "num_pages": num_pages_list, "texts_line": texts_lines_list, "texts_par": texts_pars_list, "texts_lines_par": texts_lines_par_list, "bboxes_par": par_boxes_list, "bboxes_lines_par":lines_par_boxes_list})
|
| 521 |
|
| 522 |
-
|
| 523 |
# print(f"The text data was successfully extracted by the OCR!")
|
| 524 |
|
| 525 |
return dataset, texts_lines, texts_pars, texts_lines_par, row_indexes, par_boxes, line_boxes, lines_par_boxes
|
| 526 |
|
| 527 |
-
|
| 528 |
-
# Inference
|
| 529 |
-
|
| 530 |
def prepare_inference_features_paragraph(example, tokenizer, max_length, cls_box, sep_box):
|
| 531 |
|
| 532 |
images_ids_list, chunks_ids_list, input_ids_list, attention_mask_list, bb_list, images_pixels_list = list(), list(), list(), list(), list(), list()
|
|
@@ -711,8 +704,8 @@ def predictions_token_level(images, custom_encoded_dataset, model_id, model):
|
|
| 711 |
|
| 712 |
from functools import reduce
|
| 713 |
|
| 714 |
-
# Get predictions (
|
| 715 |
-
def
|
| 716 |
|
| 717 |
ten_probs_dict, ten_input_ids_dict, ten_bboxes_dict = dict(), dict(), dict()
|
| 718 |
bboxes_list_dict, input_ids_dict_dict, probs_dict_dict, df = dict(), dict(), dict(), dict()
|
|
@@ -788,24 +781,69 @@ def predictions_paragraph_level(max_length, tokenizer, id2label, dataset, output
|
|
| 788 |
prob_label = reduce(lambda x, y: x*y, probs_list)
|
| 789 |
prob_label = prob_label**(1./(len(probs_list))) # normalization
|
| 790 |
probs_label.append(prob_label)
|
| 791 |
-
max_value = max(probs_label)
|
| 792 |
-
max_index = probs_label.index(max_value)
|
| 793 |
-
probs_bbox[str(bbox)] = max_index
|
|
|
|
| 794 |
|
| 795 |
bboxes_list_dict[image_id] = bboxes_list
|
| 796 |
input_ids_dict_dict[image_id] = input_ids_dict
|
| 797 |
probs_dict_dict[image_id] = probs_bbox
|
| 798 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 799 |
df[image_id] = pd.DataFrame()
|
| 800 |
-
df[image_id]["bboxes"] =
|
| 801 |
-
df[image_id]["texts"] = [
|
| 802 |
-
df[image_id]["labels"] = [id2label[probs_bbox[str(bbox)]] for bbox in
|
| 803 |
|
| 804 |
-
return
|
| 805 |
|
| 806 |
else:
|
| 807 |
print("An error occurred while getting predictions!")
|
| 808 |
|
|
|
|
| 809 |
# Get labeled images with lines bounding boxes
|
| 810 |
def get_labeled_images(id2label, dataset, images_ids_list, bboxes_list_dict, probs_dict_dict):
|
| 811 |
|
|
@@ -925,4 +963,5 @@ def display_chunk_lines_inference(dataset, encoded_dataset, index_chunk=None):
|
|
| 925 |
print("\n>> Dataframe of annotated lines\n")
|
| 926 |
cols = ["texts", "bboxes"]
|
| 927 |
df = df[cols]
|
| 928 |
-
display(df)
|
|
|
|
|
|
| 70 |
|
| 71 |
# bounding boxes start and end of a sequence
|
| 72 |
cls_box = [0, 0, 0, 0]
|
| 73 |
+
cls_box1, cls_box2 = cls_box, cls_box
|
| 74 |
+
|
| 75 |
sep_box_lilt = cls_box
|
| 76 |
+
sep_box1 = sep_box_lilt
|
| 77 |
+
|
| 78 |
sep_box_layoutxlm = [1000, 1000, 1000, 1000]
|
| 79 |
+
sep_box2 = sep_box_layoutxlm
|
| 80 |
|
| 81 |
# models
|
| 82 |
model_id_lilt = "pierreguillou/lilt-xlm-roberta-base-finetuned-with-DocLayNet-base-at-paragraphlevel-ml512"
|
| 83 |
+
model_id1 = model_id_lilt
|
| 84 |
model_id_layoutxlm = "pierreguillou/layout-xlm-base-finetuned-with-DocLayNet-base-at-paragraphlevel-ml512"
|
| 85 |
+
model_id2 = model_id_layoutxlm
|
| 86 |
|
| 87 |
# tokenizer for LayoutXLM
|
| 88 |
tokenizer_id_layoutxlm = "xlm-roberta-base"
|
| 89 |
|
| 90 |
# (tokenization) The maximum length of a feature (sequence)
|
| 91 |
+
if (str(384) in model_id_lilt) and (str(384) in model_id_layoutxlm):
|
| 92 |
+
max_length = 384
|
| 93 |
+
elif (str(512) in model_id_lilt) and (str(512) in model_id_layoutxlm):
|
| 94 |
+
max_length = 512
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
else:
|
| 96 |
+
print("Error with max_length of chunks!")
|
| 97 |
|
| 98 |
# (tokenization) overlap
|
| 99 |
doc_stride = 128 # The authorized overlap between two part of the context when splitting it is needed.
|
| 100 |
|
| 101 |
# max PDF page images that will be displayed
|
| 102 |
+
max_imgboxes = 2
|
| 103 |
|
| 104 |
# get files
|
| 105 |
examples_dir = 'files/'
|
|
|
|
| 159 |
model_lilt = AutoModelForTokenClassification.from_pretrained(model_id_lilt);
|
| 160 |
model_lilt.to(device);
|
| 161 |
|
| 162 |
+
tokenizer1 = tokenizer_lilt
|
| 163 |
+
model1 = model_lilt
|
| 164 |
+
|
| 165 |
## model LayoutXLM
|
| 166 |
from transformers import LayoutLMv2ForTokenClassification # LayoutXLMTokenizerFast,
|
| 167 |
model_layoutxlm = LayoutLMv2ForTokenClassification.from_pretrained(model_id_layoutxlm);
|
|
|
|
| 175 |
from transformers import AutoTokenizer
|
| 176 |
tokenizer_layoutxlm = AutoTokenizer.from_pretrained(tokenizer_id_layoutxlm)
|
| 177 |
|
| 178 |
+
tokenizer2 = tokenizer_layoutxlm
|
| 179 |
+
model2 = model_layoutxlm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
|
| 181 |
|
| 182 |
# General
|
|
|
|
| 516 |
from datasets import Dataset
|
| 517 |
dataset = Dataset.from_dict({"images_ids": images_ids_list, "images": images_list, "images_pixels": images_pixels_list, "page_no": page_no_list, "num_pages": num_pages_list, "texts_line": texts_lines_list, "texts_par": texts_pars_list, "texts_lines_par": texts_lines_par_list, "bboxes_par": par_boxes_list, "bboxes_lines_par":lines_par_boxes_list})
|
| 518 |
|
|
|
|
| 519 |
# print(f"The text data was successfully extracted by the OCR!")
|
| 520 |
|
| 521 |
return dataset, texts_lines, texts_pars, texts_lines_par, row_indexes, par_boxes, line_boxes, lines_par_boxes
|
| 522 |
|
|
|
|
|
|
|
|
|
|
| 523 |
def prepare_inference_features_paragraph(example, tokenizer, max_length, cls_box, sep_box):
|
| 524 |
|
| 525 |
images_ids_list, chunks_ids_list, input_ids_list, attention_mask_list, bb_list, images_pixels_list = list(), list(), list(), list(), list(), list()
|
|
|
|
| 704 |
|
| 705 |
from functools import reduce
|
| 706 |
|
| 707 |
+
# Get predictions (paragraph level)
|
| 708 |
+
def predictions_probs_paragraph_level(max_length, tokenizer, id2label, dataset, outputs, images_ids_list, chunk_ids, input_ids, bboxes, cls_box, sep_box):
|
| 709 |
|
| 710 |
ten_probs_dict, ten_input_ids_dict, ten_bboxes_dict = dict(), dict(), dict()
|
| 711 |
bboxes_list_dict, input_ids_dict_dict, probs_dict_dict, df = dict(), dict(), dict(), dict()
|
|
|
|
| 781 |
prob_label = reduce(lambda x, y: x*y, probs_list)
|
| 782 |
prob_label = prob_label**(1./(len(probs_list))) # normalization
|
| 783 |
probs_label.append(prob_label)
|
| 784 |
+
# max_value = max(probs_label)
|
| 785 |
+
# max_index = probs_label.index(max_value)
|
| 786 |
+
# probs_bbox[str(bbox)] = max_index
|
| 787 |
+
probs_bbox[str(bbox)] = probs_label
|
| 788 |
|
| 789 |
bboxes_list_dict[image_id] = bboxes_list
|
| 790 |
input_ids_dict_dict[image_id] = input_ids_dict
|
| 791 |
probs_dict_dict[image_id] = probs_bbox
|
| 792 |
|
| 793 |
+
# df[image_id] = pd.DataFrame()
|
| 794 |
+
# df[image_id]["bboxes"] = bboxes_list
|
| 795 |
+
# df[image_id]["texts"] = [tokenizer.decode(input_ids_dict[str(bbox)]) for bbox in bboxes_list]
|
| 796 |
+
# df[image_id]["labels"] = [id2label[probs_bbox[str(bbox)]] for bbox in bboxes_list]
|
| 797 |
+
|
| 798 |
+
return probs_bbox, bboxes_list_dict, input_ids_dict_dict, probs_dict_dict #, df
|
| 799 |
+
|
| 800 |
+
else:
|
| 801 |
+
print("An error occurred while getting predictions!")
|
| 802 |
+
|
| 803 |
+
from functools import reduce
|
| 804 |
+
|
| 805 |
+
# Get predictions (paragraph level)
|
| 806 |
+
def predictions_paragraph_level(max_length, tokenizer1, id2label, dataset, outputs1, images_ids_list1, chunk_ids1, input_ids1, bboxes1, cls_box1, sep_box1, tokenizer2, outputs2, images_ids_list2, chunk_ids2, input_ids2, bboxes2, cls_box2, sep_box2):
|
| 807 |
+
|
| 808 |
+
bboxes_list_dict, input_ids_dict_dict, probs_dict_dict, df = dict(), dict(), dict(), dict()
|
| 809 |
+
|
| 810 |
+
probs_bbox1, bboxes_list_dict1, input_ids_dict_dict1, probs_dict_dict1 = predictions_probs_paragraph_level(max_length, tokenizer1, id2label, dataset, outputs1, images_ids_list1, chunk_ids1, input_ids1, bboxes1, cls_box1, sep_box1)
|
| 811 |
+
probs_bbox2, bboxes_list_dict2, input_ids_dict_dict2, probs_dict_dict2 = predictions_probs_paragraph_level(max_length, tokenizer2, id2label, dataset, outputs2, images_ids_list2, chunk_ids2, input_ids2, bboxes2, cls_box2, sep_box2)
|
| 812 |
+
|
| 813 |
+
if len(images_ids_list1) > 0:
|
| 814 |
+
|
| 815 |
+
for i, image_id in enumerate(images_ids_list1):
|
| 816 |
+
|
| 817 |
+
bboxes_list1 = bboxes_list_dict1[image_id]
|
| 818 |
+
input_ids_dict1 = input_ids_dict_dict1[image_id]
|
| 819 |
+
probs_bbox1 = probs_dict_dict1[image_id]
|
| 820 |
+
|
| 821 |
+
bboxes_list2 = bboxes_list_dict2[image_id]
|
| 822 |
+
input_ids_dict2 = input_ids_dict_dict2[image_id]
|
| 823 |
+
probs_bbox2 = probs_dict_dict2[image_id]
|
| 824 |
+
|
| 825 |
+
probs_bbox = dict()
|
| 826 |
+
for bbox in bboxes_list1:
|
| 827 |
+
prob_bbox = [(p1+p2)/2 for p1,p2 in zip(probs_bbox1[str(bbox)], probs_bbox2[str(bbox)])]
|
| 828 |
+
max_value = max(prob_bbox)
|
| 829 |
+
max_index = prob_bbox.index(max_value)
|
| 830 |
+
probs_bbox[str(bbox)] = max_index
|
| 831 |
+
|
| 832 |
+
bboxes_list_dict[image_id] = bboxes_list1
|
| 833 |
+
input_ids_dict_dict[image_id] = input_ids_dict1
|
| 834 |
+
probs_dict_dict[image_id] = probs_bbox
|
| 835 |
+
|
| 836 |
df[image_id] = pd.DataFrame()
|
| 837 |
+
df[image_id]["bboxes"] = bboxes_list1
|
| 838 |
+
df[image_id]["texts"] = [tokenizer1.decode(input_ids_dict1[str(bbox)]) for bbox in bboxes_list1]
|
| 839 |
+
df[image_id]["labels"] = [id2label[probs_bbox[str(bbox)]] for bbox in bboxes_list1]
|
| 840 |
|
| 841 |
+
return bboxes_list_dict, input_ids_dict_dict, probs_dict_dict, df
|
| 842 |
|
| 843 |
else:
|
| 844 |
print("An error occurred while getting predictions!")
|
| 845 |
|
| 846 |
+
|
| 847 |
# Get labeled images with lines bounding boxes
|
| 848 |
def get_labeled_images(id2label, dataset, images_ids_list, bboxes_list_dict, probs_dict_dict):
|
| 849 |
|
|
|
|
| 963 |
print("\n>> Dataframe of annotated lines\n")
|
| 964 |
cols = ["texts", "bboxes"]
|
| 965 |
df = df[cols]
|
| 966 |
+
display(df)
|
| 967 |
+
|