Spaces:
Sleeping
Sleeping
Update ai_mapping.py
Browse files- ai_mapping.py +21 -13
ai_mapping.py
CHANGED
|
@@ -15,17 +15,18 @@ tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base")
|
|
| 15 |
feature_extractor = LayoutLMv3ImageProcessor(apply_ocr=False)
|
| 16 |
model = LayoutLMv3ForTokenClassification.from_pretrained("microsoft/layoutlmv3-base")
|
| 17 |
|
| 18 |
-
def extract_key_values_with_layoutlm(
|
| 19 |
"""
|
| 20 |
Extract key-value pairs from PDF text using LayoutLMv3-base or fallback to regex.
|
| 21 |
Args:
|
| 22 |
-
|
| 23 |
pdf_path (str): Path to the PDF file.
|
| 24 |
Returns:
|
| 25 |
dict: Key-value pairs extracted from the document.
|
| 26 |
"""
|
| 27 |
try:
|
| 28 |
-
# Fallback to regex
|
|
|
|
| 29 |
key_values = {}
|
| 30 |
dates = re.findall(r'\d{1,2}/\d{1,2}/\d{4}', text_data)
|
| 31 |
amounts = re.findall(r'\$\d{1,3}(?:,\d{3})*(?:\.\d{2})?', text_data)
|
|
@@ -34,17 +35,23 @@ def extract_key_values_with_layoutlm(text_data: str, pdf_path: str) -> Dict[str,
|
|
| 34 |
|
| 35 |
# Attempt LayoutLMv3 processing
|
| 36 |
doc = fitz.open(pdf_path)
|
| 37 |
-
for page_num in
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
page = doc[page_num]
|
| 39 |
pix = page.get_pixmap(matrix=fitz.Matrix(300/72, 300/72)) # 300 DPI
|
| 40 |
img_path = f"{pdf_path}_page_{page_num}.png"
|
| 41 |
pix.save(img_path)
|
| 42 |
image = Image.open(img_path).convert("RGB")
|
| 43 |
|
| 44 |
-
# Tokenize text
|
| 45 |
-
words =
|
|
|
|
| 46 |
encoding = tokenizer(
|
| 47 |
words,
|
|
|
|
| 48 |
return_tensors="pt",
|
| 49 |
truncation=True,
|
| 50 |
padding=True,
|
|
@@ -52,16 +59,17 @@ def extract_key_values_with_layoutlm(text_data: str, pdf_path: str) -> Dict[str,
|
|
| 52 |
)
|
| 53 |
input_ids = encoding["input_ids"]
|
| 54 |
attention_mask = encoding["attention_mask"]
|
|
|
|
| 55 |
|
| 56 |
-
# Process image
|
| 57 |
image_encoding = feature_extractor(image, return_tensors="pt")
|
| 58 |
-
|
| 59 |
|
| 60 |
# Ensure bbox length matches input_ids
|
| 61 |
-
if len(bbox) < len(input_ids[0]):
|
| 62 |
-
bbox = torch.cat([bbox, torch.zeros((len(input_ids[0]) - len(bbox), 4), dtype=torch.int64)])
|
| 63 |
-
elif len(bbox) > len(input_ids[0]):
|
| 64 |
-
bbox = bbox[:len(input_ids[0])]
|
| 65 |
|
| 66 |
# Pass inputs to the model
|
| 67 |
with torch.no_grad():
|
|
@@ -69,7 +77,7 @@ def extract_key_values_with_layoutlm(text_data: str, pdf_path: str) -> Dict[str,
|
|
| 69 |
input_ids=input_ids,
|
| 70 |
attention_mask=attention_mask,
|
| 71 |
bbox=bbox,
|
| 72 |
-
pixel_values=
|
| 73 |
)
|
| 74 |
predictions = torch.argmax(outputs.logits, dim=2)
|
| 75 |
|
|
|
|
| 15 |
feature_extractor = LayoutLMv3ImageProcessor(apply_ocr=False)
|
| 16 |
model = LayoutLMv3ForTokenClassification.from_pretrained("microsoft/layoutlmv3-base")
|
| 17 |
|
| 18 |
+
def extract_key_values_with_layoutlm(page_data: list, pdf_path: str) -> Dict[str, str]:
|
| 19 |
"""
|
| 20 |
Extract key-value pairs from PDF text using LayoutLMv3-base or fallback to regex.
|
| 21 |
Args:
|
| 22 |
+
page_data (list): List of dictionaries with 'text' (str) and 'bbox' (list of [x0, y0, x1, y1]) per page.
|
| 23 |
pdf_path (str): Path to the PDF file.
|
| 24 |
Returns:
|
| 25 |
dict: Key-value pairs extracted from the document.
|
| 26 |
"""
|
| 27 |
try:
|
| 28 |
+
# Fallback to regex using concatenated text from all pages
|
| 29 |
+
text_data = " ".join([page["text"] for page in page_data])
|
| 30 |
key_values = {}
|
| 31 |
dates = re.findall(r'\d{1,2}/\d{1,2}/\d{4}', text_data)
|
| 32 |
amounts = re.findall(r'\$\d{1,3}(?:,\d{3})*(?:\.\d{2})?', text_data)
|
|
|
|
| 35 |
|
| 36 |
# Attempt LayoutLMv3 processing
|
| 37 |
doc = fitz.open(pdf_path)
|
| 38 |
+
for page_num, page_info in enumerate(page_data):
|
| 39 |
+
if not page_info["text"].strip() or "No text detected" in page_info["text"]:
|
| 40 |
+
continue
|
| 41 |
+
|
| 42 |
+
# Load image for the page
|
| 43 |
page = doc[page_num]
|
| 44 |
pix = page.get_pixmap(matrix=fitz.Matrix(300/72, 300/72)) # 300 DPI
|
| 45 |
img_path = f"{pdf_path}_page_{page_num}.png"
|
| 46 |
pix.save(img_path)
|
| 47 |
image = Image.open(img_path).convert("RGB")
|
| 48 |
|
| 49 |
+
# Tokenize text and prepare bounding boxes
|
| 50 |
+
words = page_info["text"].split()
|
| 51 |
+
bboxes = page_info["bbox"]
|
| 52 |
encoding = tokenizer(
|
| 53 |
words,
|
| 54 |
+
boxes=bboxes,
|
| 55 |
return_tensors="pt",
|
| 56 |
truncation=True,
|
| 57 |
padding=True,
|
|
|
|
| 59 |
)
|
| 60 |
input_ids = encoding["input_ids"]
|
| 61 |
attention_mask = encoding["attention_mask"]
|
| 62 |
+
bbox = encoding["bbox"]
|
| 63 |
|
| 64 |
+
# Process image for pixel values
|
| 65 |
image_encoding = feature_extractor(image, return_tensors="pt")
|
| 66 |
+
pixel_values = image_encoding["pixel_values"]
|
| 67 |
|
| 68 |
# Ensure bbox length matches input_ids
|
| 69 |
+
if len(bbox[0]) < len(input_ids[0]):
|
| 70 |
+
bbox = torch.cat([bbox, torch.zeros((1, len(input_ids[0]) - len(bbox[0]), 4), dtype=torch.int64)])
|
| 71 |
+
elif len(bbox[0]) > len(input_ids[0]):
|
| 72 |
+
bbox = bbox[:, :len(input_ids[0])]
|
| 73 |
|
| 74 |
# Pass inputs to the model
|
| 75 |
with torch.no_grad():
|
|
|
|
| 77 |
input_ids=input_ids,
|
| 78 |
attention_mask=attention_mask,
|
| 79 |
bbox=bbox,
|
| 80 |
+
pixel_values=pixel_values
|
| 81 |
)
|
| 82 |
predictions = torch.argmax(outputs.logits, dim=2)
|
| 83 |
|