topdu committed on
Commit
640f03f
·
1 Parent(s): ead8146

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -123
app.py CHANGED
@@ -1,123 +1,2 @@
1
- import gradio as gr
2
- import torch
3
- from threading import Thread
4
-
5
- import numpy as np
6
- import re
7
- from openrec.postprocess.unirec_postprocess import clean_special_tokens
8
- from openrec.preprocess import create_operators, transform
9
- from tools.engine.config import Config
10
- from tools.utils.ckpt import load_ckpt
11
- from tools.infer_rec import build_rec_process
12
-
13
-
14
- def set_device(device):
15
- if device == 'gpu' and torch.cuda.is_available():
16
- device = torch.device('cuda:0')
17
- else:
18
- device = torch.device('cpu')
19
- return device
20
-
21
-
22
- cfg = Config('configs/rec/unirec/focalsvtr_ardecoder_unirec.yml')
23
- cfg = cfg.cfg
24
- global_config = cfg['Global']
25
-
26
- from openrec.modeling.transformers_modeling.modeling_unirec import UniRecForConditionalGenerationNew
27
- from openrec.modeling.transformers_modeling.configuration_unirec import UniRecConfig
28
- from transformers import AutoTokenizer, TextIteratorStreamer
29
-
30
- tokenizer = AutoTokenizer.from_pretrained(global_config['vlm_ocr_config'])
31
- cfg_model = UniRecConfig.from_pretrained(global_config['vlm_ocr_config'])
32
- # cfg_model._attn_implementation = "flash_attention_2"
33
- cfg_model._attn_implementation = 'eager'
34
-
35
- model = UniRecForConditionalGenerationNew(config=cfg_model)
36
- load_ckpt(model, cfg)
37
- device = set_device(cfg['Global']['device'])
38
- model.eval()
39
- model.to(device=device)
40
-
41
- transforms, ratio_resize_flag = build_rec_process(cfg)
42
- ops = create_operators(transforms, global_config)
43
-
44
- # --- 2. 定义流式生成函数 ---
45
- def stream_chat_with_image(input_image, history):
46
- if input_image is None:
47
- yield history + [('🖼️(空)', '请先上传一张图片。')]
48
- return
49
-
50
- # 创建 TextIteratorStreamer
51
- streamer = TextIteratorStreamer(tokenizer,
52
- skip_prompt=True,
53
- skip_special_tokens=False)
54
-
55
- data = {'image': input_image}
56
- batch = transform(data, ops[1:])
57
- images = np.expand_dims(batch[0], axis=0)
58
- images = torch.from_numpy(images).to(device=device)
59
- inputs = {
60
- 'pixel_values': images,
61
- 'input_ids': None,
62
- 'attention_mask': None
63
- }
64
- generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=2048)
65
- # 后台线程运行生成
66
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
67
- thread.start()
68
- # 流式输出
69
- history = history + [('🖼️(图片)', '')]
70
- generated_text_ori = ''
71
- for new_text in streamer:
72
- generated_text_ori += new_text
73
- generated_text = clean_special_tokens(generated_text_ori.replace(' ', ''))
74
- text = generated_text.replace('<tdcolspan=', '<td colspan=')
75
- text = text.replace('<tdrowspan=', '<td rowspan=')
76
- generated_text = text.replace('"colspan=', '" colspan=')
77
- history[-1] = ('🖼️(图片)', generated_text)
78
- yield history
79
-
80
-
81
- # --- 3. Gradio UI ---
82
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
83
- gr.HTML("""
84
- <h1 style='text-align: center;'>
85
- <a href="https://github.com/Topdu/OpenOCR">
86
- UniRec-0.1B: Unified Text and Formula Recognition with 0.1B Parameters
87
- </a>
88
- </h1>
89
- <p style='text-align: center;'>
90
- A ultralight unified text and formula recognition model
91
- (Created by <a href="https://fvl.fudan.edu.cn">FVL Lab</a>,
92
- <a href="https://github.com/Topdu/OpenOCR">OCR Team</a>)
93
- </p>
94
- <p style='text-align: center;'>
95
- <a href="https://github.com/Topdu/OpenOCR/blob/main/docs/unirec.md">[Local GPU Deployment]</a>
96
- for fast recognition experience
97
- </p>"""
98
- )
99
- gr.Markdown('Upload an image, and the system will automatically recognize text and formulas.')
100
- with gr.Row():
101
- with gr.Column(scale=1): # Left column: image + clear button
102
- image_input = gr.Image(label='Upload Image or Paste Screenshot', type='pil')
103
- clear = gr.ClearButton([image_input], value='Clear')
104
- with gr.Column(scale=2):
105
- chatbot = gr.Chatbot(
106
- label='Result (Use LaTeX renderer to display formulas)',
107
- show_copy_button=True,
108
- height='auto'
109
- )
110
- clear.add([chatbot])
111
-
112
- # Trigger after upload
113
- # image_input.upload(stream_chat_with_image, [image_input, chatbot], chatbot)
114
- image_input.change(
115
- stream_chat_with_image,
116
- [image_input, chatbot],
117
- chatbot,
118
- show_progress=False
119
- )
120
-
121
- # --- 4. Launch app ---
122
- if __name__ == '__main__':
123
- demo.queue().launch(share=True)
 
1
+ from openocr.demo_unirec import launch_demo
2
+ launch_demo()