kirchik47 committed on
Commit
379b35c
·
1 Parent(s): 7bbe42c

Moving main repo to hf repo

Browse files
.gitattributes CHANGED
@@ -1,35 +1,4 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ custom_got/model.safetensors filter=lfs diff=lfs merge=lfs -text
2
+ data_80k filter=lfs diff=lfs merge=lfs -text
3
+ data_80k/data.csv filter=lfs diff=lfs merge=lfs -text
4
+ dataset.json filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.gitignore ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110
+ .pdm.toml
111
+ .pdm-python
112
+ .pdm-build/
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Environments
125
+ .env
126
+ .venv
127
+ env/
128
+ venv/
129
+ ENV/
130
+ env.bak/
131
+ venv.bak/
132
+
133
+ # Spyder project settings
134
+ .spyderproject
135
+ .spyproject
136
+
137
+ # Rope project settings
138
+ .ropeproject
139
+
140
+ # mkdocs documentation
141
+ /site
142
+
143
+ # mypy
144
+ .mypy_cache/
145
+ .dmypy.json
146
+ dmypy.json
147
+
148
+ # Pyre type checker
149
+ .pyre/
150
+
151
+ # pytype static type analyzer
152
+ .pytype/
153
+
154
+ # Cython debug symbols
155
+ cython_debug/
156
+
157
+ # PyCharm
158
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
161
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
+ #.idea/
163
+ data_80k
README.md CHANGED
@@ -1,13 +1,8 @@
1
- ---
2
- title: Ocr Task
3
- emoji: 📊
4
- colorFrom: indigo
5
- colorTo: green
6
- sdk: streamlit
7
- sdk_version: 1.38.0
8
- app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
- ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ # ocr_task
2
+ OCR assignment for PARIMAL IIT Roorkee Internship.
3
+
4
+ Two approaches to OCR were considered: GOT-OCR 2.0, and the ColPali implementation from the Byaldi library combined with Qwen2-VL. After some research, GOT was chosen because it is designed to extract text directly from an image, without relying on an LLM to explain the content of the file. In addition, GOT comes with direct instructions for training and fine-tuning the model on custom data samples. Since GOT does not generate Hindi characters at all out of the box, I needed to fine-tune the model on a Hindi dataset. The tokenizer already contained tokens for Hindi characters, so adding new tokens was not necessary.
5
+ However, GOT is only compatible with CUDA, so it is not possible to fine-tune it on my device. I chose to use Google Colab for this instead, since it provides a GPU for limited use.
6
+
7
+ During deployment on Streamlit sharing, I encountered a problem with '\left' strings, which were treated as problematic escape sequences because of the '\'. An additional script, replacer.py, was used to replace all of these strings with '\\left'.
 
 
 
 
8
 
 
app.py CHANGED
@@ -1,4 +1,29 @@
1
  import streamlit as st
 
 
2
 
3
- x = st.slider('Select a value')
4
- st.write(x, 'squared is', x * x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from main_got import extract_text
3
+ import re
4
 
5
+
6
# Streamlit UI
st.title("OCR and Document Search Web App")

# Image upload
uploaded_image = st.file_uploader("Upload an image for OCR", type=["jpg", "png", "jpeg"])

if uploaded_image is not None:
    with st.spinner("Processing image..."):
        # Extract text from the uploaded image
        extracted_text = extract_text(uploaded_image)
        st.subheader("Extracted Text")
        st.write(extracted_text)

    # Search functionality
    search_query = st.text_input("Enter a keyword to search within the text")
    if search_query:
        # The query is a literal keyword typed by the user, not a regular
        # expression. Escape it so that characters such as "(", "[" or "\"
        # cannot raise re.error and ".", "*" etc. only match themselves.
        pattern = re.escape(search_query)
        results = [match.start() for match in re.finditer(pattern, extracted_text)]
        if results:
            st.subheader("Search Results")
            for result in results:
                st.write(f"Keyword found at index: {result}")
        else:
            st.write("No results found.")
29
+
custom_got/assets/got_logo.png ADDED
custom_got/assets/got_support.jpg ADDED
custom_got/assets/train_sample.jpg ADDED
custom_got/config.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "ucaslcl/GOT-OCR2_0",
3
+ "architectures": [
4
+ "GOTQwenForCausalLM"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "modeling_GOT.GOTConfig",
8
+ "AutoModel": "modeling_GOT.GOTQwenForCausalLM"
9
+ },
10
+ "attention_dropout": 0.0,
11
+ "bos_token_id": 151643,
12
+ "eos_token_id": 151643,
13
+ "freeze_vision_tower": false,
14
+ "hidden_act": "silu",
15
+ "hidden_size": 1024,
16
+ "im_end_token": 151858,
17
+ "im_patch_token": 151859,
18
+ "im_start_token": 151857,
19
+ "image_token_len": 256,
20
+ "initializer_range": 0.02,
21
+ "intermediate_size": 2816,
22
+ "max_position_embeddings": 32768,
23
+ "max_window_layers": 21,
24
+ "model_type": "GOT",
25
+ "num_attention_heads": 16,
26
+ "num_hidden_layers": 24,
27
+ "num_key_value_heads": 16,
28
+ "rms_norm_eps": 1e-06,
29
+ "rope_theta": 1000000.0,
30
+ "sliding_window": 32768,
31
+ "tie_word_embeddings": true,
32
+ "torch_dtype": "bfloat16",
33
+ "transformers_version": "4.37.2",
34
+ "use_cache": true,
35
+ "use_im_start_end": true,
36
+ "use_sliding_window": false,
37
+ "vocab_size": 151860
38
+ }
custom_got/generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 151643,
3
+ "eos_token_id": 151643,
4
+ "max_new_tokens": 2048,
5
+ "transformers_version": "4.37.2"
6
+ }
custom_got/got_vision_b.py ADDED
@@ -0,0 +1,468 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from typing import Optional, Tuple, Type
4
+ from functools import partial
5
+ import torch.nn as nn
6
+ from typing import Type
7
+
8
+
9
+
10
class MLPBlock(nn.Module):
    """Feed-forward block: linear layer, activation, linear layer."""

    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[nn.Module] = nn.GELU,
    ) -> None:
        """
        Args:
            embedding_dim (int): Feature size of the input and of the output.
            mlp_dim (int): Feature size of the hidden layer.
            act (nn.Module): Activation class; it is instantiated with no arguments.
        """
        super().__init__()
        # Attribute names and creation order are kept as-is: they determine
        # the state_dict keys and the order of random initialisation.
        self.lin1 = nn.Linear(in_features=embedding_dim, out_features=mlp_dim)
        self.lin2 = nn.Linear(in_features=mlp_dim, out_features=embedding_dim)
        self.act = act()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.lin1(x)
        activated = self.act(hidden)
        return self.lin2(activated)
