kebson commited on
Commit
e626df6
·
verified ·
1 Parent(s): 17c79c0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +159 -0
app.py CHANGED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import gradio as gr
4
+ from PIL import Image
5
+
6
+ from transformers import (
7
+ DetrImageProcessor,
8
+ TableTransformerForObjectDetection,
9
+ TrOCRProcessor,
10
+ VisionEncoderDecoderModel
11
+ )
12
+
13
+ # ===============================
14
+ # Chargement des modèles
15
+ # ===============================
16
+
17
+ DEVICE = "cpu"
18
+
19
+ # Table detection
20
+ table_processor = DetrImageProcessor.from_pretrained(
21
+ "microsoft/table-transformer-detection"
22
+ )
23
+ table_model = TableTransformerForObjectDetection.from_pretrained(
24
+ "microsoft/table-transformer-detection"
25
+ ).to(DEVICE)
26
+ table_model.eval()
27
+
28
+ # OCR
29
+ ocr_processor = TrOCRProcessor.from_pretrained(
30
+ "microsoft/trocr-base-printed"
31
+ )
32
+ ocr_model = VisionEncoderDecoderModel.from_pretrained(
33
+ "microsoft/trocr-base-printed"
34
+ ).to(DEVICE)
35
+ ocr_model.eval()
36
+
37
+ # ===============================
38
+ # Utils
39
+ # ===============================
40
+
41
+ def cluster_columns(boxes, x_threshold=25):
42
+ """
43
+ Regroupe les bounding boxes par colonnes
44
+ en se basant sur la position X (x_min)
45
+ """
46
+ boxes = sorted(boxes, key=lambda b: b[0])
47
+ columns = []
48
+
49
+ for box in boxes:
50
+ placed = False
51
+ for col in columns:
52
+ if abs(col[0][0] - box[0]) < x_threshold:
53
+ col.append(box)
54
+ placed = True
55
+ break
56
+ if not placed:
57
+ columns.append([box])
58
+
59
+ return columns
60
+
61
+
62
+ def ocr_cell(image, box):
63
+ crop = image.crop(box)
64
+ pixel_values = ocr_processor(
65
+ crop, return_tensors="pt"
66
+ ).pixel_values.to(DEVICE)
67
+
68
+ with torch.no_grad():
69
+ generated_ids = ocr_model.generate(pixel_values)
70
+
71
+ text = ocr_processor.batch_decode(
72
+ generated_ids, skip_special_tokens=True
73
+ )[0]
74
+
75
+ return text.strip()
76
+
77
+
78
+ # ===============================
79
+ # Pipeline principal
80
+ # ===============================
81
+
82
+ def extract_second_column(image):
83
+ if image is None:
84
+ return "Aucune image fournie"
85
+
86
+ image = image.convert("RGB")
87
+
88
+ # 1. Détection des cellules
89
+ inputs = table_processor(
90
+ images=image, return_tensors="pt"
91
+ ).to(DEVICE)
92
+
93
+ with torch.no_grad():
94
+ outputs = table_model(**inputs)
95
+
96
+ target_sizes = torch.tensor(
97
+ [image.size[::-1]]
98
+ )
99
+
100
+ results = table_processor.post_process_object_detection(
101
+ outputs,
102
+ threshold=0.7,
103
+ target_sizes=target_sizes
104
+ )[0]
105
+
106
+ # 2. Garder uniquement les cellules
107
+ cells = []
108
+ for label, box in zip(results["labels"], results["boxes"]):
109
+ label_name = table_model.config.id2label[label.item()]
110
+ if label_name == "table cell":
111
+ cells.append([int(v) for v in box.tolist()])
112
+
113
+ if len(cells) == 0:
114
+ return "Aucune cellule détectée"
115
+
116
+ # 3. Regrouper par colonnes
117
+ columns = cluster_columns(cells)
118
+
119
+ if len(columns) < 2:
120
+ return "Moins de 2 colonnes détectées"
121
+
122
+ second_column = columns[1]
123
+
124
+ # Trier de haut en bas
125
+ second_column = sorted(second_column, key=lambda b: b[1])
126
+
127
+ # 4. OCR
128
+ extracted_texts = []
129
+ for box in second_column:
130
+ text = ocr_cell(image, box)
131
+ if text:
132
+ extracted_texts.append(text)
133
+
134
+ if not extracted_texts:
135
+ return "Aucun texte OCR extrait"
136
+
137
+ return "\n".join(extracted_texts)
138
+
139
+
140
+ # ===============================
141
+ # Interface Gradio
142
+ # ===============================
143
+
144
+ demo = gr.Interface(
145
+ fn=extract_second_column,
146
+ inputs=gr.Image(type="pil", label="Image du tableau"),
147
+ outputs=gr.Textbox(
148
+ label="Contenu de la 2ᵉ colonne",
149
+ lines=20
150
+ ),
151
+ title="Extraction automatique de la 2ᵉ colonne d’un tableau",
152
+ description=(
153
+ "Upload une image de tableau (JPEG/PNG).\n"
154
+ "Le système détecte le tableau et extrait uniquement "
155
+ "les cellules de la deuxième colonne."
156
+ )
157
+ )
158
+
159
+ demo.launch()