hayas commited on
Commit
caef7ce
·
1 Parent(s): aac8486

Add files

Browse files
Files changed (6) hide show
  1. .python-version +1 -0
  2. README.md +4 -3
  3. app.py +113 -0
  4. pyproject.toml +62 -0
  5. requirements.txt +319 -0
  6. uv.lock +0 -0
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.12.12
README.md CHANGED
@@ -1,10 +1,11 @@
1
  ---
2
- title: Sarashina2.2 Ocr
3
- emoji: 💻
4
  colorFrom: blue
5
- colorTo: red
6
  sdk: gradio
7
  sdk_version: 6.10.0
 
8
  app_file: app.py
9
  pinned: false
10
  ---
 
1
  ---
2
+ title: Sarashina2.2-OCR
3
+ emoji: 📄
4
  colorFrom: blue
5
+ colorTo: indigo
6
  sdk: gradio
7
  sdk_version: 6.10.0
8
+ python_version: "3.12.12"
9
  app_file: app.py
10
  pinned: false
11
  ---
app.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections.abc import Generator
2
+ from threading import Thread
3
+
4
+ import gradio as gr
5
+ import spaces
6
+ import torch
7
+ from PIL import Image
8
+ from transformers import AutoModelForCausalLM, AutoProcessor, TextIteratorStreamer
9
+
10
+ MODEL_ID = "sbintuitions/sarashina2.2-ocr"
11
+
12
+ processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True, use_fast=False)
13
+ model = AutoModelForCausalLM.from_pretrained(
14
+ MODEL_ID,
15
+ device_map="cuda",
16
+ dtype=torch.bfloat16,
17
+ trust_remote_code=True,
18
+ )
19
+
20
+
21
+ @spaces.GPU(duration=90)
22
+ def run_ocr(image: Image.Image | None) -> Generator[tuple[str, str], None, None]:
23
+ if image is None:
24
+ yield "", ""
25
+ return
26
+
27
+ message = [{"role": "user", "content": [{"type": "image", "image": image}]}]
28
+ inputs = processor.apply_chat_template(
29
+ message,
30
+ tokenize=True,
31
+ add_generation_prompt=True,
32
+ return_dict=True,
33
+ return_tensors="pt",
34
+ ).to(model.device)
35
+
36
+ streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True, timeout=20.0)
37
+ generate_kwargs = dict(
38
+ **inputs,
39
+ max_new_tokens=6000,
40
+ temperature=0.0,
41
+ top_p=0.95,
42
+ repetition_penalty=1.2,
43
+ use_cache=True,
44
+ streamer=streamer,
45
+ )
46
+
47
+ exception_holder: list[Exception] = []
48
+
49
+ def _generate() -> None:
50
+ try:
51
+ model.generate(**generate_kwargs)
52
+ except Exception as e: # noqa: BLE001
53
+ exception_holder.append(e)
54
+
55
+ thread = Thread(target=_generate)
56
+ thread.start()
57
+
58
+ result = ""
59
+ for text in streamer:
60
+ result += text
61
+ yield result, result
62
+
63
+ thread.join()
64
+ if exception_holder:
65
+ msg = f"Generation failed: {exception_holder[0]}"
66
+ raise gr.Error(msg)
67
+
68
+
69
+ with gr.Blocks() as demo:
70
+ gr.Markdown("# Sarashina2.2-OCR Demo")
71
+ gr.Markdown(
72
+ "Upload a document image to extract text using "
73
+ "[sbintuitions/sarashina2.2-ocr](https://huggingface.co/sbintuitions/sarashina2.2-ocr)."
74
+ )
75
+ with gr.Row():
76
+ with gr.Column():
77
+ image_input = gr.Image(label="Document Image", type="pil")
78
+ run_btn = gr.Button("Run OCR")
79
+ with gr.Column():
80
+ with gr.Tab("Rendered"):
81
+ output_md = gr.Markdown(
82
+ label="Result",
83
+ latex_delimiters=[
84
+ {"left": "$$", "right": "$$", "display": True},
85
+ {"left": "$", "right": "$", "display": False},
86
+ {"left": "\\(", "right": "\\)", "display": False},
87
+ {"left": "\\[", "right": "\\]", "display": True},
88
+ ],
89
+ )
90
+ with gr.Tab("Raw"):
91
+ output_text = gr.Textbox(label="Raw Markdown", lines=20)
92
+
93
+ gr.on(
94
+ triggers=[run_btn.click, image_input.upload],
95
+ fn=run_ocr,
96
+ inputs=image_input,
97
+ outputs=[output_md, output_text],
98
+ )
99
+
100
+ gr.Examples(
101
+ examples=[
102
+ ["https://huggingface.co/sbintuitions/sarashina2.2-ocr/resolve/main/assets/sample1.jpeg"],
103
+ ["https://huggingface.co/sbintuitions/sarashina2.2-ocr/resolve/main/assets/sample2.jpeg"],
104
+ ["https://huggingface.co/sbintuitions/sarashina2.2-ocr/resolve/main/assets/sample3.jpeg"],
105
+ ["https://huggingface.co/sbintuitions/sarashina2.2-ocr/resolve/main/assets/sample4.jpeg"],
106
+ ],
107
+ inputs=image_input,
108
+ fn=run_ocr,
109
+ outputs=[output_md, output_text],
110
+ )
111
+
112
+ if __name__ == "__main__":
113
+ demo.launch()
pyproject.toml ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "sarashina2-2-ocr"
3
+ version = "0.1.0"
4
+ description = "Gradio demo for sbintuitions/sarashina2.2-ocr"
5
+ readme = "README.md"
6
+ requires-python = ">=3.12"
7
+ dependencies = [
8
+ "accelerate>=1.13.0",
9
+ "gradio>=6.10.0",
10
+ "pillow>=12.1.1",
11
+ "protobuf>=7.34.1",
12
+ "sentencepiece>=0.2.1",
13
+ "spaces>=0.48.1",
14
+ "torch==2.9.1",
15
+ "torchvision>=0.24.1",
16
+ "transformers==4.57.1",
17
+ ]
18
+
19
+ [tool.ruff]
20
+ line-length = 119
21
+
22
+ [tool.ruff.lint]
23
+ select = ["ALL"]
24
+ ignore = [
25
+ "COM812", # missing-trailing-comma
26
+ "D203", # one-blank-line-before-class
27
+ "D213", # multi-line-summary-second-line
28
+ "E501", # line-too-long
29
+ "SIM117", # multiple-with-statements
30
+ #
31
+ "D100", # undocumented-public-module
32
+ "D101", # undocumented-public-class
33
+ "D102", # undocumented-public-method
34
+ "D103", # undocumented-public-function
35
+ "D104", # undocumented-public-package
36
+ "D105", # undocumented-magic-method
37
+ "D107", # undocumented-public-init
38
+ "EM101", # raw-string-in-exception
39
+ "FBT001", # boolean-type-hint-positional-argument
40
+ "FBT002", # boolean-default-value-positional-argument
41
+ "ISC001", # single-line-implicit-string-concatenation
42
+ "PGH003", # blanket-type-ignore
43
+ "PLR0913", # too-many-arguments
44
+ "PLR0915", # too-many-statements
45
+ "TRY003", # raise-vanilla-args
46
+ ]
47
+ unfixable = [
48
+ "F401", # unused-import
49
+ ]
50
+
51
+ [tool.ruff.lint.pydocstyle]
52
+ convention = "google"
53
+
54
+ [tool.ruff.lint.per-file-ignores]
55
+ "app.py" = ["INP001"]
56
+
57
+ [tool.ruff.format]
58
+ docstring-code-format = true
59
+
60
+ [dependency-groups]
61
+ dev = ["ruff>=0.15.8"]
62
+ hf-spaces = ["datasets>=4.8.4"]
requirements.txt ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file was autogenerated by uv via the following command:
2
+ # uv export --no-hashes --no-dev --group hf-spaces --no-emit-package typer-slim --no-emit-package spaces -o requirements.txt
3
+ accelerate==1.13.0
4
+ # via sarashina2-2-ocr
5
+ aiofiles==24.1.0
6
+ # via gradio
7
+ aiohappyeyeballs==2.6.1
8
+ # via aiohttp
9
+ aiohttp==3.13.4
10
+ # via fsspec
11
+ aiosignal==1.4.0
12
+ # via aiohttp
13
+ annotated-doc==0.0.4
14
+ # via
15
+ # fastapi
16
+ # typer
17
+ annotated-types==0.7.0
18
+ # via pydantic
19
+ anyio==4.13.0
20
+ # via
21
+ # gradio
22
+ # httpx
23
+ # starlette
24
+ attrs==26.1.0
25
+ # via aiohttp
26
+ audioop-lts==0.2.2 ; python_full_version >= '3.13'
27
+ # via gradio
28
+ brotli==1.2.0
29
+ # via gradio
30
+ certifi==2026.2.25
31
+ # via
32
+ # httpcore
33
+ # httpx
34
+ # requests
35
+ charset-normalizer==3.4.6
36
+ # via requests
37
+ click==8.3.1
38
+ # via
39
+ # typer
40
+ # uvicorn
41
+ colorama==0.4.6 ; sys_platform == 'win32'
42
+ # via
43
+ # click
44
+ # tqdm
45
+ datasets==4.8.4
46
+ dill==0.4.1
47
+ # via
48
+ # datasets
49
+ # multiprocess
50
+ fastapi==0.135.2
51
+ # via gradio
52
+ ffmpy==1.0.0
53
+ # via gradio
54
+ filelock==3.25.2
55
+ # via
56
+ # datasets
57
+ # huggingface-hub
58
+ # torch
59
+ # transformers
60
+ frozenlist==1.8.0
61
+ # via
62
+ # aiohttp
63
+ # aiosignal
64
+ fsspec==2026.2.0
65
+ # via
66
+ # datasets
67
+ # gradio-client
68
+ # huggingface-hub
69
+ # torch
70
+ gradio==6.10.0
71
+ # via
72
+ # sarashina2-2-ocr
73
+ # spaces
74
+ gradio-client==2.4.0
75
+ # via
76
+ # gradio
77
+ # hf-gradio
78
+ groovy==0.1.2
79
+ # via gradio
80
+ h11==0.16.0
81
+ # via
82
+ # httpcore
83
+ # uvicorn
84
+ hf-gradio==0.3.0
85
+ # via gradio
86
+ hf-xet==1.4.2 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
87
+ # via huggingface-hub
88
+ httpcore==1.0.9
89
+ # via httpx
90
+ httpx==0.28.1
91
+ # via
92
+ # datasets
93
+ # gradio
94
+ # gradio-client
95
+ # safehttpx
96
+ # spaces
97
+ huggingface-hub==0.36.2
98
+ # via
99
+ # accelerate
100
+ # datasets
101
+ # gradio
102
+ # gradio-client
103
+ # tokenizers
104
+ # transformers
105
+ idna==3.11
106
+ # via
107
+ # anyio
108
+ # httpx
109
+ # requests
110
+ # yarl
111
+ jinja2==3.1.6
112
+ # via
113
+ # gradio
114
+ # torch
115
+ markdown-it-py==4.0.0
116
+ # via rich
117
+ markupsafe==3.0.3
118
+ # via
119
+ # gradio
120
+ # jinja2
121
+ mdurl==0.1.2
122
+ # via markdown-it-py
123
+ mpmath==1.3.0
124
+ # via sympy
125
+ multidict==6.7.1
126
+ # via
127
+ # aiohttp
128
+ # yarl
129
+ multiprocess==0.70.19
130
+ # via datasets
131
+ networkx==3.6.1
132
+ # via torch
133
+ numpy==2.4.4
134
+ # via
135
+ # accelerate
136
+ # datasets
137
+ # gradio
138
+ # pandas
139
+ # torchvision
140
+ # transformers
141
+ nvidia-cublas-cu12==12.8.4.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
142
+ # via
143
+ # nvidia-cudnn-cu12
144
+ # nvidia-cusolver-cu12
145
+ # torch
146
+ nvidia-cuda-cupti-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
147
+ # via torch
148
+ nvidia-cuda-nvrtc-cu12==12.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
149
+ # via torch
150
+ nvidia-cuda-runtime-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
151
+ # via torch
152
+ nvidia-cudnn-cu12==9.10.2.21 ; platform_machine == 'x86_64' and sys_platform == 'linux'
153
+ # via torch
154
+ nvidia-cufft-cu12==11.3.3.83 ; platform_machine == 'x86_64' and sys_platform == 'linux'
155
+ # via torch
156
+ nvidia-cufile-cu12==1.13.1.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
157
+ # via torch
158
+ nvidia-curand-cu12==10.3.9.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
159
+ # via torch
160
+ nvidia-cusolver-cu12==11.7.3.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
161
+ # via torch
162
+ nvidia-cusparse-cu12==12.5.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
163
+ # via
164
+ # nvidia-cusolver-cu12
165
+ # torch
166
+ nvidia-cusparselt-cu12==0.7.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
167
+ # via torch
168
+ nvidia-nccl-cu12==2.27.5 ; platform_machine == 'x86_64' and sys_platform == 'linux'
169
+ # via torch
170
+ nvidia-nvjitlink-cu12==12.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
171
+ # via
172
+ # nvidia-cufft-cu12
173
+ # nvidia-cusolver-cu12
174
+ # nvidia-cusparse-cu12
175
+ # torch
176
+ nvidia-nvshmem-cu12==3.3.20 ; platform_machine == 'x86_64' and sys_platform == 'linux'
177
+ # via torch
178
+ nvidia-nvtx-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
179
+ # via torch
180
+ orjson==3.11.7
181
+ # via gradio
182
+ packaging==26.0
183
+ # via
184
+ # accelerate
185
+ # datasets
186
+ # gradio
187
+ # gradio-client
188
+ # huggingface-hub
189
+ # spaces
190
+ # transformers
191
+ pandas==3.0.2
192
+ # via
193
+ # datasets
194
+ # gradio
195
+ pillow==12.1.1
196
+ # via
197
+ # gradio
198
+ # sarashina2-2-ocr
199
+ # torchvision
200
+ propcache==0.4.1
201
+ # via
202
+ # aiohttp
203
+ # yarl
204
+ protobuf==7.34.1
205
+ # via sarashina2-2-ocr
206
+ psutil==5.9.8
207
+ # via
208
+ # accelerate
209
+ # spaces
210
+ pyarrow==23.0.1
211
+ # via datasets
212
+ pydantic==2.12.5
213
+ # via
214
+ # fastapi
215
+ # gradio
216
+ # spaces
217
+ pydantic-core==2.41.5
218
+ # via pydantic
219
+ pydub==0.25.1
220
+ # via gradio
221
+ pygments==2.20.0
222
+ # via rich
223
+ python-dateutil==2.9.0.post0
224
+ # via pandas
225
+ python-multipart==0.0.22
226
+ # via gradio
227
+ pytz==2026.1.post1
228
+ # via gradio
229
+ pyyaml==6.0.3
230
+ # via
231
+ # accelerate
232
+ # datasets
233
+ # gradio
234
+ # huggingface-hub
235
+ # transformers
236
+ regex==2026.3.32
237
+ # via transformers
238
+ requests==2.33.1
239
+ # via
240
+ # datasets
241
+ # huggingface-hub
242
+ # spaces
243
+ # transformers
244
+ rich==14.3.3
245
+ # via typer
246
+ safehttpx==0.1.7
247
+ # via gradio
248
+ safetensors==0.7.0
249
+ # via
250
+ # accelerate
251
+ # transformers
252
+ semantic-version==2.10.0
253
+ # via gradio
254
+ sentencepiece==0.2.1
255
+ # via sarashina2-2-ocr
256
+ setuptools==82.0.1
257
+ # via torch
258
+ shellingham==1.5.4
259
+ # via typer
260
+ six==1.17.0
261
+ # via python-dateutil
262
+ starlette==0.52.1
263
+ # via
264
+ # fastapi
265
+ # gradio
266
+ sympy==1.14.0
267
+ # via torch
268
+ tokenizers==0.22.2
269
+ # via transformers
270
+ tomlkit==0.13.3
271
+ # via gradio
272
+ torch==2.9.1
273
+ # via
274
+ # accelerate
275
+ # sarashina2-2-ocr
276
+ # torchvision
277
+ torchvision==0.24.1
278
+ # via sarashina2-2-ocr
279
+ tqdm==4.67.3
280
+ # via
281
+ # datasets
282
+ # huggingface-hub
283
+ # transformers
284
+ transformers==4.57.1
285
+ # via sarashina2-2-ocr
286
+ triton==3.5.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
287
+ # via torch
288
+ typer==0.24.1
289
+ # via
290
+ # gradio
291
+ # hf-gradio
292
+ typing-extensions==4.15.0
293
+ # via
294
+ # aiosignal
295
+ # anyio
296
+ # fastapi
297
+ # gradio
298
+ # gradio-client
299
+ # huggingface-hub
300
+ # pydantic
301
+ # pydantic-core
302
+ # spaces
303
+ # starlette
304
+ # torch
305
+ # typing-inspection
306
+ typing-inspection==0.4.2
307
+ # via
308
+ # fastapi
309
+ # pydantic
310
+ tzdata==2025.3 ; sys_platform == 'emscripten' or sys_platform == 'win32'
311
+ # via pandas
312
+ urllib3==2.6.3
313
+ # via requests
314
+ uvicorn==0.42.0
315
+ # via gradio
316
+ xxhash==3.6.0
317
+ # via datasets
318
+ yarl==1.23.0
319
+ # via aiohttp
uv.lock ADDED
The diff for this file is too large to render. See raw diff