"""Gradio demo that runs a Donut model on an invoice image and extracts the
content of its "Description" column from the generated text."""

import json

import gradio as gr
import torch
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel

# ===============================
# Load the public Donut model
# ===============================
# NOTE(review): this identifier has no "namespace/" prefix, so it is resolved
# either as a local directory or as a root-level Hub repo. The public Donut
# base checkpoint is usually referenced as "naver-clova-ix/donut-base" --
# confirm that this name is intended.
model_name = "naver-clova-ocr-donut-base"
processor = DonutProcessor.from_pretrained(model_name, revision="main")
model = VisionEncoderDecoderModel.from_pretrained(model_name, revision="main")
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)


# ===============================
# "Description" column extraction
# ===============================
def _collect_description_lines(generated_text):
    """Return the description entries found in *generated_text* as strings.

    If the text parses as a JSON object, the values of every top-level key
    containing "description" (case-insensitive) are collected; list values
    contribute one entry per item. Otherwise (invalid JSON, or JSON that is
    not an object) the text is split on newlines and every line following the
    first line that mentions "description" is collected.
    """
    try:
        data = json.loads(generated_text)
    except json.JSONDecodeError:
        # Not JSON: use the line-based fallback below instead of wrapping the
        # text in a dict, which would make that fallback unreachable.
        data = None

    desc_lines = []
    if isinstance(data, dict):
        for key, value in data.items():
            if "description" not in key.lower():
                continue
            if isinstance(value, list):
                # Stringify items so the caller's "\n".join() cannot fail on
                # numbers or nested objects.
                desc_lines.extend(str(item) for item in value)
            else:
                desc_lines.append(str(value))
        return desc_lines

    # Fallback for non-JSON (or non-object) output: keep every line that
    # comes after the header line mentioning "description".
    found_header = False
    for line in generated_text.split("\n"):
        if found_header:
            desc_lines.append(line)
        elif "description" in line.lower():
            found_header = True
    return desc_lines


def extract_description(image_pil):
    """Run Donut on *image_pil* and extract the "Description" column.

    Args:
        image_pil: PIL image supplied by the Gradio input, or ``None`` when
            the form is submitted without an image.

    Returns:
        A ``(description, generated_text)`` tuple of strings: the extracted
        description lines joined with newlines (or an error message when
        nothing was found), and the full text decoded from the model.
    """
    if image_pil is None:
        return "❌ Aucune image fournie", ""

    # Prepare the image; normalise to RGB so RGBA / greyscale uploads do not
    # reach the processor with an unexpected channel count.
    pixel_values = processor(
        images=image_pil.convert("RGB"), return_tensors="pt"
    ).pixel_values.to(device)

    # Generate and decode the text.
    generated_ids = model.generate(pixel_values, max_length=1024)
    generated_text = processor.batch_decode(
        generated_ids, skip_special_tokens=True
    )[0]

    desc_lines = _collect_description_lines(generated_text)
    if not desc_lines:
        return "❌ Colonne 'Description' non trouvée", generated_text
    return "\n".join(desc_lines), generated_text


# ===============================
# Gradio interface
# ===============================
demo = gr.Interface(
    fn=extract_description,
    inputs=gr.Image(type="pil", label="Image de facture"),
    outputs=[
        gr.Textbox(label="📋 Colonne Description"),
        gr.Textbox(label="🛠 Texte complet Donut"),
    ],
    title="Extraction de la colonne Description (Donut)",
    description="Détection automatique de la colonne Description dans les factures avec Donut",
)

demo.launch()