Spaces:
Runtime error
Runtime error
Changed some values to be updated
Browse files
app.py
CHANGED
|
@@ -15,47 +15,40 @@ model.to(device)
|
|
| 15 |
labels = model.config.id2label
|
| 16 |
print(labels)
|
| 17 |
|
| 18 |
-
# helper function to unnormalize bounding boxes
|
| 19 |
-
def unnormalize_box(bbox, width, height):
|
| 20 |
-
return [
|
| 21 |
-
width * (bbox[0] / 1000),
|
| 22 |
-
height * (bbox[1] / 1000),
|
| 23 |
-
width * (bbox[2] / 1000),
|
| 24 |
-
height * (bbox[3] / 1000),
|
| 25 |
-
]
|
| 26 |
-
|
| 27 |
# process the image in the correct format
|
| 28 |
# extract token classifications
|
| 29 |
def parse_ticket_image(image):
|
|
|
|
|
|
|
| 30 |
if image:
|
| 31 |
document = image.convert("RGB") if image.mode != "RGB" else image
|
| 32 |
else:
|
| 33 |
print(f"Warning - no image or malformed image!")
|
| 34 |
return pd.DataFrame()
|
| 35 |
|
|
|
|
| 36 |
encoding = processor(document, return_tensors="pt", truncation=True)
|
| 37 |
|
|
|
|
| 38 |
for k, v in encoding.items():
|
| 39 |
encoding[k] = v.to(device)
|
| 40 |
|
|
|
|
| 41 |
outputs = model(**encoding)
|
| 42 |
|
|
|
|
| 43 |
predictions = outputs.logits.argmax(-1).squeeze().tolist()
|
| 44 |
-
token_boxes = encoding.bbox.squeeze().tolist()
|
| 45 |
|
| 46 |
input_ids = encoding.input_ids.squeeze().tolist()
|
| 47 |
words = [processor.tokenizer.decode(id) for id in input_ids]
|
| 48 |
|
| 49 |
-
|
| 50 |
-
true_predictions = []
|
| 51 |
-
true_boxes = []
|
| 52 |
|
| 53 |
for idx, pred in enumerate(predictions):
|
| 54 |
label = model.config.id2label[pred]
|
| 55 |
# apparently 'O' stands for non-entity tokens
|
| 56 |
if label != 'O':
|
| 57 |
-
|
| 58 |
-
true_boxes.append(unnormalize_box(token_boxes[idx], width, height))
|
| 59 |
|
| 60 |
if len(extracted_fields) == 0:
|
| 61 |
print(f"Warning - no fields were extracted!")
|
|
@@ -67,8 +60,8 @@ def parse_ticket_image(image):
|
|
| 67 |
values = values[:min_length]
|
| 68 |
|
| 69 |
data = {
|
| 70 |
-
"Field":
|
| 71 |
-
"Value":
|
| 72 |
}
|
| 73 |
df = pd.DataFrame(data)
|
| 74 |
|
|
|
|
| 15 |
labels = model.config.id2label
|
| 16 |
print(labels)
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
# process the image in the correct format
|
| 19 |
# extract token classifications
|
| 20 |
def parse_ticket_image(image):
|
| 21 |
+
|
| 22 |
+
# Process image
|
| 23 |
if image:
|
| 24 |
document = image.convert("RGB") if image.mode != "RGB" else image
|
| 25 |
else:
|
| 26 |
print(f"Warning - no image or malformed image!")
|
| 27 |
return pd.DataFrame()
|
| 28 |
|
| 29 |
+
# Encode document image
|
| 30 |
encoding = processor(document, return_tensors="pt", truncation=True)
|
| 31 |
|
| 32 |
+
# Move encoding to appropriate device
|
| 33 |
for k, v in encoding.items():
|
| 34 |
encoding[k] = v.to(device)
|
| 35 |
|
| 36 |
+
# Perform inference
|
| 37 |
outputs = model(**encoding)
|
| 38 |
|
| 39 |
+
# extract predictions
|
| 40 |
predictions = outputs.logits.argmax(-1).squeeze().tolist()
|
|
|
|
| 41 |
|
| 42 |
input_ids = encoding.input_ids.squeeze().tolist()
|
| 43 |
words = [processor.tokenizer.decode(id) for id in input_ids]
|
| 44 |
|
| 45 |
+
extracted_fields = []
|
|
|
|
|
|
|
| 46 |
|
| 47 |
for idx, pred in enumerate(predictions):
|
| 48 |
label = model.config.id2label[pred]
|
| 49 |
# apparently 'O' stands for non-entity tokens
|
| 50 |
if label != 'O':
|
| 51 |
+
extracted_fields.append((label, words[idx]))
|
|
|
|
| 52 |
|
| 53 |
if len(extracted_fields) == 0:
|
| 54 |
print(f"Warning - no fields were extracted!")
|
|
|
|
| 60 |
values = values[:min_length]
|
| 61 |
|
| 62 |
data = {
|
| 63 |
+
"Field": fields,
|
| 64 |
+
"Value": values
|
| 65 |
}
|
| 66 |
df = pd.DataFrame(data)
|
| 67 |
|