pavansuresh committed on
Commit
120db3b
·
verified ·
1 Parent(s): 83973ae

Update ai_mapping.py

Browse files
Files changed (1) hide show
  1. ai_mapping.py +27 -3
ai_mapping.py CHANGED
@@ -39,14 +39,38 @@ def extract_key_values_with_layoutlm(text_data: str, pdf_path: str) -> Dict[str,
39
  pix = page.get_pixmap(matrix=fitz.Matrix(300/72, 300/72)) # 300 DPI
40
  img_path = f"{pdf_path}_page_{page_num}.png"
41
  pix.save(img_path)
42
- image = Image.open(img_path)
43
 
44
- encoding = feature_extractor(images=[image], text=text_data.splitlines(), return_tensors="pt")
 
 
 
 
 
 
 
 
45
  input_ids = encoding["input_ids"]
46
  attention_mask = encoding["attention_mask"]
47
 
 
 
 
 
 
 
 
 
 
 
 
48
  with torch.no_grad():
49
- outputs = model(input_ids=input_ids, attention_mask=attention_mask)
 
 
 
 
 
50
  predictions = torch.argmax(outputs.logits, dim=2)
51
 
52
  tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
 
39
  pix = page.get_pixmap(matrix=fitz.Matrix(300/72, 300/72)) # 300 DPI
40
  img_path = f"{pdf_path}_page_{page_num}.png"
41
  pix.save(img_path)
42
+ image = Image.open(img_path).convert("RGB")
43
 
44
+ # Tokenize text
45
+ words = text_data.splitlines()
46
+ encoding = tokenizer(
47
+ words,
48
+ return_tensors="pt",
49
+ truncation=True,
50
+ padding=True,
51
+ max_length=512
52
+ )
53
  input_ids = encoding["input_ids"]
54
  attention_mask = encoding["attention_mask"]
55
 
56
+ # Process image to get bounding boxes
57
+ image_encoding = feature_extractor(image, return_tensors="pt")
58
+ bbox = image_encoding["bbox"][0] # Shape: (num_tokens, 4)
59
+
60
+ # Ensure bbox length matches input_ids
61
+ if len(bbox) < len(input_ids[0]):
62
+ bbox = torch.cat([bbox, torch.zeros((len(input_ids[0]) - len(bbox), 4), dtype=torch.int64)])
63
+ elif len(bbox) > len(input_ids[0]):
64
+ bbox = bbox[:len(input_ids[0])]
65
+
66
+ # Pass inputs to the model
67
  with torch.no_grad():
68
+ outputs = model(
69
+ input_ids=input_ids,
70
+ attention_mask=attention_mask,
71
+ bbox=bbox,
72
+ pixel_values=image_encoding["pixel_values"]
73
+ )
74
  predictions = torch.argmax(outputs.logits, dim=2)
75
 
76
  tokens = tokenizer.convert_ids_to_tokens(input_ids[0])