|
|
import gradio as gr |
|
|
from transformers import AutoProcessor, AutoModelForVision2Seq |
|
|
from PIL import Image |
|
|
import torch |
|
|
|
|
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
|
|
|
|
|
model_id = "prithivMLmods/Doc-VLMs-v2-Localization" |
|
|
processor = AutoProcessor.from_pretrained(model_id) |
|
|
model = AutoModelForVision2Seq.from_pretrained(model_id).to(device) |
|
|
|
|
|
|
|
|
def predict(image, text_input, system_prompt="Trích thông tin, không cần diễn giải"): |
|
|
image = image.convert("RGB") |
|
|
inputs = processor(images=image, text=text_input, return_tensors="pt").to(device) |
|
|
generated_ids = model.generate( |
|
|
**inputs, |
|
|
max_new_tokens=512, |
|
|
do_sample=False, |
|
|
eos_token_id=processor.tokenizer.eos_token_id, |
|
|
pad_token_id=processor.tokenizer.pad_token_id |
|
|
) |
|
|
result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] |
|
|
return result |
|
|
|
|
|
|
|
|
demo = gr.Interface( |
|
|
fn=predict, |
|
|
inputs=[ |
|
|
gr.Image(type="pil", label="Upload ảnh tài liệu"), |
|
|
gr.Textbox(label="Câu hỏi muốn hỏi mô hình"), |
|
|
gr.Textbox(label="System prompt (tuỳ chọn)", value="Trích thông tin, không cần diễn giải") |
|
|
], |
|
|
outputs="text", |
|
|
title="Doc-VLMs v2 - Vision Document QA" |
|
|
) |
|
|
|
|
|
if __name__ == "__main__": |
|
|
demo.launch() |
|
|
|