kebson committed on
Commit
fe5b596
·
verified ·
1 Parent(s): 3a9c77b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -103
app.py CHANGED
@@ -1,136 +1,139 @@
1
- import os
2
- os.environ["OMP_NUM_THREADS"] = "1"
3
- os.environ["DISABLE_MODEL_SOURCE_CHECK"] = "True"
4
-
5
  import gradio as gr
 
6
  import cv2
 
7
  import numpy as np
8
- from paddleocr import PaddleOCR
9
  from PIL import Image
 
10
 
 
 
 
11
 
12
- ocr = PaddleOCR(lang="en")
13
-
14
 
15
- def extract_description_column(image: Image.Image):
16
- if image is None:
17
- return "❌ Aucune image fournie."
 
 
 
18
 
19
- img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
20
- result = ocr.ocr(img)
 
 
 
 
21
 
22
- if not result or not result[0]:
23
- return "❌ Aucun texte détecté."
 
24
 
25
- words = []
 
 
 
26
 
27
- # 1️⃣ OCR words
28
- for item in result[0]:
29
- box, (text, score) = item
30
- try:
31
- score = float(score)
32
- except:
33
- score = 1.0
34
 
35
- if score < 0.4 or not text.strip():
36
- continue
37
 
38
- xs = [p[0] for p in box]
39
- ys = [p[1] for p in box]
40
 
41
- words.append({
42
- "text": text.strip(),
43
- "x": min(xs),
44
- "y": min(ys),
45
- "w": max(xs) - min(xs),
46
- "h": max(ys) - min(ys),
47
- })
48
 
49
- # 2️⃣ Trouver le début du tableau ("ITEMS")
50
- table_start_y = None
51
- for w in words:
52
- if "item" in w["text"].lower():
53
- table_start_y = w["y"]
54
- break
55
 
56
- if table_start_y is None:
57
- table_start_y = 0 # fallback
 
 
58
 
59
- table_words = [w for w in words if w["y"] > table_start_y + 30]
 
60
 
