File size: 2,417 Bytes
e626df6
f823607
f7bee90
d4e0cc5
f823607
fe12926
fe5b596
f823607
fe5b596
f823607
 
 
f7bee90
30b2704
 
4acaff2
fe5b596
d4e0cc5
fe5b596
 
f823607
30b2704
 
f823607
 
 
30b2704
f823607
 
 
 
 
d4e0cc5
f823607
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30b2704
71b8f4e
f823607
71b8f4e
f823607
a6d8f18
fe5b596
 
 
6f02e5b
fe5b596
c794a72
fe5b596
 
f823607
fe5b596
f823607
 
6f02e5b
 
fe5b596
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import gradio as gr
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
import torch
import json

# ===============================
# Load the public Donut model
# ===============================
# Hugging Face Hub repo ids have the form "<namespace>/<name>"; the public Donut
# base checkpoint is published by NAVER Clova under "naver-clova-ix/donut-base".
# NOTE(review): donut-base is a pre-training checkpoint; a fine-tuned Donut
# variant may be required to get structured (key/value) output — confirm.
model_name = "naver-clova-ix/donut-base"
processor = DonutProcessor.from_pretrained(model_name, revision="main")
model = VisionEncoderDecoderModel.from_pretrained(model_name, revision="main")

# Run inference on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# ===============================
# Extraction of the "Description" column
# ===============================
def _collect_description_values(node, out):
    """Recursively append (as str) every value stored under a key containing "description"."""
    if isinstance(node, dict):
        for key, value in node.items():
            if "description" in str(key).lower():
                if isinstance(value, list):
                    # Stringify each item so the final "\n".join never hits a non-str.
                    out.extend(str(item) for item in value)
                else:
                    out.append(str(value))
            else:
                # Descend into non-matching values: invoice tables are usually
                # nested (e.g. a list of line-item dicts).
                _collect_description_values(value, out)
    elif isinstance(node, list):
        for item in node:
            _collect_description_values(item, out)


def _lines_after_description_header(text):
    """Return the non-blank lines that follow the first line mentioning "description"."""
    lines = text.splitlines()
    for index, line in enumerate(lines):
        if "description" in line.lower():
            return [following for following in lines[index + 1:] if following.strip()]
    return []


def extract_description(image_pil):
    """Run Donut on an invoice image and extract the "Description" column.

    Args:
        image_pil: PIL image supplied by the Gradio ``Image`` input; may be
            ``None`` when the form is submitted without an image.

    Returns:
        A ``(description_text, generated_text)`` tuple. ``description_text``
        holds the extracted lines joined by newlines, or an error message
        prefixed with "❌". ``generated_text`` is the raw decoded Donut output
        (empty string when no image was given).
    """
    if image_pil is None:
        return "❌ Aucune image fournie", ""

    # Prepare the image and generate text.
    pixel_values = processor(images=image_pil, return_tensors="pt").pixel_values.to(device)
    generated_ids = model.generate(pixel_values, max_length=1024)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # The output may be JSON or semi-structured text; only a genuine JSON
    # object/array goes through the structured search.
    try:
        data = json.loads(generated_text)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        data = None

    if isinstance(data, (dict, list)):
        desc_lines = []
        _collect_description_values(data, desc_lines)
    else:
        # Fallback for non-JSON output: take everything after the header line.
        desc_lines = _lines_after_description_header(generated_text)

    if not desc_lines:
        return "❌ Colonne 'Description' non trouvée", generated_text
    return "\n".join(desc_lines), generated_text

# ===============================
# Gradio interface: one invoice image in, two text boxes out
# (the extracted Description column and the raw Donut transcription).
# ===============================
image_input = gr.Image(type="pil", label="Image de facture")
text_outputs = [
    gr.Textbox(label="📋 Colonne Description"),
    gr.Textbox(label="🛠 Texte complet Donut"),
]

demo = gr.Interface(
    fn=extract_description,
    inputs=image_input,
    outputs=text_outputs,
    title="Extraction de la colonne Description (Donut)",
    description="Détection automatique de la colonne Description dans les factures avec Donut",
)

# Start the web server (blocks until it is stopped).
demo.launch()