kebson commited on
Commit
30b2704
·
verified ·
1 Parent(s): e6d8b93

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -134
app.py CHANGED
@@ -1,144 +1,55 @@
1
  import gradio as gr
2
- import torch
3
- import cv2
4
- import pytesseract
5
- import numpy as np
6
  from PIL import Image
7
- from transformers import DetrImageProcessor, TableTransformerForObjectDetection
8
 
9
  # ===============================
10
- # Chargement des modèles
11
  # ===============================
12
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
13
 
14
- # Détection de tableau
15
- det_processor = DetrImageProcessor.from_pretrained(
16
- "microsoft/table-transformer-detection"
17
- )
18
- det_model = TableTransformerForObjectDetection.from_pretrained(
19
- "microsoft/table-transformer-detection"
20
- ).to(DEVICE)
21
-
22
- # Structure (cellules)
23
- struct_processor = DetrImageProcessor.from_pretrained(
24
- "microsoft/table-transformer-structure-recognition"
25
- )
26
- struct_model = TableTransformerForObjectDetection.from_pretrained(
27
- "microsoft/table-transformer-structure-recognition"
28
- ).to(DEVICE)
29
-
30
- # ===============================
31
- # OCR cellule
32
- # ===============================
33
- def ocr_cell(image):
34
- gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
35
- text = pytesseract.image_to_string(gray, config="--psm 6")
36
- return text.strip()
37
 
38
  # ===============================
39
- # Fonction principale
40
  # ===============================
41
  def extract_description(image_pil):
