LLDDWW committed on
Commit
e35cc62
·
1 Parent(s): 00b2fbb

feat: upgrade ocr to paddleocr and qwen 1.5b

Browse files
Files changed (2) hide show
  1. app.py +30 -17
  2. requirements.txt +2 -1
app.py CHANGED
@@ -2,22 +2,27 @@ import json
2
  import re
3
  from typing import Any, Dict, List, Optional, Sequence
4
 
5
- import easyocr
6
  import gradio as gr
7
  import numpy as np
8
  import torch
9
  from PIL import Image, ImageDraw
 
10
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
11
 
12
  # --- OCR pipeline ---------------------------------------------------------
13
  # Use a high-capacity OCR model for better accuracy on prescription labels.
14
- OCR_LANGS = ["ko", "en"]
15
- LLM_MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
16
 
17
 
18
  def _load_ocr():
19
  use_gpu = torch.cuda.is_available()
20
- return easyocr.Reader(OCR_LANGS, gpu=use_gpu)
 
 
 
 
 
21
 
22
 
23
  ocr_reader = _load_ocr()
@@ -26,10 +31,15 @@ ocr_reader = _load_ocr()
26
  def _load_llm():
27
  device_map = "auto" if torch.cuda.is_available() else None
28
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
29
- model = AutoModelForCausalLM.from_pretrained(LLM_MODEL_ID, device_map=device_map, torch_dtype=dtype)
 
 
 
 
 
30
  if device_map is None:
31
  model = model.to(torch.device("cpu"))
32
- tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID)
33
  return model, tokenizer
34
 
35
 
@@ -156,21 +166,24 @@ def parse_fields(raw: str) -> Dict[str, Any]:
156
 
157
  def ocr_and_parse(image: Image.Image) -> Dict[str, Any]:
158
  np_img = np.array(image.convert("RGB"))
159
- results = ocr_reader.readtext(np_img, detail=1, paragraph=False)
160
 
161
  segments: List[Dict[str, Any]] = []
162
  lines: List[str] = []
163
- for bbox, text, confidence in results:
164
- cleaned = text.strip()
165
- if not cleaned:
166
  continue
167
- lines.append(cleaned)
168
- box_serializable = np.asarray(bbox, dtype=float).tolist()
169
- segments.append({
170
- "text": cleaned,
171
- "confidence": float(confidence),
172
- "bbox": box_serializable,
173
- })
 
 
 
 
174
 
175
  raw_text = "\n".join(lines)
176
  fields = parse_fields(raw_text)
 
2
  import re
3
  from typing import Any, Dict, List, Optional, Sequence
4
 
 
5
  import gradio as gr
6
  import numpy as np
7
  import torch
8
  from PIL import Image, ImageDraw
9
+ from paddleocr import PaddleOCR
10
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
11
 
12
  # --- OCR pipeline ---------------------------------------------------------
13
  # Use a high-capacity OCR model for better accuracy on prescription labels.
14
+ OCR_LANGS = ["korean", "en"]
15
+ LLM_MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
16
 
17
 
18
  def _load_ocr():
19
  use_gpu = torch.cuda.is_available()
20
+ return PaddleOCR(
21
+ use_angle_cls=True,
22
+ lang=OCR_LANGS[0],
23
+ show_log=False,
24
+ use_gpu=use_gpu,
25
+ )
26
 
27
 
28
  ocr_reader = _load_ocr()
 
31
  def _load_llm():
32
  device_map = "auto" if torch.cuda.is_available() else None
33
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
34
+ model = AutoModelForCausalLM.from_pretrained(
35
+ LLM_MODEL_ID,
36
+ device_map=device_map,
37
+ torch_dtype=dtype,
38
+ trust_remote_code=True,
39
+ )
40
  if device_map is None:
41
  model = model.to(torch.device("cpu"))
42
+ tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID, trust_remote_code=True)
43
  return model, tokenizer
44
 
45
 
 
166
 
167
  def ocr_and_parse(image: Image.Image) -> Dict[str, Any]:
168
  np_img = np.array(image.convert("RGB"))
169
+ ocr_results = ocr_reader.ocr(np_img, cls=True)
170
 
171
  segments: List[Dict[str, Any]] = []
172
  lines: List[str] = []
173
+ for result in ocr_results:
174
+ if not result:
 
175
  continue
176
+ for bbox, (text, confidence) in result:
177
+ cleaned = (text or "").strip()
178
+ if not cleaned:
179
+ continue
180
+ lines.append(cleaned)
181
+ box_serializable = np.asarray(bbox, dtype=float).tolist()
182
+ segments.append({
183
+ "text": cleaned,
184
+ "confidence": float(confidence),
185
+ "bbox": box_serializable,
186
+ })
187
 
188
  raw_text = "\n".join(lines)
189
  fields = parse_fields(raw_text)
requirements.txt CHANGED
@@ -3,6 +3,7 @@ torch
3
  gradio
4
  Pillow
5
  sentencepiece
6
- easyocr
 
7
  opencv-python-headless
8
  numpy
 
3
  gradio
4
  Pillow
5
  sentencepiece
6
+ paddleocr
7
+ paddlepaddle
8
  opencv-python-headless
9
  numpy