kebson committed on
Commit
6f02e5b
·
verified ·
1 Parent(s): 010fbd5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -133
app.py CHANGED
@@ -1,153 +1,139 @@
1
import gradio as gr
import cv2
import numpy as np
import torch
from PIL import Image
from transformers import TableTransformerForObjectDetection, AutoImageProcessor
from paddleocr import PaddleOCR
from unidecode import unidecode

# =========================
# Model initialisation
# =========================

# CPU-only inference (no GPU assumed on the host).
device = "cpu"

processor = AutoImageProcessor.from_pretrained(
    "microsoft/table-transformer-detection"
)

model = TableTransformerForObjectDetection.from_pretrained(
    "microsoft/table-transformer-detection"
).to(device)

# French OCR with text-angle classification enabled.
ocr = PaddleOCR(
    lang="fr",
    use_angle_cls=True,
)
29
 
30
- # =========================
31
- # Utils
32
- # =========================
 
 
 
33
 
34
def normalize_text(text):
    """Return *text* lowercased, trimmed, and stripped of accents."""
    cleaned = text.lower().strip()
    return unidecode(cleaned)
36
 
37
def preprocess_image(pil_img):
    """Convert a PIL image to an adaptively-thresholded grayscale array.

    Adaptive (Gaussian, 31px window) thresholding helps OCR on unevenly
    lit photos of documents.
    """
    bgr = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
    gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    return cv2.adaptiveThreshold(
        gray,
        255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY,
        31,
        2,
    )
46
-
47
- # =========================
48
- # Détection tableau
49
- # =========================
50
-
51
def detect_table(pil_img):
    """Locate the first "table" region in *pil_img*.

    Runs the table-transformer detector and returns the first detection
    labelled "table" as integer pixel coordinates [x1, y1, x2, y2], or
    None when nothing clears the 0.7 confidence threshold.
    """
    inputs = processor(images=pil_img, return_tensors="pt")
    outputs = model(**inputs)

    # post-processing expects (height, width); PIL size is (width, height)
    target_sizes = torch.tensor([pil_img.size[::-1]])
    detections = processor.post_process_object_detection(
        outputs, threshold=0.7, target_sizes=target_sizes
    )[0]

    candidates = zip(
        detections["scores"], detections["labels"], detections["boxes"]
    )
    for _score, label, box in candidates:
        if model.config.id2label[label.item()] == "table":
            return [int(coord) for coord in box.tolist()]

    return None
71
-
72
- # =========================
73
- # OCR complet image
74
- # =========================
75
-
76
def run_ocr(img):
    """Run PaddleOCR on *img* and return a list of (bbox, text) tuples.

    PaddleOCR returns one "block" per page, and a block can be None when
    nothing was detected on that page — the original code iterated such a
    block and raised TypeError. Empty/None results now yield [].
    """
    result = ocr.ocr(img, cls=True)
    lines = []
    for block in result or []:
        if not block:  # PaddleOCR emits None for pages with no text
            continue
        for bbox, (text, _confidence) in block:
            lines.append((bbox, text))
    return lines
84
-
85
- # =========================
86
- # Extraction colonne Désignations
87
- # =========================
88
-
89
def extract_designations(pil_img):
    """Extract the "Désignations" column of a table photographed in *pil_img*.

    Pipeline: detect the table region, binarise it, OCR it, group the OCR
    lines into columns by x-center, then pick the column whose header looks
    like a designation/description header.

    Returns a (status_message, designations_list) tuple; the list is empty
    on failure.
    """
    table_box = detect_table(pil_img)
    if table_box is None:
        return "❌ Aucun tableau détecté", []

    x1, y1, x2, y2 = table_box
    img = preprocess_image(pil_img)
    table_img = img[y1:y2, x1:x2]

    ocr_lines = run_ocr(table_img)

    # Group OCR lines into columns by x-center. BUG FIX: the original used
    # the exact pixel x-center as the dict key, so two cells of the same
    # column almost never grouped together; bucket within a tolerance instead.
    X_TOLERANCE = 40  # px — cells whose centers differ by less share a column
    columns = {}
    for bbox, text in ocr_lines:
        x_coords = [p[0] for p in bbox]
        x_center = int(sum(x_coords) / len(x_coords))
        key = next(
            (k for k in columns if abs(k - x_center) <= X_TOLERANCE),
            x_center,
        )
        columns.setdefault(key, []).append(text)

    # Scan columns left to right for a designation-like header.
    designation_col = None
    for _, texts in sorted(columns.items(), key=lambda kv: kv[0]):
        header = normalize_text(" ".join(texts[:2]))
        if any(k in header for k in (
            "designation", "designation des travaux",
            "libelle", "description",
        )):
            designation_col = texts[1:]  # skip the header cell itself
            break

    if designation_col is None:
        return "❌ Colonne Désignations non trouvée", []

    cleaned = [t for t in designation_col if len(t.strip()) > 2]
    return " Extraction réussie", cleaned
129
 
130
- # =========================
131
- # Gradio UI
132
- # =========================
 
 
133
 
134
def process(image):
    """Gradio callback: run the extraction and join the lines for display."""
    status, designations = extract_designations(image)
    joined = "\n".join(designations)
    return status, joined