24
+
25
+
26
+
27
class LayerNorm2d(nn.Module):
    """Normalise over dim 1 and apply a learnable per-channel scale and shift."""

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        """
        Args:
            num_channels (int): Length of the learnable weight and bias vectors.
            eps (float): Constant added to the variance before the square root.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Statistics are taken along dim 1 only.
        mean = x.mean(1, keepdim=True)
        centered = x - mean
        variance = centered.pow(2).mean(1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.eps)
        # Two trailing singleton dims let weight and bias broadcast along dim 1.
        scale = self.weight[:, None, None]
        shift = self.bias[:, None, None]
        return scale * normalized + shift
40
+
41
+
42
+
43
class ImageEncoderViT(nn.Module):
    """
    ViT image encoder: patch embedding, a stack of transformer blocks, a
    convolutional neck, and two stride-2 convolutions that map the neck
    output to 512 and then 1024 channels.
    """

    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        out_chans: int = 256,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_abs_pos: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        global_attn_indexes: Tuple[int, ...] = (),
    ) -> None:
        """
        Args:
            img_size (int): Input image size.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            out_chans (int): Number of channels produced by the neck.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks.
            global_attn_indexes (list): Indexes for blocks using global attention.
        """
        super().__init__()
        self.img_size = img_size

        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        self.pos_embed: Optional[nn.Parameter] = None
        if use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            self.pos_embed = nn.Parameter(
                torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
            )

        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=use_rel_pos,
                rel_pos_zero_init=rel_pos_zero_init,
                # Blocks listed in global_attn_indexes get window_size 0.
                window_size=window_size if i not in global_attn_indexes else 0,
                input_size=(img_size // patch_size, img_size // patch_size),
            )
            self.blocks.append(block)

        self.neck = nn.Sequential(
            nn.Conv2d(
                embed_dim,
                out_chans,
                kernel_size=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
            nn.Conv2d(
                out_chans,
                out_chans,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
        )

        # net_2 consumes the neck output, so its input width must follow
        # out_chans instead of a hard-coded 256 (identical for the default).
        self.net_2 = nn.Conv2d(out_chans, 512, kernel_size=3, stride=2, padding=1, bias=False)
        self.net_3 = nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed

        for blk in self.blocks:
            x = blk(x)

        # The blocks work on channels-last tensors; the convolutions need
        # channels in dim 1.
        x = self.neck(x.permute(0, 3, 1, 2))
        x = self.net_2(x)
        x = self.net_3(x)

        return x
150
+
151
+
152
class Block(nn.Module):
    """Transformer block with optional window attention and two residual connections."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            mlp_ratio (float): Hidden size of the MLP as a multiple of ``dim``.
            qkv_bias (bool): If True, the query/key/value projection has a bias.
            norm_layer (nn.Module): Normalization layer class, called with ``dim``.
            act_layer (nn.Module): Activation layer class used by the MLP.
            use_rel_pos (bool): If True, relative positional embeddings are added
                to the attention map.
            rel_pos_zero_init (bool): Passed through to the attention layer.
            window_size (int): Window size for window attention; 0 means the
                attention runs on the whole input.
            input_size (tuple(int, int) or None): Input resolution, used to size
                the relative positional parameters when ``window_size`` is 0.
        """
        super().__init__()
        # With windowing enabled the attention layer only ever sees windows,
        # so its positional tables are sized for the window.
        attn_input_size = input_size if window_size == 0 else (window_size, window_size)

        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            input_size=attn_input_size,
        )

        self.norm2 = norm_layer(dim)
        self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)

        self.window_size = window_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.norm1(x)

        use_windows = self.window_size > 0
        if use_windows:
            # Remember the unpadded size so the partition can be undone below.
            height, width = x.shape[1], x.shape[2]
            x, padded_hw = window_partition(x, self.window_size)

        x = self.attn(x)

        if use_windows:
            x = window_unpartition(x, self.window_size, padded_hw, (height, width))

        x = residual + x
        return x + self.mlp(self.norm2(x))
216
+
217
+
218
class Attention(nn.Module):
    """Multi-head attention over a (B, H, W, C) input, with optional relative position terms."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, the query/key/value projection has a bias.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): Accepted for interface compatibility; not
                read here (the tables are always created as zeros).
            input_size (tuple(int, int) or None): Input resolution used to size
                the relative positional tables; required when ``use_rel_pos`` is True.
        """
        super().__init__()
        self.num_heads = num_heads
        per_head_dim = dim // num_heads
        self.scale = per_head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            assert (
                input_size is not None
            ), "Input size must be provided if using relative positional encoding."
            # One table per axis, with 2 * size - 1 rows each.
            size_h, size_w = input_size[0], input_size[1]
            self.rel_pos_h = nn.Parameter(torch.zeros(2 * size_h - 1, per_head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(2 * size_w - 1, per_head_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, height, width, _ = x.shape
        num_tokens = height * width
        heads = self.num_heads

        # (B, H*W, 3, heads, C) -> (3, B, heads, H*W, C)
        qkv = self.qkv(x).reshape(batch, num_tokens, 3, heads, -1).permute(2, 0, 3, 1, 4)
        # Fold the heads into the batch dim: each of q, k, v is (B*heads, H*W, C).
        q, k, v = qkv.reshape(3, batch * heads, num_tokens, -1).unbind(0)

        attn = (q * self.scale) @ k.transpose(-2, -1)

        if self.use_rel_pos:
            attn = add_decomposed_rel_pos(
                attn, q, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
            )

        attn = attn.softmax(dim=-1)

        # Unfold the heads and merge them back into the channel dim.
        context = (attn @ v).view(batch, heads, height, width, -1)
        context = context.permute(0, 2, 3, 1, 4).reshape(batch, height, width, -1)
        return self.proj(context)
274
+
275
+
276
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """
    Split a [B, H, W, C] tensor into non-overlapping square windows, padding
    the bottom and right edges when H or W is not a multiple of the window size.

    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    batch, height, width, channels = x.shape

    # Distance from each size to the next multiple of window_size (0 if aligned).
    pad_bottom = -height % window_size
    pad_right = -width % window_size
    if pad_bottom or pad_right:
        # F.pad lists (before, after) pairs starting from the last dim: C, W, H.
        x = F.pad(x, (0, 0, 0, pad_right, 0, pad_bottom))
    padded_h = height + pad_bottom
    padded_w = width + pad_right

    rows = padded_h // window_size
    cols = padded_w // window_size
    grid = x.view(batch, rows, window_size, cols, window_size, channels)
    # Put the two window-index dims next to each other, then merge them into the batch dim.
    windows = grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, channels)
    return windows, (padded_h, padded_w)
298
+
299
+
300
def window_unpartition(
    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
    """
    Reassemble windows into a [B, H, W, C] tensor and strip any padding.

    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    padded_h, padded_w = pad_hw
    height, width = hw

    windows_per_image = padded_h * padded_w // window_size // window_size
    batch = windows.shape[0] // windows_per_image

    grid = windows.view(
        batch, padded_h // window_size, padded_w // window_size, window_size, window_size, -1
    )
    # Interleave window indices with in-window offsets, then merge each pair.
    x = grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch, padded_h, padded_w, -1)

    if padded_h > height or padded_w > width:
        # Crop the padded rows/columns back off.
        x = x[:, :height, :width, :].contiguous()
    return x
323
+
324
+
325
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
    """
    Look up relative positional embeddings for every (query, key) position pair.

    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        Extracted positional embeddings according to relative positions.
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)

    if rel_pos.shape[0] == max_rel_dist:
        table = rel_pos
    else:
        # The table has the wrong length: linearly resample it along its first dim.
        channels_first = rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1)
        resampled = F.interpolate(channels_first, size=max_rel_dist, mode="linear")
        table = resampled.reshape(-1, max_rel_dist).permute(1, 0)

    # When q and k differ in size, the shorter axis is stretched to match the longer one.
    q_scale = max(k_size / q_size, 1.0)
    k_scale = max(q_size / k_size, 1.0)
    q_coords = torch.arange(q_size)[:, None] * q_scale
    k_coords = torch.arange(k_size)[None, :] * k_scale
    # The offset shifts the differences so the smallest index is 0.
    relative_coords = (q_coords - k_coords) + (k_size - 1) * k_scale

    return table[relative_coords.long()]
356
+
357
+
358
def add_decomposed_rel_pos(
    attn: torch.Tensor,
    q: torch.Tensor,
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    q_size: Tuple[int, int],
    k_size: Tuple[int, int],
) -> torch.Tensor:
    """
    Add separate height-axis and width-axis relative position terms to an attention map.

    Args:
        attn (Tensor): attention map.
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        attn (Tensor): attention map with added relative positional embeddings.
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    table_h = get_rel_pos(q_h, k_h, rel_pos_h)
    table_w = get_rel_pos(q_w, k_w, rel_pos_w)

    batch, _, dim = q.shape
    q_grid = q.reshape(batch, q_h, q_w, dim)
    # Project the queries onto the per-axis embeddings.
    bias_h = torch.einsum("bhwc,hkc->bhwk", q_grid, table_h)
    bias_w = torch.einsum("bhwc,wkc->bhwk", q_grid, table_w)

    # Broadcast the height term over key columns and the width term over key rows.
    attn_grid = attn.view(batch, q_h, q_w, k_h, k_w)
    attn_grid = attn_grid + bias_h[:, :, :, :, None] + bias_w[:, :, :, None, :]

    return attn_grid.view(batch, q_h * q_w, k_h * k_w)
393
+
394
+
395
class PatchEmbed(nn.Module):
    """
    Image to patch embedding via a single convolution, returned channels-last.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, int] = (16, 16),
        stride: Tuple[int, int] = (16, 16),
        padding: Tuple[int, int] = (0, 0),
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()

        self.proj = nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        embedded = self.proj(x)
        # Move channels last: B C H W -> B H W C
        return embedded.permute(0, 2, 3, 1)