61
- # 3️⃣ Regrouper par colonnes X
62
- columns = {}
63
- for w in table_words:
64
- col_key = int(w["x"] // 50)
65
- columns.setdefault(col_key, []).append(w)
66
 
67
- # 4️⃣ Identifier la colonne Description
68
- best_col = None
69
- best_score = 0
70
 
71
- for col in columns.values():
72
- text_len = sum(len(w["text"]) for w in col)
73
- numeric_ratio = sum(any(c.isdigit() for c in w["text"]) for w in col) / max(len(col), 1)
 
 
74
 
75
- score = text_len * (1 - numeric_ratio)
 
 
 
 
76
 
77
- if score > best_score:
78
- best_score = score
79
- best_col = col
80
 
81
- if best_col is None:
82
- return "❌ Impossible d’identifier la colonne Description."
83
 
84
- # 5️⃣ Regrouper par lignes
85
- lines = {}
86
- for w in best_col:
87
- key = int(w["y"] // 25)
88
- lines.setdefault(key, []).append(w)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
- ordered_lines = []
91
- for k in sorted(lines.keys()):
92
- line = " ".join(
93
- w["text"] for w in sorted(lines[k], key=lambda x: x["x"])
94
  )
95
- ordered_lines.append(line)
96
-
97
- # 6️⃣ Nettoyage
98
- cleaned = []
99
- for line in ordered_lines:
100
- low = line.lower()
101
- if any(x in low for x in ["vat", "net", "gross", "each", "%"]):
102
- continue
103
- cleaned.append(line)
104
-
105
- # 7️⃣ Fusion multilignes
106
- cells = []
107
- buffer = ""
108
-
109
- for line in cleaned:
110
- if line[:2].replace(".", "").isdigit():
111
- if buffer:
112
- cells.append(buffer.strip())
113
- buffer = line.split(".", 1)[-1].strip()
114
- else:
115
- buffer += " " + line
116
-
117
- if buffer:
118
- cells.append(buffer.strip())
119
-
120
- # 8️⃣ Sortie
121
- output = ""
122
- for i, cell in enumerate(cells, 1):
123
- output += f"{i}. {cell}\n\n"
124
 
125
- return output.strip()
126
 
 
 
 
127
 
128
  demo = gr.Interface(
129
- fn=extract_description_column,
130
  inputs=gr.Image(type="pil", label="Image de facture"),
131
- outputs=gr.Textbox(lines=20, label="Colonne Description"),
132
- title="Extraction robuste de la colonne Description",
133
- description="Fonctionne sans dépendre des headers OCR"
 
 
 
134
  )
135
 
136
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
1
  import gradio as gr
2
+ import torch
3
  import cv2
4
+ import pytesseract
5
  import numpy as np
 
6
  from PIL import Image
7
+ from transformers import DetrImageProcessor, TableTransformerForObjectDetection
8
 
9
# ===============================
# Model loading
# ===============================

# Run inference on GPU when available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Detection model: locates table bounding boxes in the full page image.
det_processor = DetrImageProcessor.from_pretrained(
    "microsoft/table-transformer-detection"
)
det_model = TableTransformerForObjectDetection.from_pretrained(
    "microsoft/table-transformer-detection"
).to(DEVICE)

# Structure model: recognizes the internal layout (rows/columns) of a
# cropped table image — used by extract_description on the table crop.
struct_processor = DetrImageProcessor.from_pretrained(
    "microsoft/table-transformer-structure-recognition"
)
struct_model = TableTransformerForObjectDetection.from_pretrained(
    "microsoft/table-transformer-structure-recognition"
).to(DEVICE)
28
 
29
+ # ===============================
30
+ # OCR cellule
31
+ # ===============================
32
 
33
def ocr_cell(image):
    """Run Tesseract OCR on a single table-cell crop and return the stripped text.

    `image` is a BGR/RGB ndarray crop of one cell; it is converted to
    grayscale before OCR.
    """
    # --psm 6: treat the crop as a single uniform block of text,
    # which matches the content of an individual table cell.
    grayscale = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    return pytesseract.image_to_string(grayscale, config="--psm 6").strip()
37
 
38
+ # ===============================
39
+ # Fonction principale
40
+ # ===============================
 
 
 
 
41
 
42
def extract_description(image_pil):
    """Extract the 'Description' column from an invoice table image.

    Pipeline:
      1. Detect the table bounding box with the Table Transformer
         detection model.
      2. Detect the table's rows and columns with the structure model.
      3. Build cell boxes by intersecting rows with columns, OCR each
         cell column by column.
      4. Pick the column whose header contains "description".

    Args:
        image_pil: PIL.Image of the invoice page (assumed RGB — TODO confirm).

    Returns:
        A (description_text, debug_text) tuple of strings. On failure the
        first element is an error message starting with "❌".
    """
    image = np.array(image_pil)
    h, w = image.shape[:2]  # robust to grayscale/RGBA inputs (no 3-channel unpack)

    # ---- Table detection ----
    inputs = det_processor(images=image_pil, return_tensors="pt").to(DEVICE)
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        outputs = det_model(**inputs)

    results = det_processor.post_process_object_detection(
        outputs,
        threshold=0.8,
        target_sizes=[(h, w)]
    )[0]

    tables = [
        box for box, label in zip(results["boxes"], results["labels"])
        if det_model.config.id2label[label.item()] == "table"
    ]

    if not tables:
        return "❌ Aucun tableau détecté", ""

    # Crop the first detected table out of the page.
    x0, y0, x1, y1 = tables[0].int().tolist()
    table_img = image[y0:y1, x0:x1]

    # ---- Table structure ----
    inputs = struct_processor(
        images=Image.fromarray(table_img), return_tensors="pt"
    ).to(DEVICE)
    with torch.no_grad():
        outputs = struct_model(**inputs)

    results = struct_processor.post_process_object_detection(
        outputs,
        threshold=0.7,
        target_sizes=[table_img.shape[:2]]
    )[0]

    # BUG FIX: the structure model never emits a "table cell" label — its
    # label set is "table row" / "table column" / header variants — so the
    # previous `label_name == "table cell"` filter always produced zero
    # cells. Build cells by intersecting detected rows with columns instead.
    rows, cols = [], []
    for box, label in zip(results["boxes"], results["labels"]):
        label_name = struct_model.config.id2label[label.item()]
        if label_name == "table row":
            rows.append(box.int().tolist())
        elif label_name == "table column":
            cols.append(box.int().tolist())

    if not rows or not cols:
        return "❌ Aucune cellule détectée", ""

    # Reading order: rows top-to-bottom, columns left-to-right.
    rows.sort(key=lambda b: b[1])
    cols.sort(key=lambda b: b[0])

    # ---- OCR per column (each column = one list of cell texts) ----
    column_texts = []
    for cx0, _, cx1, _ in cols:
        col_text = []
        for _, ry0, _, ry1 in rows:
            # Clamp to the crop; model boxes can slightly overshoot it.
            cell_img = table_img[max(ry0, 0):ry1, max(cx0, 0):cx1]
            if cell_img.size == 0:  # degenerate intersection — keep row alignment
                col_text.append("")
            else:
                col_text.append(ocr_cell(cell_img))
        column_texts.append(col_text)

    # ---- Identify the Description column by its header cell ----
    desc_col = None
    for col in column_texts:
        header = col[0].lower() if col else ""
        if "description" in header:
            desc_col = col
            break

    if desc_col is None:
        # Return every column's content so the user can debug the layout.
        return "❌ Colonne 'Description' non trouvée", "\n\n".join(
            [" | ".join(col) for col in column_texts]
        )

    # First output: column body (header stripped); second: full column for debug.
    return "\n".join(desc_col[1:]), "\n\n".join(desc_col)
123
 
124
# ===============================
# Gradio interface
# ===============================

# Two outputs: the extracted Description column, and a debug view of the
# raw column the extractor selected (or all columns on failure).
demo = gr.Interface(
    fn=extract_description,
    inputs=gr.Image(type="pil", label="Image de facture"),
    outputs=[
        gr.Textbox(label="📋 Colonne Description"),
        gr.Textbox(label="🛠 Debug colonne détectée")
    ],
    title="Extraction de la colonne Description (Table Transformer)",
    description="Détection automatique de la colonne Description dans les tableaux de factures"
)

demo.launch()