rosemariafontana commited on
Commit
7e409fd
Β·
verified Β·
1 Parent(s): 7252bf3

update to extract date entity

Browse files
Files changed (1) hide show
  1. app.py +34 -0
app.py CHANGED
@@ -10,11 +10,42 @@ processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base")
10
  # More traditional approach that works from token classification basis (not questions)
11
  model = LayoutLMv3ForTokenClassification.from_pretrained("microsoft/layoutlmv3-base")
12
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
13
  model.to(device)
14
 
15
  labels = model.config.id2label
16
  print(labels)
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  # process the image in the correct format
19
  # extract token classifications
20
  def parse_ticket_image(image):
@@ -69,6 +100,9 @@ def parse_ticket_image(image):
69
  min_length = min(len(fields), len(values))
70
  fields = fields[:min_length]
71
  values = values[:min_length]
 
 
 
72
 
73
  data = {
74
  "Field": fields,
 
10
  # More traditional approach that works from token classification basis (not questions)
11
  model = LayoutLMv3ForTokenClassification.from_pretrained("microsoft/layoutlmv3-base")
12
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
+
14
+ print(f"Debug -- Using device: {device}")
15
  model.to(device)
16
 
17
  labels = model.config.id2label
18
  print(labels)
19
 
20
+ # Homemade feature extraction
21
+ def extract_features(tokens, labels):
22
+ merged_entities = []
23
+ current_date = ""
24
+
25
+ # Loop through tokens and labels
26
+ for token, label in zip(tokens, labels):
27
+ if label === 'LABEL_1':
28
+
29
+ # Date logic
30
+ if re.match(r"^\d{1,2}/$", token) or re.match(r"^\d{4}$", token):
31
+ current_date += token
32
+
33
+ # Date logic
34
+ if re.match(r"^\d{4}$", token) and current_date.count('/') == 2:
35
+ merged_entities.append(current_date)
36
+ current_date = ""
37
+ else:
38
+ if current_date:
39
+ merged_entities.append(current_date)
40
+ current_date = ""
41
+ merged_entities.append(token)
42
+
43
+ if current_date:
44
+ merged_entities.append(current_date)
45
+
46
+ return merged_entities
47
+
48
+
49
  # process the image in the correct format
50
  # extract token classifications
51
  def parse_ticket_image(image):
 
100
  min_length = min(len(fields), len(values))
101
  fields = fields[:min_length]
102
  values = values[:min_length]
103
+
104
+ #Homemade feature extraction
105
+ values = extract_features(values, fields)
106
 
107
  data = {
108
  "Field": fields,