427
+
428
+
429
+
430
def build_GOT_vit_b(checkpoint=None):
    """
    Build the vision encoder: 12 blocks, width 768, 12 heads, with blocks
    2, 5, 8 and 11 using global attention.

    Args:
        checkpoint: Forwarded unchanged to ``_build_GOT_vision``.

    Returns:
        The module returned by ``_build_GOT_vision``.
    """
    encoder_settings = dict(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
    )
    return _build_GOT_vision(checkpoint=checkpoint, **encoder_settings)
438
+
439
+
440
def _build_GOT_vision(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
):
    """
    Construct an ``ImageEncoderViT`` for 1024-pixel inputs with 16-pixel patches.

    Args:
        encoder_embed_dim: Patch embedding dimension.
        encoder_depth: Number of transformer blocks.
        encoder_num_heads: Number of attention heads per block.
        encoder_global_attn_indexes: Indexes of the blocks that use global attention.
        checkpoint: Accepted but not used by this function; no weights are loaded here.

    Returns:
        The newly constructed ``ImageEncoderViT``.
    """
    image_size = 1024
    vit_patch_size = 16
    prompt_embed_dim = 256

    return ImageEncoderViT(
        img_size=image_size,
        patch_size=vit_patch_size,
        embed_dim=encoder_embed_dim,
        depth=encoder_depth,
        num_heads=encoder_num_heads,
        mlp_ratio=4,
        out_chans=prompt_embed_dim,
        qkv_bias=True,
        norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
        use_rel_pos=True,
        window_size=14,
        global_attn_indexes=encoder_global_attn_indexes,
    )
468
+
custom_got/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:77d6144039548b14253176b6eb264896bc39eba532f8894700f210a7fd2a5956
3
+ size 1432121416
custom_got/modeling_GOT.py ADDED
@@ -0,0 +1,881 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM, StoppingCriteria, TextStreamer
2
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
3
+ from typing import List, Optional, Tuple, Union
4
+ from transformers.cache_utils import Cache
5
+ import requests
6
+ from PIL import Image
7
+ from io import BytesIO
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.nn import CrossEntropyLoss
11
+ from .got_vision_b import build_GOT_vit_b
12
+ from torchvision import transforms
13
+ from torchvision.transforms.functional import InterpolationMode
14
+ import dataclasses
15
+ ###
16
+
17
# Special marker strings for image content.
DEFAULT_IMAGE_TOKEN = '<image>'
DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
DEFAULT_IM_START_TOKEN = '<img>'
DEFAULT_IM_END_TOKEN = '</img>'
21
+
22
+ from enum import auto, Enum
23
class SeparatorStyle(Enum):
    """Different separator style."""

    # Explicit values equal to what auto() assigns (1, 2, 3 in order).
    SINGLE = 1
    TWO = 2
    MPT = 3
28
+
29
+
30
@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    # Text placed at the start of every prompt.
    system: str
    # Speaker names; not read by get_prompt, which uses the role stored per message.
    roles: List[str]
    # [role, message] pairs; a falsy message renders as the bare role.
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "<|im_end|>"
    # Second separator, only read by SeparatorStyle.TWO.
    sep2: Optional[str] = None
    version: str = "Unknown"

    skip_next: bool = False

    @staticmethod
    def _message_text(message):
        """Return the text of a message; 3-tuple messages carry it in first position."""
        if type(message) is tuple:
            message, _, _ = message
        return message

    def get_prompt(self):
        """
        Render the conversation as one prompt string according to ``sep_style``.

        Returns:
            str: The system text followed by every message.

        Raises:
            ValueError: If ``sep_style`` is not a known ``SeparatorStyle``.
        """
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep + '\n'
            for role, message in self.messages:
                if message:
                    ret += role + ": " + self._message_text(message) + self.sep
                else:
                    ret += role + ":"
            return ret

        if self.sep_style == SeparatorStyle.TWO:
            # Messages alternate between the two separators.
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + self._message_text(message) + seps[i % 2]
                else:
                    ret += role + ":"
            return ret

        if self.sep_style == SeparatorStyle.MPT:
            ret = self.system + self.sep if self.system else ''
            for role, message in self.messages:
                if message:
                    ret += role + self._message_text(message) + self.sep
                else:
                    ret += role
            return ret

        raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        """Append a ``[role, message]`` pair to the history."""
        self.messages.append([role, message])

    def copy(self):
        """
        Return a new Conversation with its own ``messages`` list.

        ``version`` is carried over (it used to be reset to "Unknown").
        ``roles`` is shared with the original and ``skip_next`` is reset
        to its default, as before.
        """
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version)
95
+
96
+
97
+
98
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stopping criterion that fires once one of the given keywords appears in the output."""

    def __init__(self, keywords, tokenizer, input_ids):
        """
        Args:
            keywords: Strings that end generation when they appear.
            tokenizer: Used to tokenize the keywords and to decode generated ids.
            input_ids: Prompt ids; ``shape[1]`` gives the offset where generated text starts.
        """
        self.keywords = keywords
        tokenized = [tokenizer(keyword).input_ids for keyword in keywords]
        # Only keywords that tokenize to exactly one id can be matched by id.
        self.keyword_ids = [ids[0] for ids in tokenized if type(ids) is list and len(ids) == 1]
        self.tokenizer = tokenizer
        self.start_len = None
        self.input_ids = input_ids

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # The first call only records the prompt length and never stops.
        if self.start_len is None:
            self.start_len = self.input_ids.shape[1]
            return False

        # Compare the newest token of the first sequence against the single-id keywords.
        newest_token = output_ids[0, -1]
        for keyword_id in self.keyword_ids:
            if newest_token == keyword_id:
                return True

        # Otherwise decode what has been generated so far and search it as text.
        generated = output_ids[:, self.start_len:]
        decoded = self.tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
        return any(keyword in decoded for keyword in self.keywords)
119
+
120
+
121
class GOTImageEvalProcessor:
    """Resize an image to a square, convert it to a tensor and normalize it."""

    def __init__(self, image_size=384, mean=None, std=None):
        """Build the transform pipeline.

        Args:
            image_size: side length of the square output.
            mean: per-channel normalization mean; a built-in default is used
                when None.
            std: per-channel normalization std; a built-in default is used
                when None.
        """
        mean = (0.48145466, 0.4578275, 0.40821073) if mean is None else mean
        std = (0.26862954, 0.26130258, 0.27577711) if std is None else std

        self.normalize = transforms.Normalize(mean, std)

        target_size = (image_size, image_size)
        pipeline = [
            transforms.Resize(target_size, interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            self.normalize,
        ]
        self.transform = transforms.Compose(pipeline)

    def __call__(self, item):
        """Apply the resize / to-tensor / normalize pipeline to ``item``."""
        return self.transform(item)
141
+
142
+
143
+
144
class GOTConfig(Qwen2Config):
    """Qwen2 configuration whose ``model_type`` tag is set to "GOT"."""
    model_type = "GOT"
146
+
147
+
148
class GOTQwenModel(Qwen2Model):
    """Qwen2 decoder that splices projected vision features into its input embeddings.

    Placeholder positions in ``input_ids`` (the tokens between an image-start
    and an image-end token) are overwritten with features produced by
    ``vision_tower_high`` followed by the ``mm_projector_vary`` linear layer.
    """
    config_class = GOTConfig

    # Token ids marking image content in the prompt. They are hard-coded and take
    # precedence over any ``im_*_token`` values stored on the config.
    _IM_START_TOKEN = 151857
    _IM_END_TOKEN = 151858
    _IM_PATCH_TOKEN = 151859

    def __init__(self, config: Qwen2Config):
        super(GOTQwenModel, self).__init__(config)

        self.vision_tower_high = build_GOT_vit_b()

        self.mm_projector_vary = nn.Linear(1024, 1024)

    def initialize_vision_modules(
        self,
        vision_tower,
        pretrained_stage1_model=None,
        freeze_vision_tower=False,
        use_im_start_end=False,
        vision_select_layer=-1,
        dtype=torch.float16,
        device="cuda"
    ):
        """Move the vision modules to ``dtype``/``device`` and record vision settings.

        ``pretrained_stage1_model`` is accepted but unused, and the
        ``use_im_start_end`` argument is ignored: the config flag is always
        set to True.

        Returns:
            dict with ``image_processor_high`` (a ``GOTImageEvalProcessor``
            built with ``image_size=1024``) and ``image_token_len`` (256).
        """
        image_processor_high = GOTImageEvalProcessor(image_size=1024)

        self.vision_tower_high = self.vision_tower_high.to(dtype=dtype, device=device)
        self.mm_projector_vary = self.mm_projector_vary.to(dtype=dtype, device=device)

        image_token_len = 256

        self.config.vision_tower = vision_tower
        self.config.image_token_len = image_token_len
        self.config.use_im_start_end = True
        self.config.vision_select_layer = vision_select_layer
        self.config.freeze_vision_tower = freeze_vision_tower

        return dict(
            image_processor_high=image_processor_high,
            image_token_len=image_token_len,
        )

    def _encode_image(self, vision_tower_high, image):
        """Encode one image, given as a 4-D stack of patches, into projected features.

        Every patch is passed through the vision tower on its own with
        gradients disabled; the projector runs with gradients enabled. The
        per-patch results are concatenated along dimension 1.

        Raises:
            ValueError: if ``image`` is not 4-dimensional.
        """
        if image.dim() != 4:
            raise ValueError(
                f"Expected a 4-D (P, C, H, W) image tensor, got shape {tuple(image.shape)}"
            )

        patch_features = []
        for image_patch in torch.unbind(image):
            with torch.no_grad():
                cnn_feature = vision_tower_high(image_patch.unsqueeze(0))
                cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1)
            patch_features.append(self.mm_projector_vary(cnn_feature))
        return torch.cat(patch_features, dim=1)

    def _merge_image_features(self, cur_input_ids, cur_input_embeds, cur_image_features):
        """Overwrite the placeholder span of one sequence with its image features.

        Sequences without any patch token are returned unchanged.

        Raises:
            ValueError: if start/end token counts differ, or the end token is
                not found right after the placeholder span.
        """
        if (cur_input_ids == self._IM_PATCH_TOKEN).sum() == 0:
            return cur_input_embeds

        start_mask = cur_input_ids == self._IM_START_TOKEN
        if start_mask.sum() != (cur_input_ids == self._IM_END_TOKEN).sum():
            raise ValueError("The number of image start tokens and image end tokens should be the same.")

        image_start_tokens = torch.where(start_mask)[0]
        for image_start_token_pos, per_cur_image_features in zip(image_start_tokens, cur_image_features):
            per_cur_image_features = per_cur_image_features.to(device=cur_input_embeds.device)
            num_patches = per_cur_image_features.shape[0]

            # The span start+1 .. start+num_patches must be closed by the end token.
            if cur_input_ids[image_start_token_pos + num_patches + 1] != self._IM_END_TOKEN:
                raise ValueError("The image end token should follow the image start token.")

            cur_input_embeds = torch.cat(
                (
                    cur_input_embeds[:image_start_token_pos+1],
                    per_cur_image_features,
                    cur_input_embeds[image_start_token_pos + num_patches + 1:]
                ),
                dim=0
            )
        return cur_input_embeds

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Embed the tokens, merge image features when applicable, and run the decoder.

        Image features are merged only when ``images`` is given and the pass
        covers more than one token (or the module is in training mode).
        ``images`` is iterated in step with the batch; each item must be a 4-D
        stack of patches. When the caller passes ``inputs_embeds`` without
        ``input_ids`` the placeholders cannot be located, so the embeddings are
        forwarded untouched instead of failing on ``input_ids.shape``.

        The parent ``forward`` is always called with ``input_ids=None`` and the
        (possibly merged) ``inputs_embeds``.
        """
        # HACK: replace back original embeddings for LLaVA pretraining
        orig_embeds_params = getattr(self, 'orig_embeds_params', None)
        if orig_embeds_params is not None:
            with torch.no_grad():
                self.get_input_embeddings().weight[:-self.num_new_tokens] = orig_embeds_params[:-self.num_new_tokens].data

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        vision_tower_high = getattr(self, 'vision_tower_high', None)

        should_merge_images = (
            vision_tower_high is not None
            and images is not None
            and input_ids is not None
            and (input_ids.shape[1] != 1 or self.training)
        )
        if should_merge_images:
            image_features = [self._encode_image(vision_tower_high, image) for image in images]
            new_input_embeds = [
                self._merge_image_features(cur_input_ids, cur_input_embeds, cur_image_features)
                for cur_input_ids, cur_input_embeds, cur_image_features in zip(
                    input_ids, inputs_embeds, image_features
                )
            ]
            inputs_embeds = torch.stack(new_input_embeds, dim=0)

        return super(GOTQwenModel, self).forward(
            input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
            inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids = position_ids,
            output_attentions=output_attentions, output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )
