rosemariafontana commited on
Commit
f9937f7
·
verified ·
1 Parent(s): a54945b

Changed to tokenization

Browse files
Files changed (1) hide show
  1. app.py +159 -94
app.py CHANGED
@@ -1,34 +1,98 @@
1
  import gradio as gr
2
  import pandas as pd
3
 
4
- from PIL import Image
5
- from transformers import LayoutLMv3Processor, LayoutLMv3ForQuestionAnswering
 
6
 
7
  processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base")
8
- model = LayoutLMv3ForQuestionAnswering.from_pretrained("microsoft/layoutlmv3-base")
9
 
10
- def process_question(question, document):
11
- #print(f"Debug - Processing Question: {question}")
12
-
13
- encoding = processor(document, question, return_tensors="pt")
14
- #print(f"Debug - Encoding Input IDs: {encoding.input_ids}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  outputs = model(**encoding)
17
- #print(f"Debug - Model Outputs: {outputs}")
18
 
19
- predicted_start_idx = outputs.start_logits.argmax(-1).item()
20
- predicted_end_idx = outputs.end_logits.argmax(-1).item()
21
 
22
- # Check if indices are valid
23
- if predicted_start_idx < 0 or predicted_end_idx < 0:
24
- print(f"Warning - Invalid prediction indices: start={predicted_start_idx}, end={predicted_end_idx}")
25
- return ""
26
 
27
- answer_tokens = encoding.input_ids.squeeze()[predicted_start_idx: predicted_end_idx + 1]
28
- answer = processor.tokenizer.decode(answer_tokens)
 
29
 
30
- return answer
31
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  #def process_question(question, document):
33
  # if not question or document is None:
34
  # return None, None, None
@@ -46,82 +110,83 @@ def process_question(question, document):
46
  #
47
  # return text_value
48
 
