Harry Pham commited on
Commit
f407757
·
1 Parent(s): 5c88957

Fix: remove paddlepaddle, pin versions for Python 3.12

Browse files
Files changed (2) hide show
  1. requirements.txt +0 -3
  2. src/inference.py +49 -167
requirements.txt CHANGED
@@ -139,9 +139,6 @@ opt-einsum==3.3.0
139
  orjson==3.11.8
140
  overrides==7.4.0
141
  packaging==23.2
142
- paddleocr==3.4.0
143
- paddlepaddle==3.3.1
144
- paddlex==3.4.3
145
  pandas==2.1.1
146
  pandocfilters==1.5.0
147
  parso==0.8.3
 
139
  orjson==3.11.8
140
  overrides==7.4.0
141
  packaging==23.2
 
 
 
142
  pandas==2.1.1
143
  pandocfilters==1.5.0
144
  parso==0.8.3
src/inference.py CHANGED
@@ -1,12 +1,10 @@
1
  # src/inference.py
2
- # ── Patch torch.load — PHẢI LÀ DÒNG ĐẦU TIÊN ──────────────
3
  import torch
4
  _orig_torch_load = torch.load
5
  def _patched_load(*args, **kwargs):
6
  kwargs.setdefault("weights_only", False)
7
  return _orig_torch_load(*args, **kwargs)
8
  torch.load = _patched_load
9
- # ───────────────────────────────────────────────────────────
10
 
11
  import cv2
12
  import json
@@ -14,143 +12,71 @@ import numpy as np
14
  from pathlib import Path
15
  from ultralytics import RTDETR
16
 
17
- # ── Device ─────────────────────────────────────────────────
18
- DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"
19
  print(f"[INFO] Device: {DEVICE}")
20
 
21
- # ── Class config ────────────────────────────────────────────
22
  CLASS_NAMES = ["note", "part-drawing", "table"]
23
- CLASS_DISPLAY = {
24
- "note": "Note",
25
- "part-drawing": "PartDrawing",
26
- "table": "Table",
27
- }
28
- COLORS = {
29
- "note": (0, 165, 255),
30
- "part-drawing": (0, 200, 0),
31
- "table": (0, 0, 220),
32
- }
33
 
34
- # ───────────────────────────────────────────────────────────
35
- # DETECTION MODEL
36
- # ───────────────────────────────────────────────────────────
37
- _det_model = None
38
 
39
- def get_det_model(checkpoint: str = "best.pt") -> RTDETR:
40
  global _det_model
41
  if _det_model is None:
42
- print(f"[INFO] Loading detection model: {checkpoint}")
43
  _det_model = RTDETR(checkpoint)
44
  return _det_model
45
 
46
-
47
- # ───────────────────────────────────────────────────────────
48
- # OCR ENGINES
49
- # ───────────────────────────────────────────────────────────
50
- _easy_reader = None
51
- _paddle_engine = None
52
-
53
  def get_easy_reader():
54
  global _easy_reader
55
  if _easy_reader is None:
56
  import easyocr
57
- print("[INFO] Loading EasyOCR (vi + en)...")
58
- _easy_reader = easyocr.Reader(
59
- ["vi", "en"],
60
- gpu=False,
61
- verbose=False,
62
- )
63
  return _easy_reader
64
 
65
-
66
- def get_paddle_engine():
67
- global _paddle_engine
68
- if _paddle_engine is None:
69
- from paddleocr import PaddleOCR
70
- print("[INFO] Loading PaddleOCR (vi)...")
71
- _paddle_engine = PaddleOCR(
72
- use_angle_cls=True,
73
- lang="vi",
74
- show_log=False,
75
- use_gpu=False,
76
- )
77
- return _paddle_engine
78
-
79
-
80
- # ───────────────────────────────────────────────────────────
81
- # PREPROCESSING
82
- # ───────────────────────────────────────────────────────────
83
- def preprocess_for_ocr(img_bgr: np.ndarray) -> np.ndarray:
84
  h, w = img_bgr.shape[:2]
85
-
86
- # Upscale nếu quá nhỏ
87
  if w < 800:
