sovitrath committed on
Commit
c2fc929
·
verified ·
1 Parent(s): c81b4f9

Update README.md

Browse files

Updated README with inference code.

Files changed (1) hide show
  1. README.md +68 -1
README.md CHANGED
@@ -7,4 +7,71 @@ base_model:
7
  tags:
8
  - ocr
9
  - vlm
10
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  tags:
8
  - ocr
9
  - vlm
10
+ ---
11
+
12
+ Usage:
13
+
14
+ ```python
15
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor

# Load the fine-tuned receipt-OCR model. `device_map='auto'` places the
# weights on the available accelerator(s); bfloat16 halves memory use.
model = AutoModelForVision2Seq.from_pretrained(
    'sovitrath/receipt-ocr-full-ft',
    device_map='auto',
    torch_dtype=torch.bfloat16,
    # Use `flash_attention_2` on Ampere GPUs and above; use `eager` on older GPUs.
    _attn_implementation='flash_attention_2'
)

# The processor bundles the tokenizer and the image preprocessor.
processor = AutoProcessor.from_pretrained('sovitrath/receipt-ocr-full-ft')

# Sample image for inference; convert to RGB up front.
test_image = Image.open('inference_data/image_1.jpeg').convert('RGB')
27
def test(model, processor, image, max_new_tokens=1024, device='cuda'):
    """Run OCR inference on a single PIL image and return the decoded text.

    Args:
        model: A vision-to-sequence model (e.g. loaded via AutoModelForVision2Seq).
        processor: The matching processor (chat template + image preprocessing).
        image: PIL image to OCR; converted to RGB if needed.
        max_new_tokens: Maximum number of tokens to generate.
        device: Device the prepared inputs are moved to before generation.

    Returns:
        The decoded OCR text for the image (first element of the batch).
    """
    messages = [
        {
            'role': 'user',
            'content': [
                {'type': 'image'},
                {'type': 'text', 'text': 'OCR this image accurately'}
            ]
        },
    ]

    # Build the prompt text from the chat template (no system message),
    # ending with the generation prompt so the model answers next.
    text_input = processor.apply_chat_template(
        messages,
        add_generation_prompt=True
    )

    # The processor expects one list of images per text prompt.
    if image.mode != 'RGB':
        image = image.convert('RGB')
    image_inputs = [[image]]

    # Tokenize the text, preprocess the image, and move tensors to the device.
    model_inputs = processor(
        text=text_input,
        images=image_inputs,
        return_tensors='pt',
    ).to(device)

    # Generate text with the model.
    generated_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)

    # Drop the prompt tokens so only newly generated tokens are decoded.
    trimmed_generated_ids = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    # Decode the generated ids back to text.
    output_text = processor.batch_decode(
        trimmed_generated_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False
    )

    return output_text[0]  # First (and only) item in the batch.
74
+
75
# Run OCR on the sample image and show the result.
output = test(model, processor, test_image)
print(output)
77
+ ```