chengyewang committed on
Commit
2cbeb26
·
verified ·
1 Parent(s): 6bebbb2

Update README.md

  1. README.md +73 -1
README.md CHANGED
@@ -11,4 +11,76 @@ This repository contains the reinforcement learning (RL) model based on **TexOCR
 
  - **Base Model**: TexOCR_OCR
  - **Training Method**: GRPO (Reinforcement Learning)
- - **Task**: Compilable Page-to-LaTeX Reconstruction
+ - **Task**: Compilable Page-to-LaTeX Reconstruction
+
+ ## Inference
+
+ You can use the following code to run inference with the fine-tuned TexOCR model.
+
+ ```python
+ import torch
+ from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
+
+ # Load the fine-tuned model
+ model = Qwen3VLForConditionalGeneration.from_pretrained(
+     "chengyewang/TexOCR-RL",
+     dtype="auto",
+     device_map="auto"
+ )
+
+
+ processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-2B-Instruct")
+
+ # Input document page image
+ image_path = "path/to/your/document_page.png"
+
+ messages = [
+     {
+         "role": "user",
+         "content": [
+             {
+                 "type": "image",
+                 "image": image_path,
+             },
+             {
+                 "type": "text",
+                 "text": (
+                     "Convert this document page image into compilable LaTeX code. "
+                 ),
+             },
+         ],
+     }
+ ]
+
+ # Preparation for inference
+ inputs = processor.apply_chat_template(
+     messages,
+     tokenize=True,
+     add_generation_prompt=True,
+     return_dict=True,
+     return_tensors="pt"
+ )
+ inputs = inputs.to(model.device)
+
+ # Inference: generate LaTeX output
+ generated_ids = model.generate(
+     **inputs,
+     max_new_tokens=2048,
+     do_sample=False
+ )
+
+ # Remove input tokens from the generated sequence
+ generated_ids_trimmed = [
+     out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+ ]
+
+ # Decode the generated LaTeX
+ latex_output = processor.batch_decode(
+     generated_ids_trimmed,
+     skip_special_tokens=True,
+     clean_up_tokenization_spaces=False
+ )
+
+ print(latex_output[0])
+ ```
+
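The added README snippet stops at printing `latex_output[0]`. Because the stated task is compilable page-to-LaTeX reconstruction, a small post-processing step before writing a `.tex` file may be useful. The following is a minimal sketch and not part of the commit: the helper names `strip_code_fence` and `is_full_document` are hypothetical, and it is only an assumption that the model might wrap its answer in a Markdown fence or omit the document preamble.

```python
import re

# Hypothetical helpers, not part of the README or the model repository.
# Assumption: the decoded output may be wrapped in a single Markdown code fence.
_FENCE_RE = re.compile(r"^```[a-zA-Z]*\s*\n(.*?)\n?```\s*$", re.DOTALL)


def strip_code_fence(text: str) -> str:
    """Return the text without one wrapping Markdown fence, if present."""
    stripped = text.strip()
    match = _FENCE_RE.match(stripped)
    return match.group(1).strip() if match else stripped


def is_full_document(latex: str) -> bool:
    """Cheap structural check: a document class declaration followed by a
    begin/end document pair, in that order. This does not run a compiler."""
    doc_class = latex.find("\\documentclass")
    begin = latex.find("\\begin{document}")
    end = latex.rfind("\\end{document}")
    return 0 <= doc_class < begin < end
```

With the variables from the README snippet, one possible use is `latex = strip_code_fence(latex_output[0])`, followed by writing `latex` to a file only when `is_full_document(latex)` holds. The structural check is a heuristic; actually running a LaTeX compiler on the file remains the only real test of compilability.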