Spaces:
Sleeping
Sleeping
Update ai_mapping.py
Browse files- ai_mapping.py +27 -3
ai_mapping.py
CHANGED
|
@@ -39,14 +39,38 @@ def extract_key_values_with_layoutlm(text_data: str, pdf_path: str) -> Dict[str,
|
|
| 39 |
pix = page.get_pixmap(matrix=fitz.Matrix(300/72, 300/72)) # 300 DPI
|
| 40 |
img_path = f"{pdf_path}_page_{page_num}.png"
|
| 41 |
pix.save(img_path)
|
| 42 |
-
image = Image.open(img_path)
|
| 43 |
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
input_ids = encoding["input_ids"]
|
| 46 |
attention_mask = encoding["attention_mask"]
|
| 47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
with torch.no_grad():
|
| 49 |
-
outputs = model(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
predictions = torch.argmax(outputs.logits, dim=2)
|
| 51 |
|
| 52 |
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
|
|
|
|
| 39 |
pix = page.get_pixmap(matrix=fitz.Matrix(300/72, 300/72)) # 300 DPI
|
| 40 |
img_path = f"{pdf_path}_page_{page_num}.png"
|
| 41 |
pix.save(img_path)
|
| 42 |
+
image = Image.open(img_path).convert("RGB")
|
| 43 |
|
| 44 |
+
# Tokenize text
|
| 45 |
+
words = text_data.splitlines()
|
| 46 |
+
encoding = tokenizer(
|
| 47 |
+
words,
|
| 48 |
+
return_tensors="pt",
|
| 49 |
+
truncation=True,
|
| 50 |
+
padding=True,
|
| 51 |
+
max_length=512
|
| 52 |
+
)
|
| 53 |
input_ids = encoding["input_ids"]
|
| 54 |
attention_mask = encoding["attention_mask"]
|
| 55 |
|
| 56 |
+
# Process image to get bounding boxes
|
| 57 |
+
image_encoding = feature_extractor(image, return_tensors="pt")
|
| 58 |
+
bbox = image_encoding["bbox"][0] # Shape: (num_tokens, 4)
|
| 59 |
+
|
| 60 |
+
# Ensure bbox length matches input_ids
|
| 61 |
+
if len(bbox) < len(input_ids[0]):
|
| 62 |
+
bbox = torch.cat([bbox, torch.zeros((len(input_ids[0]) - len(bbox), 4), dtype=torch.int64)])
|
| 63 |
+
elif len(bbox) > len(input_ids[0]):
|
| 64 |
+
bbox = bbox[:len(input_ids[0])]
|
| 65 |
+
|
| 66 |
+
# Pass inputs to the model
|
| 67 |
with torch.no_grad():
|
| 68 |
+
outputs = model(
|
| 69 |
+
input_ids=input_ids,
|
| 70 |
+
attention_mask=attention_mask,
|
| 71 |
+
bbox=bbox,
|
| 72 |
+
pixel_values=image_encoding["pixel_values"]
|
| 73 |
+
)
|
| 74 |
predictions = torch.argmax(outputs.logits, dim=2)
|
| 75 |
|
| 76 |
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
|