kebson commited on
Commit
a50192d
·
verified ·
1 Parent(s): 9d45d87

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -132
app.py CHANGED
@@ -1,159 +1,63 @@
1
- import torch
2
- import numpy as np
3
  import gradio as gr
 
4
  from PIL import Image
 
5
 
6
- from transformers import (
7
- DetrImageProcessor,
8
- TableTransformerForObjectDetection,
9
- TrOCRProcessor,
10
- VisionEncoderDecoderModel
11
- )
12
 
13
- # ===============================
14
- # Chargement des modèles
15
- # ===============================
16
 
17
- DEVICE = "cpu"
 
 
 
 
 
18
 
19
- # Table detection
20
- table_processor = DetrImageProcessor.from_pretrained(
21
- "microsoft/table-transformer-detection"
22
- )
23
- table_model = TableTransformerForObjectDetection.from_pretrained(
24
- "microsoft/table-transformer-detection"
25
- ).to(DEVICE)
26
- table_model.eval()
27
 
28
- # OCR
29
- ocr_processor = TrOCRProcessor.from_pretrained(
30
- "microsoft/trocr-base-printed"
31
- )
32
- ocr_model = VisionEncoderDecoderModel.from_pretrained(
33
- "microsoft/trocr-base-printed"
34
- ).to(DEVICE)
35
- ocr_model.eval()
36
 
37
- # ===============================
38
- # Utils
39
- # ===============================
40
 
41
- def cluster_columns(boxes, x_threshold=25):
42
- """
43
- Regroupe les bounding boxes par colonnes
44
- en se basant sur la position X (x_min)
45
- """
46
- boxes = sorted(boxes, key=lambda b: b[0])
47
  columns = []
48
-
49
- for box in boxes:
50
  placed = False
51
  for col in columns:
52
- if abs(col[0][0] - box[0]) < x_threshold:
53
- col.append(box)
54
  placed = True
55
  break
56
  if not placed:
57
- columns.append([box])
58
-
59
- return columns
60
-
61
-
62
- def ocr_cell(image, box):
63
- crop = image.crop(box)
64
- pixel_values = ocr_processor(
65
- crop, return_tensors="pt"
66
- ).pixel_values.to(DEVICE)
67
-
68
- with torch.no_grad():
69
- generated_ids = ocr_model.generate(pixel_values)
70
-
71
- text = ocr_processor.batch_decode(
72
- generated_ids, skip_special_tokens=True
73
- )[0]
74
-
75
- return text.strip()
76
-
77
-
78
- # ===============================
79
- # Pipeline principal
80
- # ===============================
81
-
82
- def extract_second_column(image):
83
- if image is None:
84
- return "Aucune image fournie"
85
-
86
- image = image.convert("RGB")
87
-
88
- # 1. Détection des cellules
89
- inputs = table_processor(
90
- images=image, return_tensors="pt"
91
- ).to(DEVICE)
92
-
93
- with torch.no_grad():
94
- outputs = table_model(**inputs)
95
-
96
- target_sizes = torch.tensor(
97
- [image.size[::-1]]
98
- )
99
-
100
- results = table_processor.post_process_object_detection(
101
- outputs,
102
- threshold=0.3,
103
- target_sizes=target_sizes
104
- )[0]
105
-
106
- # 2. Garder uniquement les cellules
107
- cells = []
108
- for label, box in zip(results["labels"], results["boxes"]):
109
- label_name = table_model.config.id2label[label.item()]
110
- if label_name == "table cell":
111
- cells.append([int(v) for v in box.tolist()])
112
-
113
- if len(cells) == 0:
114
- return "Aucune cellule détectée"
115
-
116
- # 3. Regrouper par colonnes
117
- columns = cluster_columns(cells)
118
 
119
  if len(columns) < 2:
120
  return "Moins de 2 colonnes détectées"
121
 
122
- second_column = columns[1]
123
-
124
- # Trier de haut en bas
125
- second_column = sorted(second_column, key=lambda b: b[1])
126
 
127
- # 4. OCR
128
- extracted_texts = []
129
- for box in second_column:
130
- text = ocr_cell(image, box)
131
- if text:
132
- extracted_texts.append(text)
133
-
134
- if not extracted_texts:
135
- return "Aucun texte OCR extrait"
136
-
137
- return "\n".join(extracted_texts)
138
 
139
 
140
- # ===============================
141
- # Interface Gradio
142
- # ===============================
143
-
144
  demo = gr.Interface(
145
  fn=extract_second_column,
146
- inputs=gr.Image(type="pil", label="Image du tableau"),
147
- outputs=gr.Textbox(
148
- label="Contenu de la 2ᵉ colonne",
149
- lines=20
150
- ),
151
- title="Extraction automatique de la 2ᵉ colonne d’un tableau",
152
- description=(
153
- "Upload une image de tableau (JPEG/PNG).\n"
154
- "Le système détecte le tableau et extrait uniquement "
155
- "les cellules de la deuxième colonne."
156
- )
157
  )
158
 
159
  demo.launch()
 
 
 
1
  import gradio as gr
2
+ import numpy as np
3
  from PIL import Image
4
+ import pytesseract
5
 
6
+ def extract_second_column(image):
7
+ if image is None:
8
+ return "Aucune image fournie"
 
 
 
9
 
10
+ image = image.convert("RGB")
11
+ img = np.array(image)
 
12
 
13
+ # OCR avec positions
14
+ data = pytesseract.image_to_data(
15
+ img,
16
+ output_type=pytesseract.Output.DICT,
17
+ config="--psm 6"
18
+ )
19
 
20
+ words = []
21
+ for i in range(len(data["text"])):
22
+ text = data["text"][i].strip()
23
+ if text:
24
+ x = data["left"][i]
25
+ y = data["top"][i]
26
+ words.append((text, x, y))
 
27
 
28
+ if not words:
29
+ return "Aucun texte détecté"
 
 
 
 
 
 
30
 
31
+ # Trier par X (colonnes)
32
+ words.sort(key=lambda w: w[1])
 
33
 
34
+ # Regrouper par colonnes
 
 
 
 
 
35
  columns = []
36
+ for word in words:
 
37
  placed = False
38
  for col in columns:
39
+ if abs(col[0][1] - word[1]) < 60:
40
+ col.append(word)
41
  placed = True
42
  break
43
  if not placed:
44
+ columns.append([word])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  if len(columns) < 2:
47
  return "Moins de 2 colonnes détectées"
48
 
49
+ # 2ᵉ colonne
50
+ second_column = sorted(columns[1], key=lambda w: w[2])
 
 
51
 
52
+ return "\n".join([w[0] for w in second_column])
 
 
 
 
 
 
 
 
 
 
53
 
54
 
 
 
 
 
55
  demo = gr.Interface(
56
  fn=extract_second_column,
57
+ inputs=gr.Image(type="pil"),
58
+ outputs=gr.Textbox(lines=20),
59
+ title="Extraction de la 2ᵉ colonne (Facture)",
60
+ description="OCR + regroupement par colonnes (optimisé pour factures)"
 
 
 
 
 
 
 
61
  )
62
 
63
  demo.launch()