import os
import re

import gradio as gr
import torch
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel

# Fine-tuned Donut checkpoint used for receipt parsing.
MODEL_ID = "debu-das/donut_receipt_v1.20"

processor = DonutProcessor.from_pretrained(MODEL_ID)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)


def process_document(image, task_prompt=""):
    """Run the Donut model on a single receipt image and return parsed output.

    Args:
        image: Either a filesystem path (``str`` or ``os.PathLike``) to an
            image file, or an image object already accepted by ``processor``
            (e.g. a ``PIL.Image.Image``).
        task_prompt: Text tokenized (without special tokens) and passed to the
            decoder as ``decoder_input_ids``. Defaults to ``""`` as in the
            original demo. NOTE(review): Donut checkpoints are often trained
            with a task start token (e.g. ``"<s_...>"``); confirm which one
            this checkpoint expects.

    Returns:
        The structured result of ``processor.token2json`` for the decoded
        sequence.
    """
    if isinstance(image, (str, os.PathLike)):
        # Load from disk and convert to RGB. The ``with`` block closes the
        # underlying file handle; ``convert`` returns an in-memory copy.
        with Image.open(image) as img:
            image = img.convert("RGB")

    pixel_values = processor(image, return_tensors="pt").pixel_values
    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids

    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        # Forbid the unknown token so the decoder cannot emit it.
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, ""
    )
    # Drop the first remaining <...> tag (the leading task/start token).
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    return processor.token2json(sequence)


def process_batch(images):
    """Run ``process_document`` on every uploaded image.

    Args:
        images: Iterable of images or file paths, as delivered by the Gradio
            ``File`` input. May be ``None`` or empty when nothing was uploaded.

    Returns:
        A list with one parsed result per input image, in input order; an
        empty list when there is no input.
    """
    if not images:
        # Nothing uploaded (e.g. Submit clicked without files): avoid
        # iterating over None.
        return []
    return [process_document(image) for image in images]


description = "Démo Gradio pour Donut, une instance du modèle VisionEncoderDecoderModel affiné sur CORD (analyse de documents). Pour l'utiliser, téléchargez une ou plusieurs images et cliquez sur Submit, ou cliquez sur l'un des exemples pour les charger."
article = "Cette application permet maintenant de traiter plusieurs images de tickets de caisse à la fois."
# Gradio UI: upload several receipt images at once and get one JSON result
# per image.
demo = gr.Interface(
    fn=process_batch,  # batch processing function: one result per file
    inputs=gr.File(
        file_count="multiple",
        type="filepath",
        label="Téléchargez vos images de tickets de caisse",
    ),
    outputs="json",
    title="Reconnaissance des tickets de caisse en lot 🧾",
    description=description,
    article=article,
    # One row per example; each row holds the value of the single input,
    # which is itself a list of file paths.
    examples=[
        [["example.jpg"]],
        [["example_1.jpg", "example_2.jpg"]],
        [["example_3.jpg", "example_4.jpg"]],
    ],
    cache_examples=False,
)

if __name__ == "__main__":
    # Listen on all interfaces; the port can be overridden via $PORT.
    port = int(os.environ.get("PORT", 7860))
    demo.launch(server_name="0.0.0.0", server_port=port)