88
- scale = 800 / w
89
- img_bgr = cv2.resize(
90
- img_bgr,
91
- (int(w * scale), int(h * scale)),
92
- interpolation=cv2.INTER_CUBIC,
93
- )
94
-
95
  gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
96
- clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
97
  gray = clahe.apply(gray)
98
  gray = cv2.fastNlMeansDenoising(gray, h=15,
99
- templateWindowSize=7,
100
- searchWindowSize=21)
101
- kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
102
  gray = cv2.filter2D(gray, -1, kernel)
103
-
104
  return cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
105
 
106
-
107
- # ───────────────────────────────────────────────────────────
108
- # OCR: NOTE
109
- # ───────────────────────────────────────────────────────────
110
- def ocr_note(img_path: str) -> str:
111
  img = cv2.imread(img_path)
112
  if img is None:
113
  return ""
114
-
115
  img_proc = preprocess_for_ocr(img)
116
-
117
- # EasyOCR
118
  try:
119
  reader = get_easy_reader()
120
  results = reader.readtext(img_proc, detail=1, paragraph=False,
121
  width_ths=0.7, height_ths=0.7)
122
- lines = [t for (_, t, c) in results if c >= 0.2 and t.strip()]
123
- if lines:
124
- return "\n".join(lines)
125
- except Exception as e:
126
- print(f"[WARN] EasyOCR note: {e}")
127
-
128
- # Fallback PaddleOCR
129
- try:
130
- ocr = get_paddle_engine()
131
- result = ocr.ocr(img_proc, cls=True)
132
- if result and result[0]:
133
- return "\n".join(l[1][0] for l in result[0] if l[1][1] >= 0.2)
134
  except Exception as e:
135
- print(f"[WARN] PaddleOCR note: {e}")
136
-
137
- return ""
138
-
139
 
140
- # ───────────────────────────────────────────────────────────
141
- # OCR: TABLE
142
- # ───────────────────────────────────────────────────────────
143
- def _group_rows(items: list) -> list:
144
  if not items:
145
  return []
146
  items = sorted(items, key=lambda x: x["y"])
147
  y_vals = [it["y"] for it in items]
148
  if len(y_vals) > 1:
149
- gaps = [y_vals[i+1] - y_vals[i] for i in range(len(y_vals)-1)]
150
  thresh = max(8, (sum(gaps)/len(gaps)) * 0.6)
151
  else:
152
  thresh = 12
153
-
154
  rows, cur = [], [items[0]]
155
  for item in items[1:]:
156
  if item["y"] - cur[-1]["y"] < thresh:
@@ -163,16 +89,12 @@ def _group_rows(items: list) -> list:
163
  rows.append([i["text"] for i in cur])
164
  return rows
165
 
166
-
167
- def ocr_table(img_path: str) -> dict:
168
  img = cv2.imread(img_path)
169
  if img is None:
170
- return {"rows": [], "text": ""}
171
-
172
  img_proc = preprocess_for_ocr(img)
173
- items = []
174
-
175
- # EasyOCR
176
  try:
177
  reader = get_easy_reader()
