File size: 2,337 Bytes
c81b4f9
 
 
 
 
 
 
 
 
c2fc929
1ac1152
c2fc929
 
 
 
a343eeb
 
 
 
 
 
c2fc929
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
---
license: apache-2.0
metrics:
- cer
base_model:
- HuggingFaceTB/SmolVLM-256M-Instruct
tags:
- ocr
- vlm
---
See the GitHub project for more details: https://github.com/sovit-123/receipt_ocr

Usage:

```python
import argparse
import importlib.util

import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor

MODEL_ID = 'sovitrath/receipt-ocr-full-ft'

# FlashAttention-2 only works on Ampere-or-newer GPUs (compute capability >= 8.0)
# and requires the `flash-attn` package; otherwise fall back to the portable
# `eager` attention so the example also runs on older GPUs and on CPU.
use_flash_attn = (
    torch.cuda.is_available()
    and torch.cuda.get_device_capability()[0] >= 8
    and importlib.util.find_spec('flash_attn') is not None
)

# Load the fine-tuned OCR model in bfloat16 and let `device_map='auto'`
# place it on the available device(s).
model = AutoModelForVision2Seq.from_pretrained(
    MODEL_ID,
    device_map='auto',
    torch_dtype=torch.bfloat16,
    attn_implementation='flash_attention_2' if use_flash_attn else 'eager',
)

processor = AutoProcessor.from_pretrained(MODEL_ID)

# Load the sample receipt; the context manager closes the underlying file
# once the RGB copy has been made.
with Image.open('inference_data/image_1.jpeg') as raw_image:
    test_image = raw_image.convert('RGB')

def test(model, processor, image, max_new_tokens=1024, device='cuda'):
    messages = [
        {
            'role': 'user',
            'content': [
                {'type': 'image'},
                {'type': 'text', 'text': 'OCR this image accurately'}
            ]
        },
    ]
    
    # Prepare the text input by applying the chat template
    text_input = processor.apply_chat_template(
        messages,  # Use the sample without the system message
        add_generation_prompt=True
    )

    image_inputs = []
    if image.mode != 'RGB':
        image = image.convert('RGB')
        
    image_inputs.append([image])

    # Prepare the inputs for the model
    model_inputs = processor(
        #text=[text_input],
        text=text_input,
        images=image_inputs,
        return_tensors='pt',
    ).to(device)  # Move inputs to the specified device

    # Generate text with the model
    generated_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)

    # Trim the generated ids to remove the input ids
    trimmed_generated_ids = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    # Decode the output text
    output_text = processor.batch_decode(
        trimmed_generated_ids, 
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False
    )

    return output_text[0]  # Return the first decoded output text

# Run OCR on the sample receipt and show the transcription.
output = test(model, processor, image=test_image)
print(output)
```