kebson committed on
Commit
f46ff1b
·
verified ·
1 Parent(s): 7638b6f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +141 -150
app.py CHANGED
@@ -1,162 +1,153 @@
1
  import gradio as gr
 
2
  import numpy as np
3
- import unicodedata
 
 
4
  from paddleocr import PaddleOCR
5
- from sklearn.cluster import KMeans
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- # -------------------------------------------------
8
- # OCR
9
- # -------------------------------------------------
10
  ocr = PaddleOCR(
11
  lang="fr",
12
- use_textline_orientation=True
 
13
  )
14
 
15
- # -------------------------------------------------
16
- # Normalisation texte (casse + accents)
17
- # -------------------------------------------------
18
- def normalize(text: str) -> str:
19
- text = text.lower()
20
- text = unicodedata.normalize("NFD", text)
21
- text = "".join(c for c in text if unicodedata.category(c) != "Mn")
22
- return " ".join(text.split())
23
-
24
- # -------------------------------------------------
25
- # Titres valides de la colonne 2
26
- # -------------------------------------------------
27
- COL_TITLES = {
28
- "designation",
29
- "designations",
30
- "description",
31
- "description des services"
32
- }
33
-
34
- # -------------------------------------------------
35
- # Mots / lignes à ignorer
36
- # -------------------------------------------------
37
- IGNORE_KEYWORDS = {
38
- "prix", "total", "ht", "htva", "tva",
39
- "ttc", "general", "generale"
40
- }
41
-
42
- # -------------------------------------------------
43
- # Fonction principale
44
- # -------------------------------------------------
45
- def extract_second_column(image):
46
- if image is None:
47
- return "Aucune image fournie."
48
-
49
- img = np.array(image)
50
- result = ocr.predict(img)
51
-
52
- if not result:
53
- return "OCR : aucun texte détecté."
54
-
55
- data = result[0]
56
- texts = data.get("rec_texts", [])
57
- boxes = data.get("dt_polys", [])
58
-
59
- blocks = []
60
- for text, box in zip(texts, boxes):
61
- t = text.strip()
62
- if len(t) < 2:
63
- continue
64
-
65
- x = np.mean([p[0] for p in box])
66
- y = np.mean([p[1] for p in box])
67
-
68
- blocks.append((t, x, y))
69
-
70
- if len(blocks) < 5:
71
- return "Pas assez de texte exploitable."
72
-
73
- # -------------------------------------------------
74
- # 1. Détection du X de la colonne cible via son titre
75
- # -------------------------------------------------
76
- col_x = None
77
- for text, x, y in blocks:
78
- if normalize(text) in COL_TITLES:
79
- col_x = x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  break
81
 
82
- if col_x is None:
83
- return "Titre de la colonne cible non détecté."
84
-
85
- # -------------------------------------------------
86
- # 2. Sélection des blocs proches du X détecté
87
- # -------------------------------------------------
88
- X_THRESHOLD = 45
89
- column_blocks = [
90
- (t, x, y) for t, x, y in blocks
91
- if abs(x - col_x) < X_THRESHOLD
92
- ]
93
-
94
- if not column_blocks:
95
- return "Colonne détectée mais vide."
96
-
97
- # -------------------------------------------------
98
- # 3. Tri vertical (haut → bas)
99
- # -------------------------------------------------
100
- column_blocks.sort(key=lambda e: e[2])
101
-
102
- # -------------------------------------------------
103
- # 4. Fusion intelligente des lignes OCR
104
- # -------------------------------------------------
105
- merged = []
106
- current = ""
107
- last_y = None
108
- Y_THRESHOLD = 22
109
-
110
- for text, x, y in column_blocks:
111
- nt = normalize(text)
112
-
113
- if any(k in nt for k in IGNORE_KEYWORDS):
114
- continue
115
-
116
- if last_y is None or abs(y - last_y) > Y_THRESHOLD:
117
- if current:
118
- merged.append(current.strip())
119
- current = text
120
- else:
121
- current += " " + text
122
-
123
- last_y = y
124
-
125
- if current:
126
- merged.append(current.strip())
127
-
128
- # -------------------------------------------------
129
- # 5. Nettoyage final (cellules texte uniquement)
130
- # -------------------------------------------------
131
- final = []
132
- for line in merged:
133
- nt = normalize(line)
134
- if len(nt) < 4:
135
- continue
136
- if sum(c.isdigit() for c in line) > len(line) / 2:
137
- continue
138
- final.append(line)
139
-
140
- if not final:
141
- return "Aucune cellule texte valide trouvée."
142
-
143
- # -------------------------------------------------
144
- # 6. Résultat numéroté
145
- # -------------------------------------------------
146
- return "\n".join(f"{i+1}. {line}" for i, line in enumerate(final))
147
-
148
- # -------------------------------------------------
149
- # Interface Gradio
150
- # -------------------------------------------------
151
- demo = gr.Interface(
152
- fn=extract_second_column,
153
- inputs=gr.Image(type="pil", label="Image du tableau"),
154
- outputs=gr.Textbox(label="Contenu de la colonne 2"),
155
- title="Extraction fiable de la colonne 2 (Désignation / Description)",
156
- description=(
157
- "Extraction robuste de la deuxième colonne des tableaux scannés "
158
- "(Désignation, DESIGNATIONS, Description, Description des services)."
159
  )
160
- )
161
 
