File size: 4,409 Bytes
b20f998
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import json
import os
import ast
from pathlib import Path
import datasets
from PIL import Image
import pandas as pd

# Module-level logger obtained through the `datasets` logging helper.
logger = datasets.logging.get_logger(__name__)
# BibTeX entry handed to `DatasetInfo(citation=...)`; every field is still an
# empty placeholder.
_CITATION = """\
@article{,
  title={},
  author={},
  journal={},
  year={},
  volume={}
}
"""
# Short summary handed to `DatasetInfo(description=...)`.
_DESCRIPTION = """\
This is a sample dataset for training layoutlmv3 model on custom annotated data.
"""

def load_image(image_path):
    """Open an image file and return it in RGB mode together with its size.

    Args:
        image_path: Location of the image, in any form ``PIL.Image.open``
            accepts.

    Returns:
        A tuple ``(image, (width, height))`` where ``image`` is the RGB
        conversion of the opened file and the size is taken from that image.
    """
    # ``Image.open`` is lazy and keeps the underlying file open; the context
    # manager releases the handle as soon as the RGB copy has been produced.
    with Image.open(image_path) as source:
        image = source.convert("RGB")
    return image, image.size

def normalize_bbox(bbox, size):
    """Rescale a box to thousandths of the given size.

    Args:
        bbox: Indexable with at least four numbers. Items 0 and 2 are scaled
            against ``size[0]``, items 1 and 3 against ``size[1]``.
        size: Indexable with at least two numbers (the callers in this file
            pass ``(width, height)``).

    Returns:
        A list of four ints, each ``int(1000 * value / dimension)``; ``int``
        truncates toward zero.
    """
    width, height = size[0], size[1]
    left, top, right, bottom = bbox[0], bbox[1], bbox[2], bbox[3]
    scaled = (
        1000 * left / width,
        1000 * top / height,
        1000 * right / width,
        1000 * bottom / height,
    )
    return [int(value) for value in scaled]


# Presumably reserved for download URLs; it is empty and nothing in the
# visible code reads it.
_URLS = []

'''Edit your working directory folder path here if required. 
If this file is in the same folder as the "layoutlmv3" folder keep it as it is.
'''
# Base directory that `_split_generators` joins with 'layoutlmv3' to locate
# the split files.
data_path = r'./' 

class DatasetConfig(datasets.BuilderConfig):
    """Builder configuration for the InvoiceExtraction dataset.

    Adds no options of its own; it only forwards to ``datasets.BuilderConfig``.
    """

    def __init__(self, **kwargs):
        """Create the configuration.

        Args:
          **kwargs: keyword arguments passed unchanged to the parent class.
        """
        super().__init__(**kwargs)


class InvoiceExtraction(datasets.GeneratorBasedBuilder):
    """Dataset builder that yields invoice examples from local annotation files.

    The data lives in the ``layoutlmv3`` folder under ``data_path``. Each split
    is described by a text file (``train.txt`` / ``test.txt``) in which every
    non-blank line is a Python-literal dict with the keys ``file_name``,
    ``tokens``, ``bboxes`` and ``ner_tags``.
    """

    BUILDER_CONFIGS = [
        DatasetConfig(name="InvoiceExtraction", version=datasets.Version("1.0.0"), description="InvoiceExtraction dataset"),
    ]

    def _info(self):
        """Return the ``DatasetInfo`` describing the features of one example."""
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=datasets.Features(
                {
                    "id": datasets.Value("string"),
                    "tokens": datasets.Sequence(datasets.Value("string")),
                    "bboxes": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))),
                    "ner_tags": datasets.Sequence(
                        datasets.features.ClassLabel(
                            names = ['num_facture','date_facture','fournisseur','client','mat_client','mat_fournisseur','tva','pourcentage_tva','remise','pourcentage_remise','timbre','fodec','ttc','devise','net_ht'] #Enter the list of labels that you have here.
                            )
                    ),
                    "image_path": datasets.Value("string"),
                    "image": datasets.features.Image()
                }
            ),
            supervised_keys=None,
            citation=_CITATION,
            homepage="",
        )

    def _split_generators(self, dl_manager):
        """Return the train and test ``SplitGenerator`` objects.

        Nothing is downloaded: ``dl_manager`` is unused and both splits point
        at local files inside ``<data_path>/layoutlmv3``.
        """
        dest = os.path.join(data_path, 'layoutlmv3')

        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN, gen_kwargs={"filepath": os.path.join(dest, "train.txt"), "dest": dest}
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TEST, gen_kwargs={"filepath": os.path.join(dest, "test.txt"), "dest": dest}
            ),
        ]

    def _generate_examples(self, filepath, dest):
        """Yield ``(key, example)`` pairs for one split.

        Args:
            filepath: Path of the split file; each non-blank line is a
                Python-literal dict describing one annotated image.
            dest: Directory that the ``file_name`` entries are relative to.

        Yields:
            ``(guid, example)`` where ``guid`` is the zero-based position of
            the record among the non-blank lines and ``example`` carries the
            keys ``id``, ``tokens``, ``bboxes``, ``ner_tags``, ``image_path``
            and ``image``.

        Raises:
            ValueError, SyntaxError: If a non-blank line is not a valid Python
                literal (raised by ``ast.literal_eval``).
            KeyError: If a record lacks one of the expected keys.
        """
        logger.info("⏳ Generating examples from = %s", filepath)

        # Read the whole index first so the file handle is closed before any
        # example is yielded. Blank lines (e.g. a trailing newline) are
        # skipped because ``ast.literal_eval`` cannot parse an empty string.
        with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
            records = [stripped for stripped in map(str.strip, f) if stripped]

        for guid, record in enumerate(records):
            # literal_eval only accepts Python literals, so a record cannot
            # execute code.
            data = ast.literal_eval(record)
            image_path = os.path.join(dest, data['file_name'])
            image, size = load_image(image_path)

            text = data['tokens']
            label = data['ner_tags']
            boxes = [normalize_bbox(box, size) for box in data['bboxes']]

            # A coordinate outside 0..1000 means the annotated box does not
            # fit inside the image; report the file so it can be fixed.
            if any(not 0 <= coord <= 1000 for box in boxes for coord in box):
                logger.warning(
                    "Normalized bounding box outside the 0-1000 range in %s", image_path
                )

            yield guid, {"id": str(guid), "tokens": text, "bboxes": boxes, "ner_tags": label, "image_path": image_path, "image": image}