# Receipt_OCR / app.py
# RickyGM15's picture
# Upload folder using huggingface_hub
# d1e4f85 verified
import gradio as gr
import torch
from PIL import Image
from model_def import OCRModel
from tokenizer import OCRTokenizer
from utils import detect_text_boxes, preprocess_image, decode_predictions, crop_and_resize_line, sort_annotations_by_top
import cv2
import numpy as np
import json
# Select the compute device: CUDA when it is available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Đang sử dụng: {device}")

# Read the vocabulary file; the token mapping lives under its "vocab" key.
with open("vocab.json", mode="r", encoding="utf-8") as f:
    vocab_data = json.load(f)
vocab = vocab_data["vocab"]

# Build the tokenizer from the vocabulary.
tokenizer = OCRTokenizer(vocab)

# Construct the recognition model, then move it onto the selected device.
# The start/end token ids come from the vocabulary, defaulting to 1 and 2.
model = OCRModel(
    vocab_size=102,
    encoder_dim=512,
    embed_dim=512,
    num_heads=8,
    num_layers=3,
    sos_token_id=vocab.get("<SOS>", 1),
    eos_token_id=vocab.get("<EOS>", 2),
)
model = model.to(device)

# GPU optimisation: turn on the cuDNN benchmark flag.
torch.backends.cudnn.benchmark = True

# Load the trained weights, mapping stored tensors onto the selected device.
# NOTE(review): torch.load unpickles the checkpoint, so only trusted files
# should be loaded here; consider weights_only=True where the installed
# torch version supports it.
model.load_state_dict(
    torch.load("model/resnet34_7epoch_2h50.pt", map_location=device)
)
model.eval()
def group_by_lines(annotations, line_threshold=15):
    """Group detected text boxes into visual lines.

    Each annotation is a dict whose ``'box'`` entry is a sequence of points,
    read as ``point[0]`` (x) and ``point[1]`` (y). Annotations are visited
    from top to bottom by the mean y of their box; an annotation joins the
    first existing line whose running mean centre is within
    ``line_threshold`` of its own centre, otherwise it starts a new line.

    Args:
        annotations: Iterable of annotation dicts with a ``'box'`` key. The
            container itself is not reordered. As a side effect, a
            ``'center_y'`` key is written onto every annotation that has a
            non-empty box.
        line_threshold: Maximum vertical distance between an annotation's
            centre and a line's centre for the two to be merged.

    Returns:
        A list of ``{'items': [...], 'center_y': float}`` dicts in the order
        the lines were created (top to bottom), with each line's items sorted
        left to right by the smallest x of their box. Annotations with an
        empty box are skipped, since no centre can be computed for them.
    """
    usable = []
    for item in annotations:
        y_coords = [point[1] for point in item['box']]
        if not y_coords:
            # An empty box has no centre; skipping it avoids a
            # ZeroDivisionError here and a ValueError in min() below.
            continue
        item['center_y'] = sum(y_coords) / len(y_coords)
        usable.append(item)

    # sorted() leaves the caller's list in its original order.
    ordered = sorted(usable, key=lambda x: x['center_y'])

    lines = []
    for item in ordered:
        for line in lines:
            if abs(item['center_y'] - line['center_y']) <= line_threshold:
                members = line['items']
                members.append(item)
                # Keep the line centre equal to the mean of its members.
                line['center_y'] = sum(m['center_y'] for m in members) / len(members)
                break
        else:
            # No existing line was close enough: start a new one.
            lines.append({'items': [item], 'center_y': item['center_y']})

    for line in lines:
        line['items'].sort(key=lambda x: min(point[0] for point in x['box']))
    return lines
def ocr_pipeline(image: Image.Image):
    """Detect and recognise the text in an image and return it as a string.

    The image is passed to ``detect_text_boxes`` (PaddleOCR-based according
    to the original comment), the boxes are grouped into lines with
    ``group_by_lines``, and every box is cropped, resized to a fixed height
    and decoded by the module-level ``model`` using beam search.

    Args:
        image: The PIL image to process. ``None`` (for example, no upload in
            the UI) yields an error message instead of an exception.

    Returns:
        The recognised text: box texts on the same line joined by a single
        space, lines joined by a newline. On any failure, a string of the
        form ``"Error processing image: <reason>"`` is returned so that the
        UI always has something to display.
    """
    if image is None:
        return "Error processing image: no image provided"
    try:
        # Convert once; normalising to RGB keeps the channel layout
        # predictable for the colour conversion and the crops below.
        image_rgb = np.array(image.convert("RGB"))
        image_cv = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)

        # Detect text boxes on the BGR image, then group them into lines.
        annotations = detect_text_boxes(image_cv)
        lines = group_by_lines(annotations)

        target_height = 48
        # Mixed precision is only enabled on CUDA; on CPU the autocast
        # context is a disabled no-op, which avoids the warning raised by
        # the CUDA-only autocast entry point on CPU-only hosts.
        use_amp = device.type == "cuda"

        results = []
        for line in lines:
            line_text = []
            for item in line['items']:
                line_image = crop_and_resize_line(image_rgb, item['box'], target_height)
                pil_img = Image.fromarray(line_image)
                # Add a leading batch dimension and move to the model's device.
                input_tensor = preprocess_image(pil_img).unsqueeze(0).to(device)
                with torch.no_grad():
                    with torch.autocast(device_type=device.type, enabled=use_amp):
                        # A small beam size is used to speed up decoding.
                        output = model.predict(input_tensor, beam_size=3)
                pred_text = decode_predictions(output, tokenizer)[0]
                line_text.append(pred_text)
            results.append(" ".join(line_text))
        return "\n".join(results)
    except Exception as e:
        # UI boundary: report the failure as text rather than raising.
        return f"Error processing image: {str(e)}"
# Gradio UI: a single PIL image input whose result from ocr_pipeline is
# shown as plain text.
demo = gr.Interface(
    ocr_pipeline,
    gr.Image(type="pil"),
    "text",
    title="Vietnamese Receipt OCR",
    examples=None,
)

if __name__ == "__main__":
    # Start the app without requesting a public share link.
    demo.launch(share=False)