File size: 4,826 Bytes
76a5fff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35d3d26
76a5fff
c786b95
 
25db7d4
c786b95
 
39d84f4
25db7d4
c786b95
25db7d4
39d84f4
c786b95
25db7d4
c786b95
506e1a2
 
 
 
 
 
c786b95
25db7d4
c786b95
25db7d4
c786b95
 
25db7d4
c786b95
76a5fff
 
25db7d4
 
c786b95
25db7d4
 
76a5fff
c786b95
 
25db7d4
 
76a5fff
 
25db7d4
c786b95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25db7d4
c786b95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76a5fff
 
c786b95
76a5fff
c786b95
dcc9745
35d3d26
3400cca
ef5d4ca
0fb18ff
39d84f4
ef5d4ca
 
39d84f4
35d3d26
 
3400cca
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# Code anh Thang
# import gradio as gr
# from transformers import AutoProcessor, AutoModelForVision2Seq
# from PIL import Image
# import torch

# device = "cuda" if torch.cuda.is_available() else "cpu"
# torch.cuda.empty_cache()

# model_id = "prithivMLmods/Camel-Doc-OCR-062825"
# processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
# model = AutoModelForVision2Seq.from_pretrained(
#     model_id,
#     torch_dtype=torch.float16 if device == "cuda" else torch.float32,
#     trust_remote_code=True
# ).to(device)

# def predict(image, prompt=None):
#     image = image.convert("RGB")

#     # Cực kỳ quan trọng: text="" bắt buộc phải có
#     inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
#     # In debug để kiểm tra input_ids
#     print(">>> input_ids shape:", inputs.input_ids.shape)
#     generated_ids = model.generate(
#         **inputs,
#         max_new_tokens=512,
#         do_sample=False,
#         use_cache=False,  # ✅ Thêm dòng này để fix lỗi cache_position
#         eos_token_id=processor.tokenizer.eos_token_id,
#         pad_token_id=processor.tokenizer.pad_token_id
#     )

#     result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
#     return result

# demo = gr.Interface(
#     fn=predict,
#     inputs=[
#         gr.Image(type="pil", label="Tải ảnh tài liệu lên"),
#         gr.Textbox(label="Gợi ý (tuỳ chọn)", placeholder="VD: Trích số hóa đơn")
#     ],
#     outputs="text",
#     title="Camel-Doc OCR - Trích xuất văn bản từ ảnh"
# )

# if __name__ == "__main__":
#     demo.launch()

# Code fix
import gradio as gr
from PIL import Image, UnidentifiedImageError
from transformers import AutoProcessor, BitsAndBytesConfig, TextIteratorStreamer
from transformers.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
import torch
from threading import Thread
import time

# Cấu hình thiết bị
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()

# Load mô hình Qwen2.5-VL với quantization 4-bit
model_id = "prithivMLmods/Camel-Doc-OCR-062825"

# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.float16
# )

processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_id,
    # quantization_config=bnb_config, Quantization
    device_map="auto",
    trust_remote_code=True
).eval()

def convert_png_to_jpg(image):
    if image.mode in ["RGBA", "LA"]:
        converted = Image.new("RGB", image.size, (255, 255, 255))
        converted.paste(image, mask=image.split()[-1])
        return converted
    return image.convert("RGB")

# Hàm dự đoán
def predict(image, prompt=""):
    if image is None:
        return "=Vui lòng tải lên ảnh hợp lệ."

    try:
        image = convert_png_to_jpg(image)
        prompt = prompt.strip() if prompt else "Please describe the document."

        # Xây dựng prompt theo định dạng Qwen2.5-VL
        messages = [{
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt}
            ]
        }]
        text_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        inputs = processor(
            text=[text_prompt],
            images=[image],
            return_tensors="pt",
            padding=True
        ).to(model.device)

        # Dùng streamer để sinh kết quả mượt hơn
        streamer = TextIteratorStreamer(processor.tokenizer, skip_special_tokens=True, skip_prompt=True)
        generation_kwargs = {
            **inputs,
            "streamer": streamer,
            "max_new_tokens": 512,
            "do_sample": False,
            "use_cache": True
        }

        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        buffer = ""
        for new_text in streamer:
            buffer += new_text
            time.sleep(0.01)

        return buffer

    except UnidentifiedImageError:
        return "Không thể đọc ảnh. Ảnh có thể bị hỏng hoặc sai định dạng."
    except Exception as e:
        return f"Lỗi khi xử lý ảnh: {str(e)}"

demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Tải ảnh tài liệu lên"),
        gr.Textbox(label="Gợi ý (tuỳ chọn)", placeholder="VD: Trích số hóa đơn")
    ],
    outputs="text",
    title="Camel-Doc OCR - Trích xuất văn bản từ ảnh"
)

if __name__ == "__main__":
    demo.launch()