Spaces:
Build error
Build error
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| from model_def import OCRModel | |
| from tokenizer import OCRTokenizer | |
| from utils import detect_text_boxes, preprocess_image, decode_predictions, crop_and_resize_line, sort_annotations_by_top | |
| import cv2 | |
| import numpy as np | |
| import json | |
# Select the compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Đang sử dụng: {device}")

# Load the vocabulary; the token table is stored under the "vocab" key.
with open("vocab.json", "r", encoding="utf-8") as f:
    vocab_data = json.load(f)
vocab = vocab_data["vocab"]

# Build the tokenizer from the loaded vocabulary.
tokenizer = OCRTokenizer(vocab)

# Instantiate the recognition model and move it to the selected device.
# NOTE(review): vocab_size is hard-coded rather than derived from `vocab`;
# presumably it has to match the checkpoint loaded below -- confirm.
model = OCRModel(
    vocab_size=102,
    encoder_dim=512,
    embed_dim=512,
    num_heads=8,
    num_layers=3,
    sos_token_id=vocab.get("<SOS>", 1),
    eos_token_id=vocab.get("<EOS>", 2),
).to(device)

# Let cuDNN auto-tune its convolution algorithms.
# NOTE(review): auto-tuning pays off when input shapes repeat; verify it
# helps if the line crops fed to the model vary in width.
torch.backends.cudnn.benchmark = True

# Load the trained weights, remapping tensors onto the selected device.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model.load_state_dict(torch.load("model/resnet34_7epoch_2h50.pt", map_location=device))
model.eval()
def group_by_lines(annotations, line_threshold=15):
    """Group text boxes into lines by vertical position.

    Each annotation is a dict whose ``'box'`` entry is a sequence of
    ``(x, y)`` points. Boxes are visited from top to bottom; a box joins
    the first existing line whose running mean ``center_y`` lies within
    ``line_threshold`` of the box's own ``center_y``, otherwise it starts
    a new line.

    Side effects: a ``'center_y'`` key (mean of the box's y coordinates)
    is written into every annotation, and ``annotations`` itself is
    sorted in place by that value.

    Args:
        annotations: List of annotation dicts, each with a non-empty
            ``'box'`` list of points.
        line_threshold: Maximum vertical distance between a box centre
            and a line centre for the box to join that line.

    Returns:
        A list of ``{'items': [...], 'center_y': float}`` dicts ordered
        top to bottom; the items of each line are ordered left to right
        by the smallest x coordinate of their box.
    """
    for item in annotations:
        ys = [point[1] for point in item['box']]
        item['center_y'] = sum(ys) / len(ys)

    annotations.sort(key=lambda ann: ann['center_y'])

    lines = []
    for item in annotations:
        for line in lines:
            if abs(item['center_y'] - line['center_y']) <= line_threshold:
                members = line['items']
                members.append(item)
                # Re-centre the line on the mean of all of its members.
                line['center_y'] = sum(m['center_y'] for m in members) / len(members)
                break
        else:
            # No existing line was close enough: this box opens a new one.
            lines.append({'items': [item], 'center_y': item['center_y']})

    for line in lines:
        line['items'].sort(key=lambda ann: min(point[0] for point in ann['box']))
    return lines
def ocr_pipeline(image: Image.Image):
    """Detect and recognize the text in one image.

    Text boxes are detected on the whole image, grouped into lines with
    ``group_by_lines``, and each box is cropped, resized to a fixed
    height and decoded by the recognition model.

    Args:
        image: PIL image from the Gradio input; ``None`` when the form is
            submitted without an upload.

    Returns:
        The recognized text: boxes of one line joined by single spaces,
        lines joined by newlines. Any failure is reported as a string
        starting with ``"Error processing image: "`` instead of raising,
        so the message is shown in the output text box.
    """
    if image is None:
        return "Error processing image: no image provided"
    try:
        # Convert once, and force three channels so the RGB->BGR
        # conversion below also works for RGBA / greyscale / palette input.
        image_rgb = np.array(image.convert("RGB"))
        image_cv = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)

        # The detector receives the BGR array; crops are taken from RGB.
        annotations = detect_text_boxes(image_cv)
        lines = group_by_lines(annotations)

        target_height = 48
        # Mixed precision only makes sense on CUDA; on CPU the context is
        # entered disabled, so no "CUDA is not available" warning is raised.
        use_amp = device.type == "cuda"

        results = []
        for line in lines:
            line_text = []
            for item in line['items']:
                line_image = crop_and_resize_line(image_rgb, item['box'], target_height)
                pil_img = Image.fromarray(line_image)
                input_tensor = preprocess_image(pil_img).unsqueeze(0).to(device)
                with torch.no_grad(), torch.autocast(device_type=device.type, enabled=use_amp):
                    # Small beam size, chosen for speed.
                    output = model.predict(input_tensor, beam_size=3)
                pred_text = decode_predictions(output, tokenizer)[0]
                line_text.append(pred_text)
            results.append(" ".join(line_text))
        return "\n".join(results)
    except Exception as e:
        # UI boundary: report the failure as text rather than crashing.
        return f"Error processing image: {str(e)}"
# Gradio UI: one PIL image in, recognized text out.
demo = gr.Interface(
    fn=ocr_pipeline,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Vietnamese Receipt OCR",
    examples=None,
)

if __name__ == "__main__":
    # Launch without requesting a public share link.
    demo.launch(share=False)