feat: upgrade OCR to PaddleOCR and LLM to Qwen2.5-1.5B-Instruct
Browse files- app.py +30 -17
- requirements.txt +2 -1
app.py
CHANGED
|
@@ -2,22 +2,27 @@ import json
|
|
| 2 |
import re
|
| 3 |
from typing import Any, Dict, List, Optional, Sequence
|
| 4 |
|
| 5 |
-
import easyocr
|
| 6 |
import gradio as gr
|
| 7 |
import numpy as np
|
| 8 |
import torch
|
| 9 |
from PIL import Image, ImageDraw
|
|
|
|
| 10 |
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
| 11 |
|
| 12 |
# --- OCR pipeline ---------------------------------------------------------
|
| 13 |
# Use a high-capacity OCR model for better accuracy on prescription labels.
|
| 14 |
-
OCR_LANGS = ["
|
| 15 |
-
LLM_MODEL_ID = "Qwen/Qwen2.5-
|
| 16 |
|
| 17 |
|
| 18 |
def _load_ocr():
|
| 19 |
use_gpu = torch.cuda.is_available()
|
| 20 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
ocr_reader = _load_ocr()
|
|
@@ -26,10 +31,15 @@ ocr_reader = _load_ocr()
|
|
| 26 |
def _load_llm():
|
| 27 |
device_map = "auto" if torch.cuda.is_available() else None
|
| 28 |
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
| 29 |
-
model = AutoModelForCausalLM.from_pretrained(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
if device_map is None:
|
| 31 |
model = model.to(torch.device("cpu"))
|
| 32 |
-
tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID)
|
| 33 |
return model, tokenizer
|
| 34 |
|
| 35 |
|
|
@@ -156,21 +166,24 @@ def parse_fields(raw: str) -> Dict[str, Any]:
|
|
| 156 |
|
| 157 |
def ocr_and_parse(image: Image.Image) -> Dict[str, Any]:
|
| 158 |
np_img = np.array(image.convert("RGB"))
|
| 159 |
-
|
| 160 |
|
| 161 |
segments: List[Dict[str, Any]] = []
|
| 162 |
lines: List[str] = []
|
| 163 |
-
for
|
| 164 |
-
|
| 165 |
-
if not cleaned:
|
| 166 |
continue
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
|
| 175 |
raw_text = "\n".join(lines)
|
| 176 |
fields = parse_fields(raw_text)
|
|
|
|
| 2 |
import re
|
| 3 |
from typing import Any, Dict, List, Optional, Sequence
|
| 4 |
|
|
|
|
| 5 |
import gradio as gr
|
| 6 |
import numpy as np
|
| 7 |
import torch
|
| 8 |
from PIL import Image, ImageDraw
|
| 9 |
+
from paddleocr import PaddleOCR
|
| 10 |
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
| 11 |
|
| 12 |
# --- OCR pipeline ---------------------------------------------------------
# Use a high-capacity OCR model for better accuracy on prescription labels.
# NOTE(review): _load_ocr passes only OCR_LANGS[0] ("korean") to PaddleOCR;
# the "en" entry is currently unused — confirm whether a second reader or a
# multilingual model was intended.
OCR_LANGS = ["korean", "en"]
# Chat-tuned LLM used downstream to structure the raw OCR text.
LLM_MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
|
| 16 |
|
| 17 |
|
| 18 |
def _load_ocr():
    """Build the PaddleOCR reader used for prescription-label recognition.

    The reader is configured with angle classification enabled (handles
    rotated text) and logging suppressed. GPU inference is requested
    whenever CUDA is visible to torch. PaddleOCR takes a single primary
    language, so only the first entry of OCR_LANGS is used here.
    """
    reader_config = {
        "use_angle_cls": True,
        "lang": OCR_LANGS[0],
        "show_log": False,
        "use_gpu": torch.cuda.is_available(),
    }
    return PaddleOCR(**reader_config)
|
| 26 |
|
| 27 |
|
| 28 |
ocr_reader = _load_ocr()
|
|
|
|
| 31 |
def _load_llm():
    """Load the Qwen causal LM and its tokenizer.

    With a CUDA device available the model is loaded in float16 with
    automatic device placement ("auto"); otherwise it is loaded in
    float32 and explicitly moved to CPU.

    Returns:
        tuple: (model, tokenizer) ready for generation.
    """
    has_cuda = torch.cuda.is_available()
    model = AutoModelForCausalLM.from_pretrained(
        LLM_MODEL_ID,
        device_map="auto" if has_cuda else None,
        torch_dtype=torch.float16 if has_cuda else torch.float32,
        trust_remote_code=True,
    )
    if not has_cuda:
        # device_map was None in this case, so pin the weights to CPU.
        model = model.to(torch.device("cpu"))
    tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID, trust_remote_code=True)
    return model, tokenizer
|
| 44 |
|
| 45 |
|
|
|
|
| 166 |
|
| 167 |
def ocr_and_parse(image: Image.Image) -> Dict[str, Any]:
|
| 168 |
np_img = np.array(image.convert("RGB"))
|
| 169 |
+
ocr_results = ocr_reader.ocr(np_img, cls=True)
|
| 170 |
|
| 171 |
segments: List[Dict[str, Any]] = []
|
| 172 |
lines: List[str] = []
|
| 173 |
+
for result in ocr_results:
|
| 174 |
+
if not result:
|
|
|
|
| 175 |
continue
|
| 176 |
+
for bbox, (text, confidence) in result:
|
| 177 |
+
cleaned = (text or "").strip()
|
| 178 |
+
if not cleaned:
|
| 179 |
+
continue
|
| 180 |
+
lines.append(cleaned)
|
| 181 |
+
box_serializable = np.asarray(bbox, dtype=float).tolist()
|
| 182 |
+
segments.append({
|
| 183 |
+
"text": cleaned,
|
| 184 |
+
"confidence": float(confidence),
|
| 185 |
+
"bbox": box_serializable,
|
| 186 |
+
})
|
| 187 |
|
| 188 |
raw_text = "\n".join(lines)
|
| 189 |
fields = parse_fields(raw_text)
|
requirements.txt
CHANGED
|
@@ -3,6 +3,7 @@ torch
|
|
| 3 |
gradio
|
| 4 |
Pillow
|
| 5 |
sentencepiece
|
| 6 |
-
|
|
|
|
| 7 |
opencv-python-headless
|
| 8 |
numpy
|
|
|
|
| 3 |
gradio
|
| 4 |
Pillow
|
| 5 |
sentencepiece
|
| 6 |
+
paddleocr
|
| 7 |
+
paddlepaddle
|
| 8 |
opencv-python-headless
|
| 9 |
numpy
|