File size: 4,320 Bytes
322bbf8
 
 
 
 
55a0a6c
322bbf8
 
 
 
 
55a0a6c
 
 
 
 
 
 
322bbf8
55a0a6c
 
 
 
322bbf8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55a0a6c
322bbf8
 
 
 
55a0a6c
322bbf8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55a0a6c
 
322bbf8
 
 
55a0a6c
322bbf8
55a0a6c
322bbf8
 
 
 
 
 
 
 
 
 
 
55a0a6c
322bbf8
55a0a6c
322bbf8
55a0a6c
 
322bbf8
 
 
 
 
 
 
55a0a6c
 
322bbf8
 
 
55a0a6c
322bbf8
 
 
 
 
 
 
 
 
55a0a6c
322bbf8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import torch
import base64
import gradio as gr
from io import BytesIO
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, BitsAndBytesConfig
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts import build_no_anchoring_v4_yaml_prompt
import warnings
warnings.filterwarnings('ignore')  # silence HF/bitsandbytes noise in the demo console

# Configure 8-bit quantization to reduce memory
# (fp32 CPU offload lets layers that don't fit spill to system RAM)
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_enable_fp32_cpu_offload=True
)

print("Loading model with 8-bit quantization...")
# device_map="auto" shards the model across available devices; eval() disables dropout.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "allenai/olmOCR-2-7B-1025",
    quantization_config=quantization_config,
    device_map="auto",
    low_cpu_mem_usage=True,
).eval()

# NOTE(review): processor intentionally comes from the base Qwen2.5-VL repo,
# not the olmOCR checkpoint — presumably the fine-tune ships no processor; confirm.
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
print("Model loaded successfully")

def _load_page_image(path, page_number):
    """Load one page/image from *path*, capped at 896 px on the longest side.

    Returns a tuple ``(pil_image, base64_png_string)``.
    """
    # Case-insensitive check: ".PDF" uploads must take the PDF path too.
    if path.lower().endswith('.pdf'):
        image_base64 = render_pdf_to_base64png(
            path,
            page_number,
            target_longest_image_dim=896  # Further reduced for memory
        )
        return Image.open(BytesIO(base64.b64decode(image_base64))), image_base64

    main_image = Image.open(path)
    max_size = 896  # Reduced image size
    if max(main_image.size) > max_size:
        main_image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)

    buffered = BytesIO()
    main_image.save(buffered, format="PNG")
    return main_image, base64.b64encode(buffered.getvalue()).decode()


def process_document(file, page_number, max_tokens):
    """Run olmOCR on one page of an uploaded document.

    Args:
        file: Gradio file object (exposes a ``.name`` path); PDF, PNG, or JPEG.
        page_number: 1-based page index used when the file is a PDF.
        max_tokens: requested generation budget; hard-capped at 256.

    Returns:
        ``(extracted_text, PIL.Image)`` on success, or ``(error_message, None)``.
    """
    if file is None:
        return "Please upload a file first.", None

    try:
        # Gradio sliders may deliver floats even with step=1 — coerce.
        main_image, image_base64 = _load_page_image(file.name, int(page_number))

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": build_no_anchoring_v4_yaml_prompt()},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
                ],
            }
        ]

        text = processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        inputs = processor(
            text=[text],
            images=[main_image],
            padding=True,
            return_tensors="pt",
        )
        # Move tensors to the model's device: with device_map="auto" the first
        # layers may sit on GPU, and CPU-resident inputs would crash generate().
        inputs = inputs.to(model.device)

        # Greedy decoding (do_sample=False); temperature is ignored in that mode
        # and only triggered a transformers warning, so it is not passed.
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=min(int(max_tokens), 256),  # Limit tokens
                num_return_sequences=1,
                do_sample=False,
            )

        # Strip the prompt tokens; decode only the newly generated tail.
        prompt_length = inputs["input_ids"].shape[1]
        new_tokens = output[:, prompt_length:]
        text_output = processor.tokenizer.batch_decode(
            new_tokens, skip_special_tokens=True
        )

        return text_output[0], main_image

    except Exception as e:
        # Gradio UI boundary: surface the failure as text instead of a traceback.
        return f"Error: {str(e)}", None

# Build the Gradio UI: upload + controls in the left column, results on the right.
with gr.Blocks(title="olmOCR - Document OCR (CPU)") as demo:
    gr.Markdown("# olmOCR: Document OCR (Quantized)")
    gr.Markdown("⚠️ **Note**: Using 8-bit quantization for CPU compatibility. Processing may take 60-120 seconds.")

    with gr.Row():
        # Input side: document, page selector, token budget, run button.
        with gr.Column():
            uploaded_doc = gr.File(
                label="Upload Document (PDF, PNG, or JPEG)",
                file_types=[".pdf", ".png", ".jpg", ".jpeg"]
            )
            page_slider = gr.Slider(1, 20, value=1, step=1, label="Page Number")
            token_slider = gr.Slider(50, 256, value=128, step=16, label="Max Tokens")
            run_button = gr.Button("Extract Text", variant="primary")

        # Output side: OCR text plus the image actually fed to the model.
        with gr.Column():
            extracted_box = gr.Textbox(label="Extracted Text", lines=20)
            preview_image = gr.Image(label="Processed Image")

    run_button.click(
        fn=process_document,
        inputs=[uploaded_doc, page_slider, token_slider],
        outputs=[extracted_box, preview_image]
    )

if __name__ == "__main__":
    # Queue requests (at most 2 waiting) so concurrent users don't OOM the model.
    demo.queue(max_size=2)
    # Bind to all interfaces on the conventional Gradio port.
    demo.launch(server_name="0.0.0.0", server_port=7860)