137
 
138
# Gradio UI: image in, status + extracted designations out.
with gr.Blocks() as demo:
    gr.Markdown("## 📄 Extraction de la colonne **Désignations**")

    image_input = gr.Image(type="pil", label="Uploader une image")
    status = gr.Textbox(label="Statut")
    output = gr.Textbox(label="Désignations extraites", lines=15)

    btn = gr.Button("Extraire")
    btn.click(process, inputs=image_input, outputs=[status, output])

demo.queue().launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import numpy as np
from paddleocr import PaddleOCR
from sklearn.cluster import KMeans

# -----------------------------
# OCR
# -----------------------------
# French OCR engine with text-line orientation detection enabled.
ocr = PaddleOCR(
    use_textline_orientation=True,
    lang="fr",
)
13
 
14
+ # -----------------------------
15
+ # Fonction principale
16
+ # -----------------------------
17
+ def extract_column2_9_lines(image):
18
+ if image is None:
19
+ return "Aucune image fournie."
20
 
21
+ img = np.array(image)
22
+ result = ocr.predict(img)
23
 
24
+ if not result or len(result) == 0:
25
+ return "OCR exécuté mais aucun texte détecté."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
+ data = result[0]
28
+ texts = data.get("rec_texts", [])
29
+ boxes = data.get("dt_polys", [])
30
 
31
+ if not texts:
32
+ return "Aucun texte exploitable détecté."
33
 
34
+ # -----------------------------
35
+ # 1. Collecte OCR
36
+ # -----------------------------
37
+ elements = []
38
+ for text, box in zip(texts, boxes):
39
+ text = text.strip()
40
+ if len(text) < 3:
41
+ continue
42
 
43
+ x_center = np.mean([p[0] for p in box])
44
+ y_center = np.mean([p[1] for p in box])
 
 
 
 
 
 
 
45
 
46
+ elements.append((x_center, y_center, text))
 
47
 
48
+ if len(elements) < 5:
49
+ return "Pas assez de texte détecté."
50
 
51
+ # -----------------------------
52
+ # 2. Clustering horizontal ADAPTATIF
53
+ # -----------------------------
54
+ X = np.array([[e[0]] for e in elements])
55
+ n_clusters = min(8, max(3, len(elements) // 8))
56
 
57
+ kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
58
+ labels = kmeans.fit_predict(X)
 
59
 
60
+ columns = {}
61
+ for (x, y, text), label in zip(elements, labels):
62
+ columns.setdefault(label, []).append((x, y, text))
63
+
64
+ # -----------------------------
65
+ # 3. Choisir la colonne "Description"
66
+ # => la plus riche en texte non numérique
67
+ # -----------------------------
68
+ def column_score(col):
69
+ score = 0
70
+ for _, _, t in col:
71
+ if not any(char.isdigit() for char in t):
72
+ score += len(t)
73
+ return score
74
+
75
+ best_column = max(columns.values(), key=column_score)
76
+
77
+ # Tri vertical
78
+ best_column.sort(key=lambda e: e[1])
79
+
80
+ # -----------------------------
81
+ # 4. Fusion intelligente des lignes
82
+ # -----------------------------
83
+ merged_lines = []
84
+ current_text = ""
85
+ last_y = None
86
+ Y_THRESHOLD = 22
87
+
88
+ blacklist = (
89
+ "DESIGNATION", "UNITE", "QUANT", "PRIX", "TOTAL",
90
+ "LOT", "BORDEREAU", "DATE", "NB", "TTC", "HT"
91
  )
92
 
93
+ for _, y, text in best_column:
94
+ if text.upper().startswith(blacklist):
95
+ continue
96
+
97
+ if last_y is None or abs(y - last_y) > Y_THRESHOLD:
98
+ if current_text:
99
+ merged_lines.append(current_text.strip())
100
+ current_text = text
101
+ else:
102
+ current_text += " " + text
103
+
104
+ last_y = y
105
+
106
+ if current_text:
107
+ merged_lines.append(current_text.strip())
108
+
109
+ # -----------------------------
110
+ # 5. Nettoyage final
111
+ # -----------------------------
112
+ cleaned = []
113
+ for line in merged_lines:
114
+ if len(line) < 5:
115
+ continue
116
+ if sum(c.isdigit() for c in line) > len(line) / 2:
117
+ continue
118
+ cleaned.append(line)
119
+
120
+ final_lines = cleaned[:9]
121
+
122
+ if not final_lines:
123
+ return "Colonne détectée mais contenu non exploitable."
124
+
125
+ # Numérotation demandée
126
+ return "\n".join([f"{i+1}. {l}" for i, l in enumerate(final_lines)])
127
+
128
# -----------------------------
# Interface Gradio
# -----------------------------
demo = gr.Interface(
    fn=extract_column2_9_lines,
    inputs=gr.Image(type="pil", label="Image du tableau"),
    outputs=gr.Textbox(label="Colonne Description (9 lignes)"),
    title="Extraction robuste de la colonne Description",
    description="Optimisé pour tableaux photographiés (devis, factures, bordereaux)",
)

demo.launch(server_name="0.0.0.0", server_port=7860)