306
+
307
+
308
+
309
+ class GOTQwenForCausalLM(Qwen2ForCausalLM):
310
+ config_class = GOTConfig
311
+ # supports_gradient_checkpointing = True
312
+
313
def __init__(self, config):
    """Build the GOT backbone and the language-model head from ``config``."""
    # The MRO lookup starts *after* Qwen2ForCausalLM, so that class's own
    # __init__ is bypassed (presumably to avoid building a plain Qwen2 backbone
    # that would be replaced right away - confirm).
    super(Qwen2ForCausalLM, self).__init__(config)

    self.vocab_size = config.vocab_size
    self.model = GOTQwenModel(config)
    self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    # Initialize weights and apply final processing
    self.post_init()
322
+
323
def get_model(self):
    """Return the wrapped backbone stored in ``self.model``."""
    backbone = self.model
    return backbone
325
+
326
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    images: Optional[torch.FloatTensor] = None,
    return_dict: Optional[bool] = None,

) -> Union[Tuple, CausalLMOutputWithPast]:
    """Run the backbone, project to vocabulary logits and optionally compute the loss.

    Unset ``output_attentions`` / ``output_hidden_states`` / ``return_dict``
    fall back to the values on ``self.config``. When ``labels`` is given, a
    cross-entropy loss is computed between logits at position *t* and labels
    at position *t + 1*.

    Returns:
        A ``CausalLMOutputWithPast`` when ``return_dict`` is true, otherwise a
        tuple ``(loss?, logits, *rest_of_backbone_outputs)``.
    """
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if return_dict is None:
        return_dict = self.config.use_return_dict

    outputs = self.model(
        input_ids=input_ids,
        past_key_values=past_key_values,
        attention_mask=attention_mask,
        position_ids=position_ids,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        images=images,
        return_dict=return_dict
    )

    # The first backbone output holds the hidden states; logits are cast to float32.
    logits = self.lm_head(outputs[0]).float()

    loss = None
    if labels is not None:
        # Shift so that tokens < n predict n, then flatten for the loss.
        flat_logits = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
        flat_labels = labels[..., 1:].contiguous().view(-1)
        # Enable model parallelism
        flat_labels = flat_labels.to(flat_logits.device)
        loss = CrossEntropyLoss()(flat_logits, flat_labels)

    if return_dict:
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    output = (logits,) + outputs[1:]
    if loss is None:
        return output
    return (loss,) + output
391
+
392
+
393
def prepare_inputs_for_generation(
    self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
    """Assemble the keyword arguments for one generation step.

    Tokens already covered by ``past_key_values`` are dropped from
    ``input_ids``, the attention mask is cropped when a bounded cache would
    overflow, and ``position_ids`` are derived from the attention mask when
    the caller did not provide them.

    Returns:
        dict with ``input_ids`` (or ``inputs_embeds`` on the first step when
        embeddings were supplied), ``position_ids``, ``past_key_values``,
        ``use_cache``, ``attention_mask`` and ``images``.
    """
    if past_key_values is not None:
        if isinstance(past_key_values, Cache):
            cache_length = past_key_values.get_seq_length()
            past_length = past_key_values.seen_tokens
            max_cache_length = past_key_values.get_max_length()
        else:
            # Legacy layout: length is read from dim 2 of the first layer's first tensor.
            past_length = past_key_values[0][0].shape[2]
            cache_length = past_length
            max_cache_length = None

        # Keep only the unprocessed tokens.
        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
            # Part of the input lives only in the cache (e.g. inputs_embeds was used).
            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
        elif past_length < input_ids.shape[1]:
            # input_ids holds every token so far; drop the cached prefix.
            input_ids = input_ids[:, past_length:]
        # Otherwise input_ids is assumed to contain only unprocessed tokens.

        # Crop the mask if this step would go beyond the maximum cache length.
        would_overflow = (
            max_cache_length is not None
            and attention_mask is not None
            and cache_length + input_ids.shape[1] > max_cache_length
        )
        if would_overflow:
            attention_mask = attention_mask[:, -max_cache_length:]

    position_ids = kwargs.get("position_ids", None)
    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1] :]

    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
    first_step_with_embeds = inputs_embeds is not None and past_key_values is None
    if first_step_with_embeds:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    model_inputs["position_ids"] = position_ids
    model_inputs["past_key_values"] = past_key_values
    model_inputs["use_cache"] = kwargs.get("use_cache")
    model_inputs["attention_mask"] = attention_mask
    model_inputs["images"] = kwargs.get("images", None)
    return model_inputs