178
  results = reader.readtext(img_proc, detail=1, paragraph=False,
@@ -182,81 +104,46 @@ def ocr_table(img_path: str) -> dict:
182
  continue
183
  items.append({
184
  "text": text.strip(),
185
- "y": sum(p[1] for p in pts) / 4,
186
- "x": sum(p[0] for p in pts) / 4,
187
  })
188
  except Exception as e:
189
- print(f"[WARN] EasyOCR table: {e}")
190
-
191
- # Fallback PaddleOCR
192
- if not items:
193
- try:
194
- ocr = get_paddle_engine()
195
- result = ocr.ocr(img_proc, cls=True)
196
- if result and result[0]:
197
- for line in result[0]:
198
- pts, (text, conf) = line[0], line[1]
199
- if conf < 0.2 or not text.strip():
200
- continue
201
- items.append({
202
- "text": text.strip(),
203
- "y": sum(p[1] for p in pts) / 4,
204
- "x": sum(p[0] for p in pts) / 4,
205
- })
206
- except Exception as e:
207
- print(f"[WARN] PaddleOCR table: {e}")
208
-
209
  if not items:
210
- return {"rows": [], "text": ""}
211
-
212
  rows = _group_rows(items)
213
- return {
214
- "rows": rows,
215
- "text": "\n".join(" | ".join(r) for r in rows),
216
- }
217
 
218
-
219
- # ───────────────────────────────────────────────────────────
220
- # MAIN PIPELINE
221
- # ───────────────────────────────────────────────────────────
222
- def run_pipeline(
223
- image_path: str,
224
- output_dir: str = "outputs",
225
- checkpoint: str = "best.pt",
226
- conf_thresh: float = 0.3,
227
- ) -> tuple:
228
  image_path = str(image_path)
229
  img_name = Path(image_path).name
230
  stem = Path(image_path).stem
231
  crop_dir = Path(output_dir) / stem / "crops"
232
  crop_dir.mkdir(parents=True, exist_ok=True)
233
 
234
- # 1. Detect
235
  model = get_det_model(checkpoint)
236
  results = model(image_path, imgsz=1024, conf=conf_thresh,
237
  iou=0.5, device=DEVICE, verbose=False)
238
 
239
  img_bgr = cv2.imread(image_path)
240
  if img_bgr is None:
241
- raise ValueError(f"Không đọc được ảnh: {image_path}")
242
 
243
  objects = []
244
-
245
  for i, box in enumerate(results[0].boxes):
246
- x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
247
  cls_idx = int(box.cls[0])
248
  conf_val = round(float(box.conf[0]), 4)
249
  cls_raw = CLASS_NAMES[cls_idx]
250
  cls_show = CLASS_DISPLAY[cls_raw]
251
 
252
- # 2. Crop
253
  pad = 6
254
  crop = img_bgr[max(0,y1-pad):min(img_bgr.shape[0],y2+pad),
255
  max(0,x1-pad):min(img_bgr.shape[1],x2+pad)]
256
  crop_path = str(crop_dir / f"{cls_show}_{i+1}.jpg")
257
  cv2.imwrite(crop_path, crop, [cv2.IMWRITE_JPEG_QUALITY, 95])
258
 
259
- # 3. OCR
260
  ocr_content = None
261
  if cls_raw == "note":
262
  print(f"[OCR] Note #{i+1}...")
@@ -265,41 +152,36 @@ def run_pipeline(
265
  elif cls_raw == "table":
266
  print(f"[OCR] Table #{i+1}...")
267
  ocr_content = ocr_table(crop_path)
268
- print(f" → {repr(ocr_content.get('text','')[:80]) if ocr_content else 'EMPTY'}")
 
269
 
270
  objects.append({
271
- "id": i + 1,
272
- "class": cls_show,
273
- "confidence": conf_val,
274
- "bbox": {"x1": x1, "y1": y1, "x2": x2, "y2": y2},
275
- "crop_path": crop_path,
276
  "ocr_content": ocr_content,
277
  })
278
 
279
- # 4. Vẽ bbox
280
  color = COLORS[cls_raw]
281
- cv2.rectangle(img_bgr, (x1, y1), (x2, y2), color, 2)
282
  label = f"{cls_show} {conf_val:.2f}"
283
- (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
284
- cv2.rectangle(img_bgr, (x1, y1-th-10), (x1+tw+8, y1), color, -1)
285
- cv2.putText(img_bgr, label, (x1+4, y1-4),
286
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255,255,255), 2)
287
 
288
- # 5. Lưu visualize
289
- vis_path = str(Path(output_dir) / stem / "result_vis.jpg")
290
  cv2.imwrite(vis_path, img_bgr)
291
 
292
- # 6. Lưu JSON
293
  result = {"image": img_name, "objects": objects}
294
- json_path = str(Path(output_dir) / stem / "result.json")
295
- with open(json_path, "w", encoding="utf-8") as f:
296
  json.dump(result, f, ensure_ascii=False, indent=2)
297
 
298
- print(f"\n[✓] {len(objects)} objects | vis→{vis_path} | json→{json_path}")
299
  return result, vis_path
300
 
301
-
302
- # ── CLI ──────────────────────────────────────────────────────
303
  if __name__ == "__main__":
304
  import sys
305
  img = sys.argv[1] if len(sys.argv) > 1 else "test.jpg"
 
1
  # src/inference.py
 
2
  import torch
3
  _orig_torch_load = torch.load
4
  def _patched_load(*args, **kwargs):
5
  kwargs.setdefault("weights_only", False)
6
  return _orig_torch_load(*args, **kwargs)
7
  torch.load = _patched_load
 
8
 
9
  import cv2
10
  import json
 
12
  from pathlib import Path
13
  from ultralytics import RTDETR
14
 
15
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
16
  print(f"[INFO] Device: {DEVICE}")
17
 
 
18
  CLASS_NAMES = ["note", "part-drawing", "table"]
19
+ CLASS_DISPLAY = {"note": "Note", "part-drawing": "PartDrawing", "table": "Table"}
20
+ COLORS = {"note": (0,165,255), "part-drawing": (0,200,0), "table": (0,0,220)}
 
 
 
 
 
 
 
 
21
 
22
+ _det_model = None
23
+ _easy_reader = None
 
 
24
 
25
+ def get_det_model(checkpoint="best.pt"):
26
  global _det_model
27
  if _det_model is None:
28
+ print(f"[INFO] Loading model: {checkpoint}")
29
  _det_model = RTDETR(checkpoint)
30
  return _det_model
31
 
 
 
 
 
 
 
 
32
  def get_easy_reader():
33
  global _easy_reader
34
  if _easy_reader is None:
35
  import easyocr
36
+ print("[INFO] Loading EasyOCR...")
37
+ _easy_reader = easyocr.Reader(["vi","en"], gpu=False, verbose=False)
 
 
 
 
38
  return _easy_reader
39
 
40
+ def preprocess_for_ocr(img_bgr):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  h, w = img_bgr.shape[:2]
 
 
42
  if w < 800:
43
+ scale = 800 / w
44
+ img_bgr = cv2.resize(img_bgr, (int(w*scale), int(h*scale)),
45
+ interpolation=cv2.INTER_CUBIC)
 
 
 
 
46
  gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
47
+ clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8,8))
48
  gray = clahe.apply(gray)
