File size: 4,139 Bytes
76a5fff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35d3d26
39d84f4
76a5fff
39d84f4
76a5fff
39d84f4
76a5fff
39d84f4
 
 
76a5fff
39d84f4
 
 
 
 
 
 
 
76a5fff
 
 
 
 
 
 
 
 
 
 
39d84f4
76a5fff
 
 
 
 
 
 
 
 
 
 
 
39d84f4
76a5fff
 
 
 
 
 
 
 
4be1568
76a5fff
39d84f4
 
 
 
76a5fff
39d84f4
 
 
 
 
 
 
76a5fff
35d3d26
3400cca
ef5d4ca
39d84f4
 
ef5d4ca
 
39d84f4
35d3d26
 
3400cca
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# Code anh Thang
# import gradio as gr
# from transformers import AutoProcessor, AutoModelForVision2Seq
# from PIL import Image
# import torch

# device = "cuda" if torch.cuda.is_available() else "cpu"
# torch.cuda.empty_cache()

# model_id = "prithivMLmods/Camel-Doc-OCR-062825"
# processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
# model = AutoModelForVision2Seq.from_pretrained(
#     model_id,
#     torch_dtype=torch.float16 if device == "cuda" else torch.float32,
#     trust_remote_code=True
# ).to(device)

# def predict(image, prompt=None):
#     image = image.convert("RGB")

#     # Cực kỳ quan trọng: text="" bắt buộc phải có
#     inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
#     # In debug để kiểm tra input_ids
#     print(">>> input_ids shape:", inputs.input_ids.shape)
#     generated_ids = model.generate(
#         **inputs,
#         max_new_tokens=512,
#         do_sample=False,
#         use_cache=False,  # ✅ Thêm dòng này để fix lỗi cache_position
#         eos_token_id=processor.tokenizer.eos_token_id,
#         pad_token_id=processor.tokenizer.pad_token_id
#     )

#     result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
#     return result

# demo = gr.Interface(
#     fn=predict,
#     inputs=[
#         gr.Image(type="pil", label="Tải ảnh tài liệu lên"),
#         gr.Textbox(label="Gợi ý (tuỳ chọn)", placeholder="VD: Trích số hóa đơn")
#     ],
#     outputs="text",
#     title="Camel-Doc OCR - Trích xuất văn bản từ ảnh"
# )

# if __name__ == "__main__":
#     demo.launch()

# Code fix
import gradio as gr
from transformers import AutoProcessor, AutoModelForVision2Seq
from PIL import Image, UnidentifiedImageError
import torch
import os

# Cấu hình thiết bị
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.cuda.empty_cache()

# Load mô hình
model_id = "prithivMLmods/Camel-Doc-OCR-062825"
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    trust_remote_code=True
).to(device)

# Hỗ trợ định dạng ảnh
def is_supported_image(image):
    return isinstance(image, Image.Image)

# Chuyển PNG sang JPG
def convert_png_to_jpg(image):
    converted = Image.new("RGB", image.size, (255, 255, 255))
    converted.paste(image)
    return converted

# Hàm chính
def predict(image, prompt=None):
    # Kiểm tra ảnh hợp lệ
    if not is_supported_image(image):
        return "Không hỗ trợ định dạng file này. Vui lòng tải ảnh đúng."

    # Prompt rỗng
    if prompt is None or prompt.strip() == "":
        return "Vui lòng nhập prompt để trích xuất dữ liệu từ ảnh."

    try:
        # Nếu ảnh là PNG có alpha, convert sang RGB
        if image.mode == "RGBA" or image.mode == "LA":
            image = convert_png_to_jpg(image)

        image = image.convert("RGB")

    except UnidentifiedImageError:
        return "Không thể đọc ảnh. Vui lòng kiểm tra lại định dạng hoặc ảnh bị lỗi."
    except Exception as e:
        return f"Lỗi khi xử lý ảnh: {str(e)}"

    # Inference
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)

    generated_ids = model.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=False,
        use_cache=False,  # fix cache_position
        eos_token_id=processor.tokenizer.eos_token_id,
        pad_token_id=processor.tokenizer.pad_token_id
    )

    result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return result


demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Tải ảnh tài liệu lên"),
        gr.Textbox(label="Gợi ý (tuỳ chọn)", placeholder="VD: Trích số hóa đơn")
    ],
    outputs="text",
    title="Camel-Doc OCR - Trích xuất văn bản từ ảnh"
)

if __name__ == "__main__":
    demo.launch()