LLDDWW committed on
Commit
a97e706
·
1 Parent(s): 2879fbc

Use PaddleOCR predict API and normalize inputs

Browse files
Files changed (1) hide show
  1. app.py +85 -21
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import json
2
  import re
3
- from typing import List, Optional, Tuple
4
  import numpy as np
5
  import os
6
 
@@ -26,6 +26,55 @@ MED_MODEL = None
26
  MED_TOKENIZER = None
27
  OCR_MODEL_REPO_ID = "PaddlePaddle/korean_PP-OCRv5_mobile_rec"
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  def load_models():
30
  """๋ชจ๋ธ๋“ค์„ ํ•œ ๋ฒˆ๋งŒ ๋กœ๋“œ"""
31
  global OCR_READER, MED_MODEL, MED_TOKENIZER
@@ -90,7 +139,11 @@ def analyze_medication_image(image: Image.Image) -> Tuple[str, str]:
90
  # Step 1: OCR - PaddleOCR๋กœ ํ•œ๊ธ€ ํ…์ŠคํŠธ ์ถ”์ถœ
91
  start_time = time.time()
92
  img_array = np.array(image)
93
- ocr_results = OCR_READER.ocr(img_array)
 
 
 
 
94
  ocr_time = time.time() - start_time
95
  print(f"โฑ๏ธ OCR took {ocr_time:.2f}s")
96
 
@@ -98,20 +151,7 @@ def analyze_medication_image(image: Image.Image) -> Tuple[str, str]:
98
  return "ํ…์ŠคํŠธ๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.", ""
99
 
100
  # ํ…์ŠคํŠธ ์ถ”์ถœ
101
- texts: List[str] = []
102
- first_entry = ocr_results[0]
103
-
104
- if isinstance(first_entry, list):
105
- texts = [line[1][0] for line in first_entry if len(line) > 1 and line[1]]
106
- elif isinstance(first_entry, dict):
107
- rec_results = first_entry.get("text_recognition") or first_entry.get("rec_results")
108
- if isinstance(rec_results, list):
109
- for rec in rec_results:
110
- if isinstance(rec, dict) and rec.get("text"):
111
- texts.append(rec["text"])
112
-
113
- if not texts and isinstance(first_entry.get("text"), str):
114
- texts.append(first_entry["text"])
115
 
116
  if not texts:
117
  return "ํ…์ŠคํŠธ๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.", ""