49
  gray = cv2.fastNlMeansDenoising(gray, h=15,
50
+ templateWindowSize=7, searchWindowSize=21)
51
+ kernel = np.array([[0,-1,0],[-1,5,-1],[0,-1,0]])
 
52
  gray = cv2.filter2D(gray, -1, kernel)
 
53
  return cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
54
 
55
+ def ocr_note(img_path):
 
 
 
 
56
  img = cv2.imread(img_path)
57
  if img is None:
58
  return ""
 
59
  img_proc = preprocess_for_ocr(img)
 
 
60
  try:
61
  reader = get_easy_reader()
62
  results = reader.readtext(img_proc, detail=1, paragraph=False,
63
  width_ths=0.7, height_ths=0.7)
64
+ lines = [t for (_,t,c) in results if c >= 0.2 and t.strip()]
65
+ return "\n".join(lines)
 
 
 
 
 
 
 
 
 
 
66
  except Exception as e:
67
+ print(f"[WARN] ocr_note: {e}")
68
+ return ""
 
 
69
 
70
+ def _group_rows(items):
 
 
 
71
  if not items:
72
  return []
73
  items = sorted(items, key=lambda x: x["y"])
74
  y_vals = [it["y"] for it in items]
75
  if len(y_vals) > 1:
76
+ gaps = [y_vals[i+1]-y_vals[i] for i in range(len(y_vals)-1)]
77
  thresh = max(8, (sum(gaps)/len(gaps)) * 0.6)
78
  else:
79
  thresh = 12
 
80
  rows, cur = [], [items[0]]
81
  for item in items[1:]:
82
  if item["y"] - cur[-1]["y"] < thresh:
 
89
  rows.append([i["text"] for i in cur])
90
  return rows
91
 
92
+ def ocr_table(img_path):
 
93
  img = cv2.imread(img_path)
94
  if img is None:
95
+ return {"rows":[], "text":""}
 
96
  img_proc = preprocess_for_ocr(img)
97
+ items = []
 
 
98
  try:
99
  reader = get_easy_reader()
