Spaces: Runtime error
import os
import subprocess

# Install flash-attn at app startup (standard HF Spaces pattern); skip the
# CUDA extension build since a prebuilt wheel is used on the runtime.
# Merge with os.environ: passing a bare dict as `env` would wipe PATH, HOME,
# proxy settings, etc. for the child shell and make `pip` unfindable.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)

import gradio as gr
from huggingface_hub import snapshot_download
import spaces
import torch
from transformers import T5TokenizerFast

from pix2struct.modeling import Pix2StructModel
from pix2struct.processing import extract_patches
from pix2struct.inference import ask_generator, generate, DocumentQueries, DocumentQuery

# The checkpoint repo is private: authenticate with the HUB_TOKEN secret
# configured on the Space (None if unset, which works for public repos).
hub_token = os.environ.get('HUB_TOKEN')
# `token=` replaces the deprecated `use_auth_token=` parameter of snapshot_download.
model_path = snapshot_download('artyomxyz/pix2struct-docmatix', token=hub_token)

model = Pix2StructModel.load(model_path)
model.eval()  # inference only: freeze dropout / norm statistics
model = model.to('cuda')

tokenizer = T5TokenizerFast.from_pretrained('google/pix2struct-base')
# NOTE(review): `import spaces` alone has no effect — on ZeroGPU hardware the
# function that touches CUDA must be decorated with @spaces.GPU so a GPU is
# attached for the call. The missing decorator is the likely runtime error.
@spaces.GPU
def ask(image, questions):
    """Answer newline-separated questions about a single document image.

    Args:
        image: numpy image array as delivered by the gr.Image input.
        questions: questions separated by newlines, one per line.

    Returns:
        str: the generated answers joined by newlines, one per question,
        in the same order as the input questions.
    """
    question_list = questions.split('\n')
    documents = [
        DocumentQueries(
            meta=None,
            patches=extract_patches([image]),
            queries=[
                DocumentQuery(
                    meta=None,
                    generator=ask_generator(tokenizer, question),
                )
                for question in question_list
            ],
        )
    ]
    # inference_mode disables autograd bookkeeping; bfloat16 autocast keeps
    # memory and latency down on the GPU without manual casting.
    with torch.inference_mode():
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            result = generate(model, documents, device='cuda')
    return '\n'.join([q.output for q in result[0].queries])
# Wire the model up to a minimal Gradio UI: a document image plus a
# multi-line question box in, plain text answers out.
image_input = gr.Image(type='numpy')
questions_input = gr.Textbox(label="Questions (one question per line)")

demo = gr.Interface(
    fn=ask,
    inputs=[image_input, questions_input],
    outputs='text',
)
demo.launch()