Spaces:
Build error
Build error
| import re | |
| import gradio as gr | |
| import torch | |
| from transformers import DonutProcessor, VisionEncoderDecoderModel | |
| import os | |
| from PIL import Image | |
# One checkpoint id feeds both loaders, so the processor and the model
# weights are always taken from the same repository.
checkpoint = "debu-das/donut_receipt_v1.20"
processor = DonutProcessor.from_pretrained(checkpoint)
model = VisionEncoderDecoderModel.from_pretrained(checkpoint)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)
def process_document(image):
    """Run the Donut receipt model on one image and return the parsed result.

    Args:
        image: Either a filesystem path (``str`` or ``os.PathLike``) to an
            image file, or an already-loaded ``PIL.Image.Image``. Any other
            value is handed to the processor unchanged.

    Returns:
        Whatever ``processor.token2json`` produces for the generated token
        sequence.
    """
    if isinstance(image, (str, os.PathLike)):
        # The image is a file path: open it inside a context manager so the
        # file handle is released once ``convert`` has copied the pixels.
        with Image.open(image) as opened:
            image = opened.convert("RGB")
    elif isinstance(image, Image.Image) and image.mode != "RGB":
        # Give already-loaded images the same RGB normalisation as paths.
        image = image.convert("RGB")

    tokenizer = processor.tokenizer
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # The decoder is primed with the task start token.
    task_prompt = "<s_cord-v2>"
    decoder_input_ids = tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids

    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        # Forbid the unknown token from appearing in the output.
        bad_words_ids=[[tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    sequence = processor.batch_decode(outputs.sequences)[0]
    # Drop end-of-sequence and padding markers from the decoded text.
    sequence = sequence.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "")
    # Remove only the first ``<...>`` tag (expected to be the task prompt the
    # decoder was primed with); later tags carry the document structure.
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    return processor.token2json(sequence)
def process_batch(images):
    """Run ``process_document`` on every image and collect the results.

    Args:
        images: An iterable of inputs accepted by ``process_document``
            (typically file paths), a single path, or ``None``/empty when
            nothing was supplied.

    Returns:
        A list with one result per input, in input order. Empty when there
        is no input.
    """
    # Nothing uploaded: return an empty result instead of failing on
    # iteration over ``None``.
    if not images:
        return []
    # A lone path must be wrapped, otherwise the string would be iterated
    # character by character.
    if isinstance(images, (str, os.PathLike)):
        images = [images]
    return [process_document(image) for image in images]
# User-facing text displayed around the interface.
description = "Démo Gradio pour Donut, une instance du modèle VisionEncoderDecoderModel affiné sur CORD (analyse de documents). Pour l'utiliser, téléchargez une ou plusieurs images et cliquez sur Submit, ou cliquez sur l'un des exemples pour les charger."
article = "Cette application permet maintenant de traiter plusieurs images de tickets de caisse à la fois."

# Upload widget configured for several files at once, requested as file paths.
receipt_upload = gr.File(
    label="Téléchargez vos images de tickets de caisse",
    type="filepath",
    file_count="multiple",
)

# Each example row supplies one value for the single input: a list of files.
example_batches = [
    [["example.jpg"]],
    [["example_1.jpg", "example_2.jpg"]],
    [["example_3.jpg", "example_4.jpg"]],
]

demo = gr.Interface(
    title="Reconnaissance des tickets de caisse en lot 🧾",
    description=description,
    article=article,
    fn=process_batch,  # batch-processing callback
    inputs=receipt_upload,
    outputs="json",
    examples=example_batches,
    cache_examples=False,
)

# Bind to every network interface; the port is read from the PORT environment
# variable and falls back to 7860 when it is not set.
port_setting = os.environ.get("PORT", 7860)
port = int(port_setting)
demo.launch(server_port=port, server_name="0.0.0.0")