450
+
451
def initialize_vision_tokenizer(
    self,
    tokenizer,
    freeze_lm_model=False,
    pretrained_stage1_model=None,
    device="cuda"
):
    """Resize the token embeddings to the tokenizer and record the image token ids.

    The embedding matrix is resized once to ``len(tokenizer)`` (the original
    issued the same resize a second time behind an always-true condition).
    The config then receives the fixed image token ids and
    ``use_im_start_end = True``.

    ``freeze_lm_model``, ``pretrained_stage1_model`` and ``device`` are
    accepted for interface compatibility but are not used.
    """
    config = self.get_model().config

    self.resize_token_embeddings(len(tokenizer))

    config.im_patch_token = 151859
    config.use_im_start_end = True
    config.im_start_token, config.im_end_token = 151857, 151858
470
+
471
def load_image(self, image_file):
    """Open an image from a URL or a local path and convert it to RGB.

    Only strings that start with ``http://`` or ``https://`` are fetched over
    the network; everything else is opened as a local file. (The previous
    bare ``'http'`` prefix test also caught local names such as
    ``http_scan.png``.)

    Args:
        image_file: URL or filesystem path of the image.

    Returns:
        The opened image converted to ``'RGB'`` mode.
    """
    if image_file.startswith(('http://', 'https://')):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image
478
+
479
def disable_torch_init(self):
    """
    Turn the default parameter initialization of ``Linear`` and ``LayerNorm``
    into a no-op so that model creation is faster.

    This patches the torch classes themselves, so it affects every layer of
    those types created afterwards in the process.
    """
    import torch

    for layer_cls in (torch.nn.Linear, torch.nn.LayerNorm):
        setattr(layer_cls, "reset_parameters", lambda self: None)
486
+
487
def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False):
    """Run OCR on a single image and return the decoded text.

    Tensors are moved with ``.cuda()``, so a CUDA device is required.

    Args:
        tokenizer: called as ``tokenizer([prompt])`` and used to decode the output.
        image_file: path or URL of the image; when ``gradio_input`` is true, an
            already opened image object (its ``.copy()`` is used).
        ocr_type: ``'format'`` selects the formatted-OCR prompt; any other
            value selects plain OCR.
        ocr_box: optional string holding 2 or 4 pixel coordinates; they are
            rescaled to a 0-1000 range by the image width/height and prefixed
            to the prompt.
        ocr_color: optional color name prefixed to the prompt in brackets; it
            replaces the box prefix when both are given.
        render: when true, additionally write an HTML rendering of the result
            to ``save_render_file``.
        save_render_file: output path for the rendered HTML.
        print_prompt: print the assembled prompt.
        gradio_input: see ``image_file``.
        stream_flag: stream decoded text to stdout while generating.

    Returns:
        The decoded text with the stop string removed (before any rewriting
        done for rendering).
    """
    self.disable_torch_init()

    image_processor_high = GOTImageEvalProcessor(image_size=1024)
    image_token_len = 256

    image = image_file.copy() if gradio_input else self.load_image(image_file)
    w, h = image.size

    task = 'OCR with format: ' if ocr_type == 'format' else 'OCR: '
    qs = task

    if ocr_box:
        # NOTE(security): eval() executes arbitrary code. Only pass trusted
        # strings here; ast.literal_eval would be the safe choice if callers
        # never rely on expressions - confirm before switching.
        bbox = list(eval(ocr_box))  # list() so tuple input can be rescaled too
        if len(bbox) in (2, 4):
            # Even positions are x (scaled by width), odd positions are y (height).
            bbox = [int(v / (w if i % 2 == 0 else h) * 1000) for i, v in enumerate(bbox)]
        qs = str(bbox) + ' ' + task

    if ocr_color:
        qs = '[' + ocr_color + ']' + ' ' + task

    # Image placeholder block: start token, one patch token per image token, end token.
    qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len + DEFAULT_IM_END_TOKEN + '\n' + qs

    conv_mpt = Conversation(
        system="<|im_start|>system\nYou should follow the instructions carefully and explain your answers in detail.",
        roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
        version="mpt",
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.MPT,
        sep="<|im_end|>",
    )

    conv = conv_mpt.copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    if print_prompt:
        print(prompt)

    inputs = tokenizer([prompt])
    image_tensor_1 = image_processor_high(image)
    input_ids = torch.as_tensor(inputs.input_ids).cuda()

    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)

    generate_kwargs = dict(
        images=[image_tensor_1.unsqueeze(0).half().cuda()],
        do_sample=False,
        num_beams=1,
        no_repeat_ngram_size=20,
        max_new_tokens=4096,
        stopping_criteria=[stopping_criteria],
    )
    if stream_flag:
        generate_kwargs["streamer"] = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    with torch.autocast("cuda", dtype=torch.bfloat16):
        output_ids = self.generate(input_ids, **generate_kwargs)

    outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()

    if outputs.endswith(stop_str):
        outputs = outputs[:-len(stop_str)]
    outputs = outputs.strip()
    response_str = outputs

    if render:
        print('==============rendering===============')
        from .render_tools import svg_to_html, content_mmd_to_html, tik_html, translation_table

        if '**kern' in outputs:
            # Output containing '**kern' is handed to verovio and saved as SVG-in-HTML.
            import verovio
            tk = verovio.toolkit()
            tk.loadData(outputs)
            tk.setOptions({"pageWidth": 2100, "footer": 'none',
                           'barLineWidth': 0.5, 'beamMaxSlope': 15,
                           'staffLineWidth': 0.2, 'spacingStaff': 6})
            tk.getPageCount()
            svg = tk.renderToSVG()
            svg = svg.replace("overflow=\"inherit\"", "overflow=\"visible\"")

            svg_to_html(svg, save_render_file)

        if ocr_type == 'format' and '**kern' not in outputs:
            if '\\begin{tikzpicture}' not in outputs:
                # Unbalanced \left / \right counts: fall back to plain delimiters.
                if outputs.count('\\right') != outputs.count('\\left'):
                    for fancy, plain in (
                        ('\\left(', '('), ('\\right)', ')'),
                        ('\\left[', '['), ('\\right]', ']'),
                        ('\\left{', '{'), ('\\right}', '}'),
                        ('\\left|', '|'), ('\\right|', '|'),
                        ('\\left.', '.'), ('\\right.', '.'),
                    ):
                        outputs = outputs.replace(fancy, plain)

                outputs = outputs.replace('"', '``').replace('$', '')

                # Emit the text as a JavaScript expression: one quoted string
                # per line, joined with '+'.
                gt = '+\n'.join(
                    '"' + out.replace('\\', '\\\\') + r'\n' + '"'
                    for out in outputs.split('\n')
                )

                lines = content_mmd_to_html.split("const text =")
                new_web = lines[0] + 'const text =' + gt + lines[1]
            else:
                outputs = outputs.translate(translation_table)
                gt = ''
                for out in outputs.split('\n'):
                    if not out:
                        continue
                    if '\\begin{tikzpicture}' in out or '\\end{tikzpicture}' in out:
                        # Environment begin/end lines are copied unchanged.
                        gt += out + '\n'
                        continue
                    # rstrip instead of a manual loop: an all-space line used
                    # to raise IndexError once it had been stripped to ''.
                    out = out.rstrip(' ')
                    if out:
                        if out[-1] != ';':
                            # NOTE(review): this replaces the last character
                            # with ';' rather than appending one - kept as in
                            # the original; confirm it is intended.
                            gt += out[:-1] + ';\n'
                        else:
                            gt += out + '\n'

                # The TikZ page takes the code itself in place of the marker.
                lines = tik_html.split("const text =")
                new_web = lines[0] + gt + lines[1]

            # Both HTML templates declare UTF-8, so write the file as UTF-8.
            with open(save_render_file, 'w', encoding='utf-8') as web_f_new:
                web_f_new.write(new_web)
    return response_str
