rosemariafontana committed on
Commit
d787446
·
verified ·
1 Parent(s): 4d13fcf

changed model

Browse files
Files changed (1) hide show
  1. app.py +29 -34
app.py CHANGED
@@ -1,52 +1,47 @@
1
  import gradio as gr
2
  import pandas as pd
3
 
4
- #from transformers import pipeline
5
 
6
- from docquery import pipeline
7
- from docquery.document import load_document
8
 
 
 
9
 
10
- # Chatbot model
11
- #model = pipeline("document-question-answering", model="impira/layoutlm-document-qa")
12
 
13
def construct_pipeline(task, model):
    """Return a cached docquery pipeline for *model*, building it on first use.

    Parameters:
        task: pipeline task string, e.g. "document-question-answering".
        model: key into the module-level CHECKPOINTS mapping; also the cache
            key into the module-level PIPELINES dict.

    Returns:
        The (possibly cached) docquery pipeline object.
    """
    global PIPELINES
    if model in PIPELINES:
        return PIPELINES[model]

    # BUG FIX: `torch` was referenced below but never imported anywhere in
    # the visible module imports; a function-scope import keeps the fix local.
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    ret = pipeline(task=task, model=CHECKPOINTS[model], device=device)
    PIPELINES[model] = ret
    return ret
22
-
23
def run_pipeline(question, document):
    """Run document question-answering on *document* and return the top-3
    predictions for *question*.

    *document* is expected to expose a ``.context`` mapping of pipeline
    keyword arguments (as docquery documents do).
    """
    # Renamed the local so it no longer shadows the module-level
    # docquery `pipeline` import.
    qa_pipe = construct_pipeline("document-question-answering", "impira/layoutlm-document-qa")
    return qa_pipe(question=question, **document.context, top_k=3)
26
-
27
def process_question(question, document):
    """Return the top predicted answer text for *question* over *document*.

    Parameters:
        question: the user's question string; falsy values short-circuit.
        document: a docquery document (or None).

    Returns:
        The answer string of the highest-ranked prediction, or None when
        the inputs are missing or no prediction is produced.
    """
    if not question or document is None:
        # BUG FIX: this guard previously returned (None, None, None) while
        # the success path returns a single value; callers unpacking three
        # values would crash on success, so the single-value contract is
        # the consistent one.
        return None

    # run_pipeline requests top_k=3, but only the top prediction is
    # surfaced for now; multiple boxes can be re-enabled later.
    predictions = run_pipeline(question, document)
    for p in ensure_list(predictions):
        return p["answer"]
    return None
 
 
 
 
 
 
43
 
44
  def parse_ticket_image(image, question):
45
  """Basically just runs through these questions for the document"""
46
  # Processing the image
47
  if image:
48
  try:
49
- document = load_document(image.name)
50
  except Exception as e:
51
  traceback.print_exc()
52
  error = str(e)
 
1
  import gradio as gr
2
  import pandas as pd
3
 
4
+ from transformers import LayoutLMv2Processor, LayoutLMv3ForQuestionAnswering
5
 
6
+ processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv3-base")
7
+ model = LayoutLMv3ForQuestionAnswering.from_pretrained("microsoft/layoutlmv3-base")
8
 
9
def process_question(question, document):
    """Answer *question* about *document* (a PIL image) with LayoutLMv3.

    Performs extractive QA: encodes the page with the module-level
    `processor`, runs the module-level `model`, takes the argmax start/end
    token positions, and decodes that token span back to text.

    Returns:
        The decoded answer string (may be empty if the model predicts an
        empty or inverted span).
    """
    encoding = processor(document, question, return_tensors="pt")

    # BUG FIX: was `outputs = mode(**encoding)` — a NameError typo; the
    # module-level `model` is what must be called.
    outputs = model(**encoding)

    predicted_start_idx = outputs.start_logits.argmax(-1).item()
    predicted_end_idx = outputs.end_logits.argmax(-1).item()

    # NOTE(review): if end precedes start this slice is empty and the
    # answer decodes to "" — presumably acceptable for this demo; confirm.
    answer_tokens = encoding.input_ids.squeeze()[predicted_start_idx: predicted_end_idx + 1]
    answer = processor.tokenizer.decode(answer_tokens)

    return answer
21
+
22
+ #def process_question(question, document):
23
+ # if not question or document is None:
24
+ # return None, None, None
25
+ #
26
+ # text_value = None
27
+ # predictions = run_pipeline(question, document)
28
+ #
29
+ # for i, p in enumerate(ensure_list(predictions)):
30
+ # if i == 0:
31
+ # text_value = p["answer"]
32
+ # else:
33
+ # # Keep the code around to produce multiple boxes, but only show the top
34
+ # # prediction for now
35
+ # break
36
+ #
37
+ # return text_value
38
 
39
  def parse_ticket_image(image, question):
40
  """Basically just runs through these questions for the document"""
41
  # Processing the image
42
  if image:
43
  try:
44
+ document = Image.open(image.name).convert("RGB")
45
  except Exception as e:
46
  traceback.print_exc()
47
  error = str(e)