162
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
  import gradio as gr
2
+ import cv2
3
  import numpy as np
4
+ import torch
5
+ from PIL import Image
6
+ from transformers import TableTransformerForObjectDetection, AutoImageProcessor
7
  from paddleocr import PaddleOCR
8
+ from unidecode import unidecode
9
+
10
# =========================
# Model initialisation
# =========================

# All inference in this file is pinned to this device.
device = "cpu"

# The image processor and the detection model come from the same checkpoint.
processor = AutoImageProcessor.from_pretrained(
    "microsoft/table-transformer-detection"
)

model = TableTransformerForObjectDetection.from_pretrained(
    "microsoft/table-transformer-detection"
)
model = model.to(device)

# French OCR engine, angle classifier enabled, show_log switched off.
ocr = PaddleOCR(lang="fr", use_angle_cls=True, show_log=False)
29
 
30
+ # =========================
31
+ # Utils
32
+ # =========================
33
+
34
def normalize_text(text):
    """Return *text* lower-cased, trimmed and passed through ``unidecode``.

    Used to compare OCR output against accent-free, lower-case keywords.
    """
    lowered = text.lower()
    trimmed = lowered.strip()
    return unidecode(trimmed)
36
+
37
def preprocess_image(pil_img):
    """Binarise a PIL image before OCR.

    The image is turned into a NumPy array, reordered from RGB to BGR,
    reduced to grayscale and then thresholded with OpenCV's Gaussian
    adaptive threshold (max value 255, block size 31, constant 2).

    Returns:
        The thresholded single-channel array.
    """
    rgb = np.array(pil_img)
    bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    grayscale = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    return cv2.adaptiveThreshold(
        grayscale,
        255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY,
        31,
        2,
    )
46
+
47
# =========================
# Table detection
# =========================

def detect_table(pil_img):
    """Locate the most confident table in *pil_img*.

    Runs the Table Transformer detection model and keeps only detections
    whose label is exactly "table" and whose score passes the 0.7 threshold.

    Args:
        pil_img: PIL image to analyse.

    Returns:
        ``[x1, y1, x2, y2]`` as ints clamped to the image bounds, or ``None``
        when no table is detected or the clamped box is empty.
    """
    # Keep the inputs on the same device as the model.
    inputs = processor(images=pil_img, return_tensors="pt").to(device)

    # Inference only: do not build an autograd graph on every request.
    with torch.no_grad():
        outputs = model(**inputs)

    width, height = pil_img.size
    # post_process_object_detection expects (height, width) per image.
    target_sizes = torch.tensor([[height, width]])
    results = processor.post_process_object_detection(
        outputs,
        threshold=0.7,
        target_sizes=target_sizes,
    )[0]

    # Pick the highest-scoring "table" rather than whichever comes first.
    best_box = None
    best_score = float("-inf")
    for score, label, box in zip(
        results["scores"], results["labels"], results["boxes"]
    ):
        if model.config.id2label[label.item()] != "table":
            continue
        value = score.item()
        if value > best_score:
            best_score = value
            best_box = box.tolist()

    if best_box is None:
        return None

    # Predicted boxes can spill slightly outside the image; a negative
    # coordinate would wrap around when the caller slices the array.
    left, top, right, bottom = best_box
    x1 = min(max(int(left), 0), width)
    y1 = min(max(int(top), 0), height)
    x2 = min(max(int(right), 0), width)
    y2 = min(max(int(bottom), 0), height)
    if x2 <= x1 or y2 <= y1:
        return None

    return [x1, y1, x2, y2]
71
+
72
# =========================
# Full-image OCR
# =========================

def run_ocr(img):
    """Run PaddleOCR on *img* and flatten the result.

    Args:
        img: image array handed as-is to ``ocr.ocr``.

    Returns:
        A list of ``(bbox, text)`` tuples, one per recognised text line.
        The list is empty when nothing was recognised.
    """
    result = ocr.ocr(img, cls=True)

    lines = []
    # The result (or a per-image entry in it) may be None when no text is
    # detected; iterating over None would raise a TypeError.
    for block in result or []:
        if not block:
            continue
        for bbox, (text, _) in block:
            lines.append((bbox, text))
    return lines