672
+
673
def dynamic_preprocess(self, image, min_num=1, max_num=6, image_size=1024, use_thumbnail=True):
    """Cut an image into a grid of square tiles matching its aspect ratio.

    A (columns, rows) grid with ``min_num <= columns * rows <= max_num`` is
    chosen whose aspect ratio is closest to the image's. The image is resized
    to exactly fill that grid and cropped into ``image_size`` squares, row by
    row. When more than one tile results and ``use_thumbnail`` is true, the
    whole image resized to a single square is appended.

    Returns:
        list of tile images (plus the optional thumbnail as last element).
    """

    def pick_grid(aspect, candidates, width, height, tile):
        """Return the candidate grid closest to ``aspect``; see tie rule below."""
        chosen = (1, 1)
        smallest_gap = float('inf')
        pixels = width * height
        for cols, rows in candidates:
            gap = abs(aspect - cols / rows)
            if gap < smallest_gap:
                smallest_gap = gap
                chosen = (cols, rows)
            elif gap == smallest_gap and pixels > 0.5 * tile * tile * cols * rows:
                # Tie: take the later grid only if the image has more than
                # half as many pixels as that grid would hold.
                chosen = (cols, rows)
        return chosen

    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # Every (columns, rows) pair whose tile count lies within [min_num, max_num],
    # ordered by tile count.
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    grid_cols, grid_rows = pick_grid(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    target_width = image_size * grid_cols
    target_height = image_size * grid_rows
    blocks = grid_cols * grid_rows

    resized_img = image.resize((target_width, target_height))
    tiles_per_row = target_width // image_size

    processed_images = []
    for index in range(blocks):
        left = (index % tiles_per_row) * image_size
        top = (index // tiles_per_row) * image_size
        processed_images.append(
            resized_img.crop((left, top, left + image_size, top + image_size))
        )
    assert len(processed_images) == blocks

    if use_thumbnail and len(processed_images) != 1:
        processed_images.append(image.resize((image_size, image_size)))
    return processed_images
729
+
730
+
731
def chat_crop(self, tokenizer, image_file, ocr_type, render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False):
    """Run OCR on an image that is first split into tiles by ``dynamic_preprocess``.

    Tensors are moved with ``.cuda()``, so a CUDA device is required. The
    multi-page path of the original was unreachable (its flag was a constant
    False) and has been dropped.

    Args:
        tokenizer: called as ``tokenizer([prompt])`` and used to decode the output.
        image_file: path or URL of the image; when ``gradio_input`` is true, an
            already opened image object (its ``.copy()`` is used).
        ocr_type: ``'format'`` selects the formatted-OCR prompt; any other
            value selects plain OCR.
        render: when true, additionally write an HTML rendering of the result
            to ``save_render_file``.
        save_render_file: output path for the rendered HTML.
        print_prompt: print the assembled prompt.
        gradio_input: see ``image_file``.
        stream_flag: stream decoded text to stdout while generating.

    Returns:
        The decoded text with the stop string removed (before any rewriting
        done for rendering).
    """
    self.disable_torch_init()

    image_processor_high = GOTImageEvalProcessor(image_size=1024)
    image_token_len = 256

    if ocr_type == 'format':
        qs = 'OCR with format upon the patch reference: '
    else:
        qs = 'OCR upon the patch reference: '

    img = image_file.copy() if gradio_input else self.load_image(image_file)
    sub_images = self.dynamic_preprocess(img)
    ll = len(sub_images)

    # One processed tensor per tile, stacked along a new leading dimension.
    image_list = torch.stack([image_processor_high(image) for image in sub_images])

    print('====new images batch size======: \n',image_list.shape)

    # One block of patch tokens per tile between a single start/end token pair.
    qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len*ll + DEFAULT_IM_END_TOKEN + '\n' + qs

    conv_mpt = Conversation(
        system="<|im_start|>system\nYou should follow the instructions carefully and explain your answers in detail.",
        roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
        version="mpt",
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.MPT,
        sep="<|im_end|>",
    )

    conv = conv_mpt.copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    if print_prompt:
        print(prompt)

    inputs = tokenizer([prompt])
    input_ids = torch.as_tensor(inputs.input_ids).cuda()

    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)

    generate_kwargs = dict(
        images=[image_list.half().cuda()],
        do_sample=False,
        num_beams=1,
        max_new_tokens=4096,
        stopping_criteria=[stopping_criteria],
    )
    if stream_flag:
        generate_kwargs["streamer"] = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    with torch.autocast("cuda", dtype=torch.bfloat16):
        output_ids = self.generate(input_ids, **generate_kwargs)

    outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()

    if outputs.endswith(stop_str):
        outputs = outputs[:-len(stop_str)]
    outputs = outputs.strip()
    response_str = outputs

    if render:
        print('==============rendering===============')
        from .render_tools import content_mmd_to_html

        # Unbalanced \left / \right counts: fall back to plain delimiters.
        if outputs.count('\\right') != outputs.count('\\left'):
            for fancy, plain in (
                ('\\left(', '('), ('\\right)', ')'),
                ('\\left[', '['), ('\\right]', ']'),
                ('\\left{', '{'), ('\\right}', '}'),
                ('\\left|', '|'), ('\\right|', '|'),
                ('\\left.', '.'), ('\\right.', '.'),
            ):
                outputs = outputs.replace(fancy, plain)

        outputs = outputs.replace('"', '``').replace('$', '')

        # Emit the text as a JavaScript expression: one quoted string per
        # line, joined with '+'.
        gt = '+\n'.join(
            '"' + out.replace('\\', '\\\\') + r'\n' + '"'
            for out in outputs.split('\n')
        )

        lines = content_mmd_to_html.split("const text =")
        new_web = lines[0] + 'const text =' + gt + lines[1]

        # The HTML template declares UTF-8, so write the file as UTF-8.
        with open(save_render_file, 'w', encoding='utf-8') as web_f_new:
            web_f_new.write(new_web)

    return response_str
custom_got/qwen.tiktoken ADDED
The diff for this file is too large to render. See raw diff
 
custom_got/render_tools.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
# Full-width CJK punctuation mapped to its ASCII counterpart. The keys are
# written as escapes so they cannot be confused with (or normalized into) the
# ASCII characters they map to - an ASCII "," key would make the entry a no-op.
punctuation_dict = {
    "\uff0c": ",",  # FULLWIDTH COMMA
    "\u3002": ".",  # IDEOGRAPHIC FULL STOP

}
# Ready-to-use table for str.translate().
translation_table = str.maketrans(punctuation_dict)
8
+
9
def svg_to_html(svg_content, output_filename):
    """Wrap SVG markup in a minimal HTML page and write it to a file.

    The page declares ``charset="UTF-8"``, so the file is written as UTF-8
    regardless of the platform's default encoding.

    Args:
        svg_content: SVG markup placed inside the page's ``<svg>`` element.
        output_filename: path of the HTML file to create or overwrite.
    """

    html_content = f"""
    <!DOCTYPE html>
    <html lang="en">
    <head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>SVG Embedded in HTML</title>
    </head>
    <body>
    <svg width="2100" height="15000" xmlns="http://www.w3.org/2000/svg">
    {svg_content}
    </svg>
    </body>
    </html>
    """

    with open(output_filename, 'w', encoding='utf-8') as file:
        file.write(html_content)
29
+
30
+
31
+
32
# HTML page template that renders text with the mathpix-markdown-it bundle
# loaded from a CDN. The text to render is injected by splitting this string
# at the "const text =" marker and inserting a JavaScript string expression
# right after it, so the marker must occur exactly once.
content_mmd_to_html = """<!DOCTYPE html>
<html lang="en" data-lt-installed="true"><head>
<meta charset="UTF-8">
<title>Title</title>
<script>
const text =
</script>
<style>
#content {
max-width: 800px;
margin: auto;
}
</style>
<script>
let script = document.createElement('script');
script.src = "https://cdn.jsdelivr.net/npm/mathpix-markdown-it@1.3.6/es5/bundle.js";
document.head.append(script);

script.onload = function() {
const isLoaded = window.loadMathJax();
if (isLoaded) {
console.log('Styles loaded!')
}

const el = window.document.getElementById('content-text');
if (el) {
const options = {
htmlTags: true
};
const html = window.render(text, options);
el.outerHTML = html;
}
};
</script>
</head>
<body>
<div id="content"><div id="content-text"></div></div>
</body>
</html>
"""
72
+
73
+
74
+
75
# HTML page template that loads the TikZJax script and stylesheet. Callers
# split this string at the "const text =" marker and put the TikZ source in
# its place inside the <script type="text/tikz"> element, so the marker must
# occur exactly once.
tik_html = """
<!DOCTYPE html>

<html>

<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Document</title>
<link rel="stylesheet" type="text/css" href="https://tikzjax.com/v1/fonts.css">
<script src="https://tikzjax.com/v1/tikzjax.js"></script>
</head>
<body>
<script type="text/tikz">
const text =
</script>
</body>
</html>"""
93
+
94
+
95
+
96
+ # print(tik_html)
custom_got/special_tokens_map.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "pad_token": {
3
+ "content": "<|endoftext|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ }
9
+ }
custom_got/tokenization_qwen.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """Tokenization classes for QWen."""
7
+
8
+ import base64
9
+ import logging
10
+ import os
11
+ import unicodedata
12
+ from typing import Collection, Dict, List, Set, Tuple, Union
13
+
14
+ import tiktoken
15
+ from transformers import PreTrainedTokenizer, AddedToken
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
# Maps the tokenizer's vocab-file argument name to the file name used on disk.
VOCAB_FILES_NAMES = {"vocab_file": "qwen.tiktoken"}

# Pre-tokenization split pattern handed to tiktoken as `pat_str`.
PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""

# Surface forms of the core special tokens.
ENDOFTEXT = "<|endoftext|>"
IMSTART = "<|im_start|>"
IMEND = "<|im_end|>"

# Special tokens are allowed inside regular text by default, so their surface
# forms are chosen to be as unlike ordinary text as possible to minimise the
# chance of an accidental match.
EXTRAS = tuple(f"<|extra_{idx}|>" for idx in range(205))

# Order matters: ids are assigned to special tokens in this order.
SPECIAL_TOKENS = (ENDOFTEXT, IMSTART, IMEND, *EXTRAS)
35
+
36
+
37
+ def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
38
+ with open(tiktoken_bpe_file, "rb") as f:
39
+ contents = f.read()
40
+ return {
41
+ base64.b64decode(token): int(rank)
42
+ for token, rank in (line.split() for line in contents.splitlines() if line)
43
+ }
44
+
45
class QWenTokenizer(PreTrainedTokenizer):
    """QWen tokenizer backed by a tiktoken BPE encoding.

    Regular tokens are ``bytes`` objects loaded from the tiktoken vocabulary
    file. Special tokens (``SPECIAL_TOKENS`` followed by the image/ref/box/quad
    tags) are ``str`` objects and receive the ids directly after the regular
    vocabulary.
    """

    vocab_files_names = VOCAB_FILES_NAMES

    def __init__(
        self,
        vocab_file,
        errors="replace",
        image_start_tag='<img>',
        image_end_tag='</img>',
        image_pad_tag='<imgpad>',
        ref_start_tag='<ref>',
        ref_end_tag='</ref>',
        box_start_tag='<box>',
        box_end_tag='</box>',
        quad_start_tag='<quad>',
        quad_end_tag='</quad>',
        **kwargs,
    ):
        """Build the tiktoken encoding and the token <-> id lookup tables.

        Args:
            vocab_file: Path of the tiktoken vocabulary file.
            errors: Error policy used when decoding token bytes as UTF-8.
            image_start_tag, image_end_tag, image_pad_tag: Surface forms of
                the image special tokens.
            ref_start_tag, ref_end_tag: Surface forms of the ref tokens.
            box_start_tag, box_end_tag: Surface forms of the box tokens.
            quad_start_tag, quad_end_tag: Surface forms of the quad tokens.
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``.
        """
        super().__init__(**kwargs)

        self.image_start_tag = image_start_tag
        self.image_end_tag = image_end_tag
        self.image_pad_tag = image_pad_tag
        self.ref_start_tag = ref_start_tag
        self.ref_end_tag = ref_end_tag
        self.box_start_tag = box_start_tag
        self.box_end_tag = box_end_tag
        self.quad_start_tag = quad_start_tag
        self.quad_end_tag = quad_end_tag
        # Extra special tokens appended after SPECIAL_TOKENS; the order here
        # determines their ids.
        self.IMAGE_ST = (
            ref_start_tag, ref_end_tag,
            box_start_tag, box_end_tag,
            quad_start_tag, quad_end_tag,
            image_start_tag, image_end_tag,
            image_pad_tag
        )

        self.errors = errors  # how to handle errors in decoding

        self.mergeable_ranks = _load_tiktoken_bpe(vocab_file)  # type: dict[bytes, int]
        # Special-token ids start right after the last regular token id.
        self.special_tokens = {
            token: index
            for index, token in enumerate(
                SPECIAL_TOKENS + self.IMAGE_ST, start=len(self.mergeable_ranks)
            )
        }

        self.img_start_id = self.special_tokens[self.image_start_tag]
        self.img_end_id = self.special_tokens[self.image_end_tag]
        self.img_pad_id = self.special_tokens[self.image_pad_tag]
        self.ref_start_id = self.special_tokens[self.ref_start_tag]
        self.ref_end_id = self.special_tokens[self.ref_end_tag]
        self.box_start_id = self.special_tokens[self.box_start_tag]
        self.box_end_id = self.special_tokens[self.box_end_tag]
        self.quad_start_id = self.special_tokens[self.quad_start_tag]
        self.quad_end_id = self.special_tokens[self.quad_end_tag]

        enc = tiktoken.Encoding(
            "Qwen",
            pat_str=PAT_STR,
            mergeable_ranks=self.mergeable_ranks,
            special_tokens=self.special_tokens,
        )
        # Sanity check: regular and special ids must not overlap or leave gaps.
        assert (
            len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
        ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"

        # id -> token; regular tokens are bytes, special tokens are str.
        self.decoder = {
            v: k for k, v in self.mergeable_ranks.items()
        }  # type: dict[int, bytes|str]
        self.decoder.update({v: k for k, v in self.special_tokens.items()})

        self.tokenizer = enc  # type: tiktoken.Encoding

        self.eod_id = self.tokenizer.eot_token
        self.im_start_id = self.special_tokens[IMSTART]
        self.im_end_id = self.special_tokens[IMEND]

    def __len__(self) -> int:
        """Return the total vocabulary size, special tokens included."""
        return self.tokenizer.n_vocab

    def get_vocab(self) -> Dict[bytes, int]:
        """Return the regular (non-special) token -> rank mapping."""
        return self.mergeable_ranks

    def convert_tokens_to_ids(
        self, tokens: Union[bytes, str, List[Union[bytes, str]]]
    ) -> Union[int, List[int]]:
        """Map a token, or a list of tokens, to ids.

        A single ``str``/``bytes`` token yields a single id; a list yields a
        list of ids. Tokens found in neither table map to ``None``.
        """
        if isinstance(tokens, (str, bytes)):
            if tokens in self.special_tokens:
                return self.special_tokens[tokens]
            return self.mergeable_ranks.get(tokens)

        ids = []
        for token in tokens:
            if token in self.special_tokens:
                ids.append(self.special_tokens[token])
            else:
                ids.append(self.mergeable_ranks.get(token))
        return ids

    def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
        """Validate a request to add tokens; the vocabulary itself is fixed.

        Only special tokens this tokenizer already knows are accepted, and
        accepting them is a no-op.

        Returns:
            Always 0, since no token is ever actually added.

        Raises:
            ValueError: If regular tokens are requested, or if a special token
                is not one of the known special tokens.
        """
        if not special_tokens and new_tokens:
            raise ValueError('Adding regular tokens is not supported')
        # The image/ref/box/quad tags are registered as special tokens in
        # __init__, so they count as known too. IMAGE_ST is read defensively
        # in case this hook runs before __init__ has assigned it.
        known_special = SPECIAL_TOKENS + tuple(getattr(self, "IMAGE_ST", ()))
        for token in new_tokens:
            surface_form = token.content if isinstance(token, AddedToken) else token
            if surface_form not in known_special:
                raise ValueError('Adding unknown special tokens is not supported')
        return 0

    def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
        """
        Save only the vocabulary of the tokenizer (vocabulary).

        Each regular token is written as ``<base64 token> <rank>`` on its own
        line; special tokens are not written.

        Returns:
            `Tuple(str)`: Paths to the files saved.
        """
        file_path = os.path.join(save_directory, "qwen.tiktoken")
        with open(file_path, "w", encoding="utf8") as w:
            for k, v in self.mergeable_ranks.items():
                line = base64.b64encode(k).decode("utf8") + " " + str(v) + "\n"
                w.write(line)
        return (file_path,)

    def tokenize(
        self,
        text: str,
        allowed_special: Union[Set, str] = "all",
        disallowed_special: Union[Collection, str] = (),
        **kwargs,
    ) -> List[Union[bytes, str]]:
        """
        Converts a string into a sequence of tokens.

        Args:
            text (`str`):
                The sequence to be encoded.
            allowed_special (`Literal["all"]` or `set`):
                The surface forms of the tokens to be encoded as special tokens in regular texts.
                Default to "all".
            disallowed_special (`Literal["all"]` or `Collection`):
                The surface forms of the tokens that should not be in regular texts and trigger errors.
                Default to an empty tuple.

            kwargs (additional keyword arguments, *optional*):
                Accepted for interface compatibility; not used here.

        Returns:
            `List[bytes|str]`: The list of tokens.
        """
        text = unicodedata.normalize("NFC", text)

        # this implementation takes a detour: text -> token id -> token surface forms
        token_ids = self.tokenizer.encode(
            text, allowed_special=allowed_special, disallowed_special=disallowed_special
        )
        return [self.decoder[t] for t in token_ids]

    def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
        """
        Converts a sequence of tokens into a single string.

        Consecutive ``bytes`` tokens are concatenated before being decoded so
        that multi-byte UTF-8 characters split across tokens decode correctly;
        ``str`` (special) tokens are appended verbatim.

        Raises:
            TypeError: If a token is neither ``bytes`` nor ``str``.
        """
        text = ""
        temp = b""
        for t in tokens:
            if isinstance(t, str):
                # Flush any pending bytes before the special token.
                if temp:
                    text += temp.decode("utf-8", errors=self.errors)
                    temp = b""
                text += t
            elif isinstance(t, bytes):
                temp += t
            else:
                raise TypeError("token should only be of type bytes or str")
        if temp:
            text += temp.decode("utf-8", errors=self.errors)
        return text

    @property
    def vocab_size(self):
        """Total vocabulary size, special tokens included."""
        return self.tokenizer.n_vocab

    def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
        """Converts an id to a token, special tokens included"""
        if index in self.decoder:
            return self.decoder[index]
        raise ValueError("unknown ids")

    def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
        """Converts a token to an id using the vocab, special tokens included"""
        if token in self.special_tokens:
            return self.special_tokens[token]
        if token in self.mergeable_ranks:
            return self.mergeable_ranks[token]
        raise ValueError("unknown token")

    def _tokenize(self, text: str, **kwargs):
        """
        Converts a string into a sequence of tokens (string), using the tokenizer. Split in words for word-based
        vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).

        Do NOT take care of added tokens.

        Not implemented: this class overrides `tokenize` directly instead.
        """
        raise NotImplementedError

    def _decode(
        self,
        token_ids: Union[int, List[int]],
        skip_special_tokens: bool = False,
        errors: str = None,
        **kwargs,
    ) -> str:
        """Decode one id or a list of ids back into text.

        Args:
            token_ids: A single id or a list of ids.
            skip_special_tokens: If true, drop every id at or above `eod_id`
                before decoding.
            errors: UTF-8 decoding error policy; falls back to `self.errors`.
        """
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        if skip_special_tokens:
            # <|endoftext|> is the first special token, so every special id
            # is >= eod_id.
            token_ids = [i for i in token_ids if i < self.eod_id]
        return self.tokenizer.decode(token_ids, errors=errors or self.errors)
custom_got/tokenizer_config.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {},
3
+ "auto_map": {
4
+ "AutoTokenizer": [
5
+ "tokenization_qwen.QWenTokenizer",
6
+ null
7
+ ]
8
+ },
9
+ "clean_up_tokenization_spaces": true,
10
+ "model_max_length": 8000,
11
+ "pad_token": "<|endoftext|>",
12
+ "padding_side": "right",
13
+ "tokenizer_class": "QWenTokenizer"
14
+ }
dataset.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5c52de9875d5559635129df10b3a3466167b4041def7208b435c72531c970320
3
+ size 23713278
dataset_creation.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import json
3
+ import os
4
+
5
+
6
# Build the training manifest: one record per CSV row, pairing the image path
# with the text stored for that image.
dataset = pd.read_csv('ocr_task/data_80k/data.csv')
labels = dataset['image_file']
text = dataset['text']
images_path = 'drive/MyDrive/data_80k/output_images/'

# Iterate the two columns in lockstep instead of indexing by position.
json_data = [
    {
        "query": "<image>",
        "response": response,
        "images": [os.path.join(images_path, image_file)],
    }
    for image_file, response in zip(labels, text)
]

# Explicit encoding so the output does not depend on the platform default.
with open('dataset.json', 'w', encoding='utf-8') as f:
    json.dump(json_data, f)
21
+
main.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModel, AutoTokenizer
2
+ import torch
3
+ from byaldi.RAGModel import RAGMultiModalModel
4
+ from byaldi.colpali import ColPaliModel
5
+ from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
6
+ from qwen_vl_utils import process_vision_info
7
+ import torch
8
+ from PIL import Image
9
+ import numpy as np
10
+
11
+
12
# Load the retrieval model and show its document metadata.
colpali_model = ColPaliModel.from_pretrained('vidore/colpali')
print(colpali_model.doc_id_to_metadata)

# Load the vision-language model (inference mode) and its processor.
model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", torch_dtype=torch.bfloat16).eval()
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True)