100
  results = reader.readtext(img_proc, detail=1, paragraph=False,
 
104
  continue
105
  items.append({
106
  "text": text.strip(),
107
+ "y": sum(p[1] for p in pts)/4,
108
+ "x": sum(p[0] for p in pts)/4,
109
  })
110
  except Exception as e:
111
+ print(f"[WARN] ocr_table: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  if not items:
113
+ return {"rows":[], "text":""}
 
114
  rows = _group_rows(items)
115
+ return {"rows": rows, "text": "\n".join(" | ".join(r) for r in rows)}
 
 
 
116
 
117
+ def run_pipeline(image_path, output_dir="outputs",
118
+ checkpoint="best.pt", conf_thresh=0.3):
 
 
 
 
 
 
 
 
119
  image_path = str(image_path)
120
  img_name = Path(image_path).name
121
  stem = Path(image_path).stem
122
  crop_dir = Path(output_dir) / stem / "crops"
123
  crop_dir.mkdir(parents=True, exist_ok=True)
124
 
 
125
  model = get_det_model(checkpoint)
126
  results = model(image_path, imgsz=1024, conf=conf_thresh,
127
  iou=0.5, device=DEVICE, verbose=False)
128
 
129
  img_bgr = cv2.imread(image_path)
130
  if img_bgr is None:
131
+ raise ValueError(f"Cannot read: {image_path}")
132
 
133
  objects = []
 
134
  for i, box in enumerate(results[0].boxes):
135
+ x1,y1,x2,y2 = map(int, box.xyxy[0].tolist())
136
  cls_idx = int(box.cls[0])
137
  conf_val = round(float(box.conf[0]), 4)
138
  cls_raw = CLASS_NAMES[cls_idx]
139
  cls_show = CLASS_DISPLAY[cls_raw]
140
 
 
141
  pad = 6
142
  crop = img_bgr[max(0,y1-pad):min(img_bgr.shape[0],y2+pad),
143
  max(0,x1-pad):min(img_bgr.shape[1],x2+pad)]
144
  crop_path = str(crop_dir / f"{cls_show}_{i+1}.jpg")
145
  cv2.imwrite(crop_path, crop, [cv2.IMWRITE_JPEG_QUALITY, 95])
146
 
 
147
  ocr_content = None
148
  if cls_raw == "note":
149
  print(f"[OCR] Note #{i+1}...")
 
152
  elif cls_raw == "table":
153
  print(f"[OCR] Table #{i+1}...")
154
  ocr_content = ocr_table(crop_path)
155
+ preview = ocr_content.get("text","")[:80]
156
+ print(f" → {repr(preview) if preview else 'EMPTY'}")
157
 
158
  objects.append({
159
+ "id": i+1, "class": cls_show,
160
+ "confidence": conf_val,
161
+ "bbox": {"x1":x1,"y1":y1,"x2":x2,"y2":y2},
162
+ "crop_path": crop_path,
 
163
  "ocr_content": ocr_content,
164
  })
165
 
 
166
  color = COLORS[cls_raw]
167
+ cv2.rectangle(img_bgr, (x1,y1), (x2,y2), color, 2)
168
  label = f"{cls_show} {conf_val:.2f}"
169
+ (tw,th),_ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
170
+ cv2.rectangle(img_bgr, (x1,y1-th-10), (x1+tw+8,y1), color, -1)
171
+ cv2.putText(img_bgr, label, (x1+4,y1-4),
172
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255,255,255), 2)
173
 
174
+ vis_path = str(Path(output_dir)/stem/"result_vis.jpg")
 
175
  cv2.imwrite(vis_path, img_bgr)
176
 
 
177
  result = {"image": img_name, "objects": objects}
178
+ json_path = str(Path(output_dir)/stem/"result.json")
179
+ with open(json_path,"w",encoding="utf-8") as f:
180
  json.dump(result, f, ensure_ascii=False, indent=2)
181
 
182
+ print(f"[✓] {len(objects)} objects | {vis_path} | {json_path}")
183
  return result, vis_path
184
 
 
 
185
  if __name__ == "__main__":
186
  import sys
187
  img = sys.argv[1] if len(sys.argv) > 1 else "test.jpg"