kebson commited on
Commit
30ffd4f
·
verified ·
1 Parent(s): 44d10bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -9
app.py CHANGED
@@ -9,7 +9,6 @@ from transformers import DetrImageProcessor, TableTransformerForObjectDetection
9
  # ===============================
10
  # Chargement des modèles
11
  # ===============================
12
-
13
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
14
 
15
  det_processor = DetrImageProcessor.from_pretrained(
@@ -29,7 +28,6 @@ struct_model = TableTransformerForObjectDetection.from_pretrained(
29
  # ===============================
30
  # OCR cellule
31
  # ===============================
32
-
33
  def ocr_cell(image):
34
  gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
35
  text = pytesseract.image_to_string(gray, config="--psm 6")
@@ -38,14 +36,14 @@ def ocr_cell(image):
38
  # ===============================
39
  # Fonction principale
40
  # ===============================
41
-
42
  def extract_description(image_pil):
43
-
44
  image = np.array(image_pil)
45
  h, w, _ = image.shape
46
 
47
  # ---- Détection du tableau ----
48
- inputs = det_processor(images=image_pil, return_tensors="pt").to(DEVICE)
 
49
  outputs = det_model(**inputs)
50
 
51
  results = det_processor.post_process_object_detection(
@@ -67,7 +65,8 @@ def extract_description(image_pil):
67
  table_img = image[y0:y1, x0:x1]
68
 
69
  # ---- Structure du tableau ----
70
- inputs = struct_processor(images=Image.fromarray(table_img), return_tensors="pt").to(DEVICE)
 
71
  outputs = struct_model(**inputs)
72
 
73
  results = struct_processor.post_process_object_detection(
@@ -87,7 +86,6 @@ def extract_description(image_pil):
87
 
88
  # ---- Grouper par colonne (X) ----
89
  cells_sorted = sorted(cells, key=lambda b: (b[0] + b[2]) / 2)
90
-
91
  columns = {}
92
  for cell in cells_sorted:
93
  cx = (cell[0] + cell[2]) // 2
@@ -116,7 +114,7 @@ def extract_description(image_pil):
116
 
117
  if desc_col is None:
118
  return "❌ Colonne 'Description' non trouvée", "\n\n".join(
119
- [" | ".join(col) for col in column_texts]
120
  )
121
 
122
  return "\n".join(desc_col[1:]), "\n\n".join(desc_col)
@@ -124,7 +122,6 @@ def extract_description(image_pil):
124
  # ===============================
125
  # Interface Gradio
126
  # ===============================
127
-
128
  demo = gr.Interface(
129
  fn=extract_description,
130
  inputs=gr.Image(type="pil", label="Image de facture"),
 
9
  # ===============================
10
  # Chargement des modèles
11
  # ===============================
 
12
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
13
 
14
  det_processor = DetrImageProcessor.from_pretrained(
 
28
  # ===============================
29
  # OCR cellule
30
  # ===============================
 
31
  def ocr_cell(image):
32
  gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
33
  text = pytesseract.image_to_string(gray, config="--psm 6")
 
36
  # ===============================
37
  # Fonction principale
38
  # ===============================
 
39
  def extract_description(image_pil):
40
+ # Convertir PIL -> np.array
41
  image = np.array(image_pil)
42
  h, w, _ = image.shape
43
 
44
  # ---- Détection du tableau ----
45
+ inputs = det_processor(images=image, return_tensors="pt")
46
+ inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
47
  outputs = det_model(**inputs)
48
 
49
  results = det_processor.post_process_object_detection(
 
65
  table_img = image[y0:y1, x0:x1]
66
 
67
  # ---- Structure du tableau ----
68
+ inputs = struct_processor(images=table_img, return_tensors="pt")
69
+ inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
70
  outputs = struct_model(**inputs)
71
 
72
  results = struct_processor.post_process_object_detection(
 
86
 
87
  # ---- Grouper par colonne (X) ----
88
  cells_sorted = sorted(cells, key=lambda b: (b[0] + b[2]) / 2)
 
89
  columns = {}
90
  for cell in cells_sorted:
91
  cx = (cell[0] + cell[2]) // 2
 
114
 
115
  if desc_col is None:
116
  return "❌ Colonne 'Description' non trouvée", "\n\n".join(
117
+ [f"Col {i}: " + " | ".join(col) for i, col in enumerate(column_texts)]
118
  )
119
 
120
  return "\n".join(desc_col[1:]), "\n\n".join(desc_col)
 
122
  # ===============================
123
  # Interface Gradio
124
  # ===============================
 
125
  demo = gr.Interface(
126
  fn=extract_description,
127
  inputs=gr.Image(type="pil", label="Image de facture"),