42
- # Convertir PIL -> np.array
43
- image = np.array(image_pil)
44
- h, w, _ = image.shape
45
-
46
- # ---- Détection du tableau ----
47
- inputs = det_processor(images=image, return_tensors="pt")
48
- inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
49
- outputs = det_model(**inputs)
50
-
51
- results = det_processor.post_process_object_detection(
52
- outputs,
53
- threshold=0.8,
54
- target_sizes=[(h, w)]
55
- )[0]
56
-
57
- tables = [
58
- box for box, label in zip(results["boxes"], results["labels"])
59
- if det_model.config.id2label[label.item()] == "table"
60
- ]
61
-
62
- if not tables:
63
- return "❌ Aucun tableau détecté", ""
64
-
65
- # Extraire premier tableau détecté
66
- table_box = tables[0].int().tolist()
67
- x0, y0, x1, y1 = table_box
68
- table_img = image[y0:y1, x0:x1]
69
-
70
- # ---- Optionnel : vérifier visuellement le tableau ----
71
- # Image.fromarray(table_img).show()
72
-
73
- # ---- Redimensionner le tableau pour la structure ----
74
- max_size = 1024
75
- scale = max(table_img.shape[:2]) / max_size
76
- new_w = int(table_img.shape[1] / scale)
77
- new_h = int(table_img.shape[0] / scale)
78
- table_resized = cv2.resize(table_img, (new_w, new_h))
79
-
80
- # ---- Structure du tableau ----
81
- inputs = struct_processor(images=table_resized, return_tensors="pt")
82
- inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
83
- outputs = struct_model(**inputs)
84
-
85
- results = struct_processor.post_process_object_detection(
86
- outputs,
87
- threshold=0.5, # seuil abaissé pour capturer plus de cellules
88
- target_sizes=[table_resized.shape[:2]]
89
- )[0]
90
-
91
- cells = []
92
- for box, label in zip(results["boxes"], results["labels"]):
93
- label_name = struct_model.config.id2label[label.item()]
94
- if label_name == "table cell":
95
- # Remettre les coordonnées à l'échelle originale
96
- scale_x = table_img.shape[1] / table_resized.shape[1]
97
- scale_y = table_img.shape[0] / table_resized.shape[0]
98
- x0c, y0c, x1c, y1c = box.int().tolist()
99
- x0c = int(x0c * scale_x)
100
- x1c = int(x1c * scale_x)
101
- y0c = int(y0c * scale_y)
102
- y1c = int(y1c * scale_y)
103
- cells.append([x0c, y0c, x1c, y1c])
104
-
105
- if not cells:
106
- return "❌ Aucune cellule détectée", ""
107
-
108
- # ---- Grouper par colonne (X) ----
109
- cells_sorted = sorted(cells, key=lambda b: (b[0] + b[2]) / 2)
110
- columns = {}
111
- for cell in cells_sorted:
112
- cx = (cell[0] + cell[2]) // 2
113
- columns.setdefault(cx // 50, []).append(cell)
114
-
115
- columns = list(columns.values())
116
- columns = sorted(columns, key=lambda col: np.mean([(c[0]+c[2])/2 for c in col]))
117
-
118
- # ---- OCR par colonne ----
119
- column_texts = []
120
- for col in columns:
121
- col_text = []
122
- for x0, y0, x1, y1 in sorted(col, key=lambda b: b[1]):
123
- cell_img = table_img[y0:y1, x0:x1]
124
- text = ocr_cell(cell_img)
125
- col_text.append(text)
126
- column_texts.append(col_text)
127
-
128
- # ---- Identifier colonne Description ----
129
- desc_col = None
130
- for col in column_texts:
131
- header = col[0].lower() if col else ""
132
- if "description" in header:
133
- desc_col = col
134
- break
135
-
136
- if desc_col is None:
137
- return "❌ Colonne 'Description' non trouvée", "\n\n".join(
138
- [f"Col {i}: " + " | ".join(col) for i, col in enumerate(column_texts)]
139
- )
140
-
141
- return "\n".join(desc_col[1:]), "\n\n".join(desc_col)
142
 
143
  # ===============================
144
  # Interface Gradio
@@ -148,10 +59,10 @@ demo = gr.Interface(
148
  inputs=gr.Image(type="pil", label="Image de facture"),
149
  outputs=[
150
  gr.Textbox(label="📋 Colonne Description"),
151
- gr.Textbox(label="🛠 Debug colonne détectée")
152
  ],
153
- title="Extraction de la colonne Description (Table Transformer)",
154
- description="Détection automatique de la colonne Description dans les tableaux de factures"
155
  )
156
 
157
  demo.launch()
 
1
  import gradio as gr
2
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
 
 
 
3
  from PIL import Image
4
+ import re
5
 
6
  # ===============================
7
+ # Charger le modèle pré-entraîné TrOCR
8
  # ===============================
9
+ model_name = "microsoft/trocr-base-table-finetuned" # Spécial tables
10
+ processor = TrOCRProcessor.from_pretrained(model_name)
11
+ model = VisionEncoderDecoderModel.from_pretrained(model_name)
12
 
13
+ device = "cuda" if torch.cuda.is_available() else "cpu"
14
+ model.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  # ===============================
17
+ # Fonction d'extraction de la colonne Description
18
  # ===============================
19
  def extract_description(image_pil):
20
+ # OCR avec TrOCR
21
+ pixel_values = processor(images=image_pil, return_tensors="pt").pixel_values.to(device)
22
+ generated_ids = model.generate(pixel_values)
23
+ ocr_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
24
+
25
+ # Séparer le texte en lignes
26
+ lines = [line.strip() for line in ocr_text.split("\n") if line.strip()]
27
+
28
+ # Identifier la colonne Description
29
+ desc_col = []
30
+ header_found = False
31
+ headers = []
32
+ # Détecter les headers possibles
33
+ if lines:
34
+ first_line = lines[0].lower()
35
+ # Split en colonnes par tabulation ou espaces multiples
36
+ headers = re.split(r"\t+|\s{2,}", first_line)
37
+ try:
38
+ desc_index = next(i for i, h in enumerate(headers) if "description" in h.lower())
39
+ header_found = True
40
+ except StopIteration:
41
+ desc_index = None
42
+
43
+ # Extraire les valeurs sous la colonne Description
44
+ if header_found:
45
+ for line in lines[1:]:
46
+ cols = re.split(r"\t+|\s{2,}", line)
47
+ if desc_index is not None and desc_index < len(cols):
48
+ desc_col.append(cols[desc_index])
49
+ else:
50
+ return "❌ Colonne 'Description' non trouvée", ocr_text
51
+
52
+ return "\n".join(desc_col), ocr_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  # ===============================
55
  # Interface Gradio
 
59
  inputs=gr.Image(type="pil", label="Image de facture"),
60
  outputs=[
61
  gr.Textbox(label="📋 Colonne Description"),
62
+ gr.Textbox(label="🛠 OCR complet pour debug")
63
  ],
64
+ title="Extraction de la colonne Description (TrOCR + tables)",
65
+ description="Détection automatique de la colonne Description dans les factures avec TrOCR"
66
  )
67
 
68
  demo.launch()