Spaces:
Runtime error
Runtime error
HariNuvo
committed on
Commit
·
6320d54
1
Parent(s):
3a9699c
initial commit
Browse files- .gitignore +3 -0
- Dockerfile +35 -0
- app.py +48 -0
- create_dataset.py +171 -0
- encoder/config.json +26 -0
- encoder/model.safetensors +3 -0
- encoder/special_tokens_map.json +7 -0
- encoder/tokenizer.json +0 -0
- encoder/tokenizer_config.json +64 -0
- encoder/vocab.txt +0 -0
- final_model.pkl +3 -0
- inference.py +88 -0
- model.py +382 -0
- requirements.txt +24 -0
- torch_train.py +543 -0
- utilities.py +481 -0
.gitignore
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.ipynb_checkpoints
|
| 2 |
+
__pycache__
|
| 3 |
+
*.json
|
Dockerfile
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10
ENV PATH="/root/miniconda3/bin:${PATH}"
ARG PATH="/root/miniconda3/bin:${PATH}"

# System packages: tesseract for OCR of PDF page images.
# update+install in one layer so the package index is never stale.
RUN apt-get update && \
    apt-get install -y wget tesseract-ocr libtesseract-dev && \
    rm -rf /var/lib/apt/lists/*

# Miniconda is installed only to get poppler (PDF rendering backend)
# from conda-forge.
RUN wget \
    https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
    && mkdir /root/.conda \
    && bash Miniconda3-latest-Linux-x86_64.sh -b \
    && rm -f Miniconda3-latest-Linux-x86_64.sh
RUN conda --version

RUN conda install -c conda-forge poppler -y

WORKDIR /code

COPY ./requirements.txt /code/requirements.txt

RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt

COPY . .

# Run as a non-root user (required by Hugging Face Spaces).
RUN useradd -m -u 1000 user
USER user
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH

WORKDIR $HOME/app

COPY --chown=user . $HOME/app

# BUG FIX: the previous CMD ran `uvicorn app.main:app`, but this repo has no
# `app/` package -- the entry point is `app.py`, whose main() launches gradio
# itself on 0.0.0.0:7860. The wrong module path made the container fail at
# startup (the Space's "Runtime error"). Also removed `RUN echo $pwd`, which
# printed nothing ($pwd is undefined; $PWD is the shell variable).
CMD ["python", "app.py"]
|
app.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from inference import Inference
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def predict_url_class(url):
    """Predict the class of the given PDF url.

    Returns a dict mapping class names to probabilities, the shape
    expected by gradio's Label output.
    """
    inference = Inference(pdf_url=url)
    try:
        outputs = inference.predict()
    except Exception as e:
        # BUG FIX: previously execution fell through to use the undefined
        # `outputs` after the warning, crashing with a NameError. Surface
        # the error in the UI and return an empty prediction instead.
        gr.Warning(str(e))
        return {"Lighting": 0.0, "Non-Lighting": 0.0}
    # BUG FIX: sklearn's predict_proba orders columns by class label, so
    # index 0 is non-lighting (label 0) and index 1 is lighting (label 1) --
    # exactly how Inference._pretty_print_probability reads the same vector.
    # The previous mapping was inverted.
    output_for_gradio = {
        "Lighting": outputs[1],
        "Non-Lighting": outputs[0],
    }
    return output_for_gradio
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def main():
    """Build the gradio interface and launch it on 0.0.0.0:7860."""
    # Description shown above the input box. (Typos in the original copy
    # fixed: "is trained", "a URL", "Lighting", "up to 30 seconds",
    # "may take a long time".)
    description = "<p>The model is trained on a number of PDFs related to lighting and non-lighting products. The model takes a URL as input and predicts whether the product in the PDF corresponds to a Lighting product or not. The model may take up to 30 seconds to make a prediction. This is because we need to first extract textual, tabular and image information from various pages of the PDF and this may take a long time. Make sure that the URL provided is unblocked and can be downloaded without any extra steps.</p>"
    inputs = gr.Text(lines=1, placeholder="Enter the url of the PDF", label="URL")
    outputs = gr.Label(
        num_top_classes=2,
        label="Prediction",
        every=2,
    )
    gradio_app = gr.Interface(
        fn=predict_url_class,
        inputs=inputs,
        outputs=outputs,
        title="PDF",
        description=description,
        theme="snehilsanyal/scikit-learn",
        examples=[
            [
                "https://www.topbrasslighting.com/wp-content/uploads/TopBrass-138.01-tearsheet-Jun12018.pdf"
            ],
            ["https://lyntec.com/wp-content/uploads/2018/12/LynTec-XPC-Brochure.pdf"],
        ],
        allow_flagging="never",
    )
    # queue() is required because predictions can take ~30s.
    gradio_app.queue().launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    # Run Gradio app
    main()
|
create_dataset.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from utilities import get_simple_logger, PDFExtractor, ModuleException
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import os, glob
|
| 4 |
+
import json
|
| 5 |
+
from tqdm.auto import tqdm
|
| 6 |
+
import re
|
| 7 |
+
|
| 8 |
+
file_dir = os.path.dirname(os.path.realpath(__file__))
|
| 9 |
+
csv_file_dir = os.path.join(file_dir, "materials")
|
| 10 |
+
data_dir = os.path.join(file_dir, "data")
|
| 11 |
+
train_json_dir = os.path.join(data_dir, "train_jsons")
|
| 12 |
+
test_json_dir = os.path.join(data_dir, "test_jsons")
|
| 13 |
+
|
| 14 |
+
logger = get_simple_logger("create_dataset")
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def save_json(data, split="train"):
    """Persist one extraction record as <id>.json in the split's directory."""
    target_dir = train_json_dir if split == "train" else test_json_dir
    to_save = os.path.join(target_dir, f"{data['id']}.json")
    logger.debug(f"Saving the json to {to_save}")
    with open(to_save, "w") as f:
        json.dump(data, f)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def clean_text(text):
    """Normalize extracted PDF text.

    Collapses runs of 2+ whitespace characters to a single space, then
    runs of 2+ newlines to one, drops lines with 3 or fewer characters,
    and caps the result at 10000 characters.
    """
    # Note: \s matches newlines too, so most multi-newline runs are already
    # flattened by the first substitution.
    text = re.sub(r"\s{2,}", " ", text)
    text = re.sub(r"\n{2,}", "\n", text)
    kept_lines = [line for line in text.split("\n") if len(line) > 3]
    return "\n".join(kept_lines)[:10000]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def create_id(id_, split):
    """Rebuild the CSV-style ID ("P-<n>" for train, "TP<n>" for test)
    from a numeric or zero-padded id."""
    numeric = int(id_)
    return f"P-{numeric}" if split == "train" else f"TP{numeric}"
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def create_json(split="train"):
    """Creates the dataset from the csv file and saves it to the data_dir

    Downloads every PDF listed in the split's CSV, extracts up to 3 pages
    of content, and stores one JSON record per row via ``save_json``.
    Rows whose PDF cannot be fetched get ``status == "error"`` with null
    contents. Extraction is resumable: IDs already present on disk are
    skipped.

    Parameters
    ----------
    split : str, optional
        The split to create the dataset for, by default "train"
    """
    logger.info(f"Creating the dataset for {split}")
    df_path = os.path.join(csv_file_dir, f"parspec_{split}_data.csv")
    df = pd.read_csv(df_path)
    df.dropna(inplace=True)
    json_dir = train_json_dir if split == "train" else test_json_dir
    os.makedirs(json_dir, exist_ok=True)

    # IDs (filename stems) already extracted in a previous run.
    # BUG FIX: this previously listed ``train_json_dir`` unconditionally,
    # so resuming never worked for the test split.
    extracted_files = [name.split(".")[0] for name in os.listdir(json_dir)]
    logger.info(f"{len(extracted_files)} files are already extracted.")

    for i, row in tqdm(
        df.iterrows(),
        desc="extracting information...",
        total=len(df) - len(extracted_files),
    ):
        id_ = row["ID"]
        if "-" in id_:
            # train IDs look like "P-123"
            id_ = id_.split("-")[1]
        else:
            # test IDs look like "TP123"
            id_ = id_[2:]
        # zero-pad so filenames sort and match extracted_files stems
        id_ = id_.zfill(4)
        if id_ in extracted_files:
            logger.debug(f"File {id_} already extracted")
            continue
        logger.info(f"Extracting the file for ID {id_}")
        url = row["URL"]
        label = 1 if row["Is lighting product?"] in [1, "Yes"] else 0
        try:
            pdf_extractor = PDFExtractor(
                file_path=url,
                is_url=True,
                min_characters=5,
                maximum_pages=3,
            )
            final = pdf_extractor.extract_pages()
            data = {
                "status": "ok",
                "id": id_,
                "label": label,
                "page_contents": pdf_extractor.page_contents,
                "final_content": clean_text(final),
                "url": url,
            }
        except ModuleException:
            # unreachable / broken URL: keep the row with null contents so
            # downstream code can account for the failure
            logger.error(f"Url is not valid for ID {id_}. Using Null values.")
            data = {
                "status": "error",
                "id": id_,
                "label": label,
                "page_contents": None,
                "final_content": None,
                "url": url,
            }
        save_json(data, split)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def create_dataframe(split):
    """Collect the per-PDF JSON records for *split* into a single CSV.

    Reads every JSON in the split's directory, skips records whose
    extraction failed, re-joins them with the original CSV on ID, and
    writes data/<split>.csv with columns id/content/label/url
    (label mapped Yes->1, No->0). Returns the merged DataFrame.
    """
    df_path = os.path.join(csv_file_dir, f"parspec_{split}_data.csv")
    df = pd.read_csv(df_path)
    json_dir = train_json_dir if split == "train" else test_json_dir
    json_files = glob.glob(f"{json_dir}/*.json")
    # parallel column accumulators, one entry per successfully extracted PDF
    statuss = []
    ids = []
    labels = []
    contents = []
    urls = []
    for file in tqdm(json_files, "creating dataframe..."):
        with open(file, "r") as f:
            data = json.load(f)
        # records with failed downloads carry null content -- drop them
        if data["status"] == "error":
            continue
        statuss.append(data["status"])
        # rebuild the "P-<n>" / "TP<n>" form so the merge key matches df["ID"]
        ids.append(create_id(data["id"], split=split))
        labels.append(data["label"])
        contents.append(clean_text(data["final_content"]))
        urls.append(data["url"])

    final_df = pd.DataFrame(
        {
            "status": statuss,
            "id": ids,
            "label": labels,
            "content": contents,
            "url": urls,
        }
    )
    # inner join against the source CSV; the CSV's "Is lighting product?"
    # column replaces the JSON label below
    final = pd.merge(final_df, df, left_on="id", right_on="ID")[
        ["id", "content", "Is lighting product?", "url"]
    ]
    final.rename(columns={"Is lighting product?": "label"}, inplace=True)
    # NOTE(review): values other than "Yes"/"No" (e.g. already-numeric 1/0)
    # would map to NaN here -- confirm the CSV only contains Yes/No.
    final["label"] = final["label"].map(
        {
            "Yes": 1,
            "No": 0,
        }
    )
    final.to_csv(
        os.path.join(data_dir, f"{split}.csv"), index=False, escapechar="\\"
    )  # setting escapechar is required
    return final
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
if __name__ == "__main__":
    # Build the consolidated CSV for the held-out test split from the
    # already-extracted JSON records.
    create_dataframe(split="test")
|
encoder/config.json
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_name_or_path": "sentence-transformers/paraphrase-MiniLM-L3-v2",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"BertModel"
|
| 5 |
+
],
|
| 6 |
+
"attention_probs_dropout_prob": 0.1,
|
| 7 |
+
"classifier_dropout": null,
|
| 8 |
+
"gradient_checkpointing": false,
|
| 9 |
+
"hidden_act": "gelu",
|
| 10 |
+
"hidden_dropout_prob": 0.1,
|
| 11 |
+
"hidden_size": 384,
|
| 12 |
+
"initializer_range": 0.02,
|
| 13 |
+
"intermediate_size": 1536,
|
| 14 |
+
"layer_norm_eps": 1e-12,
|
| 15 |
+
"max_position_embeddings": 512,
|
| 16 |
+
"model_type": "bert",
|
| 17 |
+
"num_attention_heads": 12,
|
| 18 |
+
"num_hidden_layers": 3,
|
| 19 |
+
"pad_token_id": 0,
|
| 20 |
+
"position_embedding_type": "absolute",
|
| 21 |
+
"torch_dtype": "float32",
|
| 22 |
+
"transformers_version": "4.35.2",
|
| 23 |
+
"type_vocab_size": 2,
|
| 24 |
+
"use_cache": true,
|
| 25 |
+
"vocab_size": 30522
|
| 26 |
+
}
|
encoder/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f59d73a201e0f6092d7e88ac8589f886a946725cff96d0250231d7e272e63071
|
| 3 |
+
size 69565312
|
encoder/special_tokens_map.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cls_token": "[CLS]",
|
| 3 |
+
"mask_token": "[MASK]",
|
| 4 |
+
"pad_token": "[PAD]",
|
| 5 |
+
"sep_token": "[SEP]",
|
| 6 |
+
"unk_token": "[UNK]"
|
| 7 |
+
}
|
encoder/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
encoder/tokenizer_config.json
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"added_tokens_decoder": {
|
| 3 |
+
"0": {
|
| 4 |
+
"content": "[PAD]",
|
| 5 |
+
"lstrip": false,
|
| 6 |
+
"normalized": false,
|
| 7 |
+
"rstrip": false,
|
| 8 |
+
"single_word": false,
|
| 9 |
+
"special": true
|
| 10 |
+
},
|
| 11 |
+
"100": {
|
| 12 |
+
"content": "[UNK]",
|
| 13 |
+
"lstrip": false,
|
| 14 |
+
"normalized": false,
|
| 15 |
+
"rstrip": false,
|
| 16 |
+
"single_word": false,
|
| 17 |
+
"special": true
|
| 18 |
+
},
|
| 19 |
+
"101": {
|
| 20 |
+
"content": "[CLS]",
|
| 21 |
+
"lstrip": false,
|
| 22 |
+
"normalized": false,
|
| 23 |
+
"rstrip": false,
|
| 24 |
+
"single_word": false,
|
| 25 |
+
"special": true
|
| 26 |
+
},
|
| 27 |
+
"102": {
|
| 28 |
+
"content": "[SEP]",
|
| 29 |
+
"lstrip": false,
|
| 30 |
+
"normalized": false,
|
| 31 |
+
"rstrip": false,
|
| 32 |
+
"single_word": false,
|
| 33 |
+
"special": true
|
| 34 |
+
},
|
| 35 |
+
"103": {
|
| 36 |
+
"content": "[MASK]",
|
| 37 |
+
"lstrip": false,
|
| 38 |
+
"normalized": false,
|
| 39 |
+
"rstrip": false,
|
| 40 |
+
"single_word": false,
|
| 41 |
+
"special": true
|
| 42 |
+
}
|
| 43 |
+
},
|
| 44 |
+
"clean_up_tokenization_spaces": true,
|
| 45 |
+
"cls_token": "[CLS]",
|
| 46 |
+
"do_basic_tokenize": true,
|
| 47 |
+
"do_lower_case": true,
|
| 48 |
+
"mask_token": "[MASK]",
|
| 49 |
+
"max_length": 128,
|
| 50 |
+
"model_max_length": 512,
|
| 51 |
+
"never_split": null,
|
| 52 |
+
"pad_to_multiple_of": null,
|
| 53 |
+
"pad_token": "[PAD]",
|
| 54 |
+
"pad_token_type_id": 0,
|
| 55 |
+
"padding_side": "right",
|
| 56 |
+
"sep_token": "[SEP]",
|
| 57 |
+
"stride": 0,
|
| 58 |
+
"strip_accents": null,
|
| 59 |
+
"tokenize_chinese_chars": true,
|
| 60 |
+
"tokenizer_class": "BertTokenizer",
|
| 61 |
+
"truncation_side": "right",
|
| 62 |
+
"truncation_strategy": "longest_first",
|
| 63 |
+
"unk_token": "[UNK]"
|
| 64 |
+
}
|
encoder/vocab.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
final_model.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c9ebd5d8ba12c83ccd63476660c075cace91eca79b56f650e3db48431788c8f5
|
| 3 |
+
size 5788529
|
inference.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from utilities import (
|
| 2 |
+
PDFExtractor,
|
| 3 |
+
get_simple_logger,
|
| 4 |
+
ModuleException,
|
| 5 |
+
)
|
| 6 |
+
from create_dataset import clean_text
|
| 7 |
+
from model import PDFDataSet
|
| 8 |
+
import pickle
|
| 9 |
+
import argparse
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Inference:
    """A class that does inference given a url of the pdf. Uses the saved RandomForest model for inference."""

    def __init__(self, pdf_url, model_path="final_model.pkl") -> None:
        # pdf_url: URL of the PDF to classify
        # model_path: pickled sklearn classifier shipped with the app
        self.pdf_url = pdf_url
        self.model_path = model_path
        self.logger = get_simple_logger("inference", level="debug")

    def _extract_text(self):
        """Download the PDF and extract up to 3 pages of content.

        Returns a (record dict, raw extracted text) tuple. Any extraction
        error is logged and re-raised to the caller.
        """
        self.logger.debug("Extracting text from the PDF.")
        pdf_extractor = PDFExtractor(
            self.pdf_url,
            min_characters=5,
            maximum_pages=3,
        )
        try:
            final = pdf_extractor.extract_pages()
            data = {
                "status": "ok",
                "id": "001",
                "page_contents": pdf_extractor.page_contents,
                "final_content": clean_text(final),
                "url": self.pdf_url,
            }
        except ModuleException as e:
            # expected failure mode: unreachable/invalid URL
            self.logger.warning("A module exception is raised.")
            self.logger.error(e)
            raise e
        except Exception as e:
            self.logger.error(e)
            raise e
        return data, final

    def _load_model(self):
        # NOTE(review): pickle.load is only safe because model_path points at
        # a local artifact shipped with the app -- never load untrusted paths.
        with open(self.model_path, "rb") as f:
            model = pickle.load(f)
        self.model = model
        return self.model

    def _create_embedding(self, text):
        # Builds a fresh PDFDataSet (tokenizer + encoder load) on every call;
        # returns a (1, dim) embedding for the classifier.
        dataset = PDFDataSet()
        embedding = dataset.sentences_to_embedding([text])
        return embedding.reshape(1, -1)

    def _pretty_print_probability(self, p):
        # p is the two-element predict_proba vector:
        # p[0] = non-lighting (label 0), p[1] = lighting (label 1)
        non_lighting_probability = p[0] * 100
        lighting_probablity = p[1] * 100
        if lighting_probablity > non_lighting_probability:
            print(
                f"This is a Lighting product with a probability of {lighting_probablity:.2f}%"
            )
        else:
            print(
                f"This is a Non-Lighting product with a probability of {non_lighting_probability:.2f}%"
            )

    def predict(self):
        """Run the full pipeline: extract text, embed, classify.

        Returns the [non-lighting, lighting] probability vector for the
        single input PDF.
        """
        _, sentence = self._extract_text()
        embedding = self._create_embedding(sentence)
        model = self._load_model()
        prediction = model.predict_proba(embedding)
        prediction = prediction[0]
        self._pretty_print_probability(prediction)
        return prediction
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def main(args):
    """Run one prediction for the URL given on the command line."""
    runner = Inference(args.url)
    result = runner.predict()
    print(result)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
if __name__ == "__main__":
    # CLI entry point: `python inference.py --url <pdf-url>`
    parser = argparse.ArgumentParser()
    parser.add_argument("--url", type=str, required=True)
    args = parser.parse_args()
    main(args)
|
model.py
ADDED
|
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import AutoTokenizer, AutoModel
|
| 2 |
+
from datasets import load_dataset, Dataset, concatenate_datasets
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
from sklearn.model_selection import train_test_split
|
| 7 |
+
from sklearn.metrics import (
|
| 8 |
+
classification_report,
|
| 9 |
+
confusion_matrix,
|
| 10 |
+
accuracy_score,
|
| 11 |
+
precision_score,
|
| 12 |
+
)
|
| 13 |
+
from sklearn.ensemble import RandomForestClassifier
|
| 14 |
+
from xgboost import XGBClassifier
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torchmetrics
|
| 17 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import pandas as pd
|
| 21 |
+
import os
|
| 22 |
+
import pickle
|
| 23 |
+
import argparse
|
| 24 |
+
from torch_train import TorchTrain
|
| 25 |
+
from utilities import get_simple_logger
|
| 26 |
+
|
| 27 |
+
FILE_DIR = os.path.dirname(os.path.realpath(__file__))
|
| 28 |
+
DATA_DIR = os.path.join(FILE_DIR, "data")
|
| 29 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
| 30 |
+
random_state = 42
|
| 31 |
+
# set random state
|
| 32 |
+
np.random.seed(random_state)
|
| 33 |
+
torch.manual_seed(random_state)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class PDFDataLoader:
    """Map-style dataset over rows carrying precomputed "embeddings" and a
    scalar "label"; yields (float embedding tensor, (1,) float label tensor)
    on the module-level device."""

    def __init__(self, df):
        self.df = df

    def __getitem__(self, idx):
        record = self.df[idx]
        features = torch.from_numpy(np.array(record["embeddings"])).float()
        # label -> numpy scalar -> shape (1,) so BCELoss sees a 2-D batch
        target = np.expand_dims(np.array(record["label"]), axis=0)
        return features.to(device), torch.from_numpy(target).to(device).float()

    def __len__(self):
        return len(self.df)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class PDFDataSet:
    """Loads the train/test CSVs and converts each row's text content into a
    sentence embedding with a locally saved BERT-style encoder."""

    def __init__(
        self,
        data_dir=DATA_DIR,
        fraction_test_data_in_train=0.2,
        model_ckpt="encoder",
    ) -> None:
        # data_dir: directory containing train.csv / test.csv
        # fraction_test_data_in_train: fraction of the test CSV folded into
        #   the training split (falsy value disables the fold-in)
        # model_ckpt: path of the saved tokenizer + encoder checkpoint
        self.data_dir = data_dir
        self.fraction_test_data_in_train = fraction_test_data_in_train
        self.model_ckpt = model_ckpt
        tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
        encoding_model = AutoModel.from_pretrained(model_ckpt)
        encoding_model = encoding_model.to(device)
        # eval() disables dropout so embeddings are deterministic
        encoding_model = encoding_model.eval()
        self.encoding_model = encoding_model
        self.tokenizer = tokenizer
        self.logger = get_simple_logger("pdf_dataset")

    def create_datasets(self):
        """Read the CSVs and return (train, validation, test) Datasets.

        train.csv is split 70/30 into train/validation (fixed seed 42);
        optionally a fraction of test.csv is moved into the training set.
        """
        train_data_path = os.path.join(FILE_DIR, self.data_dir, "train.csv")
        test_data_path = os.path.join(FILE_DIR, self.data_dir, "test.csv")
        df = pd.read_csv(train_data_path)
        test_df = pd.read_csv(test_data_path)
        train_df, validation_df = train_test_split(df, test_size=0.3, random_state=42)
        if self.fraction_test_data_in_train:
            self.logger.info(
                f"Adding {self.fraction_test_data_in_train} fraction of test dataset to the training set."
            )
            # the moved rows are also removed from test_df, so the final test
            # set never overlaps the training data
            test_df, test_df_for_training = train_test_split(
                test_df, test_size=self.fraction_test_data_in_train, random_state=42
            )
            train_df = pd.concat([train_df, test_df_for_training])

        train_dataset = Dataset.from_pandas(train_df)
        validation_dataset = Dataset.from_pandas(validation_df)
        test_dataset = Dataset.from_pandas(test_df)
        return train_dataset, validation_dataset, test_dataset

    def mean_pooling(self, model_output, attention_mask):
        """Average token embeddings, weighting out padding positions."""
        token_embeddings = model_output[
            0
        ]  # First element of model_output contains all token embeddings
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        # clamp guards against division by zero for fully-masked inputs
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

    def sentences_to_embedding(self, sentences):
        """Encode a list of sentences into L2-normalized embedding vectors."""
        # Tokenize sentences
        encoded_input = self.tokenizer(
            sentences, padding=True, truncation=True, return_tensors="pt"
        )
        # NOTE(review): encoded_input is never moved to `device`; if the
        # encoder runs on GPU this would fail -- presumably CPU-only; confirm.
        sentence_embeddings = self.mean_pooling(
            self.encoding_model(**encoded_input), encoded_input["attention_mask"]
        )
        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
        # remove last dimension
        sentence_embeddings = sentence_embeddings.squeeze()
        return sentence_embeddings.detach()

    def get_embeddings(self, row):
        # Mapper for Dataset.map: embeds the row's "content" column into a
        # new "embeddings" column.
        return {
            "embeddings": self.sentences_to_embedding(
                sentences=row["content"],
            )
        }

    def create_embeddings(self):
        """Return all three splits with an "embeddings" column attached."""
        train_dataset, validation_dataset, test_dataset = self.create_datasets()
        train_dataset = train_dataset.map(self.get_embeddings)
        validation_dataset = validation_dataset.map(self.get_embeddings)
        test_dataset = test_dataset.map(self.get_embeddings)
        return train_dataset, validation_dataset, test_dataset
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class PDFModel(nn.Module):
    """Feed-forward binary classifier over sentence embeddings: a stack of
    Linear+ReLU layers followed by a single Linear projection squashed
    through a sigmoid, so the output is a probability in [0, 1]."""

    def __init__(self, input_size, hidden_sizes, output_size):
        super(PDFModel, self).__init__()
        self.seq_model = nn.Sequential()
        in_features = input_size
        # child names ("linear_i"/"relu_i") are kept stable for state_dict
        # compatibility
        for idx, width in enumerate(hidden_sizes):
            self.seq_model.add_module(f"linear_{idx}", nn.Linear(in_features, width))
            self.seq_model.add_module(f"relu_{idx}", nn.ReLU())
            in_features = width
        self.last_layer = nn.Linear(in_features, output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        hidden = self.seq_model(x)
        logits = self.last_layer(hidden)
        return self.sigmoid(logits)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def evaluate_model(y_true, y_pred, model_name, split="train"):
    """Print accuracy, precision and the full sklearn classification report
    for one model on one split, framed by separator lines."""
    separator = "------" * 10
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred)
    report = classification_report(y_true, y_pred)
    print(separator)
    print(f"Evaluating for the model: {model_name} for {split} dataset...")
    print(f"Accuracy: {accuracy}")
    print(f"Precision: {precision}")
    print(report)
    print(separator)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def train_dl_model(
    train_data,
    validation_data,
    epochs=30,
    input_shape=384,
    hidden_sizes=[32, 16],
):
    """Train the PDFModel MLP with Adam + BCE loss and a cosine LR schedule.

    Parameters
    ----------
    train_data, validation_data : DataLoaders yielding (embeddings, label)
    epochs : int, number of training epochs
    input_shape : int, embedding dimensionality (384 matches the encoder's
        hidden_size)
    hidden_sizes : list[int], hidden-layer widths
        (mutable default is only iterated, never mutated, so sharing is safe)

    Returns
    -------
    (history, model) as produced by TorchTrain.fit
    """
    model = PDFModel(input_size=input_shape, hidden_sizes=hidden_sizes, output_size=1)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_fn = nn.BCELoss()
    accuracy = torchmetrics.Accuracy(
        task="binary", num_classes=2, threshold=0.5, average="macro"
    )
    precision = torchmetrics.Precision(task="binary", average="macro")
    metrics = {
        "accuracy": accuracy,
        "precision": precision,
    }
    # decay LR from 0.001 down to 0.0001 over a 10-epoch cosine cycle
    scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=0.0001)
    tt = TorchTrain(model, optimizer, loss_fn, metrics=metrics, scheduler=scheduler)
    history = tt.fit(train_data, validation_data, verbose=True, epochs=epochs)
    return history, model
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def evaluate_models(fraction_test_data_in_train=0.1):
    """Train and evaluate the DL model plus RandomForest/XGBoost baselines.

    Creates sentence embeddings for the train/validation/test splits, trains a
    small feed-forward network and two classical ML models on them, and prints
    accuracy/precision/classification reports for every model and split.

    Parameters
    ----------
    fraction_test_data_in_train : float, optional
        Fraction of the test data mixed into the training data (the test
        distribution differs from the train distribution), by default 0.1.

    Returns
    -------
    None
        All results are written to stdout.
    """
    print("Creating Embeddings...")
    ds = PDFDataSet(fraction_test_data_in_train=fraction_test_data_in_train)
    train_dataset, validation_dataset, test_dataset = ds.create_embeddings()
    print("Done\n")

    print("Training DL Model")
    # Create dataset for DL models:
    BATCH_SIZE = 8
    train_dataloader = PDFDataLoader(train_dataset)
    validation_dataloader = PDFDataLoader(validation_dataset)
    test_dataloader = PDFDataLoader(test_dataset)

    train_data = DataLoader(train_dataloader, batch_size=BATCH_SIZE, shuffle=True)
    validation_data = DataLoader(
        validation_dataloader,
        batch_size=BATCH_SIZE,
        shuffle=True,
    )
    test_data = DataLoader(test_dataloader, batch_size=BATCH_SIZE, shuffle=True)
    # Infer the embedding dimensionality from the first batch.
    for X, y in train_data:
        input_shape = int(X.shape[1])
        break
    epochs = 30
    hidden_sizes = [32, 16]
    history, model = train_dl_model(
        train_data=train_data,
        validation_data=validation_data,
        epochs=epochs,
        # BUG FIX: input_shape was computed above but never passed, so the
        # model silently fell back to the hard-coded default of 384.
        input_shape=input_shape,
        hidden_sizes=hidden_sizes,
    )
    print("Done\n")
    print("Evaluating DL Model")
    y_test_pred = model(torch.from_numpy(np.array(test_dataset["embeddings"])).float())
    y_test_pred = y_test_pred.detach().numpy()
    # Threshold the sigmoid outputs at 0.5 to get hard labels.
    y_test_pred = np.where(y_test_pred > 0.5, 1, 0)
    evaluate_model(
        y_true=test_dataset["label"],
        y_pred=y_test_pred,
        model_name="DL Model",
        split="test",
    )
    print("Done\n")

    # ML Models
    print("Training and evaluating ML Models.")
    X_train = train_dataset["embeddings"]
    y_train = train_dataset["label"]
    X_validation = validation_dataset["embeddings"]
    y_validation = validation_dataset["label"]
    X_test = test_dataset["embeddings"]
    y_test = test_dataset["label"]
    # Hyperparameters found by a prior (Optuna) search; kept hard-coded so the
    # evaluation is reproducible.
    rfc_best_params = {
        "max_depth": 23,
        "max_features": "log2",
        "n_estimators": 469,
    }

    xgb_best_params = {
        "max_depth": 25,
        "n_estimators": 372,
        "learning_rate": 0.2522824287799319,
    }
    print("Fitting RandomForest")
    rfc = RandomForestClassifier(**rfc_best_params)
    rfc.fit(X_train, y_train)
    evaluate_model(
        y_true=y_train,
        y_pred=rfc.predict(X_train),
        model_name="RandomForest",
        split="train",
    )
    evaluate_model(
        y_true=y_validation,
        y_pred=rfc.predict(X_validation),
        model_name="RandomForest",
        split="validation",
    )
    evaluate_model(
        y_true=y_test,
        y_pred=rfc.predict(X_test),
        model_name="RandomForest",
        split="test",
    )

    print("Fitting XGBoost")
    xgb = XGBClassifier(**xgb_best_params)
    xgb.fit(X_train, y_train)
    evaluate_model(
        y_true=y_train,
        y_pred=xgb.predict(X_train),
        model_name="XGBoost",
        split="train",
    )
    evaluate_model(
        y_true=y_validation,
        y_pred=xgb.predict(X_validation),
        model_name="XGBoost",
        split="validation",
    )
    evaluate_model(
        y_true=y_test,
        y_pred=xgb.predict(X_test),
        model_name="XGBoost",
        split="test",
    )
    print("All Done")
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def train_and_save_final_model(model_save_path="final_model.pkl"):
    """Train the production model and pickle it to disk.

    The final model is a RandomForestClassifier fitted on the combined
    train + validation embeddings, with 10% of the test data mixed into
    training. Mixing in that 10% is necessary because the test distribution
    differs markedly from the training distribution; since it is seen during
    training, it is excluded when reporting the final accuracy.

    Parameters
    ----------
    model_save_path : str, optional
        The path to save the final model, by default "final_model.pkl"

    Returns
    -------
    None

    Examples
    --------
    >>> train_and_save_final_model()
    >>> train_and_save_final_model(model_save_path="final_model.pkl")
    """
    print("Creating Embeddings...")
    model_save_path = os.path.join(FILE_DIR, model_save_path)
    ds = PDFDataSet(fraction_test_data_in_train=0.1)
    train_split, validation_split, test_split = ds.create_embeddings()
    # Fold the validation split into training: the final model uses all
    # available non-held-out data.
    train_split = concatenate_datasets([train_split, validation_split])
    X_train = train_split["embeddings"]
    X_test = test_split["embeddings"]
    y_train = train_split["label"]
    y_test = test_split["label"]

    print("Training and evaluating the model...")
    # Hyperparameters from a prior tuning run; hard-coded for reproducibility.
    best_params = {
        "max_depth": 23,
        "max_features": "log2",
        "n_estimators": 469,
    }
    final_model = RandomForestClassifier(**best_params)
    final_model.fit(X_train, y_train)
    for split_name, X, y in (("train", X_train, y_train), ("test", X_test, y_test)):
        evaluate_model(
            y_true=y,
            y_pred=final_model.predict(X),
            model_name="Final Model",
            split=split_name,
        )

    print("Saving the model...")
    with open(model_save_path, "wb") as f:
        pickle.dump(final_model, f)
    print(f"Model saved to: {model_save_path}")
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
def main(args):
    """Dispatch the CLI task.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments with ``task`` and, depending on the
        task, ``model_save_path`` or ``fraction``.
    """
    if args.task == "train":
        train_and_save_final_model(model_save_path=args.model_save_path)
    elif args.task == "evaluate":
        evaluate_models(args.fraction)
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
if __name__ == "__main__":
    # CLI entry point: --task is mandatory, the rest have defaults.
    parser = argparse.ArgumentParser(description="Train and evaluate models")
    parser.add_argument(
        "--task",
        type=str,
        required=True,
        choices=["train", "evaluate"],
        help="Whether to train and save the best model or evaluate all the models.",
    )
    parser.add_argument(
        "--fraction",
        type=float,
        default=0.1,
        help="Fraction of test data in train dataset",
    )
    parser.add_argument(
        "--model_save_path",
        type=str,
        default="final_model.pkl",
        help="Path to save the final model",
    )
    main(parser.parse_args())
|
requirements.txt
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
numpy==1.26.2
pandas==2.1.3
pypdf==3.17.1
pytesseract==0.3.10
pdf2image==1.16.3
pdfminer.six==20221105
pdfplumber==0.10.3
Pillow==10.1.0
PyPDF2==3.0.1
pypdfium2==4.24.0
requests==2.31.0
torch==2.1.1
transformers==4.35.2
datasets==2.15.0
optuna==3.4.0
scikit-learn==1.3.2
scipy==1.11.4
torchmetrics==1.2.1
tqdm==4.66.1
tokenizers==0.15.0
xgboost==2.0.2
gradio
|
torch_train.py
ADDED
|
@@ -0,0 +1,543 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class TorchTrain:
    """A class for training a model in PyTorch.

    Parameters
    -----------
    model (torch.nn.Module): The PyTorch model to train.
    optimizer (torch.optim.Optimizer): The optimizer to use for training.
    loss_function (callable): The loss function to use for training.
    metrics (dict or callable, optional): The metrics to evaluate during training.
        If a dictionary, the keys are the metric names and the values are functions that
        take in `yhat` and `y` and return a metric value. If a callable, it should take
        in `yhat` and `y` and return a metric value. Defaults to None.

    Attributes
    -----------
    DEVICE (torch.device): The device to use for training (cuda if available, cpu otherwise).
    model (torch.nn.Module): The PyTorch model being trained.
    optimizer (torch.optim.Optimizer): The optimizer being used for training.
    loss_function (callable): The loss function being used for training.
    metrics (dict or callable): The metrics being evaluated during training.
    metrics_evaluated (dict): The metrics evaluated during training.
    train_loss (float): The average training loss.
    test_loss (float): The average test loss.
    train_iteration (int): The number of training iterations.
    test_iteration (int): The number of test iterations.
    train_metrics (dict): The metrics evaluated on the training data.
    test_metrics (dict): The metrics evaluated on the test data.
    """

    # Single shared device chosen at import time for all trainer instances.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def __init__(
        self,
        model,
        optimizer,
        loss_function,
        metrics=None,
        scheduler=None,
        task_type="classification",
    ) -> None:
        """Initialize the TorchTrain object.

        Parameters
        -----------
        model : torch.nn.Module
            The PyTorch model to train.
        optimizer : torch.optim.Optimizer
            The optimizer to use for training.
        loss_function : callable
            The loss function to use for training.
        metrics : dict, optional
            The metrics to evaluate during training. The keys are the metric names
            and the values are functions that take in `yhat` and `y` and return a
            metric value. Defaults to None.
            NOTE(review): the class docstring claims a bare callable is also
            accepted, but __preprocess_metrics raises TypeError for non-dicts —
            confirm which contract is intended.
        scheduler : torch.optim.lr_scheduler, optional
            The learning rate scheduler to use for training. Defaults to None.
        task_type : str, optional
            Either "classification" or another task label; only consulted by
            ``predict``. Defaults to "classification".
        """
        self.model = model
        self.model.to(self.DEVICE)
        self.optimizer = optimizer
        self.loss_function = loss_function
        self.metrics = self.__preprocess_metrics(metrics)
        self.scheduler = scheduler
        # Running (per-epoch) accumulators; reset by __reset_counters().
        self.metrics_evaluated = {}
        self.train_loss = 0
        self.test_loss = 0
        self.train_iteration = 0
        self.test_iteration = 0
        self.train_metrics = {}
        self.test_metrics = {}
        self.history = {}
        # Per-epoch histories appended to by __reset_counters().
        self.train_loss_all = []
        self.test_loss_all = []
        self.train_metrics_all = []
        self.test_metrics_all = []
        # Guards so __scale_matrices divides each epoch's totals only once.
        self.__train_scaled = False
        self.__test_scaled = False
        self.task_type = task_type

    def __preprocess_metrics(self, metrics):
        """Preprocesses the given metrics.

        Title-cases the metric names for display; raises TypeError for
        anything other than None or a dict.
        """
        if metrics is None:
            return {}
        if isinstance(metrics, dict):
            return {key.title(): value for key, value in metrics.items()}
        else:
            # NOTE(review): despite the error message, a bare callable is NOT
            # accepted here — only dicts pass through.
            raise TypeError(
                "Metrics should be a dictionary of metrics or a function which takes yhat, y"
            )

    def __scale_matrices(self, loss, metrics, type="train"):
        """Scales the loss and metrics

        Divides the accumulated epoch totals by the number of iterations to
        obtain per-batch averages; a boolean guard ensures each side
        (train/test) is scaled at most once per epoch.

        Parameters
        -----------
        loss : float
            The loss to scale
        metrics : dict
            The metrics to scale
        type : str, optional
            The type of scaling to do, either "train" or "test", by default "train"

        Returns
        --------
        loss : float
            The scaled loss
        metrics : dict
            The scaled metrics
        """
        if type == "train" and not self.__train_scaled:
            scale = self.train_iteration
            self.__train_scaled = True
        elif type == "test" and not self.__test_scaled:
            scale = self.test_iteration
            self.__test_scaled = True
        else:
            return loss, metrics
        # NOTE(review): if the corresponding iteration count is 0 (e.g. fit()
        # was called without a validation loader), this divides by zero.
        loss /= scale
        for key in metrics:
            metrics[key] /= scale
        return loss, metrics

    def __reset_counters(self):
        """Resets all the counters and loss objects for a new epoch"""
        # Scale the accumulated totals to per-batch averages first...
        self.train_loss, self.train_metrics = self.__scale_matrices(
            self.train_loss, self.train_metrics, type="train"
        )

        self.test_loss, self.test_metrics = self.__scale_matrices(
            self.test_loss, self.test_metrics, type="test"
        )

        # ...then archive them into the per-epoch histories and zero them.
        self.train_loss_all.append(self.train_loss)
        self.train_loss = 0

        self.test_loss_all.append(self.test_loss)
        self.test_loss = 0

        self.train_iteration = 0
        self.test_iteration = 0

        self.train_metrics_all.append(self.train_metrics)
        self.train_metrics = {}

        self.test_metrics_all.append(self.test_metrics)
        self.test_metrics = {}
        self.__train_scaled = False
        self.__test_scaled = False

    @property
    def loss(self):
        """Returns the training loss"""
        # Most recent completed epoch's average training loss.
        return self.train_loss_all[-1]

    def __create_history(self):
        """Creates the history dictionary"""
        # Keras-style history: "train_<metric>" / "val_<metric>" lists with one
        # entry per epoch.
        history = {
            "train_loss": self.train_loss_all,
            "val_loss": self.test_loss_all,
        }
        for key, value in self.metrics.items():
            history[f"train_{key.lower()}"] = []
            history[f"val_{key.lower()}"] = []

        for item in self.train_metrics_all:
            for key, value in item.items():
                history[f"train_{key.lower()}"].append(value)

        for item in self.test_metrics_all:
            for key, value in item.items():
                history[f"val_{key.lower()}"].append(value)
        return history

    def __parse_val(self, val):
        """Parses the given value to a float"""
        if isinstance(val, torch.Tensor):
            val = val.item()
        elif isinstance(val, np.ndarray):
            val = float(val)
        elif isinstance(val, (int, float)):
            pass
        else:
            raise TypeError(
                f"The given Metric function should return a tensor, numpy array, int, or float.\n\
                Got {type(val)}"
            )
        return val

    def _train_step(self, x, y):
        """Perform a single training step.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor.
        y : torch.Tensor
            The target tensor.

        Returns
        -------
        tuple
            A tuple containing the loss and the predicted output tensor.
        """
        self.model.train()
        yhat = self.model(x)
        l = self.loss_function(yhat, y)
        self.optimizer.zero_grad()
        l.backward()
        self.optimizer.step()
        self.train_iteration += 1
        return l.item(), yhat

    def _test_step(self, x, y):
        """Perform a single testing step.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor.
        y : torch.Tensor
            The target tensor.

        Returns
        -------
        tuple
            A tuple containing the loss and the predicted output tensor.
        """
        self.model.eval()
        # inference_mode disables autograd bookkeeping for the forward pass.
        with torch.inference_mode():
            yhat = self.model(x)
            l = self.loss_function(yhat, y)
        self.test_iteration += 1
        return l.item(), yhat

    def predict(self, x):
        """Make predictions on a batch of data.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor.

        Returns
        -------
        torch.Tensor
            The predicted output tensor. For classification tasks the raw
            outputs are converted to hard labels (round for 1-D outputs,
            argmax over dim 1 otherwise).
        """
        self.model.eval()
        yhat = self.model(x)
        if self.task_type == "classification":
            if len(yhat.shape) == 1:
                # round
                yhat = torch.round(yhat)
                yhat = yhat.unsqueeze(1)
            else:
                yhat = torch.argmax(yhat, dim=1)

        return yhat

    def __calculate_metrics(self, yhat, y):
        """Calculate the metrics for a batch of data.

        Parameters
        ----------
        yhat : torch.Tensor
            The predicted output tensor.
        y : torch.Tensor
            The target tensor.

        Returns
        -------
        dict
            A dictionary containing the values of the metrics.
        """
        metrics = {}
        for key, metric in self.metrics.items():
            val = metric(yhat, y)
            # Coerce the metric result to a plain float-compatible value.
            if isinstance(val, torch.Tensor):
                val = val.item()
            elif isinstance(val, np.ndarray):
                val = float(val)
            elif isinstance(val, (int, float)):
                pass
            else:
                raise TypeError(
                    f"Metric {key} should return a tensor, numpy array, int, or float"
                )
            metrics[key] = val
        self.metrics_evaluated = metrics
        return metrics

    def __progress_bar(self, cur_iter, all_iter):
        """Creates a progress bar showing the progress of the current batch.

        Parameters
        ----------
        cur_iter : int
            The current batch number.
        all_iter : int
            The total number of batches.

        Returns
        -------
        str
            The progress bar, in the form of "10/100[====----]".
        """
        len_progress_bar = 20
        progress = int((cur_iter + 1) / all_iter * len_progress_bar)
        progress_bar = "=" * progress + "-" * (len_progress_bar - progress)
        return f"[{progress_bar}]"

    def progress(self, cur_iter, all_iter, loss, metrics, on="train"):
        """Prints a progress bar showing the progress of the current batch.

        Parameters
        ----------
        cur_iter : int
            The current batch number.
        all_iter : int
            The total number of batches.
        loss : float
            The current loss. Should be averaged over all batches.
        metrics : dict
            The metrics evaluated on the current batch.
        on : str, optional
            Whether the progress bar is for the training or testing data. Defaults to "train".

        Returns
        -------
        str
            The progress bar, in the form of "10/100[====----]".

        Notes
        -----
        The progress bar shows the progress of the current batch as a bar of equal signs ("=") and
        hyphens ("-"). The length of the bar is fixed at 20 characters. The current batch number
        and total number of batches are displayed at the beginning of the progress bar. The current
        loss and any metrics evaluated on the current batch are displayed at the end of the progress
        bar.
        """
        # len_progress_bar = 20
        # progress = int((cur_iter + 1) / all_iter * len_progress_bar)
        # progress_bar = "=" * progress + "-" * (len_progress_bar - progress)
        progress_bar = self.__progress_bar(cur_iter=cur_iter, all_iter=all_iter)

        if on.lower() == "train":
            iteration = self.train_iteration
            prefix = f"Epoch {(self.current_epoch+1):2d}/{self.epochs:2d} Batch "
        else:
            iteration = self.test_iteration
            prefix = "Epoch "

        # Divide the running totals by the iteration count to display averages.
        text = f"{prefix}{cur_iter:>4d}/{all_iter:>4d}{progress_bar} {on.title()} loss: {loss/iteration:.4f}"
        for metric_name, metric_value in metrics.items():
            text += f" | {on.title()} {metric_name}: {metric_value/iteration:.4f}"

        return text

    def update_metrics(self, cur_metrics, new_metrics):
        """Update the metrics with the values for a new batch of data.

        Parameters
        ----------
        cur_metrics : dict
            The current values of the metrics.
        new_metrics : dict
            The values of the metrics for a new batch of data.

        Returns
        -------
        dict
            A dictionary containing the updated values of the metrics.
        """
        for key, value in new_metrics.items():
            if key not in cur_metrics:
                cur_metrics[key] = value
            else:
                cur_metrics[key] += value
        return cur_metrics

    def fit(
        self,
        train_loader,
        validation_data_loader=None,
        epochs=1,
        verbose=True,
        train_steps_per_epoch=None,
        validation_steps_per_epoch=None,
    ):
        """Fit the PyTorch model.

        Parameters
        ----------
        train_loader : torch.utils.data.DataLoader
            The data loader for the training data.
        validation_data_loader : torch.utils.data.DataLoader, optional
            The data loader for the test data. Defaults to None.
        epochs : int, optional
            The number of epochs to train for. Defaults to 1.
        verbose : bool, optional
            Whether to print the training progress during training. Defaults to True.
        train_steps_per_epoch : int, optional
            The number of batches to train on per epoch. Defaults to None.
        validation_steps_per_epoch : int, optional
            The number of batches to test on per epoch. Defaults to None.

        Returns
        -------
        None

        Examples
        --------
        >>> model = MyModel()
        >>> optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        >>> loss_function = nn.CrossEntropyLoss()
        >>> scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
        >>> train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
        >>> validation_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
        >>> trainer = TorchTrain(model, optimizer, loss_function, scheduler=scheduler)
        >>> trainer.fit(train_loader, validation_data_loader=validation_data_loader, epochs=10, verbose=True)
        """
        self.epochs = epochs
        if train_steps_per_epoch is None:
            train_steps_per_epoch = len(train_loader)
        if validation_data_loader is not None:
            if validation_steps_per_epoch is None:
                validation_steps_per_epoch = len(validation_data_loader)

        for epoch in range(epochs):
            self.current_epoch = epoch
            for i, (x, y) in enumerate(train_loader):
                x = x.to(self.DEVICE)
                # Targets may be a single tensor or a list/tuple of tensors.
                if isinstance(y, list) or isinstance(y, tuple):
                    y = [y_.to(self.DEVICE) for y_ in y]
                else:
                    y = y.to(self.DEVICE)

                train_loss, yhat = self._train_step(x, y)
                self.train_loss += train_loss
                metrics = self.__calculate_metrics(yhat, y)
                self.train_metrics = self.update_metrics(self.train_metrics, metrics)

                b_progress = self.progress(
                    i + 1,
                    train_steps_per_epoch,
                    self.train_loss,
                    self.train_metrics,
                    on="train",
                )
                if i == train_steps_per_epoch - 1:
                    # Last batch of the epoch: print a final (non-overwritten) line.
                    print(b_progress)
                    break
                else:
                    if verbose:
                        print(b_progress, end="\r")
            if validation_data_loader is not None:
                for i, (x, y) in enumerate(validation_data_loader):
                    x = x.to(self.DEVICE)
                    if isinstance(y, list) or isinstance(y, tuple):
                        y = [y_.to(self.DEVICE) for y_ in y]
                    else:
                        y = y.to(self.DEVICE)
                    test_loss, yhat = self._test_step(x, y)
                    self.test_loss += test_loss
                    metrics = self.__calculate_metrics(yhat, y)
                    self.test_metrics = self.update_metrics(self.test_metrics, metrics)
                    if i == validation_steps_per_epoch - 1:
                        break
                test_progress = self.progress(
                    epoch + 1,
                    epochs,
                    self.test_loss,
                    self.test_metrics,
                    on="test",
                )
                print(test_progress)
            self.__reset_counters()
            if self.scheduler is not None:
                self.scheduler.step()
            if verbose and self.scheduler is not None:
                print(f"New Learning rate: {self.scheduler.get_last_lr()[0]:.6f}")

        return self.__create_history()

    def save(self, path):
        """Save the model to a file.

        Parameters
        ----------
        path : str
            The path to the file to save the model to.
        """
        torch.save(self.model.state_dict(), path)

    def load(self, path):
        """Load the model from a file.

        Parameters
        ----------
        path : str
            The path to the file to load the model from.
        """
        self.model.load_state_dict(torch.load(path))

    def evaluate(self, data_loader, metric):
        """Evaluate the model on a data loader and the given metric.

        Parameters
        ----------
        data_loader : torch.utils.data.DataLoader
            The data loader to evaluate the model on.
        metric : function
            The metric to evaluate the model with.

        Returns
        -------
        float
            The score of the model on the given metric.
        """
        running_score = 0
        data_length = len(data_loader)
        for i, (x, y) in enumerate(data_loader):
            progress_bar = self.__progress_bar(i, data_length)
            x = x.to(self.DEVICE)
            if isinstance(y, list) or isinstance(y, tuple):
                y = [y_.to(self.DEVICE) for y_ in y]
            else:
                y = y.to(self.DEVICE)

            yhat = self.model(x)
            # NOTE(review): outputs are always rounded here regardless of
            # task_type — confirm this is intended for non-binary tasks.
            yhat = torch.round(yhat)
            score = metric(y, yhat)
            score = self.__parse_val(score)
            running_score += score

            progress_bar = f"{i+1}/{data_length}" + progress_bar
            progress_bar += f" Score: {(running_score/(i+1)):4f}"
            print(progress_bar, end="\r")
        return running_score / (len(data_loader))
|
utilities.py
ADDED
|
@@ -0,0 +1,481 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import PyPDF2
|
| 2 |
+
from pdfminer.high_level import extract_pages
|
| 3 |
+
from pdfminer.layout import LTTextContainer, LTFigure
|
| 4 |
+
import pdfplumber
|
| 5 |
+
from pdf2image import convert_from_bytes
|
| 6 |
+
|
| 7 |
+
import pytesseract
|
| 8 |
+
import logging
|
| 9 |
+
import requests
|
| 10 |
+
from io import BytesIO
|
| 11 |
+
|
| 12 |
+
import signal
|
| 13 |
+
import time
|
| 14 |
+
import functools
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
import time
|
| 18 |
+
import signal
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class ModuleException(Exception):
    """Base class for every custom exception raised by this module.

    Exists so callers can catch a single root type instead of enumerating
    the concrete subclasses (``TimeoutError``, ``BadUrlException``).
    """

    def __init__(self, *args: object) -> None:
        # Pure pass-through to Exception; no extra state is added here.
        super().__init__(*args)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class TimeoutError(ModuleException):
    """Raised by the ``timeout`` decorator when a call exceeds its budget.

    NOTE(review): this shadows the builtin ``TimeoutError``; code catching
    it must reference this module's class, not the builtin.
    """

    def __init__(self, *args: object) -> None:
        # Pass-through constructor; the class exists only as a distinct type.
        super().__init__(*args)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class BadUrlException(ModuleException):
    """Raised when the url is not valid"""

    def __init__(self, message):
        super().__init__(message)
        # Also expose the message as an attribute for programmatic access.
        self.message = message
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def timeout(seconds_before_timeout):
    """Decorator factory: abort the wrapped call after a wall-clock budget.

    Uses ``signal.SIGALRM``, so it only works on Unix and only in the main
    thread. If an outer alarm is already pending, the shorter of the two
    deadlines wins, and the outer timer is restored (minus the time this
    call consumed) when the wrapped function returns.

    Parameters
    ----------
    seconds_before_timeout : int
        Number of seconds the wrapped function may run before a
        ``TimeoutError`` is raised inside it.
    """

    def decorator(func):
        def _handler(signum, frame):
            # Delivered on SIGALRM: interrupt the wrapped call.
            raise TimeoutError()

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            previous_handler = signal.signal(signal.SIGALRM, _handler)
            old_time_left = signal.alarm(seconds_before_timeout)
            # BUGFIX: the original referenced an undefined name
            # `second_before_timeout` here (NameError on nested timers).
            if 0 < old_time_left < seconds_before_timeout:
                # Never lengthen an existing (outer) timer.
                signal.alarm(old_time_left)
            start_time = time.time()
            try:
                result = func(*args, **kwargs)
            finally:
                if old_time_left > 0:
                    # Deduct this call's run time from the saved outer timer.
                    # signal.alarm() only accepts ints, and 0 would cancel the
                    # outer timer entirely, so clamp to at least 1 second.
                    elapsed = time.time() - start_time
                    old_time_left = max(1, int(round(old_time_left - elapsed)))
                signal.signal(signal.SIGALRM, previous_handler)
                signal.alarm(old_time_left)
            return result

        return wrapper

    return decorator
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def get_simple_logger(name, level="info"):
    """Build a logger that writes formatted records through a StreamHandler.

    Parameters
    ----------
    name : str
        Logger name passed to ``logging.getLogger``.
    level : str or int, optional
        Either a level name ("debug", "info", "warning", "error",
        "critical", case-insensitive) or a numeric ``logging`` level,
        by default "info".

    Returns
    -------
    logging.Logger
        The named logger with exactly one freshly attached handler.
    """
    named_levels = {
        "debug": logging.DEBUG,
        "info": logging.INFO,
        "warning": logging.WARNING,
        "error": logging.ERROR,
        "critical": logging.CRITICAL,
    }
    if isinstance(level, str):
        level = named_levels[level.lower()]

    logger = logging.getLogger(name)
    logger.setLevel(level)

    # Drop handlers left over from a previous call for the same name so
    # records are not emitted multiple times.
    if logger.hasHandlers():
        logger.handlers.clear()

    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(level)
    stream_handler.setFormatter(
        logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    )
    logger.addHandler(stream_handler)
    return logger
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class PDFExtractor:
    """Extracts text, table and image (OCR) content from a PDF document.

    Combines three PDF libraries over the same byte stream: pdfminer for
    layout elements (text/figures), pdfplumber for tables, and PyPDF2 for
    cropping figure regions that are then OCRed with pytesseract.
    """

    def __init__(
        self,
        file_path,
        min_characters=5,
        maximum_pages=3,
        is_url=True,
    ) -> None:
        """A class that can be used to extract pdf information from a file path or url

        Parameters
        ----------
        file_path : str
            The path to the file or the url
        min_characters : int, optional
            The minimum number of characters that a text needs to have to be
            considered a relevant text, by default 5
        maximum_pages : int, optional
            The maximum number of pages that will be processed, by default 3
        is_url : bool, optional
            Whether the file_path is a url or not, by default True

        Raises
        ------
        BadUrlException
            If the file or url cannot be turned into a byte stream (any
            download/open failure is re-raised as this type).

        Returns
        -------
        None

        Examples
        --------
        >>> extractor = PDFExtractor(file_path="file_path", min_characters=5, maximum_pages=5, is_url=False)
        >>> extractor.extract_pages()
        """
        self.file_path = file_path
        self.min_characters = min_characters
        self.maximum_pages = maximum_pages
        self.logger = get_simple_logger(name="pdf_extractor", level="info")
        try:
            self.byte_file = self._create_byte_object(self.file_path, is_url)
        except Exception as e:
            # Normalize every acquisition failure to BadUrlException so
            # callers only need to handle one error type.
            self.logger.error(e)
            raise BadUrlException(str(e))
        # PyPDF2 object (used for cropping figure regions)
        self.pypdf_object = PyPDF2.PdfReader(self.byte_file)
        # pdfplumber view of the same stream (used for table extraction)
        self.pdf_plumber_object = pdfplumber.open(self.byte_file)
        # Pages from pdfminer
        # NOTE(review): this generator is never consumed — extract_pages()
        # below creates its own iterator from self.byte_file.
        self.pdfminer_pages = extract_pages(self.byte_file)
        # Per-page accumulators filled by extract_one_page() and merged by
        # create_final_text_content().
        self.page_contents = []
        self.final_content = ""
        self.one_page_contents = {
            "contents": [],
            "images": [],
            "tables": [],
        }

    @timeout(120)
    def _create_byte_object(self, path, is_url):
        """Creates the byte file based on if it is url or not

        Downloads the document when ``is_url`` is True, otherwise opens the
        local file; either way the result is a binary file-like object.
        Guarded by a 120 s timeout to bound slow downloads.
        """
        if is_url:
            res = requests.get(path)
            if res.status_code == 200:
                byte_file = BytesIO(res.content)
            else:
                raise BadUrlException(f"The url raised status code: {res.status_code}")
        else:
            # NOTE(review): this handle is never explicitly closed (see the
            # commented-out __del__ note at the end of the class).
            byte_file = open(path, "rb")

        return byte_file

    def __check_for_relevant_text(self, text):
        """Checks if the text is a relevant text or not

        A text is "relevant" when, after removing line breaks, it is longer
        than ``self.min_characters``.

        Parameters
        ----------
        text : str
            The text to be checked

        Returns
        -------
        bool
            True if the text is a relevant text, False otherwise
        """
        # Remove the line breaker from the text
        text_ = text.replace("\n", "")
        if len(text_) > self.min_characters:
            self.logger.debug(f"Text: {text_} is a relevant text.")
            return True
        self.logger.debug(f"Text: {text_} is not a relevant text.")
        return False

    def _handle_text(self, component):
        """Handles the extraction of textual information

        Appends the component's text to ``one_page_contents["contents"]``
        when it passes the relevance check.

        Parameters
        ----------
        component : pdfminer.layout.LTTextContainer
            The layout element to extract the text from

        Returns
        -------
        bool
            True if the text was relevant (and stored), False otherwise
        """
        text = component.get_text()
        if self.__check_for_relevant_text(text):
            self.one_page_contents["contents"].append(text)
            return True
        return False

    def __table_converter(self, table):
        """Converts a table from pdfplumber's output into a readable string

        Each row becomes one ``|cell|cell|...|`` line. ``None`` cells become
        the string "None"; embedded line breaks inside a cell are replaced
        with spaces.

        Parameters
        ----------
        table : list
            The table to be converted (a list of rows, each a list of cells)

        Returns
        -------
        str
            The converted table as a string

        Examples
        --------
        >>> table = [["Name", "Age"], ["John", "30"], [None, "25\\nyrs"]]
        >>> # produces:
        >>> # |Name|Age|
        >>> # |John|30|
        >>> # |None|25 yrs|
        """
        table_string = ""
        # Iterate through each row of the table
        for row_num in range(len(table)):
            row = table[row_num]
            # Remove the line breaker from the wrapted texts
            cleaned_row = [
                item.replace("\n", " ")
                if item is not None and "\n" in item
                else "None"
                if item is None
                else item
                for item in row
            ]
            # Convert the table into a string
            table_string += "|" + "|".join(cleaned_row) + "|" + "\n"
        # Removing the last line break
        table_string = table_string[:-1]
        return table_string

    def _handle_table(self, page_number):
        """Handles the extraction of tabular information

        Extracts all tables from the given page with pdfplumber, converts
        each to its string form, and stores them under
        ``one_page_contents["tables"]``.

        Parameters
        ----------
        page_number : int
            The page number to extract the table from

        Returns
        -------
        None
        """
        page = self.pdf_plumber_object.pages[page_number]
        try:
            tables = page.extract_tables()
        # NOTE(review): bare except silently replaces any extraction failure
        # with a single ["NA", "NA"] placeholder table.
        except:
            tables = [
                ["NA", "NA"],
            ]
        tables_final = [self.__table_converter(table) for table in tables]
        self.one_page_contents["tables"] = tables_final

    # Create a function to crop the image elements from PDFs

    def __crop_image(self, element, page_number):
        """Crops the pdf and creates a new pdf with only the cropped area.

        The single-page result is later rendered to an image and OCRed with
        pytesseract.

        NOTE(review): the mediabox of the shared PyPDF2 page object is
        mutated in place, so later crops of the same page see the already
        shrunk box — confirm this is intended.

        Parameters
        ----------
        element : pdfminer.layout.LTFigure
            The element whose bounding box is cropped from the pdf
        page_number : int
            The page number to crop the element from

        Returns
        -------
        bytes
            The cropped single-page pdf as a byte string
        """
        pypdf_page = self.pypdf_object.pages[page_number]
        # Get the coordinates to crop the image from PDF
        [image_left, image_top, image_right, image_bottom] = [
            element.x0,
            element.y0,
            element.x1,
            element.y1,
        ]
        # Crop the page using coordinates (left, bottom, right, top)
        pypdf_page.mediabox.lower_left = (image_left, image_bottom)
        pypdf_page.mediabox.upper_right = (image_right, image_top)
        # Save the cropped page to a new PDF
        cropped_pdf_writer = PyPDF2.PdfWriter()
        cropped_pdf_writer.add_page(pypdf_page)
        # convert to byte
        cropped_pdf_stream = BytesIO()
        cropped_pdf_writer.write(cropped_pdf_stream)
        byte_object = cropped_pdf_stream.getvalue()
        return byte_object

    # Create a function to convert the PDF to images
    def _convert_to_images(self, pdf_byte):
        """Converts the pdf byte object to an image

        Parameters
        ----------
        pdf_byte : bytes
            The (single-page) pdf byte object to be converted

        Returns
        -------
        PIL.Image
            The rendered image of the first page
        """
        images = convert_from_bytes(pdf_byte)
        # Only the first page is kept — the input is expected to be a
        # single-page crop produced by __crop_image.
        image = images[0]
        return image

    @timeout(20)
    def _image_to_text(self, image):
        """Extracts text from image using pytesseract

        Guarded by a 20 s timeout since OCR can hang on pathological inputs.

        Parameters
        ----------
        image : PIL.Image
            The image to extract text from

        Returns
        -------
        str
            The extracted text from the image
        """
        text = pytesseract.image_to_string(image)
        self.logger.debug(f"Extracted {text} from the image.")
        return text

    def _handle_image(self, element, page_number):
        """Handles the extraction of image information

        Crops the figure region out of the pdf, renders it to an image,
        OCRs it, and — if the extracted text passes the relevance check —
        stores it under ``one_page_contents["images"]``.

        Parameters
        ----------
        element : pdfminer.layout.LTFigure
            The element to extract the image from
        page_number : int
            The page number to extract the image from

        Returns
        -------
        bool
            True if relevant OCR text was extracted (and stored),
            False otherwise

        Notes
        -----
        A TimeoutError from ``_image_to_text`` is NOT caught here; it will
        propagate to the caller.
        """
        cropped_pdf = self.__crop_image(element, page_number)
        image = self._convert_to_images(pdf_byte=cropped_pdf)
        extracted_text = self._image_to_text(image)
        if self.__check_for_relevant_text(extracted_text):
            self.one_page_contents["images"].append(extracted_text)
            return True
        return False

    def extract_one_page(self, page_number, pdfminer_page):
        """Extracts information from one page of the pdf

        Resets the per-page accumulator, extracts tables, then walks the
        pdfminer layout elements collecting relevant text and OCRing at
        most two figures, and finally appends the page's accumulator to
        ``self.page_contents``.

        Parameters
        ----------
        page_number : int
            The page number to extract the information from
        pdfminer_page : pdfminer.layout.LTPage
            The pdfminer page object to extract the information from

        Returns
        -------
        None
        """
        self.one_page_contents = {
            "contents": [],
            "images": [],
            "tables": [],
        }
        self._handle_table(page_number=page_number)
        # Cap OCR work: at most two figures per page are processed.
        max_image_per_page = 2
        image_number = 0
        for element_number, element in enumerate(pdfminer_page._objs):
            type_ = type(element)
            self.logger.debug(
                f"Handling Page: {page_number}, Element: {element_number} Type: {type_}"
            )

            if isinstance(element, LTTextContainer):
                self._handle_text(element)

            if isinstance(element, LTFigure) and image_number < max_image_per_page:
                added = self._handle_image(
                    element=element,
                    page_number=page_number,
                )
                if added:
                    image_number += 1
        self.page_contents.append(self.one_page_contents)

    def create_final_text_content(self):
        """Creates the final text using the information extracted so far.

        The final text has all the textual information, image information
        and tabular information, delimited by ``PAGE n`` / ``PAGE n ENDS``,
        ``IMAGE i`` / ``IMAGE i ENDS`` and ``TABLE i`` / ``TABLE i ENDS``
        markers. This text can directly be used for machine learning.
        """
        final_content = ""
        for page_number, content in enumerate(self.page_contents):
            final_content += f"PAGE {page_number}\n"
            text_contents = content["contents"]
            final_content += "\n".join(text_contents)
            for i, image in enumerate(content["images"]):
                final_content += f"IMAGE {i}\n{image.strip()}\nIMAGE {i} ENDS\n"

            for i, table in enumerate(content["tables"]):
                final_content += f"TABLE {i}\n{table.strip()}\nTABLE {i} ENDS\n"
            final_content += f"PAGE {page_number} ENDS\n"
        self.final_content = final_content
        return final_content

    def extract_pages(self):
        """Extracts information from all the pages of the pdf. This is the final method to be used

        Processes at most ``self.maximum_pages`` pages, then assembles the
        combined text.

        Parameters
        ----------
        None

        Returns
        -------
        str
            The final text content of the pdf. This text can be directly used for machine learning.
        """
        pages = extract_pages(self.byte_file)
        for page_number, page in enumerate(pages):
            self.logger.info(f"Working on the page: {page_number}")
            if page_number >= self.maximum_pages:
                self.logger.info(f"Maximum page limit reached. Breaking...")
                break
            self.extract_one_page(page_number, page)

        final_content = self.create_final_text_content()
        return final_content

    # NOTE(review): there is no __del__/close(); self.byte_file (and the
    # pdfplumber handle) stay open for the object's lifetime.
|