artyomxyz committed on
Commit
4bab72e
·
1 Parent(s): 45a5093

generation

Browse files
Files changed (1) hide show
  1. app.py +15 -11
app.py CHANGED
@@ -4,25 +4,28 @@ import os
4
  import gradio as gr
5
  from huggingface_hub import snapshot_download
6
  import spaces
 
 
 
7
 
8
  from pix2struct.modeling import Pix2StructModel
9
  from pix2struct.processing import extract_patches
10
- from pix2struct.inference import ask_generator, generate
11
 
12
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
13
 
14
  hub_token = os.environ.get('HUB_TOKEN')
15
  model_path = snapshot_download('artyomxyz/pix2struct-docmatix', use_auth_token=hub_token)
16
- model = Pix2StructModel.load(model_path)
17
- model.eval()
18
 
19
  tokenizer = T5TokenizerFast.from_pretrained('google/pix2struct-base')
20
 
21
  @spaces.GPU
22
- def ask(image, question):
 
23
  accelerator = Accelerator(mixed_precision="bf16")
 
 
24
  model = accelerator.prepare(model)
25
-
26
  documents = [
27
  DocumentQueries(
28
  meta=None,
@@ -30,21 +33,22 @@ def ask(image, question):
30
  queries=[
31
  DocumentQuery(
32
  meta=None,
33
- generator=ask_generator(tokenizer, qa['question'])
34
  )
 
35
  ]
36
  )
37
  ]
38
- result = generate(model, documents, device=accelerator.device)
39
-
40
-
41
- return result[0].queries[0].output
42
 
43
  demo = gr.Interface(
44
  fn=ask,
45
  inputs=[
46
  gr.Image(type='numpy'),
47
- gr.Textbox(),
48
  ],
49
  outputs='text'
50
  )
 
4
  import gradio as gr
5
  from huggingface_hub import snapshot_download
6
  import spaces
7
+ import torch
8
+ from transformers import T5TokenizerFast
9
+ from accelerate import Accelerator
10
 
11
  from pix2struct.modeling import Pix2StructModel
12
  from pix2struct.processing import extract_patches
13
+ from pix2struct.inference import ask_generator, generate, DocumentQueries, DocumentQuery
14
 
15
# --- Runtime setup (module level, runs once at Space startup) ---
# Install flash-attn at runtime (HF Spaces idiom; CUDA build skipped via env flag).
# NOTE(review): env={...} REPLACES the child's entire environment, so pip runs
# without PATH/HOME etc. — presumably works on this image, but the usual form is
# env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}; confirm.
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
16

17
# Token for the private model repo; None if the HUB_TOKEN secret is unset.
  hub_token = os.environ.get('HUB_TOKEN')
18
# Download the fine-tuned pix2struct checkpoint and remember its local path.
  model_path = snapshot_download('artyomxyz/pix2struct-docmatix', use_auth_token=hub_token)


19

20
# Tokenizer is loaded once at import time; the model itself is loaded per request
# inside ask() (see below), presumably because of the @spaces.GPU lifecycle.
  tokenizer = T5TokenizerFast.from_pretrained('google/pix2struct-base')
21
 
22
# ask(image, questions) — Gradio handler, run on GPU via the ZeroGPU decorator.
# Takes one document image and a newline-separated string of questions; returns
# the answers joined by newlines (one answer line per question, in order).
# NOTE(review): this is a rendered diff; "+" marks lines added by this commit and
# bare numbers are the file's own line numbers from the web view.
  @spaces.GPU
23
+ def ask(image, questions):
24
# One question per textbox line.
+ questions = questions.split('\n')
25
  accelerator = Accelerator(mixed_precision="bf16")
26
# NOTE(review): the checkpoint is re-loaded from disk on every request —
# presumably required by the spaces.GPU worker lifecycle; otherwise this is
# an avoidable per-call cost. Confirm before hoisting to module level.
+ model = Pix2StructModel.load(model_path)
27
+ model.eval()
28
  model = accelerator.prepare(model)

29
# Build one DocumentQueries record for the single input image, with one
# DocumentQuery per question.
  documents = [
30
  DocumentQueries(
31
  meta=None,

# NOTE(review): file line 32 is missing from this diff rendering (it fell
# between the two hunks) — presumably the argument that passes the image /
# extracted patches into DocumentQueries; confirm against the full file.
33
  queries=[
34
  DocumentQuery(
35
  meta=None,
36
+ generator=ask_generator(tokenizer, question)
37
  )
38
+ for question in questions
39
  ]
40
  )
41
  ]
42
# Inference only: no autograd state, bf16 autocast from the Accelerator.
+ with torch.inference_mode():
43
+ with accelerator.autocast():
44
+ result = generate(model, documents, device=accelerator.device)
45
# One output line per query, same order as the input questions.
+ return '\n'.join([q.output for q in result[0].queries])
46
 
47
# Gradio UI: a document image plus a multi-line questions box, answered by ask().
# The single text output carries one answer per line, matching the input order.
  demo = gr.Interface(
48
  fn=ask,
49
  inputs=[
50
# Image arrives as a numpy array (type='numpy'), as ask() expects.
  gr.Image(type='numpy'),
51
+ gr.Textbox(label="Questions (one question per line)"),
52
  ],
53
  outputs='text'
54
  )