import gradio as gr
import torch
from PIL import Image
from model_def import OCRModel
from tokenizer import OCRTokenizer
from utils import detect_text_boxes, preprocess_image, decode_predictions, crop_and_resize_line, sort_annotations_by_top
import cv2
import numpy as np
import json
import logging

logger = logging.getLogger(__name__)

# Select the compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Đang sử dụng: {device}")

# Mixed precision is only enabled on CUDA. On CPU the autocast context is
# created disabled, instead of torch.cuda.amp.autocast warning on every call
# that CUDA is unavailable.
USE_AMP = device.type == "cuda"

# Load the vocabulary; the tokenizer and the special-token ids come from it.
with open("vocab.json", "r", encoding="utf-8") as f:
    vocab_data = json.load(f)
vocab = vocab_data["vocab"]

tokenizer = OCRTokenizer(vocab)

# NOTE(review): vocab_size is hard-coded to 102 rather than len(vocab);
# presumably it must match the checkpoint loaded below -- confirm.
# NOTE(review): the special-token keys below are empty strings, so unless
# vocab.json contains an "" key these lookups fall back to 1 and 2. They look
# like "<sos>"/"<eos>"-style keys whose text was lost -- verify against vocab.json.
model = OCRModel(
    vocab_size=102,
    encoder_dim=512,
    embed_dim=512,
    num_heads=8,
    num_layers=3,
    sos_token_id=vocab.get("", 1),
    eos_token_id=vocab.get("", 2),
).to(device)

# Let cuDNN auto-tune convolution algorithms.
# NOTE(review): cuDNN re-tunes for every new input shape; if line crops have
# varying widths this may cost more than it saves -- worth measuring.
torch.backends.cudnn.benchmark = True

# map_location lets a checkpoint saved on one device load on the current one.
model.load_state_dict(torch.load("model/resnet34_7epoch_2h50.pt", map_location=device))
model.eval()


def group_by_lines(annotations, line_threshold=15):
    """Group detected text boxes into text lines.

    Boxes are visited top to bottom by centre-y. Each box joins the first
    existing line whose current centre-y is within ``line_threshold`` of the
    box's own centre-y; otherwise it starts a new line. Within each line,
    boxes are then ordered left to right by their smallest x coordinate.

    Note: ``annotations`` is modified in place -- every item gains a
    ``'center_y'`` key and the list itself is sorted by it.

    Args:
        annotations: List of dicts, each with a ``'box'`` key holding a
            sequence of points indexable as ``point[0]`` (x) and ``point[1]`` (y).
        line_threshold: Maximum vertical distance, in box-coordinate units,
            between a box centre and a line centre for the box to join it.

    Returns:
        List of ``{'items': [...], 'center_y': float}`` dicts, in the order
        the lines were created (i.e. by the centre-y of each line's first box).
    """
    for item in annotations:
        ys = [point[1] for point in item['box']]
        item['center_y'] = sum(ys) / len(ys)
    annotations.sort(key=lambda a: a['center_y'])

    lines = []
    for item in annotations:
        for line in lines:
            if abs(item['center_y'] - line['center_y']) <= line_threshold:
                line['items'].append(item)
                # Re-centre the line on the mean centre-y of all its members.
                line['center_y'] = sum(i['center_y'] for i in line['items']) / len(line['items'])
                break
        else:
            # No existing line is close enough: start a new one.
            lines.append({'items': [item], 'center_y': item['center_y']})

    for line in lines:
        line['items'].sort(key=lambda a: min(point[0] for point in a['box']))
    return lines


def _recognize_box(image_rgb, box, target_height):
    """Crop one text box from an RGB array and return the model's decoded text for it."""
    line_image = crop_and_resize_line(image_rgb, box, target_height)
    input_tensor = preprocess_image(Image.fromarray(line_image)).unsqueeze(0).to(device)
    # No gradients are needed at inference; autocast is a no-op unless on CUDA.
    with torch.no_grad(), torch.autocast(device_type=device.type, enabled=USE_AMP):
        # beam_size kept small (3) to speed up decoding.
        output = model.predict(input_tensor, beam_size=3)
    return decode_predictions(output, tokenizer)[0]


def ocr_pipeline(image: Image.Image):
    """Run text detection and recognition on one image and return its text.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input. It can be
            None (e.g. an empty submission); any PIL mode is accepted and
            converted to RGB.

    Returns:
        The recognised text: one output line per grouped text line, with the
        boxes of a line joined by single spaces. On failure, returns an
        ``"Error processing image: ..."`` message and logs the traceback.
    """
    if image is None:
        return "Error processing image: no image provided"

    target_height = 48
    try:
        # Normalise to 3-channel RGB so RGBA / grayscale / palette uploads
        # don't break the RGB->BGR conversion or the line crops.
        image_rgb = np.array(image.convert("RGB"))
        image_cv = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)

        # Detect text boxes on the BGR image (the original comment notes PaddleOCR).
        annotations = detect_text_boxes(image_cv)
        lines = group_by_lines(annotations)

        results = []
        for line in lines:
            words = [_recognize_box(image_rgb, item['box'], target_height) for item in line['items']]
            results.append(" ".join(words))
        return "\n".join(results)
    except Exception as e:
        # UI boundary: keep the app alive and show the error, but record the
        # full traceback instead of discarding it.
        logger.exception("OCR pipeline failed")
        return f"Error processing image: {e}"


demo = gr.Interface(
    fn=ocr_pipeline,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Vietnamese Receipt OCR",
    examples=None,
)

if __name__ == "__main__":
    demo.launch(share=False)