# Open the document once and use the same image for both the chat message and
# the processor call. Previously the message referenced 'template.jpg' while
# the pixels fed to the model came from 'docs/hindi_template.jpg'.
img = Image.open('docs/hindi_template.jpg')
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": img,
            },
            {"type": "text", "text": 'Return full text of the document as a plain text'},
        ],
    }
]

# Render the conversation into a prompt string (not tokenized yet).
text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
inputs = processor(
    text=text,
    images=img,
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to("cpu")
generated_ids = model.generate(**inputs, max_new_tokens=5000)
# Drop the prompt tokens so only the newly generated ids are decoded.
generated_ids_trimmed = [
    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)
main_got.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModel, AutoTokenizer
2
+ import torch
3
+
4
+
5
# Loaded (tokenizer, model) pairs keyed by device type, so repeated calls to
# extract_text() do not reload the checkpoint from disk every time.
_GOT_CACHE = {}


def _load_got(device_type):
    """Return the (tokenizer, model) pair for `device_type`, loading it on first use."""
    if device_type not in _GOT_CACHE:
        tokenizer = AutoTokenizer.from_pretrained(
            'custom_got',
            trust_remote_code=True,  # Allows the custom code shipped with the model to run
        )
        model = AutoModel.from_pretrained(
            'custom_got',
            trust_remote_code=True,
            low_cpu_mem_usage=True,
            device_map=device_type,
            use_safetensors=True,  # This format is faster, more memory efficient
                                   # and provides safe deserialization unlike pickle-based one
            # NOTE(review): the bundled tokenizer_config.json only declares a
            # pad_token, so eos_token_id may be None here - confirm whether
            # tokenizer.pad_token_id was intended.
            pad_token_id=tokenizer.eos_token_id,
        )
        _GOT_CACHE[device_type] = (tokenizer, model)
    return _GOT_CACHE[device_type]


def extract_text(image_path):
    """Run OCR on the image at `image_path` and return the model's result.

    Uses CUDA when available and the CPU otherwise. The tokenizer and model
    are loaded on the first call and reused afterwards.

    Args:
        image_path: Path of the image file passed to the model.

    Returns:
        Whatever `model.chat(..., ocr_type='ocr')` returns for the image.
    """
    # If cuda is available, use it, otherwise use CPU
    device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
    tokenizer, model = _load_got(device_type)

    # Extract text
    return model.chat(tokenizer, image_path, ocr_type='ocr')
requirements.txt ADDED
Binary file (7.14 kB). View file