@@ -261,16 +301,40 @@ def format_results(extracted_text: str, medications: List[str]) -> Tuple[str, st
261
  return text_output, med_output
262
 
263
 
264
- def run_analysis(image: Optional[Image.Image], progress=gr.Progress()):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
  """๋ฉ”์ธ ๋ถ„์„ ํŒŒ์ดํ”„๋ผ์ธ: OCR + ์•ฝ ์ •๋ณด ๋ถ„์„"""
266
- if image is None:
 
 
267
  return "๐Ÿ“ท ์•ฝ ๋ด‰ํˆฌ๋‚˜ ์ฒ˜๋ฐฉ์ „ ์‚ฌ์ง„์„ ์—…๋กœ๋“œํ•ด์ฃผ์„ธ์š”.", ""
268
 
269
  progress(0.3, desc="๐Ÿ“ธ 1๋‹จ๊ณ„: OCR ํ…์ŠคํŠธ ์ถ”์ถœ ์ค‘...")
270
  progress(0.6, desc="๐Ÿค– 2๋‹จ๊ณ„: ์•ฝ ์ •๋ณด ๋ถ„์„ ์ค‘...")
271
 
272
  try:
273
- ocr_text, analysis = analyze_medication_image(image)
274
  progress(1.0, desc="โœ… ์™„๋ฃŒ!")
275
 
276
  ocr_output = f"### ๐Ÿ“„ ์ถ”์ถœ๋œ ํ…์ŠคํŠธ\n\n```\n{ocr_text}\n```"
@@ -375,7 +439,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
375
 
376
  with gr.Column(elem_classes=["upload-section"]):
377
  gr.Markdown("### ๐Ÿ“ธ ์‚ฌ์ง„ ์—…๋กœ๋“œ")
378
- image_input = gr.Image(type="pil", label="์•ฝ๋ด‰ํˆฌ ๋˜๋Š” ์ฒ˜๋ฐฉ์ „ ์‚ฌ์ง„", height=350)
379
  analyze_button = gr.Button("๐Ÿ” ์•ฝ ์ •๋ณด ๋ถ„์„ํ•˜๊ธฐ", elem_classes=["analyze-btn"], size="lg")
380
 
381
  with gr.Row():
@@ -406,7 +470,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
406
  - AI๊ฐ€ ์ƒ์„ฑํ•œ ์ •๋ณด์ด๋ฏ€๋กœ ์ •ํ™•ํ•˜์ง€ ์•Š์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค
407
 
408
  **๐Ÿค– ๊ธฐ์ˆ  ์Šคํƒ**
409
- - EasyOCR (ํ•œ๊ธ€+์˜์–ด, ์ดˆ๊ณ ์† OCR)
410
  - Google Gemma-2-2B-IT (8bit ์–‘์žํ™”, ๋น ๋ฅธ ์˜๋ฃŒ ์ •๋ณด ๋ถ„์„)
411
 
412
  **๐Ÿ”‘ ์„ค์ • ๋ฐฉ๋ฒ•**
 
1
  import json
2
  import re
3
+ from typing import List, Optional, Tuple, Union
4
  import numpy as np
5
  import os
6
 
 
26
  MED_TOKENIZER = None
27
  OCR_MODEL_REPO_ID = "PaddlePaddle/korean_PP-OCRv5_mobile_rec"
28
 
29
+
30
+ def _collect_ocr_texts(ocr_payload) -> List[str]:
31
+ """PaddleOCR ๊ฒฐ๊ณผ ๊ตฌ์กฐ์—์„œ ํ…์ŠคํŠธ๋งŒ ์ถ”์ถœ"""
32
+ texts: List[str] = []
33
+ seen = set()
34
+
35
+ def add_text(candidate: str):
36
+ if not isinstance(candidate, str):
37
+ return
38
+ normalized = candidate.strip()
39
+ if normalized and normalized not in seen:
40
+ seen.add(normalized)
41
+ texts.append(normalized)
42
+
43
+ def walk(node):
44
+ if isinstance(node, str):
45
+ add_text(node)
46
+ return
47
+
48
+ if isinstance(node, dict):
49
+ for key in ("text", "label", "transcription"):
50
+ add_text(node.get(key))
51
+
52
+ for key in ("texts", "labels"):
53
+ values = node.get(key)
54
+ if isinstance(values, (list, tuple)):
55
+ for value in values:
56
+ add_text(value)
57
+
58
+ for key in ("text_recognition", "rec_results", "data", "results"):
59
+ if key in node:
60
+ walk(node[key])
61
+ return
62
+
63
+ if isinstance(node, (list, tuple)):
64
+ if len(node) >= 2:
65
+ second = node[1]
66
+ if isinstance(second, str):
67
+ add_text(second)
68
+ elif isinstance(second, (list, tuple)) and second:
69
+ maybe_text = second[0]
70
+ add_text(maybe_text)
71
+
72
+ for item in node:
73
+ walk(item)
74
+
75
+ walk(ocr_payload)
76
+ return texts
77
+
78
  def load_models():
79
  """๋ชจ๋ธ๋“ค์„ ํ•œ ๋ฒˆ๋งŒ ๋กœ๋“œ"""
80
  global OCR_READER, MED_MODEL, MED_TOKENIZER
 
139
  # Step 1: OCR - PaddleOCR๋กœ ํ•œ๊ธ€ ํ…์ŠคํŠธ ์ถ”์ถœ
140
  start_time = time.time()
141
  img_array = np.array(image)
142
+
143
+ try:
144
+ ocr_results = OCR_READER.predict(img_array)
145
+ except (TypeError, AttributeError):
146
+ ocr_results = OCR_READER.ocr(img_array)
147
  ocr_time = time.time() - start_time
148
  print(f"โฑ๏ธ OCR took {ocr_time:.2f}s")
149
 
 
151
  return "ํ…์ŠคํŠธ๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.", ""
152
 
153
  # ํ…์ŠคํŠธ ์ถ”์ถœ
154
+ texts = _collect_ocr_texts(ocr_results)
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
  if not texts:
157
  return "ํ…์ŠคํŠธ๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.", ""
 
301
  return text_output, med_output
302
 
303
 
304
def _ensure_pil(image_input: Optional[Union[Image.Image, np.ndarray, str]]) -> Optional[Image.Image]:
    """Convert a Gradio input (PIL image, numpy array, or file path) to an RGB PIL image.

    Args:
        image_input: Whatever the Gradio Image component delivered, or None.

    Returns:
        An RGB ``PIL.Image.Image``, or ``None`` when the input is missing
        or unusable, so callers can show an "upload an image" message
        instead of crashing.
    """
    if image_input is None:
        return None

    if isinstance(image_input, Image.Image):
        return image_input

    if isinstance(image_input, np.ndarray):
        arr = image_input
        if arr.dtype != np.uint8:
            # Float images are commonly normalized to [0, 1]; clipping such
            # an array straight to uint8 would yield an almost-black image,
            # so rescale to [0, 255] first. Arrays already in [0, 255]
            # (e.g. float copies of uint8 data) are clipped as-is.
            if np.issubdtype(arr.dtype, np.floating) and arr.size and arr.max() <= 1.0:
                arr = arr * 255.0
            arr = np.clip(arr, 0, 255).astype(np.uint8)
        return Image.fromarray(arr).convert("RGB")

    if isinstance(image_input, str):
        # Gradio may hand over a temp-file path; ignore paths that vanished.
        if not os.path.exists(image_input):
            return None
        with Image.open(image_input) as img:
            return img.convert("RGB")

    return None
324
+
325
+
326
+ def run_analysis(image: Optional[Union[Image.Image, np.ndarray, str]], progress=gr.Progress()):
327
  """๋ฉ”์ธ ๋ถ„์„ ํŒŒ์ดํ”„๋ผ์ธ: OCR + ์•ฝ ์ •๋ณด ๋ถ„์„"""
328
+ pil_image = _ensure_pil(image)
329
+
330
+ if pil_image is None:
331
  return "๐Ÿ“ท ์•ฝ ๋ด‰ํˆฌ๋‚˜ ์ฒ˜๋ฐฉ์ „ ์‚ฌ์ง„์„ ์—…๋กœ๋“œํ•ด์ฃผ์„ธ์š”.", ""
332
 
333
  progress(0.3, desc="๐Ÿ“ธ 1๋‹จ๊ณ„: OCR ํ…์ŠคํŠธ ์ถ”์ถœ ์ค‘...")
334
  progress(0.6, desc="๐Ÿค– 2๋‹จ๊ณ„: ์•ฝ ์ •๋ณด ๋ถ„์„ ์ค‘...")
335
 
336
  try:
337
+ ocr_text, analysis = analyze_medication_image(pil_image)
338
  progress(1.0, desc="โœ… ์™„๋ฃŒ!")
339
 
340
  ocr_output = f"### ๐Ÿ“„ ์ถ”์ถœ๋œ ํ…์ŠคํŠธ\n\n```\n{ocr_text}\n```"
 
439
 
440
  with gr.Column(elem_classes=["upload-section"]):
441
  gr.Markdown("### ๐Ÿ“ธ ์‚ฌ์ง„ ์—…๋กœ๋“œ")
442
+ image_input = gr.Image(type="numpy", image_mode="RGB", label="์•ฝ๋ด‰ํˆฌ ๋˜๋Š” ์ฒ˜๋ฐฉ์ „ ์‚ฌ์ง„", height=350)
443
  analyze_button = gr.Button("๐Ÿ” ์•ฝ ์ •๋ณด ๋ถ„์„ํ•˜๊ธฐ", elem_classes=["analyze-btn"], size="lg")
444
 
445
  with gr.Row():
 
470
  - AI๊ฐ€ ์ƒ์„ฑํ•œ ์ •๋ณด์ด๋ฏ€๋กœ ์ •ํ™•ํ•˜์ง€ ์•Š์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค
471
 
472
  **๐Ÿค– ๊ธฐ์ˆ  ์Šคํƒ**
473
+ - PaddleOCR PP-OCRv5 (ํ•œ๊ตญ์–ด ์ตœ์ ํ™” OCR)
474
  - Google Gemma-2-2B-IT (8bit ์–‘์žํ™”, ๋น ๋ฅธ ์˜๋ฃŒ ์ •๋ณด ๋ถ„์„)
475
 
476
  **๐Ÿ”‘ ์„ค์ • ๋ฐฉ๋ฒ•**