# Author: isslao
# Commit: 315c653 ("Update app.py")
import re
import gradio as gr
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel
import os
from PIL import Image
# Hugging Face Hub checkpoint shared by the processor and the model
# (the name suggests a Donut model fine-tuned on receipts — see the UI text).
_CHECKPOINT = "debu-das/donut_receipt_v1.20"

# Image preprocessing + tokenizer, and the vision encoder-decoder itself.
processor = DonutProcessor.from_pretrained(_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(_CHECKPOINT)

# Prefer the GPU when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)
def process_document(image):
    """Run the Donut model on a single image and return the parsed fields.

    Args:
        image: Path to an image file (``str`` or ``os.PathLike``), or an
            already-loaded image. PIL images that are not RGB are converted
            to RGB; any other object is passed to ``processor`` unchanged.

    Returns:
        Whatever ``processor.token2json`` builds from the decoded sequence
        (presumably a dict of CORD-style receipt fields — confirm against
        the checkpoint's output).
    """
    if isinstance(image, (str, os.PathLike)):
        # Load from disk; the context manager closes the file handle as soon
        # as the RGB copy has been made.
        with Image.open(image) as img:
            image = img.convert("RGB")
    elif isinstance(image, Image.Image) and image.mode != "RGB":
        # Normalize in-memory images the same way file inputs are normalized;
        # the processor presumably expects 3-channel input, so RGBA /
        # grayscale / palette images are converted first.
        image = image.convert("RGB")

    pixel_values = processor(image, return_tensors="pt").pixel_values

    # Decoder prompt that selects the task; it must match the prompt the
    # checkpoint was fine-tuned with.
    task_prompt = "<s_cord-v2>"
    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids

    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        # Forbid the unknown token so it never appears in the output.
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    sequence = processor.batch_decode(outputs.sequences)[0]
    # Remove the EOS and padding token strings from the decoded text.
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, ""
    )
    # Drop the first tag (the task start token) so only the field markup
    # reaches token2json.
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    return processor.token2json(sequence)
def process_batch(images):
    """Parse every image in *images* with ``process_document``.

    Args:
        images: Iterable of inputs accepted by ``process_document`` (the
            Gradio ``File`` component in this app is configured to supply
            file paths), or ``None``.

    Returns:
        A list with one ``process_document`` result per input, in input
        order; an empty list when *images* is ``None`` or empty.
    """
    if images is None:
        # Presumably what Gradio delivers when the form is submitted with no
        # files; iterating None would raise TypeError.
        return []
    return [process_document(image) for image in images]
# User-facing texts shown above and below the interface.
description = (
    "Démo Gradio pour Donut, une instance du modèle VisionEncoderDecoderModel "
    "affiné sur CORD (analyse de documents). Pour l'utiliser, téléchargez une "
    "ou plusieurs images et cliquez sur Submit, ou cliquez sur l'un des "
    "exemples pour les charger."
)
article = (
    "Cette application permet maintenant de traiter plusieurs images "
    "de tickets de caisse à la fois."
)

# Each example row holds a single value: the list of files for the File input.
_examples = [
    [["example.jpg"]],
    [["example_1.jpg", "example_2.jpg"]],
    [["example_3.jpg", "example_4.jpg"]],
]

# Multi-file upload; "filepath" hands process_batch file paths.
_receipt_files = gr.File(
    file_count="multiple",
    type="filepath",
    label="Téléchargez vos images de tickets de caisse",
)

demo = gr.Interface(
    fn=process_batch,
    inputs=_receipt_files,
    outputs="json",
    title="Reconnaissance des tickets de caisse en lot 🧾",
    description=description,
    article=article,
    examples=_examples,
    cache_examples=False,
)
if __name__ == "__main__":
    # Start the server only when run as a script, so importing this module
    # (tests, tooling) does not launch it as a side effect.
    # Use $PORT when it is set and non-empty, otherwise fall back to 7860
    # (an empty string would make int() raise ValueError).
    port = int(os.environ.get("PORT") or 7860)
    # Bind on all interfaces so the app is reachable from outside a container.
    demo.launch(server_name="0.0.0.0", server_port=port)