84
+
85
# =========================
# Extraction of the "Désignations" column
# =========================

# Accent-free, lower-case fragments that identify the target column header.
_DESIGNATION_KEYWORDS = ("designation", "libelle", "description")


def _cell_bounds(bbox):
    """Return (x_min, y_min, x_max, y_max) of an OCR quadrilateral."""
    xs = [point[0] for point in bbox]
    ys = [point[1] for point in bbox]
    return min(xs), min(ys), max(xs), max(ys)


def _find_designation_cells(ocr_lines):
    """Return the texts located under the designation header, top to bottom.

    Returns None when no header matching _DESIGNATION_KEYWORDS is found.
    """
    cells = []
    for bbox, text in ocr_lines:
        x_min, y_min, x_max, y_max = _cell_bounds(bbox)
        cells.append((x_min, y_min, x_max, y_max, text))

    headers = [
        cell for cell in cells
        if any(key in normalize_text(cell[4]) for key in _DESIGNATION_KEYWORDS)
    ]
    if not headers:
        return None

    # The topmost match is taken as the column header.
    head_x_min, head_y_min, head_x_max, head_y_max, _ = min(
        headers, key=lambda cell: cell[1]
    )

    # The nearest neighbours on the header row delimit the column: header and
    # body cells rarely share the same x-centre (centred header, left-aligned
    # body), so exact or near-exact x matching is not reliable.
    left_bound = float("-inf")
    right_bound = float("inf")
    for x_min, y_min, x_max, y_max, _ in cells:
        y_center = (y_min + y_max) / 2
        if not head_y_min <= y_center <= head_y_max:
            continue
        if x_max <= head_x_min:
            left_bound = max(left_bound, x_max)
        elif x_min >= head_x_max:
            right_bound = min(right_bound, x_min)

    body = []
    for x_min, y_min, x_max, y_max, text in cells:
        if (y_min + y_max) / 2 <= head_y_max:
            continue  # header row or above it
        if left_bound < (x_min + x_max) / 2 < right_bound:
            body.append((y_min, x_min, text))

    # Reading order: top to bottom, then left to right.
    body.sort(key=lambda item: (item[0], item[1]))
    return [text for _, _, text in body]


def extract_designations(pil_img):
    """Extract the cells of the designation column from a table image.

    The table is detected, cropped from the binarised image and OCR'd; the
    column is then located through its header text.

    Args:
        pil_img: PIL image containing the table, or None.

    Returns:
        A ``(status, designations)`` tuple: a status message and the list of
        extracted cell texts (empty on failure).
    """
    if pil_img is None:
        return "❌ Aucune image fournie", []

    table_box = detect_table(pil_img)
    if table_box is None:
        return "❌ Aucun tableau détecté", []

    x1, y1, x2, y2 = table_box
    # Negative indices would wrap around in the NumPy slice below.
    x1, y1 = max(0, x1), max(0, y1)
    img = preprocess_image(pil_img)
    table_img = img[y1:y2, x1:x2]
    if table_img.size == 0:
        return "❌ Aucun tableau détecté", []

    ocr_lines = run_ocr(table_img)

    designation_col = _find_designation_cells(ocr_lines)
    if designation_col is None:
        return " Colonne Désignations non trouvée", []

    # Drop OCR fragments too short to be a real designation.
    cleaned = [t for t in designation_col if len(t.strip()) > 2]
    return "✅ Extraction réussie", cleaned
129
+
130
# =========================
# Gradio UI
# =========================

def process(image):
    """Gradio callback: extract the designations from the uploaded image.

    Args:
        image: PIL image from the upload widget, or None when the button is
            clicked without an image.

    Returns:
        A ``(status, text)`` tuple where *text* holds one designation per line.
    """
    # Without this guard a missing image crashes inside the detector.
    if image is None:
        return "❌ Aucune image fournie", ""

    status, designations = extract_designations(image)
    return status, "\n".join(designations)
137
+
138
# Layout: one image upload, a status box, a multi-line result box and a
# button that triggers the extraction.
with gr.Blocks() as demo:
    gr.Markdown("## 📄 Extraction de la colonne **Désignations**")

    image_input = gr.Image(type="pil", label="Uploader une image")
    status = gr.Textbox(label="Statut")
    output = gr.Textbox(label="Désignations extraites", lines=15)

    btn = gr.Button("Extraire")

    # process returns (status, text), mapped onto the two text boxes in order.
    btn.click(fn=process, inputs=image_input, outputs=[status, output])

# Listen on all interfaces on port 7860.
demo.launch(server_name="0.0.0.0", server_port=7860)