Spaces:
Runtime error
Runtime error
File size: 2,417 Bytes
e626df6 f823607 f7bee90 d4e0cc5 f823607 fe12926 fe5b596 f823607 fe5b596 f823607 f7bee90 30b2704 4acaff2 fe5b596 d4e0cc5 fe5b596 f823607 30b2704 f823607 30b2704 f823607 d4e0cc5 f823607 30b2704 71b8f4e f823607 71b8f4e f823607 a6d8f18 fe5b596 6f02e5b fe5b596 c794a72 fe5b596 f823607 fe5b596 f823607 6f02e5b fe5b596 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import gradio as gr
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
import torch
import json
# ===============================
# Load the public Donut model
# ===============================
# Hub repository IDs have the form "<owner>/<name>". The public Donut base
# checkpoint lives under the "naver-clova-ix" organisation; the previous ID
# ("naver-clova-ocr-donut-base") is not a valid repository ID, so loading
# failed as soon as the app started.
# NOTE(review): donut-base is the pre-trained (text-reading) checkpoint rather
# than one fine-tuned for document parsing, so its output is presumably plain
# text, not JSON — confirm whether a fine-tuned variant is intended here.
model_name = "naver-clova-ix/donut-base"
processor = DonutProcessor.from_pretrained(model_name, revision="main")
model = VisionEncoderDecoderModel.from_pretrained(model_name, revision="main")
# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
# ===============================
# "Description" column extraction
# ===============================
def _run_donut(image_pil):
    """Run Donut on one PIL image and return the decoded text of the first sequence."""
    pixel_values = processor(images=image_pil, return_tensors="pt").pixel_values.to(device)
    generated_ids = model.generate(pixel_values, max_length=1024)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


def _description_lines(generated_text):
    """Return the "Description" entries found in Donut's decoded text.

    If the text parses as a JSON object, every value whose key contains
    "description" (case-insensitive) is collected; a list value contributes
    one entry per item. Otherwise the text is scanned line by line and every
    line after the first one mentioning "description" is returned.

    Args:
        generated_text: Decoded model output.

    Returns:
        A list of strings; empty when nothing matching was found.
    """
    try:
        data = json.loads(generated_text)
    except json.JSONDecodeError:
        # Not JSON: fall through to the plain-text scan below. Wrapping the
        # text in a dict here would make that scan unreachable.
        data = None

    if isinstance(data, dict):
        found = []
        for key, value in data.items():
            if "description" in key.lower():
                if isinstance(value, list):
                    # str() so numbers / nested objects don't break "\n".join.
                    found.extend(str(item) for item in value)
                else:
                    found.append(str(value))
        return found

    # Plain text (or JSON that is not an object): everything below the header.
    lines = generated_text.split("\n")
    for index, line in enumerate(lines):
        if "description" in line.lower():
            # Skip the header line itself; keep every line after it.
            return lines[index + 1:]
    return []


def extract_description(image_pil):
    """Extract the "Description" column from an invoice image.

    Args:
        image_pil: PIL image from the Gradio ``gr.Image(type="pil")`` input,
            or ``None`` when the user submits without an image.

    Returns:
        A ``(description, raw_text)`` tuple of strings: the extracted lines
        joined with newlines (or an error message when none are found), and
        Donut's full decoded output for inspection.
    """
    if image_pil is None:
        return "❌ Aucune image fournie", ""

    generated_text = _run_donut(image_pil)
    desc_lines = _description_lines(generated_text)
    if not desc_lines:
        return "❌ Colonne 'Description' non trouvée", generated_text
    return "\n".join(desc_lines), generated_text
# ===============================
# Gradio interface
# ===============================
# One image in; two text boxes out, one per element of the
# (description, full_text) tuple returned by extract_description.
demo = gr.Interface(
    fn=extract_description,
    inputs=gr.Image(type="pil", label="Image de facture"),
    outputs=[
        gr.Textbox(label="📋 Colonne Description"),
        gr.Textbox(label="🛠 Texte complet Donut"),
    ],
    title="Extraction de la colonne Description (Donut)",
    description="Détection automatique de la colonne Description dans les factures avec Donut",
)
demo.launch()