HariNuvo commited on
Commit
6320d54
·
1 Parent(s): 3a9699c

initial commit

Browse files
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ *.ipynb_checkpoints
2
+ __pycache__
3
+ *.json
Dockerfile ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.10

# Make the Miniconda install (added below) visible at build and run time.
ENV PATH="/root/miniconda3/bin:${PATH}"
ARG PATH="/root/miniconda3/bin:${PATH}"

# Combine `update` and `install` in ONE layer so a cached `apt-get update`
# can never pair with a newer package list, and drop the apt cache afterwards.
RUN apt-get update \
    && apt-get install -y wget tesseract-ocr libtesseract-dev \
    && rm -rf /var/lib/apt/lists/*

# Miniconda is only used to obtain the conda-forge poppler build
# (pdf2image backend); the installer is removed after running.
RUN wget \
    https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
    && mkdir /root/.conda \
    && bash Miniconda3-latest-Linux-x86_64.sh -b \
    && rm -f Miniconda3-latest-Linux-x86_64.sh
RUN conda --version

RUN conda install -c conda-forge poppler -y

WORKDIR /code

COPY ./requirements.txt /code/requirements.txt

RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt

COPY . .

# Run as an unprivileged user (required by e.g. HF Spaces).
RUN useradd -m -u 1000 user
USER user
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH

WORKDIR $HOME/app

COPY --chown=user . $HOME/app

# BUG FIX: `RUN echo $pwd` printed an empty line ($pwd is undefined in sh);
# log the actual working directory instead.
RUN pwd

# NOTE(review): app.py launches a Gradio app directly; confirm that an
# `app/main.py` FastAPI module exists for this uvicorn target.
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from inference import Inference
3
+
4
+
5
def predict_url_class(url):
    """Predict the class of the given PDF url.

    Creates the ``{label: probability}`` mapping expected by ``gr.Label``.

    Parameters
    ----------
    url : str
        Direct link to a downloadable PDF.
    """
    inference = Inference(pdf_url=url)
    try:
        outputs = inference.predict()
    except Exception as e:
        # BUG FIX: the original issued gr.Warning(e) and fell through to use
        # `outputs`, which was unbound here and raised a NameError. Surface
        # the failure in the UI instead.
        raise gr.Error(str(e))
    # BUG FIX: `predict` returns probabilities ordered by label — index 0 is
    # Non-Lighting (label 0) and index 1 is Lighting (label 1), as documented
    # by Inference._pretty_print_probability. The original mapping was swapped.
    return {
        "Lighting": outputs[1],
        "Non-Lighting": outputs[0],
    }
17
+
18
+
19
def main():
    """Build and launch the Gradio interface for the PDF classifier."""
    # User-facing copy; the original contained several typos
    # ("in trained", "Ligthing", "may a long time", "upto"), fixed here.
    description = (
        "<p>The model is trained on a number of PDFs related to lighting and "
        "non-lighting products. The model takes a URL as input and predicts "
        "whether the product in the PDF corresponds to a Lighting product or "
        "not. The model may take up to 30 seconds to make a prediction. This "
        "is because we need to first extract textual, tabular and image "
        "information from various pages of the PDF and this may take a long "
        "time. Make sure that the URL provided is unblocked and can be "
        "downloaded without any extra steps.</p>"
    )
    inputs = gr.Text(lines=1, placeholder="Enter the url of the PDF", label="URL")
    outputs = gr.Label(
        num_top_classes=2,
        label="Prediction",
        every=2,
    )
    gradio_app = gr.Interface(
        fn=predict_url_class,
        inputs=inputs,
        outputs=outputs,
        title="PDF",
        description=description,
        theme="snehilsanyal/scikit-learn",
        examples=[
            [
                "https://www.topbrasslighting.com/wp-content/uploads/TopBrass-138.01-tearsheet-Jun12018.pdf"
            ],
            ["https://lyntec.com/wp-content/uploads/2018/12/LynTec-XPC-Brochure.pdf"],
        ],
        allow_flagging="never",
    )
    # queue() keeps long-running predictions from hitting request timeouts.
    gradio_app.queue().launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    # Run Gradio app
    main()
create_dataset.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utilities import get_simple_logger, PDFExtractor, ModuleException
2
+ import pandas as pd
3
+ import os, glob
4
+ import json
5
+ from tqdm.auto import tqdm
6
+ import re
7
+
8
# Paths are resolved relative to this file so the script works from any cwd.
file_dir = os.path.dirname(os.path.realpath(__file__))
csv_file_dir = os.path.join(file_dir, "materials")  # input CSVs live here
data_dir = os.path.join(file_dir, "data")  # generated outputs live here
train_json_dir = os.path.join(data_dir, "train_jsons")  # per-PDF JSONs (train)
test_json_dir = os.path.join(data_dir, "test_jsons")  # per-PDF JSONs (test)

logger = get_simple_logger("create_dataset")
15
+
16
+
17
def save_json(data, split="train"):
    """Write *data* as ``<id>.json`` into the split's JSON directory."""
    target_dir = train_json_dir if split == "train" else test_json_dir
    destination = os.path.join(target_dir, f"{data['id']}.json")
    logger.debug(f"Saving the json to {destination}")
    with open(destination, "w") as handle:
        json.dump(data, handle)
26
+
27
+
28
def clean_text(text):
    """Normalise text extracted from a PDF.

    Collapses whitespace runs, drops lines of 3 characters or fewer,
    and truncates the result to 10,000 characters.
    """
    # Any run of 2+ whitespace characters (spaces, tabs, newline pairs)
    # collapses to a single space.
    text = re.sub(r"\s{2,}", " ", text)
    # NOTE(review): after the step above no consecutive newlines remain, so
    # this substitution is effectively a no-op; kept for behavioural parity
    # with the original implementation.
    text = re.sub(r"\n{2,}", "\n", text)
    # Lines this short are usually extraction noise (page numbers, bullets).
    kept_lines = [line for line in text.split("\n") if len(line) > 3]
    # Cap the document at 10,000 characters.
    return "\n".join(kept_lines)[:10000]
43
+
44
+
45
def create_id(id_, split):
    """Format a numeric id as the split-specific ID string.

    Train ids become ``P-<n>``, everything else becomes ``TP<n>``;
    leading zeros are stripped via the int conversion.
    """
    numeric = int(id_)
    prefix = "P-" if split == "train" else "TP"
    return f"{prefix}{numeric}"
51
+
52
+
53
def create_json(split="train"):
    """Extract PDF contents for every row of the split CSV and save one
    JSON file per document.

    Documents whose JSON already exists are skipped, so the function can be
    re-run to resume an interrupted extraction.

    Parameters
    ----------
    split : str, optional
        The split to create the dataset for, by default "train"
    """
    logger.info(f"Creating the dataset for {split}")
    df_path = os.path.join(csv_file_dir, f"parspec_{split}_data.csv")
    df = pd.read_csv(df_path)
    df.dropna(inplace=True)
    json_dir = train_json_dir if split == "train" else test_json_dir
    os.makedirs(json_dir, exist_ok=True)

    # Ids already extracted (filename without extension).
    # BUG FIX: the original re-listed `train_json_dir` here regardless of the
    # split, so resuming a *test* extraction skipped the wrong files; list the
    # split's own directory instead.
    extracted_files = [name.split(".")[0] for name in os.listdir(json_dir)]
    logger.info(f"{len(extracted_files)} files are already extracted.")

    for _, row in tqdm(
        df.iterrows(),
        desc="extracting information...",
        total=len(df) - len(extracted_files),
    ):
        id_ = row["ID"]
        if "-" in id_:
            # train ids look like "P-123"
            id_ = id_.split("-")[1]
        else:
            # test ids look like "TP123"
            id_ = id_[2:]
        id_ = id_.zfill(4)
        if id_ in extracted_files:
            logger.debug(f"File {id_} already extracted")
            continue
        logger.info(f"Extracting the file for ID {id_}")
        url = row["URL"]
        label = 1 if row["Is lighting product?"] in [1, "Yes"] else 0
        try:
            pdf_extractor = PDFExtractor(
                file_path=url,
                is_url=True,
                min_characters=5,
                maximum_pages=3,
            )
            final = pdf_extractor.extract_pages()
            data = {
                "status": "ok",
                "id": id_,
                "label": label,
                "page_contents": pdf_extractor.page_contents,
                "final_content": clean_text(final),
                "url": url,
            }
        except ModuleException:
            # Keep a placeholder record so the row is not retried forever.
            logger.error(f"Url is not valid for ID {id_}. Using Null values.")
            data = {
                "status": "error",
                "id": id_,
                "label": label,
                "page_contents": None,
                "final_content": None,
                "url": url,
            }
        # save the json
        save_json(data, split)
122
+
123
+
124
def create_dataframe(split):
    """Aggregate the per-document JSONs of *split* into ``data/<split>.csv``.

    Reads every extracted JSON, skips error records, joins against the
    original split CSV on the formatted id, writes the CSV and returns
    the resulting frame.
    """
    df_path = os.path.join(csv_file_dir, f"parspec_{split}_data.csv")
    df = pd.read_csv(df_path)
    json_dir = train_json_dir if split == "train" else test_json_dir

    records = []
    for file in tqdm(glob.glob(f"{json_dir}/*.json"), "creating dataframe..."):
        with open(file, "r") as f:
            data = json.load(f)
        if data["status"] == "error":
            # Extraction failed for this document; nothing to aggregate.
            continue
        records.append(
            {
                "status": data["status"],
                "id": create_id(data["id"], split=split),
                "label": data["label"],
                "content": clean_text(data["final_content"]),
                "url": data["url"],
            }
        )

    final_df = pd.DataFrame(
        records, columns=["status", "id", "label", "content", "url"]
    )
    # Join back onto the source CSV to pick up the authoritative label text.
    final = pd.merge(final_df, df, left_on="id", right_on="ID")[
        ["id", "content", "Is lighting product?", "url"]
    ]
    final.rename(columns={"Is lighting product?": "label"}, inplace=True)
    final["label"] = final["label"].map({"Yes": 1, "No": 0})
    final.to_csv(
        os.path.join(data_dir, f"{split}.csv"), index=False, escapechar="\\"
    )  # setting escapechar is required
    return final
168
+
169
+
170
if __name__ == "__main__":
    # Aggregate the already-extracted test-split JSONs into data/test.csv.
    create_dataframe(split="test")
encoder/config.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "sentence-transformers/paraphrase-MiniLM-L3-v2",
3
+ "architectures": [
4
+ "BertModel"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "classifier_dropout": null,
8
+ "gradient_checkpointing": false,
9
+ "hidden_act": "gelu",
10
+ "hidden_dropout_prob": 0.1,
11
+ "hidden_size": 384,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 1536,
14
+ "layer_norm_eps": 1e-12,
15
+ "max_position_embeddings": 512,
16
+ "model_type": "bert",
17
+ "num_attention_heads": 12,
18
+ "num_hidden_layers": 3,
19
+ "pad_token_id": 0,
20
+ "position_embedding_type": "absolute",
21
+ "torch_dtype": "float32",
22
+ "transformers_version": "4.35.2",
23
+ "type_vocab_size": 2,
24
+ "use_cache": true,
25
+ "vocab_size": 30522
26
+ }
encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f59d73a201e0f6092d7e88ac8589f886a946725cff96d0250231d7e272e63071
3
+ size 69565312
encoder/special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "cls_token": "[CLS]",
3
+ "mask_token": "[MASK]",
4
+ "pad_token": "[PAD]",
5
+ "sep_token": "[SEP]",
6
+ "unk_token": "[UNK]"
7
+ }
encoder/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
encoder/tokenizer_config.json ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "100": {
12
+ "content": "[UNK]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "101": {
20
+ "content": "[CLS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "102": {
28
+ "content": "[SEP]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "103": {
36
+ "content": "[MASK]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "clean_up_tokenization_spaces": true,
45
+ "cls_token": "[CLS]",
46
+ "do_basic_tokenize": true,
47
+ "do_lower_case": true,
48
+ "mask_token": "[MASK]",
49
+ "max_length": 128,
50
+ "model_max_length": 512,
51
+ "never_split": null,
52
+ "pad_to_multiple_of": null,
53
+ "pad_token": "[PAD]",
54
+ "pad_token_type_id": 0,
55
+ "padding_side": "right",
56
+ "sep_token": "[SEP]",
57
+ "stride": 0,
58
+ "strip_accents": null,
59
+ "tokenize_chinese_chars": true,
60
+ "tokenizer_class": "BertTokenizer",
61
+ "truncation_side": "right",
62
+ "truncation_strategy": "longest_first",
63
+ "unk_token": "[UNK]"
64
+ }
encoder/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
final_model.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c9ebd5d8ba12c83ccd63476660c075cace91eca79b56f650e3db48431788c8f5
3
+ size 5788529
inference.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utilities import (
2
+ PDFExtractor,
3
+ get_simple_logger,
4
+ ModuleException,
5
+ )
6
+ from create_dataset import clean_text
7
+ from model import PDFDataSet
8
+ import pickle
9
+ import argparse
10
+
11
+
12
class Inference:
    """Runs inference for a single PDF url using the pickled RandomForest model."""

    def __init__(self, pdf_url, model_path="final_model.pkl") -> None:
        self.pdf_url = pdf_url
        self.model_path = model_path
        self.logger = get_simple_logger("inference", level="debug")

    def _extract_text(self):
        """Download the PDF and return ``(metadata_record, combined_text)``.

        Raises whatever the extractor raises, after logging it.
        """
        self.logger.debug("Extracting text from the PDF.")
        extractor = PDFExtractor(
            self.pdf_url,
            min_characters=5,
            maximum_pages=3,
        )
        try:
            combined = extractor.extract_pages()
            record = {
                "status": "ok",
                "id": "001",
                "page_contents": extractor.page_contents,
                "final_content": clean_text(combined),
                "url": self.pdf_url,
            }
        except ModuleException as e:
            self.logger.warning("A module exception is raised.")
            self.logger.error(e)
            raise e
        except Exception as e:
            self.logger.error(e)
            raise e
        return record, combined

    def _load_model(self):
        """Unpickle the classifier and cache it on the instance."""
        with open(self.model_path, "rb") as handle:
            self.model = pickle.load(handle)
        return self.model

    def _create_embedding(self, text):
        """Encode *text* into a single ``(1, dim)`` sentence embedding."""
        encoder = PDFDataSet()
        return encoder.sentences_to_embedding([text]).reshape(1, -1)

    def _pretty_print_probability(self, p):
        """Print a human-readable verdict.

        ``p[0]`` is the Non-Lighting probability, ``p[1]`` the Lighting one.
        """
        non_lighting_pct = p[0] * 100
        lighting_pct = p[1] * 100
        if lighting_pct > non_lighting_pct:
            print(
                f"This is a Lighting product with a probability of {lighting_pct:.2f}%"
            )
        else:
            print(
                f"This is a Non-Lighting product with a probability of {non_lighting_pct:.2f}%"
            )

    def predict(self):
        """Run the full pipeline; return ``[non_lighting, lighting]`` probabilities."""
        _, sentence = self._extract_text()
        features = self._create_embedding(sentence)
        classifier = self._load_model()
        probabilities = classifier.predict_proba(features)[0]
        self._pretty_print_probability(probabilities)
        return probabilities
76
+
77
+
78
def main(args):
    """CLI entry point: classify the PDF at ``args.url`` and print the result."""
    prediction = Inference(args.url).predict()
    print(prediction)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--url", type=str, required=True)
    main(parser.parse_args())
model.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModel
2
+ from datasets import load_dataset, Dataset, concatenate_datasets
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch.utils.data import DataLoader
6
+ from sklearn.model_selection import train_test_split
7
+ from sklearn.metrics import (
8
+ classification_report,
9
+ confusion_matrix,
10
+ accuracy_score,
11
+ precision_score,
12
+ )
13
+ from sklearn.ensemble import RandomForestClassifier
14
+ from xgboost import XGBClassifier
15
+ import torch.nn as nn
16
+ import torchmetrics
17
+ from torch.optim.lr_scheduler import CosineAnnealingLR
18
+
19
+ import numpy as np
20
+ import pandas as pd
21
+ import os
22
+ import pickle
23
+ import argparse
24
+ from torch_train import TorchTrain
25
+ from utilities import get_simple_logger
26
+
27
# Repository-relative locations of the prepared CSV data.
FILE_DIR = os.path.dirname(os.path.realpath(__file__))
DATA_DIR = os.path.join(FILE_DIR, "data")
# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
random_state = 42
# set random state
np.random.seed(random_state)
torch.manual_seed(random_state)
34
+
35
+
36
class PDFDataLoader:
    """Map-style dataset adapting embedding/label records to ``(X, y)`` tensors.

    NOTE(review): despite the name this is a dataset, not a DataLoader —
    callers wrap it in ``torch.utils.data.DataLoader``.
    """

    def __init__(self, df):
        self.df = df

    def __getitem__(self, idx):
        record = self.df[idx]
        features = torch.from_numpy(np.array(record["embeddings"])).float()
        # Shape the label as (1,) so it lines up with the model's sigmoid output.
        target = np.expand_dims(np.array(record["label"]), axis=0)
        return features.to(device), torch.from_numpy(target).to(device).float()

    def __len__(self):
        return len(self.df)
53
+
54
+
55
class PDFDataSet:
    """Builds embedding datasets for the PDF classifier.

    Loads the prepared train/test CSVs, splits train into train/validation,
    optionally folds a fraction of the test split into training, and encodes
    each document's text with a frozen transformer encoder.
    """

    def __init__(
        self,
        data_dir=DATA_DIR,
        fraction_test_data_in_train=0.2,
        model_ckpt="encoder",
    ) -> None:
        """
        Parameters
        ----------
        data_dir : str, optional
            Directory containing train.csv / test.csv, by default DATA_DIR.
        fraction_test_data_in_train : float, optional
            Fraction of the test split moved into training, by default 0.2.
        model_ckpt : str, optional
            Checkpoint/directory of the HF encoder, by default "encoder".
        """
        self.data_dir = data_dir
        self.fraction_test_data_in_train = fraction_test_data_in_train
        self.model_ckpt = model_ckpt
        tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
        encoding_model = AutoModel.from_pretrained(model_ckpt)
        # The encoder is a frozen feature extractor: eval mode, on `device`.
        encoding_model = encoding_model.to(device)
        encoding_model = encoding_model.eval()
        self.encoding_model = encoding_model
        self.tokenizer = tokenizer
        self.logger = get_simple_logger("pdf_dataset")

    def create_datasets(self):
        """Return ``(train, validation, test)`` HF Datasets built from the CSVs."""
        train_data_path = os.path.join(FILE_DIR, self.data_dir, "train.csv")
        test_data_path = os.path.join(FILE_DIR, self.data_dir, "test.csv")
        df = pd.read_csv(train_data_path)
        test_df = pd.read_csv(test_data_path)
        train_df, validation_df = train_test_split(df, test_size=0.3, random_state=42)
        if self.fraction_test_data_in_train:
            # The test distribution differs from train; fold a slice of it in.
            self.logger.info(
                f"Adding {self.fraction_test_data_in_train} fraction of test dataset to the training set."
            )
            test_df, test_df_for_training = train_test_split(
                test_df, test_size=self.fraction_test_data_in_train, random_state=42
            )
            train_df = pd.concat([train_df, test_df_for_training])

        train_dataset = Dataset.from_pandas(train_df)
        validation_dataset = Dataset.from_pandas(validation_df)
        test_dataset = Dataset.from_pandas(test_df)
        return train_dataset, validation_dataset, test_dataset

    def mean_pooling(self, model_output, attention_mask):
        """Attention-masked mean over the token embeddings."""
        token_embeddings = model_output[
            0
        ]  # First element of model_output contains all token embeddings
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

    def sentences_to_embedding(self, sentences):
        """Encode *sentences* (list of str) into L2-normalised embeddings."""
        # Tokenize sentences
        encoded_input = self.tokenizer(
            sentences, padding=True, truncation=True, return_tensors="pt"
        )
        # BUG FIX: the encoder lives on `device` but the tokenizer returns CPU
        # tensors, which crashed on GPU hosts; move the inputs across first.
        encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
        sentence_embeddings = self.mean_pooling(
            self.encoding_model(**encoded_input), encoded_input["attention_mask"]
        )
        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
        # remove last dimension
        sentence_embeddings = sentence_embeddings.squeeze()
        # Return on CPU so downstream numpy/sklearn consumers work unchanged
        # (a no-op when already on CPU).
        return sentence_embeddings.detach().cpu()

    def get_embeddings(self, row):
        """``Dataset.map`` helper: attach an "embeddings" column to *row*."""
        return {
            "embeddings": self.sentences_to_embedding(
                sentences=row["content"],
            )
        }

    def create_embeddings(self):
        """Return ``(train, validation, test)`` datasets with embedding columns."""
        train_dataset, validation_dataset, test_dataset = self.create_datasets()
        train_dataset = train_dataset.map(self.get_embeddings)
        validation_dataset = validation_dataset.map(self.get_embeddings)
        test_dataset = test_dataset.map(self.get_embeddings)
        return train_dataset, validation_dataset, test_dataset
130
+
131
+
132
class PDFModel(nn.Module):
    """Small fully-connected binary classifier over sentence embeddings.

    Stacks a Linear+ReLU pair per entry of *hidden_sizes*, then a final
    Linear layer followed by a sigmoid, so outputs lie in (0, 1).
    """

    def __init__(self, input_size, hidden_sizes, output_size):
        super(PDFModel, self).__init__()
        self.seq_model = nn.Sequential()
        in_features = input_size
        for layer_idx, width in enumerate(hidden_sizes):
            self.seq_model.add_module(
                f"linear_{layer_idx}", nn.Linear(in_features, width)
            )
            self.seq_model.add_module(f"relu_{layer_idx}", nn.ReLU())
            in_features = width
        self.last_layer = nn.Linear(in_features, output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        hidden = self.seq_model(x)
        return self.sigmoid(self.last_layer(hidden))
147
+
148
+
149
def evaluate_model(y_true, y_pred, model_name, split="train"):
    """Print accuracy, precision and the full classification report."""
    separator = "------" * 10
    print(separator)
    print(f"Evaluating for the model: {model_name} for {split} dataset...")
    print(f"Accuracy: {accuracy_score(y_true, y_pred)}")
    print(f"Precision: {precision_score(y_true, y_pred)}")
    print(classification_report(y_true, y_pred))
    print(separator)
159
+
160
+
161
def train_dl_model(
    train_data,
    validation_data,
    epochs=30,
    input_shape=384,
    hidden_sizes=[32, 16],
):
    """Train the feed-forward classifier and return ``(history, model)``.

    NOTE(review): the mutable default for *hidden_sizes* is kept for
    interface compatibility; it is never mutated here.
    """
    model = PDFModel(input_size=input_shape, hidden_sizes=hidden_sizes, output_size=1)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    metrics = {
        "accuracy": torchmetrics.Accuracy(
            task="binary", num_classes=2, threshold=0.5, average="macro"
        ),
        "precision": torchmetrics.Precision(task="binary", average="macro"),
    }
    trainer = TorchTrain(
        model,
        optimizer,
        nn.BCELoss(),
        metrics=metrics,
        scheduler=CosineAnnealingLR(optimizer, T_max=10, eta_min=0.0001),
    )
    history = trainer.fit(train_data, validation_data, verbose=True, epochs=epochs)
    return history, model
183
+
184
+
185
def evaluate_models(fraction_test_data_in_train=0.1):
    """Train and report metrics for the DL model and the two ML baselines.

    Parameters
    ----------
    fraction_test_data_in_train : float, optional
        Fraction of the test split folded into training, by default 0.1.
    """
    print("Creating Embeddings...")
    ds = PDFDataSet(fraction_test_data_in_train=fraction_test_data_in_train)
    train_dataset, validation_dataset, test_dataset = ds.create_embeddings()
    print("Done\n")

    print("Training DL Model")
    # Create dataset for DL models:
    BATCH_SIZE = 8
    train_dataloader = PDFDataLoader(train_dataset)
    validation_dataloader = PDFDataLoader(validation_dataset)
    test_dataloader = PDFDataLoader(test_dataset)

    train_data = DataLoader(train_dataloader, batch_size=BATCH_SIZE, shuffle=True)
    validation_data = DataLoader(
        validation_dataloader,
        batch_size=BATCH_SIZE,
        shuffle=True,
    )
    test_data = DataLoader(test_dataloader, batch_size=BATCH_SIZE, shuffle=True)

    # Derive the embedding width from the first batch rather than hard-coding it.
    for X, _ in train_data:
        input_shape = int(X.shape[1])
        break
    epochs = 30
    hidden_sizes = [32, 16]
    history, model = train_dl_model(
        train_data=train_data,
        validation_data=validation_data,
        epochs=epochs,
        # BUG FIX: input_shape was computed above but never passed, so the
        # model silently fell back to the default of 384.
        input_shape=input_shape,
        hidden_sizes=hidden_sizes,
    )
    print("Done\n")
    print("Evaluating DL Model")
    y_test_pred = model(torch.from_numpy(np.array(test_dataset["embeddings"])).float())
    y_test_pred = y_test_pred.detach().numpy()
    # Threshold the sigmoid outputs at 0.5 to get hard labels.
    y_test_pred = np.where(y_test_pred > 0.5, 1, 0)
    evaluate_model(
        y_true=test_dataset["label"],
        y_pred=y_test_pred,
        model_name="DL Model",
        split="test",
    )
    print("Done\n")

    # ML Models
    print("Training and evaluating ML Models.")
    X_train = train_dataset["embeddings"]
    y_train = train_dataset["label"]
    X_validation = validation_dataset["embeddings"]
    y_validation = validation_dataset["label"]
    X_test = test_dataset["embeddings"]
    y_test = test_dataset["label"]
    # Pre-tuned hyper-parameters (presumably from an optuna search —
    # optuna is pinned in requirements.txt; confirm).
    rfc_best_params = {
        "max_depth": 23,
        "max_features": "log2",
        "n_estimators": 469,
    }
    xgb_best_params = {
        "max_depth": 25,
        "n_estimators": 372,
        "learning_rate": 0.2522824287799319,
    }

    print("Fitting RandomForest")
    rfc = RandomForestClassifier(**rfc_best_params)
    rfc.fit(X_train, y_train)
    for split_name, X_split, y_split in (
        ("train", X_train, y_train),
        ("validation", X_validation, y_validation),
        ("test", X_test, y_test),
    ):
        evaluate_model(
            y_true=y_split,
            y_pred=rfc.predict(X_split),
            model_name="RandomForest",
            split=split_name,
        )

    print("Fitting XGBoost")
    xgb = XGBClassifier(**xgb_best_params)
    xgb.fit(X_train, y_train)
    for split_name, X_split, y_split in (
        ("train", X_train, y_train),
        ("validation", X_validation, y_validation),
        ("test", X_test, y_test),
    ):
        evaluate_model(
            y_true=y_split,
            y_pred=xgb.predict(X_split),
            model_name="XGBoost",
            split=split_name,
        )
    print("All Done")
293
+
294
+
295
def train_and_save_final_model(model_save_path="final_model.pkl"):
    """Create and save the final model.

    The final model is a RandomForestClassifier trained on all the training
    data plus 10% of the test data; the test distribution differs markedly
    from training, so folding in that slice is necessary. The folded-in slice
    is excluded when the final test metrics are computed.

    Parameters
    ----------
    model_save_path : str, optional
        The path to save the final model, by default "final_model.pkl"

    Returns
    -------
    None

    Examples
    --------
    >>> train_and_save_final_model()
    >>> train_and_save_final_model(model_save_path="final_model.pkl")
    """
    print("Creating Embeddings...")
    model_save_path = os.path.join(FILE_DIR, model_save_path)
    dataset_builder = PDFDataSet(fraction_test_data_in_train=0.1)
    train_dataset, validation_dataset, test_dataset = (
        dataset_builder.create_embeddings()
    )
    # The final model trains on train + validation combined.
    train_dataset = concatenate_datasets([train_dataset, validation_dataset])
    X_train = train_dataset["embeddings"]
    y_train = train_dataset["label"]
    X_test = test_dataset["embeddings"]
    y_test = test_dataset["label"]

    print("Training and evaluating the model...")
    rfc_best_params = {
        "max_depth": 23,
        "max_features": "log2",
        "n_estimators": 469,
    }
    rfc_model = RandomForestClassifier(**rfc_best_params)
    rfc_model.fit(X_train, y_train)
    for split_name, X_split, y_split in (
        ("train", X_train, y_train),
        ("test", X_test, y_test),
    ):
        evaluate_model(
            y_true=y_split,
            y_pred=rfc_model.predict(X_split),
            model_name="Final Model",
            split=split_name,
        )

    print("Saving the model...")
    with open(model_save_path, "wb") as f:
        pickle.dump(rfc_model, f)
    print(f"Model saved to: {model_save_path}")
348
+
349
+
350
def main(args):
    """Dispatch to the requested task: final training or full evaluation."""
    if args.task == "train":
        train_and_save_final_model(model_save_path=args.model_save_path)
    elif args.task == "evaluate":
        evaluate_models(args.fraction)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train and evaluate models")
    parser.add_argument(
        "--task",
        type=str,
        choices=["train", "evaluate"],
        required=True,
        help="Whether to train and save the best model or evaluate all the models.",
    )
    parser.add_argument(
        "--fraction",
        type=float,
        default=0.1,
        help="Fraction of test data in train dataset",
    )
    parser.add_argument(
        "--model_save_path",
        type=str,
        default="final_model.pkl",
        help="Path to save the final model",
    )
    main(parser.parse_args())
requirements.txt ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Duplicate pins removed (pytesseract and requests were each listed twice).
numpy==1.26.2
pandas==2.1.3
pypdf==3.17.1
pytesseract==0.3.10
pdf2image==1.16.3
pdfminer.six==20221105
pdfplumber==0.10.3
Pillow==10.1.0
PyPDF2==3.0.1
pypdfium2==4.24.0
requests==2.31.0
torch==2.1.1
transformers==4.35.2
datasets==2.15.0
optuna==3.4.0
scikit-learn==1.3.2
scipy==1.11.4
torchmetrics==1.2.1
tqdm==4.66.1
tokenizers==0.15.0
xgboost==2.0.2
# TODO(review): pin gradio like the other dependencies for reproducible builds
gradio
torch_train.py ADDED
@@ -0,0 +1,543 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
class TorchTrain:
    """A Keras-style class for training a model in PyTorch.

    Parameters
    ----------
    model : torch.nn.Module
        The PyTorch model to train.
    optimizer : torch.optim.Optimizer
        The optimizer to use for training.
    loss_function : callable
        The loss function; called as ``loss_function(yhat, y)``.
    metrics : dict, optional
        Mapping of metric name -> function taking ``(yhat, y)`` and returning
        a tensor, numpy array, int or float. Keys are title-cased internally.
        Anything other than ``None`` or a dict raises ``TypeError``.
    scheduler : torch.optim.lr_scheduler, optional
        Learning-rate scheduler stepped once per epoch. Defaults to None.
    task_type : str, optional
        When "classification", :meth:`predict` rounds 1-D outputs (binary)
        and argmaxes 2-D outputs. Defaults to "classification".

    Attributes
    ----------
    DEVICE : torch.device
        cuda if available, cpu otherwise; the model and every batch are
        moved to this device.
    train_loss, test_loss : float
        Cumulative losses for the epoch in progress (averaged at epoch end).
    train_iteration, test_iteration : int
        Number of batches processed in the epoch in progress.
    train_loss_all, test_loss_all : list
        Per-epoch averaged losses.
    train_metrics_all, test_metrics_all : list
        Per-epoch averaged metric dicts.
    """

    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def __init__(
        self,
        model,
        optimizer,
        loss_function,
        metrics=None,
        scheduler=None,
        task_type="classification",
    ) -> None:
        self.model = model
        self.model.to(self.DEVICE)
        self.optimizer = optimizer
        self.loss_function = loss_function
        self.metrics = self.__preprocess_metrics(metrics)
        self.scheduler = scheduler
        self.metrics_evaluated = {}
        # Cumulative accumulators for the epoch in progress.
        self.train_loss = 0
        self.test_loss = 0
        self.train_iteration = 0
        self.test_iteration = 0
        self.train_metrics = {}
        self.test_metrics = {}
        self.history = {}
        # Per-epoch averages, appended at the end of each epoch.
        self.train_loss_all = []
        self.test_loss_all = []
        self.train_metrics_all = []
        self.test_metrics_all = []
        # Guards so each epoch's totals are averaged exactly once.
        self.__train_scaled = False
        self.__test_scaled = False
        self.task_type = task_type

    def __preprocess_metrics(self, metrics):
        """Normalize ``metrics`` to a dict with title-cased keys; reject non-dicts."""
        if metrics is None:
            return {}
        if isinstance(metrics, dict):
            return {key.title(): value for key, value in metrics.items()}
        else:
            raise TypeError(
                "Metrics should be a dictionary of metrics or a function which takes yhat, y"
            )

    def __scale_matrices(self, loss, metrics, type="train"):
        """Average the cumulative loss and metrics by the iteration count.

        A no-op when the values were already scaled this epoch, or when no
        batches were processed (avoids a ZeroDivisionError when ``fit`` runs
        without a validation loader).

        Parameters
        ----------
        loss : float
            The cumulative loss to scale.
        metrics : dict
            The cumulative metrics to scale (modified in place).
        type : str, optional
            Either "train" or "test", by default "train".

        Returns
        -------
        tuple
            The (possibly) scaled ``(loss, metrics)``.
        """
        if type == "train" and not self.__train_scaled:
            scale = self.train_iteration
            self.__train_scaled = True
        elif type == "test" and not self.__test_scaled:
            scale = self.test_iteration
            self.__test_scaled = True
        else:
            return loss, metrics
        if not scale:
            # Fix: nothing was accumulated (e.g. no validation data); the
            # original divided by zero here.
            return loss, metrics
        loss /= scale
        for key in metrics:
            metrics[key] /= scale
        return loss, metrics

    def __reset_counters(self):
        """Average and archive the epoch's loss/metrics, then reset for the next epoch."""
        self.train_loss, self.train_metrics = self.__scale_matrices(
            self.train_loss, self.train_metrics, type="train"
        )
        self.test_loss, self.test_metrics = self.__scale_matrices(
            self.test_loss, self.test_metrics, type="test"
        )

        self.train_loss_all.append(self.train_loss)
        self.train_loss = 0

        self.test_loss_all.append(self.test_loss)
        self.test_loss = 0

        self.train_iteration = 0
        self.test_iteration = 0

        self.train_metrics_all.append(self.train_metrics)
        self.train_metrics = {}

        self.test_metrics_all.append(self.test_metrics)
        self.test_metrics = {}
        self.__train_scaled = False
        self.__test_scaled = False

    @property
    def loss(self):
        """The most recent epoch's averaged training loss."""
        return self.train_loss_all[-1]

    def __create_history(self):
        """Assemble a Keras-like history dict from the per-epoch records."""
        history = {
            "train_loss": self.train_loss_all,
            "val_loss": self.test_loss_all,
        }
        for key in self.metrics:
            history[f"train_{key.lower()}"] = []
            history[f"val_{key.lower()}"] = []

        for item in self.train_metrics_all:
            for key, value in item.items():
                history[f"train_{key.lower()}"].append(value)

        for item in self.test_metrics_all:
            for key, value in item.items():
                history[f"val_{key.lower()}"].append(value)
        return history

    def __parse_val(self, val):
        """Coerce a metric's return value to a plain Python number.

        Raises
        ------
        TypeError
            If ``val`` is not a tensor, numpy array, int or float.
        """
        if isinstance(val, torch.Tensor):
            val = val.item()
        elif isinstance(val, np.ndarray):
            val = float(val)
        elif isinstance(val, (int, float)):
            pass
        else:
            raise TypeError(
                f"The given Metric function should return a tensor, numpy array, int, or float.\n\
            Got {type(val)}"
            )
        return val

    def _train_step(self, x, y):
        """Run one optimization step.

        Parameters
        ----------
        x : torch.Tensor
            The input batch.
        y : torch.Tensor
            The target batch.

        Returns
        -------
        tuple
            ``(loss_value, yhat)`` for this batch.
        """
        self.model.train()
        yhat = self.model(x)
        batch_loss = self.loss_function(yhat, y)
        self.optimizer.zero_grad()
        batch_loss.backward()
        self.optimizer.step()
        self.train_iteration += 1
        return batch_loss.item(), yhat

    def _test_step(self, x, y):
        """Run one evaluation step without building an autograd graph.

        Parameters
        ----------
        x : torch.Tensor
            The input batch.
        y : torch.Tensor
            The target batch.

        Returns
        -------
        tuple
            ``(loss_value, yhat)`` for this batch.
        """
        self.model.eval()
        with torch.inference_mode():
            yhat = self.model(x)
            batch_loss = self.loss_function(yhat, y)
        self.test_iteration += 1
        return batch_loss.item(), yhat

    def predict(self, x):
        """Make predictions on a batch of data.

        For classification, 1-D model outputs are rounded (binary case) and
        given a trailing dimension; 2-D outputs are argmaxed along dim 1.

        Parameters
        ----------
        x : torch.Tensor
            The input batch.

        Returns
        -------
        torch.Tensor
            The predicted output tensor.
        """
        self.model.eval()
        # Fix: the original ran the forward pass with autograd enabled,
        # building a needless graph during pure prediction.
        with torch.inference_mode():
            yhat = self.model(x)
        if self.task_type == "classification":
            if len(yhat.shape) == 1:
                yhat = torch.round(yhat)
                yhat = yhat.unsqueeze(1)
            else:
                yhat = torch.argmax(yhat, dim=1)

        return yhat

    def __calculate_metrics(self, yhat, y):
        """Evaluate every configured metric on one batch, coerced to numbers.

        Returns
        -------
        dict
            Metric name -> numeric value for this batch.
        """
        metrics = {}
        for key, metric in self.metrics.items():
            # Reuse the shared coercion helper instead of duplicating its
            # tensor/ndarray/number handling here.
            metrics[key] = self.__parse_val(metric(yhat, y))
        self.metrics_evaluated = metrics
        return metrics

    def __progress_bar(self, cur_iter, all_iter):
        """Render a fixed 20-character ``[====----]`` bar for ``cur_iter`` of ``all_iter``."""
        len_progress_bar = 20
        progress = int((cur_iter + 1) / all_iter * len_progress_bar)
        progress_bar = "=" * progress + "-" * (len_progress_bar - progress)
        return f"[{progress_bar}]"

    def progress(self, cur_iter, all_iter, loss, metrics, on="train"):
        """Build the one-line progress string for the current batch/epoch.

        ``loss`` and ``metrics`` are cumulative; they are divided by the
        current iteration count to display running averages.

        Parameters
        ----------
        cur_iter : int
            The current batch (or epoch, for the test line) number.
        all_iter : int
            The total number of batches (or epochs).
        loss : float
            The cumulative loss so far this epoch.
        metrics : dict
            The cumulative metrics so far this epoch.
        on : str, optional
            "train" or "test". Defaults to "train".

        Returns
        -------
        str
            The formatted progress line.
        """
        progress_bar = self.__progress_bar(cur_iter=cur_iter, all_iter=all_iter)

        if on.lower() == "train":
            iteration = self.train_iteration
            prefix = f"Epoch {(self.current_epoch+1):2d}/{self.epochs:2d} Batch "
        else:
            iteration = self.test_iteration
            prefix = "Epoch "

        text = f"{prefix}{cur_iter:>4d}/{all_iter:>4d}{progress_bar} {on.title()} loss: {loss/iteration:.4f}"
        for metric_name, metric_value in metrics.items():
            text += f" | {on.title()} {metric_name}: {metric_value/iteration:.4f}"

        return text

    def update_metrics(self, cur_metrics, new_metrics):
        """Accumulate ``new_metrics`` into ``cur_metrics`` (in place) and return it.

        Parameters
        ----------
        cur_metrics : dict
            The running totals.
        new_metrics : dict
            The values for a new batch of data.

        Returns
        -------
        dict
            The updated running totals.
        """
        for key, value in new_metrics.items():
            if key not in cur_metrics:
                cur_metrics[key] = value
            else:
                cur_metrics[key] += value
        return cur_metrics

    def fit(
        self,
        train_loader,
        validation_data_loader=None,
        epochs=1,
        verbose=True,
        train_steps_per_epoch=None,
        validation_steps_per_epoch=None,
    ):
        """Train the model, optionally validating each epoch.

        Parameters
        ----------
        train_loader : torch.utils.data.DataLoader
            The data loader for the training data.
        validation_data_loader : torch.utils.data.DataLoader, optional
            The data loader for the validation data. Defaults to None.
        epochs : int, optional
            The number of epochs to train for. Defaults to 1.
        verbose : bool, optional
            Whether to print per-batch progress during training. Defaults to True.
        train_steps_per_epoch : int, optional
            Batches per training epoch; defaults to ``len(train_loader)``.
        validation_steps_per_epoch : int, optional
            Batches per validation pass; defaults to ``len(validation_data_loader)``.

        Returns
        -------
        dict
            History with "train_loss", "val_loss", and one
            "train_<metric>"/"val_<metric>" list per configured metric.
        """
        self.epochs = epochs
        if train_steps_per_epoch is None:
            train_steps_per_epoch = len(train_loader)
        if validation_data_loader is not None and validation_steps_per_epoch is None:
            validation_steps_per_epoch = len(validation_data_loader)

        for epoch in range(epochs):
            self.current_epoch = epoch
            for i, (x, y) in enumerate(train_loader):
                x = x.to(self.DEVICE)
                # Multi-target batches arrive as a list/tuple of tensors.
                if isinstance(y, (list, tuple)):
                    y = [y_.to(self.DEVICE) for y_ in y]
                else:
                    y = y.to(self.DEVICE)

                train_loss, yhat = self._train_step(x, y)
                self.train_loss += train_loss
                metrics = self.__calculate_metrics(yhat, y)
                self.train_metrics = self.update_metrics(self.train_metrics, metrics)

                b_progress = self.progress(
                    i + 1,
                    train_steps_per_epoch,
                    self.train_loss,
                    self.train_metrics,
                    on="train",
                )
                if i == train_steps_per_epoch - 1:
                    # Last batch of the epoch: always print the final line.
                    print(b_progress)
                    break
                elif verbose:
                    print(b_progress, end="\r")
            if validation_data_loader is not None:
                for i, (x, y) in enumerate(validation_data_loader):
                    x = x.to(self.DEVICE)
                    if isinstance(y, (list, tuple)):
                        y = [y_.to(self.DEVICE) for y_ in y]
                    else:
                        y = y.to(self.DEVICE)
                    test_loss, yhat = self._test_step(x, y)
                    self.test_loss += test_loss
                    metrics = self.__calculate_metrics(yhat, y)
                    self.test_metrics = self.update_metrics(self.test_metrics, metrics)
                    if i == validation_steps_per_epoch - 1:
                        break
                test_progress = self.progress(
                    epoch + 1,
                    epochs,
                    self.test_loss,
                    self.test_metrics,
                    on="test",
                )
                print(test_progress)
            self.__reset_counters()
            if self.scheduler is not None:
                self.scheduler.step()
                if verbose:
                    print(f"New Learning rate: {self.scheduler.get_last_lr()[0]:.6f}")

        return self.__create_history()

    def save(self, path):
        """Save the model's ``state_dict`` to a file.

        Parameters
        ----------
        path : str
            The path to the file to save the model to.
        """
        torch.save(self.model.state_dict(), path)

    def load(self, path):
        """Load a ``state_dict`` from a file into the model.

        Parameters
        ----------
        path : str
            The path to the file to load the model from.
        """
        self.model.load_state_dict(torch.load(path))

    def evaluate(self, data_loader, metric):
        """Evaluate the model on a data loader with the given metric.

        The model output is rounded before scoring (binary-classification
        convention — called as ``metric(y, yhat)``), and the average score
        over all batches is returned.

        Parameters
        ----------
        data_loader : torch.utils.data.DataLoader
            The data loader to evaluate the model on.
        metric : callable
            The metric to evaluate the model with.

        Returns
        -------
        float
            The average score over the data loader.
        """
        running_score = 0
        data_length = len(data_loader)
        self.model.eval()
        for i, (x, y) in enumerate(data_loader):
            progress_bar = self.__progress_bar(i, data_length)
            x = x.to(self.DEVICE)
            if isinstance(y, (list, tuple)):
                y = [y_.to(self.DEVICE) for y_ in y]
            else:
                y = y.to(self.DEVICE)

            # Fix: evaluation should not build an autograd graph.
            with torch.inference_mode():
                yhat = self.model(x)
            yhat = torch.round(yhat)
            score = self.__parse_val(metric(y, yhat))
            running_score += score

            progress_bar = f"{i+1}/{data_length}" + progress_bar
            # Fix: was ':4f' (field width 4), intended precision '.4f'.
            progress_bar += f" Score: {(running_score/(i+1)):.4f}"
            print(progress_bar, end="\r")
        return running_score / data_length
utilities.py ADDED
@@ -0,0 +1,481 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import PyPDF2
2
+ from pdfminer.high_level import extract_pages
3
+ from pdfminer.layout import LTTextContainer, LTFigure
4
+ import pdfplumber
5
+ from pdf2image import convert_from_bytes
6
+
7
+ import pytesseract
8
+ import logging
9
+ import requests
10
+ from io import BytesIO
11
+
12
+ import signal
13
+ import time
14
+ import functools
15
+
16
+
17
+ import time
18
+ import signal
19
+
20
+
21
class ModuleException(Exception):
    """Base class for every exception raised by this module."""
24
+
25
+
26
class TimeoutError(ModuleException):
    """Raised when a ``@timeout``-decorated call exceeds its time budget.

    NOTE(review): this name shadows the builtin ``TimeoutError`` inside this
    module — confirm that is intentional.
    """
29
+
30
+
31
class BadUrlException(ModuleException):
    """Raised when the url is not valid"""

    def __init__(self, message):
        # Keep the message as an attribute for programmatic access.
        self.message = message
        super().__init__(message)
37
+
38
+
39
def timeout(seconds_before_timeout):
    """Decorator factory: raise ``TimeoutError`` if the wrapped function runs
    longer than ``seconds_before_timeout`` seconds.

    Implemented with ``signal.SIGALRM``, so it only works on Unix, in the
    main thread, and with whole-second resolution. If an outer alarm was
    already pending, the shorter deadline wins and the outer timer is
    re-armed (minus the elapsed time) once the call finishes.
    """

    def wrapper_wrapper(func):
        def handler(signum, frame):
            raise TimeoutError()

        @functools.wraps(func)
        def wrapper_function(*args, **kwargs):
            old_handler = signal.signal(signal.SIGALRM, handler)
            old_time_left = signal.alarm(seconds_before_timeout)
            # Never lengthen an already-pending (shorter) outer timer.
            # Fix: the original referenced the misspelled name
            # `second_before_timeout` here, which raised NameError.
            if 0 < old_time_left < seconds_before_timeout:
                signal.alarm(old_time_left)
            start_time = time.time()
            try:
                result = func(*args, **kwargs)
            finally:
                # Fix: always cancel our alarm and restore the previous
                # handler — the original left the alarm armed (and our
                # handler installed) whenever no outer timer was pending,
                # so a stray SIGALRM could fire later in unrelated code.
                signal.alarm(0)
                signal.signal(signal.SIGALRM, old_handler)
                if old_time_left > 0:
                    # Re-arm the outer timer, less the time we consumed.
                    # Fix: signal.alarm requires an int; the original
                    # passed the float produced by the subtraction.
                    remaining = old_time_left - (time.time() - start_time)
                    signal.alarm(max(int(remaining), 1))
            return result

        return wrapper_function

    return wrapper_wrapper
65
+
66
+
67
def get_simple_logger(name, level="info"):
    """Build a console logger with a single stream handler.

    Parameters
    ----------
    name : str
        Logger name.
    level : str or int, optional
        A ``logging`` level constant, or one of "debug", "info", "warning",
        "error", "critical" (case-insensitive). Defaults to "info".

    Returns
    -------
    logging.Logger
        The configured logger.
    """
    if isinstance(level, str):
        level = {
            "debug": logging.DEBUG,
            "info": logging.INFO,
            "warning": logging.WARNING,
            "error": logging.ERROR,
            "critical": logging.CRITICAL,
        }[level.lower()]
    logger = logging.getLogger(name)
    logger.setLevel(level)
    # Drop handlers left over from a previous call so output is not duplicated.
    if logger.hasHandlers():
        logger.handlers.clear()
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(level)
    stream_handler.setFormatter(
        logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    )
    logger.addHandler(stream_handler)
    return logger
91
+
92
+
93
class PDFExtractor:
    """Extracts textual, tabular and image (OCR) content from a PDF file or url."""

    def __init__(
        self,
        file_path,
        min_characters=5,
        maximum_pages=3,
        is_url=True,
    ) -> None:
        """A class that can be used to extract pdf information from a file path or url

        Parameters
        ----------
        file_path : str
            The path to the file or the url
        min_characters : int, optional
            The minimum number of characters that a text needs to have (after
            removing newlines) to be considered a relevant text, by default 5
        maximum_pages : int, optional
            The maximum number of pages processed by ``extract_pages``, by default 3
        is_url : bool, optional
            Whether the file_path is a url or not, by default True

        Raises
        ------
        BadUrlException
            If downloading/opening the file fails for any reason (the
            underlying error is logged and re-raised as this type), or the
            url returns a status code other than 200.

        Examples
        --------
        >>> extractor = PDFExtractor(file_path="file_path", is_url=False)
        >>> extractor.extract_pages()
        """
        self.file_path = file_path
        self.min_characters = min_characters
        self.maximum_pages = maximum_pages
        self.logger = get_simple_logger(name="pdf_extractor", level="info")
        try:
            self.byte_file = self._create_byte_object(self.file_path, is_url)
        except Exception as e:
            self.logger.error(e)
            raise BadUrlException(str(e))
        # Three parallel views over the same bytes: PyPDF2 for page cropping,
        # pdfplumber for tables, pdfminer for layout elements.
        self.pypdf_object = PyPDF2.PdfReader(self.byte_file)
        self.pdf_plumber_object = pdfplumber.open(self.byte_file)
        # Lazy generator of pdfminer pages.
        self.pdfminer_pages = extract_pages(self.byte_file)
        self.page_contents = []
        self.final_content = ""
        # Per-page scratch dict; reset at the start of extract_one_page().
        self.one_page_contents = {
            "contents": [],
            "images": [],
            "tables": [],
        }

    @timeout(120)
    def _create_byte_object(self, path, is_url):
        """Return a binary file-like object for ``path`` (downloads if url).

        Aborts after 120 seconds via the ``@timeout`` decorator.
        """
        if is_url:
            res = requests.get(path)
            if res.status_code == 200:
                byte_file = BytesIO(res.content)
            else:
                raise BadUrlException(f"The url raised status code: {res.status_code}")
        else:
            byte_file = open(path, "rb")

        return byte_file

    def __check_for_relevant_text(self, text):
        """Checks if the text is a relevant text or not

        Parameters
        ----------
        text : str
            The text to be checked

        Returns
        -------
        bool
            True if the text (with newlines removed) is longer than
            ``min_characters``, False otherwise
        """
        # Remove the line breaker from the text
        text_ = text.replace("\n", "")
        if len(text_) > self.min_characters:
            self.logger.debug(f"Text: {text_} is a relevant text.")
            return True
        self.logger.debug(f"Text: {text_} is not a relevant text.")
        return False

    def _handle_text(self, component):
        """Handles the extraction of textual information

        Parameters
        ----------
        component : pdfminer.layout.LTTextContainer
            The layout element to extract the text from

        Returns
        -------
        bool
            True if the text was relevant and stored, False otherwise
        """
        text = component.get_text()
        if self.__check_for_relevant_text(text):
            self.one_page_contents["contents"].append(text)
            return True
        return False

    def __table_converter(self, table):
        """Converts a pdfplumber table (list of rows) to a pipe-delimited string.

        ``None`` cells become the string "None"; newlines inside wrapped
        cells become spaces.

        Parameters
        ----------
        table : list
            The table to be converted

        Returns
        -------
        str
            The converted table as a string

        Examples
        --------
        >>> table = [["Name", "Age"], ["John", "30"], ["Jane", "25"]]
        >>> print(__table_converter(table))
        |Name|Age|
        |John|30|
        |Jane|25|
        """
        table_string = ""
        # Iterate through each row of the table
        for row_num in range(len(table)):
            row = table[row_num]
            # Remove the line breaker from the wrapped texts; map None -> "None".
            cleaned_row = [
                item.replace("\n", " ")
                if item is not None and "\n" in item
                else "None"
                if item is None
                else item
                for item in row
            ]
            # Convert the table into a string
            table_string += "|" + "|".join(cleaned_row) + "|" + "\n"
        # Removing the last line break
        table_string = table_string[:-1]
        return table_string

    def _handle_table(self, page_number):
        """Extracts all tables on a page into ``one_page_contents["tables"]``.

        Parameters
        ----------
        page_number : int
            The page number to extract the table from

        Returns
        -------
        None
        """
        page = self.pdf_plumber_object.pages[page_number]
        try:
            tables = page.extract_tables()
        # NOTE(review): bare except hides the real pdfplumber error and
        # substitutes a placeholder table — consider catching Exception
        # and logging it.
        except:
            tables = [
                ["NA", "NA"],
            ]
        tables_final = [self.__table_converter(table) for table in tables]
        self.one_page_contents["tables"] = tables_final

    # Create a function to crop the image elements from PDFs

    def __crop_image(self, element, page_number):
        """Crops the element's bounding box out of the page as a one-page PDF.

        The result is later rendered to an image and OCRed with pytesseract.

        NOTE(review): this mutates the shared PyPDF2 page's mediabox in
        place, so a second crop on the same page operates on the
        already-cropped page — confirm this is intended.

        Parameters
        ----------
        element : pdfminer.layout.LTFigure
            The element to crop from the pdf
        page_number : int
            The page number to crop the element from

        Returns
        -------
        bytes
            The cropped pdf as a byte object
        """
        pypdf_page = self.pypdf_object.pages[page_number]
        # Get the coordinates to crop the image from PDF
        [image_left, image_top, image_right, image_bottom] = [
            element.x0,
            element.y0,
            element.x1,
            element.y1,
        ]
        # Crop the page using coordinates (left, bottom, right, top)
        pypdf_page.mediabox.lower_left = (image_left, image_bottom)
        pypdf_page.mediabox.upper_right = (image_right, image_top)
        # Save the cropped page to a new single-page PDF held in memory
        cropped_pdf_writer = PyPDF2.PdfWriter()
        cropped_pdf_writer.add_page(pypdf_page)
        # convert to byte
        cropped_pdf_stream = BytesIO()
        cropped_pdf_writer.write(cropped_pdf_stream)
        byte_object = cropped_pdf_stream.getvalue()
        return byte_object

    # Create a function to convert the PDF to images
    def _convert_to_images(self, pdf_byte):
        """Renders the first page of the given PDF bytes to a PIL image.

        Parameters
        ----------
        pdf_byte : bytes
            The pdf byte object to be converted

        Returns
        -------
        PIL.Image
            The rendered first page
        """
        images = convert_from_bytes(pdf_byte)
        image = images[0]
        return image

    @timeout(20)
    def _image_to_text(self, image):
        """OCRs the given PIL image with pytesseract (20-second timeout).

        Parameters
        ----------
        image : PIL.Image
            The image to extract text from

        Returns
        -------
        str
            The extracted text from the image
        """
        text = pytesseract.image_to_string(image)
        self.logger.debug(f"Extracted {text} from the image.")
        return text

    def _handle_image(self, element, page_number):
        """Crops a figure, renders it, OCRs it, and keeps the text if relevant.

        Parameters
        ----------
        element : pdfminer.layout.LTFigure
            The element to extract the image from
        page_number : int
            The page number to extract the image from

        Returns
        -------
        bool
            True if the OCR text was relevant and stored, False otherwise
        """
        cropped_pdf = self.__crop_image(element, page_number)
        image = self._convert_to_images(pdf_byte=cropped_pdf)
        extracted_text = self._image_to_text(image)
        if self.__check_for_relevant_text(extracted_text):
            self.one_page_contents["images"].append(extracted_text)
            return True
        return False

    def extract_one_page(self, page_number, pdfminer_page):
        """Extracts tables, text and up to two OCRed images from one page.

        The per-page result dict is appended to ``self.page_contents``.

        Parameters
        ----------
        page_number : int
            The page number to extract the information from
        pdfminer_page : pdfminer.layout.LTPage
            The pdfminer page object to extract the information from

        Returns
        -------
        None
        """
        self.one_page_contents = {
            "contents": [],
            "images": [],
            "tables": [],
        }
        self._handle_table(page_number=page_number)
        # OCR is slow, so cap the number of images processed per page.
        max_image_per_page = 2
        image_number = 0
        for element_number, element in enumerate(pdfminer_page._objs):
            type_ = type(element)
            self.logger.debug(
                f"Handling Page: {page_number}, Element: {element_number} Type: {type_}"
            )

            if isinstance(element, LTTextContainer):
                self._handle_text(element)

            if isinstance(element, LTFigure) and image_number < max_image_per_page:
                added = self._handle_image(
                    element=element,
                    page_number=page_number,
                )
                if added:
                    image_number += 1
        self.page_contents.append(self.one_page_contents)

    def create_final_text_content(self):
        """Flattens all collected page contents into one tagged text blob.

        Pages, images and tables are delimited with "PAGE n", "IMAGE n" and
        "TABLE n" ... "ENDS" markers so the result can be fed directly to a
        machine-learning model.
        """
        final_content = ""
        for page_number, content in enumerate(self.page_contents):
            final_content += f"PAGE {page_number}\n"
            text_contents = content["contents"]
            final_content += "\n".join(text_contents)
            for i, image in enumerate(content["images"]):
                final_content += f"IMAGE {i}\n{image.strip()}\nIMAGE {i} ENDS\n"

            for i, table in enumerate(content["tables"]):
                final_content += f"TABLE {i}\n{table.strip()}\nTABLE {i} ENDS\n"
            final_content += f"PAGE {page_number} ENDS\n"
        self.final_content = final_content
        return final_content

    def extract_pages(self):
        """Extracts information from up to ``maximum_pages`` pages. This is the
        final method to be used.

        Returns
        -------
        str
            The final text content of the pdf, ready for machine learning.
        """
        pages = extract_pages(self.byte_file)
        for page_number, page in enumerate(pages):
            self.logger.info(f"Working on the page: {page_number}")
            if page_number >= self.maximum_pages:
                self.logger.info(f"Maximum page limit reached. Breaking...")
                break
            self.extract_one_page(page_number, page)

        final_content = self.create_final_text_content()
        return final_content
477
+
478
+ # def __del__(self):
479
+ # # Make sure that the file is closed once the object is deleted
480
+ # self.logger.debug("Closing open file and removing the temporary directory.")
481
+ # self.byte_file.close()