49
- def parse_ticket_image(image, question):
50
- """Basically just runs through these questions for the document"""
51
- # Processing the image
52
- if image:
53
- try:
54
- if image.mode != "RGB":
55
- document = image.convert("RGB")
56
- else:
57
- document = image
58
- except Exception as e:
59
- traceback.print_exc()
60
- error = str(e)
61
-
62
-
63
- # Define questions you want to ask the model
64
-
65
- questions = [
66
- "What is the ticket number?",
67
- "What is the type of grain (For example: corn, soybeans, wheat)?",
68
- "What is the date?",
69
- "What is the time?",
70
- "What is the gross weight?",
71
- "What is the tare weight?",
72
- "What is the net weight?",
73
- "What is the moisture (moist) percentage?",
74
- "What is the damage percentage?",
75
- "What is the gross units?",
76
- "What is the dock units?",
77
- "What is the comment?",
78
- "What is the assembly number?",
79
- ]
80
-
81
- # Use the model to answer each question
82
- answers = {}
83
- for q in questions:
84
- print(f"Question: {q}")
85
- answer_text = process_question(q, document)
86
- print(f"Answer Text extracted here: {answer_text}")
87
- answers[q] = answer_text
88
-
89
-
90
- ticket_number = answers["What is the ticket number?"]
91
- grain_type = answers["What is the type of grain (For example: corn, soybeans, wheat)?"]
92
- date = answers["What is the date?"]
93
- time = answers["What is the time?"]
94
- gross_weight = answers["What is the gross weight?"]
95
- tare_weight = answers["What is the tare weight?"]
96
- net_weight = answers["What is the net weight?"]
97
- moisture = answers["What is the moisture (moist) percentage?"]
98
- damage = answers["What is the damage percentage?"]
99
- gross_units = answers["What is the gross units?"]
100
- dock_units = answers["What is the dock units?"]
101
- comment = answers["What is the comment?"]
102
- assembly_number = answers["What is the assembly number?"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
-
105
- # Create a structured format (like a table) using pandas
106
- data = {
107
- "Ticket Number": [ticket_number],
108
- "Grain Type": [grain_type],
109
- "Assembly Number": [assembly_number],
110
- "Date": [date],
111
- "Time": [time],
112
- "Gross Weight": [gross_weight],
113
- "Tare Weight": [tare_weight],
114
- "Net Weight": [net_weight],
115
- "Moisture": [moisture],
116
- "Damage": [damage],
117
- "Gross Units": [gross_units],
118
- "Dock Units": [dock_units],
119
- "Comment": [comment],
120
- }
121
- df = pd.DataFrame(data)
122
-
123
- return df
124
-
125
 
126
  """
127
  For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
 
1
  import gradio as gr
2
  import pandas as pd
3
 
4
+ from PIL import Image, ImageDraw, ImageFont
5
+ import torch
6
+ from transformers import LayoutLMv3Processor, LayoutLMv3ForQuestionAnswering, LayoutLMv3ForTokenClassification
7
 
8
  processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base")
 
9
 
10
+ # More traditional approach that works from token classification basis (not questions)
11
+ model = LayoutLMv3ForTokenClassification.from_pretrained("microsoft/layoutlmv3-base")
12
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
+ model.to(device)
14
+
15
+ labels = model.config.id2label
16
+ print(labels)
17
+
18
+ # helper function to unnormalize bounding boxes
19
+ def unnormalize_box(bbox, width, height):
20
+ return [
21
+ width * (bbox[0] / 1000),
22
+ height * (bbox[1] / 1000),
23
+ width * (bbox[2] / 1000),
24
+ height * (bbox[3] / 1000),
25
+ ]
26
+
27
+ # process the image in the correct format
28
+ # extract token classifications
29
+ def parse_ticket_image(image):
30
+ if image:
31
+ document = image.convert("RGB") if image.mode != "RGB" else image
32
+ else:
33
+ print(f"Warning - no image or malformed image!")
34
+ return pd.DataFrame()
35
+
36
+ encoding = processor(document, return_tensors="pt", truncation=True)
37
+
38
+ for k, v in encoding.items():
39
+ encoding[k] = v.to(device)
40
 
41
  outputs = model(**encoding)
 
42
 
43
+ predictions = outputs.logits.argmax(-1).squeeze().tolist()
44
+ token_boxes = encoding.bbox.squeeze().tolist()
45
 
46
+ input_ids = encoding.input_ids.squeeze().tolist()
47
+ words = [processor.tokenizer.decode(id) for id in input_ids]
 
 
48
 
49
+ width, height = document.size
50
+ true_predictions = []
51
+ true_boxes = []
52
 
53
+ for idx, pred in enumerate(predictions):
54
+ label = model.config.id2label[pred]
55
+ # apparently 'O' stands for non-entity tokens
56
+ if label != 'O':
57
+ true_predictions.append(label)
58
+ true_boxes.append(unnormalize_box(token_boxes[idx], width, height))
59
+
60
+ data = {
61
+ "Field": true_predictions,
62
+ "Value": words[1:len(true_predictions)+1]
63
+ }
64
+ df = pd.DataFrame(data)
65
+
66
+ return df
67
+
68
+
69
+ # This is how to use questions to find answers in the document
70
+ # Less traditional approach, less flexibility, easier to implement/understand (didnt provide robust answers)
71
+ #model = LayoutLMv3ForQuestionAnswering.from_pretrained("microsoft/layoutlmv3-base")
72
+
73
+ #def process_question(question, document):
74
+ # #print(f"Debug - Processing Question: {question}")
75
+ #
76
+ # encoding = processor(document, question, return_tensors="pt")
77
+ # #print(f"Debug - Encoding Input IDs: {encoding.input_ids}")
78
+ #
79
+ # outputs = model(**encoding)
80
+ # #print(f"Debug - Model Outputs: {outputs}")
81
+ #
82
+ # predicted_start_idx = outputs.start_logits.argmax(-1).item()
83
+ # predicted_end_idx = outputs.end_logits.argmax(-1).item()
84
+ #
85
+ # # Check if indices are valid
86
+ # if predicted_start_idx < 0 or predicted_end_idx < 0:
87
+ # print(f"Warning - Invalid prediction indices: start={predicted_start_idx}, end={predicted_end_idx}")
88
+ # return ""
89
+ #
90
+ # answer_tokens = encoding.input_ids.squeeze()[predicted_start_idx: predicted_end_idx + 1]
91
+ # answer = processor.tokenizer.decode(answer_tokens)
92
+ #
93
+ # return answer
94
+
95
+ # Older iteration of the code, retaining for emergencies ?
96
  #def process_question(question, document):
97
  # if not question or document is None:
98
  # return None, None, None
 
110
  #
111
  # return text_value
112
 
113
+ #def parse_ticket_image(image, question):
114
+ # """Basically just runs through these questions for the document"""
115
+ # # Processing the image
116
+ # if image:
117
+ # try:
118
+ # if image.mode != "RGB":
119
+ # document = image.convert("RGB")
120
+ # else:
121
+ # document = image
122
+ # except Exception as e:
123
+ # traceback.print_exc()
124
+ # error = str(e)
125
+ #
126
+ #
127
+ # # Define questions you want to ask the model
128
+ #
129
+ # questions = [
130
+ # "What is the ticket number?",
131
+ # "What is the type of grain (For example: corn, soybeans, wheat)?",
132
+ # "What is the date?",
133
+ # "What is the time?",
134
+ # "What is the gross weight?",
135
+ # "What is the tare weight?",
136
+ # "What is the net weight?",
137
+ # "What is the moisture (moist) percentage?",
138
+ # "What is the damage percentage?",
139
+ # "What is the gross units?",
140
+ # "What is the dock units?",
141
+ # "What is the comment?",
142
+ # "What is the assembly number?",
143
+ # ]
144
+ #
145
+ # # Use the model to answer each question
146
+ # answers = {}
147
+ # for q in questions:
148
+ # print(f"Question: {q}")
149
+ # answer_text = process_question(q, document)
150
+ # print(f"Answer Text extracted here: {answer_text}")
151
+ # answers[q] = answer_text
152
+ #
153
+ #
154
+ # ticket_number = answers["What is the ticket number?"]
155
+ # grain_type = answers["What is the type of grain (For example: corn, soybeans, wheat)?"]
156
+ # date = answers["What is the date?"]
157
+ # time = answers["What is the time?"]
158
+ # gross_weight = answers["What is the gross weight?"]
159
+ # tare_weight = answers["What is the tare weight?"]
160
+ # net_weight = answers["What is the net weight?"]
161
+ # moisture = answers["What is the moisture (moist) percentage?"]
162
+ # damage = answers["What is the damage percentage?"]
163
+ # gross_units = answers["What is the gross units?"]
164
+ # dock_units = answers["What is the dock units?"]
165
+ # comment = answers["What is the comment?"]
166
+ # assembly_number = answers["What is the assembly number?"]
167
+ #
168
+ #
169
+ # # Create a structured format (like a table) using pandas
170
+ # data = {
171
+ # "Ticket Number": [ticket_number],
172
+ # "Grain Type": [grain_type],
173
+ # "Assembly Number": [assembly_number],
174
+ # "Date": [date],
175
+ # "Time": [time],
176
+ # "Gross Weight": [gross_weight],
177
+ # "Tare Weight": [tare_weight],
178
+ # "Net Weight": [net_weight],
179
+ # "Moisture": [moisture],
180
+ # "Damage": [damage],
181
+ # "Gross Units": [gross_units],
182
+ # "Dock Units": [dock_units],
183
+ # "Comment": [comment],
184
+ # }
185
+ # df = pd.DataFrame(data)
186
+ #
187
+ # return df
188
+
189
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
 
191
  """
192
  For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface