# Spaces page header (capture residue): status "Build error"; file size: 2,626 bytes.
import re
import gradio as gr
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel
import os
from PIL import Image
# Checkpoint identifier shared by the processor and the model so the two
# always come from the same source.
_CHECKPOINT = "debu-das/donut_receipt_v1.20"

processor = DonutProcessor.from_pretrained(_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(_CHECKPOINT)

# Run on the GPU when CUDA is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)
def process_document(image):
    """Run the Donut model on one receipt image and return the parsed result.

    Args:
        image: Either a filesystem path (``str`` or ``os.PathLike``) to an
            image file, or an already-loaded image. Paths are opened with
            PIL; PIL images in another mode are converted to RGB. Any other
            object is passed to the processor unchanged.

    Returns:
        Whatever ``processor.token2json`` produces for the decoded sequence
        (presumably a JSON-serialisable nested structure; confirm against
        the Donut documentation).
    """
    if isinstance(image, (str, os.PathLike)):
        # Open inside a context manager so the file handle is released as
        # soon as the pixel data has been copied by convert().
        with Image.open(image) as opened:
            image = opened.convert("RGB")
    elif isinstance(image, Image.Image) and image.mode != "RGB":
        # Normalise RGBA / grayscale / palette images the same way path
        # inputs are normalised, so both entry routes behave alike.
        image = image.convert("RGB")

    pixel_values = processor(image, return_tensors="pt").pixel_values

    # The decoder is primed with the task start token; no extra special
    # tokens are added around it.
    task_prompt = "<s_cord-v2>"
    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids

    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        # Forbid the unknown token from being generated.
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    sequence = processor.batch_decode(outputs.sequences)[0]
    # Strip end-of-sequence and padding markers left in the decoded text.
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, ""
    )
    # Drop only the first "<...>" tag, i.e. the leading task start token.
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    return processor.token2json(sequence)
def process_batch(images):
    """Parse every image in a batch and return the results in input order.

    Args:
        images: An iterable of images (or image file paths) accepted by
            ``process_document``. ``None`` is treated as an empty batch,
            and a single path string is treated as a batch of one.

    Returns:
        A list with one ``process_document`` result per input image; an
        empty list when there is nothing to process.
    """
    if images is None:
        # Nothing was supplied (e.g. the form was submitted without files).
        return []
    if isinstance(images, str):
        # A bare path would otherwise be iterated character by character.
        images = [images]
    return [process_document(image) for image in images]
# User-facing texts shown on the demo page (kept in French, as published).
description = "Démo Gradio pour Donut, une instance du modèle VisionEncoderDecoderModel affiné sur CORD (analyse de documents). Pour l'utiliser, téléchargez une ou plusieurs images et cliquez sur Submit, ou cliquez sur l'un des exemples pour les charger."
article = "Cette application permet maintenant de traiter plusieurs images de tickets de caisse à la fois."

demo = gr.Interface(
    fn=process_batch,  # batch-processing entry point
    # Multi-file upload; the handler receives file paths rather than file objects.
    inputs=gr.File(
        file_count="multiple",
        type="filepath",
        label="Téléchargez vos images de tickets de caisse",
    ),
    outputs="json",
    title="Reconnaissance des tickets de caisse en lot 🧾",
    description=description,
    article=article,
    # Each example is one batch: a list of image paths for the single File input.
    examples=[
        [["example.jpg"]],
        [["example_1.jpg", "example_2.jpg"]],
        [["example_3.jpg", "example_4.jpg"]],
    ],
    cache_examples=False,
)

# Take the port from the PORT environment variable, falling back to 7860.
port = int(os.environ.get("PORT", 7860))
demo.launch(server_name="0.0.0.0", server_port=port)