rosemariafontana commited on
Commit
623e388
Β·
verified Β·
1 Parent(s): cad3a8c

Changed some values to be updated

Browse files
Files changed (1) hide show
  1. app.py +10 -17
app.py CHANGED
@@ -15,47 +15,40 @@ model.to(device)
15
  labels = model.config.id2label
16
  print(labels)
17
 
18
- # helper function to unnormalize bounding boxes
19
- def unnormalize_box(bbox, width, height):
20
- return [
21
- width * (bbox[0] / 1000),
22
- height * (bbox[1] / 1000),
23
- width * (bbox[2] / 1000),
24
- height * (bbox[3] / 1000),
25
- ]
26
-
27
  # process the image in the correct format
28
  # extract token classifications
29
  def parse_ticket_image(image):
 
 
30
  if image:
31
  document = image.convert("RGB") if image.mode != "RGB" else image
32
  else:
33
  print(f"Warning - no image or malformed image!")
34
  return pd.DataFrame()
35
 
 
36
  encoding = processor(document, return_tensors="pt", truncation=True)
37
 
 
38
  for k, v in encoding.items():
39
  encoding[k] = v.to(device)
40
 
 
41
  outputs = model(**encoding)
42
 
 
43
  predictions = outputs.logits.argmax(-1).squeeze().tolist()
44
- token_boxes = encoding.bbox.squeeze().tolist()
45
 
46
  input_ids = encoding.input_ids.squeeze().tolist()
47
  words = [processor.tokenizer.decode(id) for id in input_ids]
48
 
49
- width, height = document.size
50
- true_predictions = []
51
- true_boxes = []
52
 
53
  for idx, pred in enumerate(predictions):
54
  label = model.config.id2label[pred]
55
  # apparently 'O' stands for non-entity tokens
56
  if label != 'O':
57
- true_predictions.append(label)
58
- true_boxes.append(unnormalize_box(token_boxes[idx], width, height))
59
 
60
  if len(extracted_fields) == 0:
61
  print(f"Warning - no fields were extracted!")
@@ -67,8 +60,8 @@ def parse_ticket_image(image):
67
  values = values[:min_length]
68
 
69
  data = {
70
- "Field": true_predictions,
71
- "Value": words[1:len(true_predictions)+1]
72
  }
73
  df = pd.DataFrame(data)
74
 
 
15
  labels = model.config.id2label
16
  print(labels)
17
 
 
 
 
 
 
 
 
 
 
18
  # process the image in the correct format
19
  # extract token classifications
20
  def parse_ticket_image(image):
21
+
22
+ # Process image
23
  if image:
24
  document = image.convert("RGB") if image.mode != "RGB" else image
25
  else:
26
  print(f"Warning - no image or malformed image!")
27
  return pd.DataFrame()
28
 
29
+ # Encode document image
30
  encoding = processor(document, return_tensors="pt", truncation=True)
31
 
32
+ # Move encoding to appropriate device
33
  for k, v in encoding.items():
34
  encoding[k] = v.to(device)
35
 
36
+ # Perform inference
37
  outputs = model(**encoding)
38
 
39
+ # extract predictions
40
  predictions = outputs.logits.argmax(-1).squeeze().tolist()
 
41
 
42
  input_ids = encoding.input_ids.squeeze().tolist()
43
  words = [processor.tokenizer.decode(id) for id in input_ids]
44
 
45
+ extracted_fields = []
 
 
46
 
47
  for idx, pred in enumerate(predictions):
48
  label = model.config.id2label[pred]
49
  # apparently 'O' stands for non-entity tokens
50
  if label != 'O':
51
+ extracted_fields.append((label, words[idx]))
 
52
 
53
  if len(extracted_fields) == 0:
54
  print(f"Warning - no fields were extracted!")
 
60
  values = values[:min_length]
61
 
62
  data = {
63
+ "Field": fields,
64
+ "Value": values
65
  }
66
  df = pd.DataFrame(data)
67