kebson commited on
Commit
f823607
·
verified ·
1 Parent(s): 71b8f4e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -24
app.py CHANGED
@@ -1,14 +1,15 @@
1
  import gradio as gr
2
- from transformers import TrOCRProcessor, VisionEncoderDecoderModel
3
  from PIL import Image
4
  import torch
 
5
 
6
  # ===============================
7
- # Charger le modèle TrOCR public
8
  # ===============================
9
- model_name = "microsoft/trocr-base-handwritten" # modèle OCR général
10
- processor = TrOCRProcessor.from_pretrained(model_name)
11
- model = VisionEncoderDecoderModel.from_pretrained(model_name)
12
 
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
  model.to(device)
@@ -17,29 +18,42 @@ model.to(device)
17
  # Fonction extraction colonne Description
18
  # ===============================
19
  def extract_description(image_pil):
20
- # OCR avec TrOCR
21
  pixel_values = processor(images=image_pil, return_tensors="pt").pixel_values.to(device)
22
- generated_ids = model.generate(pixel_values)
23
- ocr_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
24
 
25
- # Séparer le texte en lignes
26
- lines = [line.strip() for line in ocr_text.split("\n") if line.strip()]
 
27
 
28
- # Détection de la colonne Description via mot-clé
29
- desc_lines = []
30
- found_header = False
 
 
31
 
32
- for line in lines:
33
- if found_header:
34
- # toutes les lignes après le header sont considérées comme contenu de la colonne
35
- desc_lines.append(line)
36
- elif "description" in line.lower():
37
- found_header = True
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  if not desc_lines:
40
- return "❌ Colonne 'Description' non trouvée", ocr_text
41
  else:
42
- return "\n".join(desc_lines), ocr_text
43
 
44
  # ===============================
45
  # Interface Gradio
@@ -49,10 +63,10 @@ demo = gr.Interface(
49
  inputs=gr.Image(type="pil", label="Image de facture"),
50
  outputs=[
51
  gr.Textbox(label="📋 Colonne Description"),
52
- gr.Textbox(label="🛠 OCR complet pour debug")
53
  ],
54
- title="Extraction de la colonne Description (TrOCR public)",
55
- description="Détection automatique de la colonne Description dans les factures avec TrOCR"
56
  )
57
 
58
  demo.launch()
 
1
  import gradio as gr
2
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
3
  from PIL import Image
4
  import torch
5
+ import json
6
 
7
  # ===============================
8
+ # Charger le modèle Donut public
9
  # ===============================
10
+ model_name = "naver-clova-ocr-donut-base"
11
+ processor = DonutProcessor.from_pretrained(model_name, revision="main")
12
+ model = VisionEncoderDecoderModel.from_pretrained(model_name, revision="main")
13
 
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
  model.to(device)
 
18
  # Fonction extraction colonne Description
19
  # ===============================
20
  def extract_description(image_pil):
21
+ # Préparer l'image
22
  pixel_values = processor(images=image_pil, return_tensors="pt").pixel_values.to(device)
 
 
23
 
24
+ # Générer le texte
25
+ generated_ids = model.generate(pixel_values, max_length=1024)
26
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
27
 
28
+ # Donut renvoie souvent du JSON ou semi-structuré
29
+ try:
30
+ data = json.loads(generated_text)
31
+ except:
32
+ data = {"text": generated_text}
33
 
34
+ # Extraire les lignes contenant "Description"
35
+ desc_lines = []
36
+ if isinstance(data, dict):
37
+ for key, value in data.items():
38
+ if "description" in key.lower():
39
+ if isinstance(value, list):
40
+ desc_lines.extend(value)
41
+ else:
42
+ desc_lines.append(str(value))
43
+ else:
44
+ # fallback si Donut ne renvoie pas JSON
45
+ lines = generated_text.split("\n")
46
+ found_header = False
47
+ for line in lines:
48
+ if found_header:
49
+ desc_lines.append(line)
50
+ elif "description" in line.lower():
51
+ found_header = True
52
 
53
  if not desc_lines:
54
+ return "❌ Colonne 'Description' non trouvée", generated_text
55
  else:
56
+ return "\n".join(desc_lines), generated_text
57
 
58
  # ===============================
59
  # Interface Gradio
 
63
  inputs=gr.Image(type="pil", label="Image de facture"),
64
  outputs=[
65
  gr.Textbox(label="📋 Colonne Description"),
66
+ gr.Textbox(label="🛠 Texte complet Donut")
67
  ],
68
+ title="Extraction de la colonne Description (Donut)",
69
+ description="Détection automatique de la colonne Description dans les factures avec Donut"
70
  )
71
 
72
  demo.launch()