Moving main repo to hf repo
Browse files- .gitattributes +4 -35
- .gitignore +163 -0
- README.md +7 -12
- app.py +27 -2
- custom_got/assets/got_logo.png +0 -0
- custom_got/assets/got_support.jpg +0 -0
- custom_got/assets/train_sample.jpg +0 -0
- custom_got/config.json +38 -0
- custom_got/generation_config.json +6 -0
- custom_got/got_vision_b.py +468 -0
- custom_got/model.safetensors +3 -0
- custom_got/modeling_GOT.py +881 -0
- custom_got/qwen.tiktoken +0 -0
- custom_got/render_tools.py +96 -0
- custom_got/special_tokens_map.json +9 -0
- custom_got/tokenization_qwen.py +264 -0
- custom_got/tokenizer_config.json +14 -0
- dataset.json +3 -0
- dataset_creation.py +21 -0
- main.py +47 -0
- main_got.py +26 -0
- requirements.txt +0 -0
.gitattributes
CHANGED
|
@@ -1,35 +1,4 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 1 |
+
custom_got/model.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
data_80k filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
data_80k/data.csv filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
dataset.json filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.gitignore
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py,cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
# .python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# poetry
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 102 |
+
#poetry.lock
|
| 103 |
+
|
| 104 |
+
# pdm
|
| 105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 106 |
+
#pdm.lock
|
| 107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 108 |
+
# in version control.
|
| 109 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
| 110 |
+
.pdm.toml
|
| 111 |
+
.pdm-python
|
| 112 |
+
.pdm-build/
|
| 113 |
+
|
| 114 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 115 |
+
__pypackages__/
|
| 116 |
+
|
| 117 |
+
# Celery stuff
|
| 118 |
+
celerybeat-schedule
|
| 119 |
+
celerybeat.pid
|
| 120 |
+
|
| 121 |
+
# SageMath parsed files
|
| 122 |
+
*.sage.py
|
| 123 |
+
|
| 124 |
+
# Environments
|
| 125 |
+
.env
|
| 126 |
+
.venv
|
| 127 |
+
env/
|
| 128 |
+
venv/
|
| 129 |
+
ENV/
|
| 130 |
+
env.bak/
|
| 131 |
+
venv.bak/
|
| 132 |
+
|
| 133 |
+
# Spyder project settings
|
| 134 |
+
.spyderproject
|
| 135 |
+
.spyproject
|
| 136 |
+
|
| 137 |
+
# Rope project settings
|
| 138 |
+
.ropeproject
|
| 139 |
+
|
| 140 |
+
# mkdocs documentation
|
| 141 |
+
/site
|
| 142 |
+
|
| 143 |
+
# mypy
|
| 144 |
+
.mypy_cache/
|
| 145 |
+
.dmypy.json
|
| 146 |
+
dmypy.json
|
| 147 |
+
|
| 148 |
+
# Pyre type checker
|
| 149 |
+
.pyre/
|
| 150 |
+
|
| 151 |
+
# pytype static type analyzer
|
| 152 |
+
.pytype/
|
| 153 |
+
|
| 154 |
+
# Cython debug symbols
|
| 155 |
+
cython_debug/
|
| 156 |
+
|
| 157 |
+
# PyCharm
|
| 158 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 159 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 160 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 161 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 162 |
+
#.idea/
|
| 163 |
+
data_80k
|
README.md
CHANGED
|
@@ -1,13 +1,8 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
app_file: app.py
|
| 9 |
-
pinned: false
|
| 10 |
-
license: apache-2.0
|
| 11 |
-
---
|
| 12 |
|
| 13 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
+
# ocr_task
|
| 2 |
+
OCR assignment for PARIMAL IIT Roorkee Internship.
|
| 3 |
+
|
| 4 |
+
Two models for OCR were considered: GOT 2.0 and Colpali implementation of Byaldi library + Qwen2-VL. After research GOT was chosen because it has specification of extracting text from image directly without using LLM for explaining the content of the file. Besides that, GOT has direct instructions for training and fine-tuning model with data samples. Since GOT does not generate hindi symbols at all, I've needed to fine-tune the model on hindi dataset. Tokenizer already contained tokens for hindi symbols, so adding tokens was not necessary.
|
| 5 |
+
However, GOT is only compatible with CUDA, so on my device it won't be possible to fine-tune it. I've chosen to use Google Colab for this since it provides GPU for limited use.
|
| 6 |
+
|
| 7 |
+
During deployment on streamlit sharing encountered a problem with '\left' strings which were problematic escape sequences due to '\'. Used additional script replacer.py to replace all these string to '\\left'.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
|
|
app.py
CHANGED
|
@@ -1,4 +1,29 @@
|
|
| 1 |
import streamlit as st
|
|
|
|
|
|
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
+
from main_got import extract_text
|
| 3 |
+
import re
|
| 4 |
|
| 5 |
+
|
| 6 |
+
# Streamlit UI
|
| 7 |
+
st.title("OCR and Document Search Web App")
|
| 8 |
+
|
| 9 |
+
# Image upload
|
| 10 |
+
uploaded_image = st.file_uploader("Upload an image for OCR", type=["jpg", "png", "jpeg"])
|
| 11 |
+
|
| 12 |
+
if uploaded_image is not None:
|
| 13 |
+
with st.spinner("Processing image..."):
|
| 14 |
+
# Extract text from the uploaded image
|
| 15 |
+
extracted_text = extract_text(uploaded_image)
|
| 16 |
+
st.subheader("Extracted Text")
|
| 17 |
+
st.write(extracted_text)
|
| 18 |
+
|
| 19 |
+
# Search functionality
|
| 20 |
+
search_query = st.text_input("Enter a keyword to search within the text")
|
| 21 |
+
if search_query:
|
| 22 |
+
results = [match.start() for match in re.finditer(search_query, extracted_text)]
|
| 23 |
+
if results:
|
| 24 |
+
st.subheader("Search Results")
|
| 25 |
+
for result in results:
|
| 26 |
+
st.write(f"Keyword found at index: {result}")
|
| 27 |
+
else:
|
| 28 |
+
st.write("No results found.")
|
| 29 |
+
|
custom_got/assets/got_logo.png
ADDED
|
custom_got/assets/got_support.jpg
ADDED
|
custom_got/assets/train_sample.jpg
ADDED
|
custom_got/config.json
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_name_or_path": "ucaslcl/GOT-OCR2_0",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"GOTQwenForCausalLM"
|
| 5 |
+
],
|
| 6 |
+
"auto_map": {
|
| 7 |
+
"AutoConfig": "modeling_GOT.GOTConfig",
|
| 8 |
+
"AutoModel": "modeling_GOT.GOTQwenForCausalLM"
|
| 9 |
+
},
|
| 10 |
+
"attention_dropout": 0.0,
|
| 11 |
+
"bos_token_id": 151643,
|
| 12 |
+
"eos_token_id": 151643,
|
| 13 |
+
"freeze_vision_tower": false,
|
| 14 |
+
"hidden_act": "silu",
|
| 15 |
+
"hidden_size": 1024,
|
| 16 |
+
"im_end_token": 151858,
|
| 17 |
+
"im_patch_token": 151859,
|
| 18 |
+
"im_start_token": 151857,
|
| 19 |
+
"image_token_len": 256,
|
| 20 |
+
"initializer_range": 0.02,
|
| 21 |
+
"intermediate_size": 2816,
|
| 22 |
+
"max_position_embeddings": 32768,
|
| 23 |
+
"max_window_layers": 21,
|
| 24 |
+
"model_type": "GOT",
|
| 25 |
+
"num_attention_heads": 16,
|
| 26 |
+
"num_hidden_layers": 24,
|
| 27 |
+
"num_key_value_heads": 16,
|
| 28 |
+
"rms_norm_eps": 1e-06,
|
| 29 |
+
"rope_theta": 1000000.0,
|
| 30 |
+
"sliding_window": 32768,
|
| 31 |
+
"tie_word_embeddings": true,
|
| 32 |
+
"torch_dtype": "bfloat16",
|
| 33 |
+
"transformers_version": "4.37.2",
|
| 34 |
+
"use_cache": true,
|
| 35 |
+
"use_im_start_end": true,
|
| 36 |
+
"use_sliding_window": false,
|
| 37 |
+
"vocab_size": 151860
|
| 38 |
+
}
|
custom_got/generation_config.json
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token_id": 151643,
|
| 3 |
+
"eos_token_id": 151643,
|
| 4 |
+
"max_new_tokens": 2048,
|
| 5 |
+
"transformers_version": "4.37.2"
|
| 6 |
+
}
|
custom_got/got_vision_b.py
ADDED
|
@@ -0,0 +1,468 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from typing import Optional, Tuple, Type
|
| 4 |
+
from functools import partial
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from typing import Type
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class MLPBlock(nn.Module):
|
| 11 |
+
def __init__(
|
| 12 |
+
self,
|
| 13 |
+
embedding_dim: int,
|
| 14 |
+
mlp_dim: int,
|
| 15 |
+
act: Type[nn.Module] = nn.GELU,
|
| 16 |
+
) -> None:
|
| 17 |
+
super().__init__()
|
| 18 |
+
self.lin1 = nn.Linear(embedding_dim, mlp_dim)
|
| 19 |
+
self.lin2 = nn.Linear(mlp_dim, embedding_dim)
|
| 20 |
+
self.act = act()
|
| 21 |
+
|
| 22 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 23 |
+
return self.lin2(self.act(self.lin1(x)))
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class LayerNorm2d(nn.Module):
|
| 28 |
+
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
|
| 29 |
+
super().__init__()
|
| 30 |
+
self.weight = nn.Parameter(torch.ones(num_channels))
|
| 31 |
+
self.bias = nn.Parameter(torch.zeros(num_channels))
|
| 32 |
+
self.eps = eps
|
| 33 |
+
|
| 34 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 35 |
+
u = x.mean(1, keepdim=True)
|
| 36 |
+
s = (x - u).pow(2).mean(1, keepdim=True)
|
| 37 |
+
x = (x - u) / torch.sqrt(s + self.eps)
|
| 38 |
+
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
| 39 |
+
return x
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class ImageEncoderViT(nn.Module):
|
| 44 |
+
def __init__(
|
| 45 |
+
self,
|
| 46 |
+
img_size: int = 1024,
|
| 47 |
+
patch_size: int = 16,
|
| 48 |
+
in_chans: int = 3,
|
| 49 |
+
embed_dim: int = 768,
|
| 50 |
+
depth: int = 12,
|
| 51 |
+
num_heads: int = 12,
|
| 52 |
+
mlp_ratio: float = 4.0,
|
| 53 |
+
out_chans: int = 256,
|
| 54 |
+
qkv_bias: bool = True,
|
| 55 |
+
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
| 56 |
+
act_layer: Type[nn.Module] = nn.GELU,
|
| 57 |
+
use_abs_pos: bool = True,
|
| 58 |
+
use_rel_pos: bool = False,
|
| 59 |
+
rel_pos_zero_init: bool = True,
|
| 60 |
+
window_size: int = 0,
|
| 61 |
+
global_attn_indexes: Tuple[int, ...] = (),
|
| 62 |
+
) -> None:
|
| 63 |
+
"""
|
| 64 |
+
Args:
|
| 65 |
+
img_size (int): Input image size.
|
| 66 |
+
patch_size (int): Patch size.
|
| 67 |
+
in_chans (int): Number of input image channels.
|
| 68 |
+
embed_dim (int): Patch embedding dimension.
|
| 69 |
+
depth (int): Depth of ViT.
|
| 70 |
+
num_heads (int): Number of attention heads in each ViT block.
|
| 71 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
| 72 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
| 73 |
+
norm_layer (nn.Module): Normalization layer.
|
| 74 |
+
act_layer (nn.Module): Activation layer.
|
| 75 |
+
use_abs_pos (bool): If True, use absolute positional embeddings.
|
| 76 |
+
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
|
| 77 |
+
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
| 78 |
+
window_size (int): Window size for window attention blocks.
|
| 79 |
+
global_attn_indexes (list): Indexes for blocks using global attention.
|
| 80 |
+
"""
|
| 81 |
+
super().__init__()
|
| 82 |
+
self.img_size = img_size
|
| 83 |
+
|
| 84 |
+
self.patch_embed = PatchEmbed(
|
| 85 |
+
kernel_size=(patch_size, patch_size),
|
| 86 |
+
stride=(patch_size, patch_size),
|
| 87 |
+
in_chans=in_chans,
|
| 88 |
+
embed_dim=embed_dim,
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
self.pos_embed: Optional[nn.Parameter] = None
|
| 92 |
+
if use_abs_pos:
|
| 93 |
+
# Initialize absolute positional embedding with pretrain image size.
|
| 94 |
+
self.pos_embed = nn.Parameter(
|
| 95 |
+
torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
self.blocks = nn.ModuleList()
|
| 99 |
+
for i in range(depth):
|
| 100 |
+
block = Block(
|
| 101 |
+
dim=embed_dim,
|
| 102 |
+
num_heads=num_heads,
|
| 103 |
+
mlp_ratio=mlp_ratio,
|
| 104 |
+
qkv_bias=qkv_bias,
|
| 105 |
+
norm_layer=norm_layer,
|
| 106 |
+
act_layer=act_layer,
|
| 107 |
+
use_rel_pos=use_rel_pos,
|
| 108 |
+
rel_pos_zero_init=rel_pos_zero_init,
|
| 109 |
+
window_size=window_size if i not in global_attn_indexes else 0,
|
| 110 |
+
input_size=(img_size // patch_size, img_size // patch_size),
|
| 111 |
+
)
|
| 112 |
+
self.blocks.append(block)
|
| 113 |
+
|
| 114 |
+
self.neck = nn.Sequential(
|
| 115 |
+
nn.Conv2d(
|
| 116 |
+
embed_dim,
|
| 117 |
+
out_chans,
|
| 118 |
+
kernel_size=1,
|
| 119 |
+
bias=False,
|
| 120 |
+
),
|
| 121 |
+
LayerNorm2d(out_chans),
|
| 122 |
+
nn.Conv2d(
|
| 123 |
+
out_chans,
|
| 124 |
+
out_chans,
|
| 125 |
+
kernel_size=3,
|
| 126 |
+
padding=1,
|
| 127 |
+
bias=False,
|
| 128 |
+
),
|
| 129 |
+
LayerNorm2d(out_chans),
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
|
| 134 |
+
self.net_3 = nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, bias=False)
|
| 135 |
+
|
| 136 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 137 |
+
x = self.patch_embed(x)
|
| 138 |
+
if self.pos_embed is not None:
|
| 139 |
+
x = x + self.pos_embed
|
| 140 |
+
|
| 141 |
+
for blk in self.blocks:
|
| 142 |
+
x = blk(x)
|
| 143 |
+
|
| 144 |
+
x = self.neck(x.permute(0, 3, 1, 2))
|
| 145 |
+
x = self.net_2(x)
|
| 146 |
+
x = self.net_3(x)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
return x
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class Block(nn.Module):
|
| 153 |
+
"""Transformer blocks with support of window attention and residual propagation blocks"""
|
| 154 |
+
|
| 155 |
+
def __init__(
|
| 156 |
+
self,
|
| 157 |
+
dim: int,
|
| 158 |
+
num_heads: int,
|
| 159 |
+
mlp_ratio: float = 4.0,
|
| 160 |
+
qkv_bias: bool = True,
|
| 161 |
+
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
| 162 |
+
act_layer: Type[nn.Module] = nn.GELU,
|
| 163 |
+
use_rel_pos: bool = False,
|
| 164 |
+
rel_pos_zero_init: bool = True,
|
| 165 |
+
window_size: int = 0,
|
| 166 |
+
input_size: Optional[Tuple[int, int]] = None,
|
| 167 |
+
) -> None:
|
| 168 |
+
"""
|
| 169 |
+
Args:
|
| 170 |
+
dim (int): Number of input channels.
|
| 171 |
+
num_heads (int): Number of attention heads in each ViT block.
|
| 172 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
| 173 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
| 174 |
+
norm_layer (nn.Module): Normalization layer.
|
| 175 |
+
act_layer (nn.Module): Activation layer.
|
| 176 |
+
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
|
| 177 |
+
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
| 178 |
+
window_size (int): Window size for window attention blocks. If it equals 0, then
|
| 179 |
+
use global attention.
|
| 180 |
+
input_size (tuple(int, int) or None): Input resolution for calculating the relative
|
| 181 |
+
positional parameter size.
|
| 182 |
+
"""
|
| 183 |
+
super().__init__()
|
| 184 |
+
self.norm1 = norm_layer(dim)
|
| 185 |
+
self.attn = Attention(
|
| 186 |
+
dim,
|
| 187 |
+
num_heads=num_heads,
|
| 188 |
+
qkv_bias=qkv_bias,
|
| 189 |
+
use_rel_pos=use_rel_pos,
|
| 190 |
+
rel_pos_zero_init=rel_pos_zero_init,
|
| 191 |
+
input_size=input_size if window_size == 0 else (window_size, window_size),
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
self.norm2 = norm_layer(dim)
|
| 195 |
+
self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
|
| 196 |
+
|
| 197 |
+
self.window_size = window_size
|
| 198 |
+
|
| 199 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 200 |
+
shortcut = x
|
| 201 |
+
x = self.norm1(x)
|
| 202 |
+
# Window partition
|
| 203 |
+
if self.window_size > 0:
|
| 204 |
+
H, W = x.shape[1], x.shape[2]
|
| 205 |
+
x, pad_hw = window_partition(x, self.window_size)
|
| 206 |
+
|
| 207 |
+
x = self.attn(x)
|
| 208 |
+
# Reverse window partition
|
| 209 |
+
if self.window_size > 0:
|
| 210 |
+
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
|
| 211 |
+
|
| 212 |
+
x = shortcut + x
|
| 213 |
+
x = x + self.mlp(self.norm2(x))
|
| 214 |
+
|
| 215 |
+
return x
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
class Attention(nn.Module):
|
| 219 |
+
"""Multi-head Attention block with relative position embeddings."""
|
| 220 |
+
|
| 221 |
+
def __init__(
|
| 222 |
+
self,
|
| 223 |
+
dim: int,
|
| 224 |
+
num_heads: int = 8,
|
| 225 |
+
qkv_bias: bool = True,
|
| 226 |
+
use_rel_pos: bool = False,
|
| 227 |
+
rel_pos_zero_init: bool = True,
|
| 228 |
+
input_size: Optional[Tuple[int, int]] = None,
|
| 229 |
+
) -> None:
|
| 230 |
+
"""
|
| 231 |
+
Args:
|
| 232 |
+
dim (int): Number of input channels.
|
| 233 |
+
num_heads (int): Number of attention heads.
|
| 234 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
| 235 |
+
rel_pos (bool): If True, add relative positional embeddings to the attention map.
|
| 236 |
+
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
| 237 |
+
input_size (tuple(int, int) or None): Input resolution for calculating the relative
|
| 238 |
+
positional parameter size.
|
| 239 |
+
"""
|
| 240 |
+
super().__init__()
|
| 241 |
+
self.num_heads = num_heads
|
| 242 |
+
head_dim = dim // num_heads
|
| 243 |
+
self.scale = head_dim**-0.5
|
| 244 |
+
|
| 245 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 246 |
+
self.proj = nn.Linear(dim, dim)
|
| 247 |
+
|
| 248 |
+
self.use_rel_pos = use_rel_pos
|
| 249 |
+
if self.use_rel_pos:
|
| 250 |
+
assert (
|
| 251 |
+
input_size is not None
|
| 252 |
+
), "Input size must be provided if using relative positional encoding."
|
| 253 |
+
# initialize relative positional embeddings
|
| 254 |
+
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
|
| 255 |
+
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
|
| 256 |
+
|
| 257 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 258 |
+
B, H, W, _ = x.shape
|
| 259 |
+
# qkv with shape (3, B, nHead, H * W, C)
|
| 260 |
+
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
| 261 |
+
# q, k, v with shape (B * nHead, H * W, C)
|
| 262 |
+
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
|
| 263 |
+
|
| 264 |
+
attn = (q * self.scale) @ k.transpose(-2, -1)
|
| 265 |
+
|
| 266 |
+
if self.use_rel_pos:
|
| 267 |
+
attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
|
| 268 |
+
|
| 269 |
+
attn = attn.softmax(dim=-1)
|
| 270 |
+
x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
|
| 271 |
+
x = self.proj(x)
|
| 272 |
+
|
| 273 |
+
return x
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
|
| 277 |
+
"""
|
| 278 |
+
Partition into non-overlapping windows with padding if needed.
|
| 279 |
+
Args:
|
| 280 |
+
x (tensor): input tokens with [B, H, W, C].
|
| 281 |
+
window_size (int): window size.
|
| 282 |
+
|
| 283 |
+
Returns:
|
| 284 |
+
windows: windows after partition with [B * num_windows, window_size, window_size, C].
|
| 285 |
+
(Hp, Wp): padded height and width before partition
|
| 286 |
+
"""
|
| 287 |
+
B, H, W, C = x.shape
|
| 288 |
+
|
| 289 |
+
pad_h = (window_size - H % window_size) % window_size
|
| 290 |
+
pad_w = (window_size - W % window_size) % window_size
|
| 291 |
+
if pad_h > 0 or pad_w > 0:
|
| 292 |
+
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
|
| 293 |
+
Hp, Wp = H + pad_h, W + pad_w
|
| 294 |
+
|
| 295 |
+
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
|
| 296 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
| 297 |
+
return windows, (Hp, Wp)
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def window_unpartition(
|
| 301 |
+
windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
|
| 302 |
+
) -> torch.Tensor:
|
| 303 |
+
"""
|
| 304 |
+
Window unpartition into original sequences and removing padding.
|
| 305 |
+
Args:
|
| 306 |
+
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
|
| 307 |
+
window_size (int): window size.
|
| 308 |
+
pad_hw (Tuple): padded height and width (Hp, Wp).
|
| 309 |
+
hw (Tuple): original height and width (H, W) before padding.
|
| 310 |
+
|
| 311 |
+
Returns:
|
| 312 |
+
x: unpartitioned sequences with [B, H, W, C].
|
| 313 |
+
"""
|
| 314 |
+
Hp, Wp = pad_hw
|
| 315 |
+
H, W = hw
|
| 316 |
+
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
|
| 317 |
+
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
|
| 318 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
|
| 319 |
+
|
| 320 |
+
if Hp > H or Wp > W:
|
| 321 |
+
x = x[:, :H, :W, :].contiguous()
|
| 322 |
+
return x
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
|
| 326 |
+
"""
|
| 327 |
+
Get relative positional embeddings according to the relative positions of
|
| 328 |
+
query and key sizes.
|
| 329 |
+
Args:
|
| 330 |
+
q_size (int): size of query q.
|
| 331 |
+
k_size (int): size of key k.
|
| 332 |
+
rel_pos (Tensor): relative position embeddings (L, C).
|
| 333 |
+
|
| 334 |
+
Returns:
|
| 335 |
+
Extracted positional embeddings according to relative positions.
|
| 336 |
+
"""
|
| 337 |
+
max_rel_dist = int(2 * max(q_size, k_size) - 1)
|
| 338 |
+
# Interpolate rel pos if needed.
|
| 339 |
+
if rel_pos.shape[0] != max_rel_dist:
|
| 340 |
+
# Interpolate rel pos.
|
| 341 |
+
rel_pos_resized = F.interpolate(
|
| 342 |
+
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
|
| 343 |
+
size=max_rel_dist,
|
| 344 |
+
mode="linear",
|
| 345 |
+
)
|
| 346 |
+
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
|
| 347 |
+
else:
|
| 348 |
+
rel_pos_resized = rel_pos
|
| 349 |
+
|
| 350 |
+
# Scale the coords with short length if shapes for q and k are different.
|
| 351 |
+
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
|
| 352 |
+
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
|
| 353 |
+
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
|
| 354 |
+
|
| 355 |
+
return rel_pos_resized[relative_coords.long()]
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def add_decomposed_rel_pos(
|
| 359 |
+
attn: torch.Tensor,
|
| 360 |
+
q: torch.Tensor,
|
| 361 |
+
rel_pos_h: torch.Tensor,
|
| 362 |
+
rel_pos_w: torch.Tensor,
|
| 363 |
+
q_size: Tuple[int, int],
|
| 364 |
+
k_size: Tuple[int, int],
|
| 365 |
+
) -> torch.Tensor:
|
| 366 |
+
"""
|
| 367 |
+
Args:
|
| 368 |
+
attn (Tensor): attention map.
|
| 369 |
+
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
|
| 370 |
+
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
|
| 371 |
+
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
|
| 372 |
+
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
|
| 373 |
+
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
|
| 374 |
+
|
| 375 |
+
Returns:
|
| 376 |
+
attn (Tensor): attention map with added relative positional embeddings.
|
| 377 |
+
"""
|
| 378 |
+
q_h, q_w = q_size
|
| 379 |
+
k_h, k_w = k_size
|
| 380 |
+
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
|
| 381 |
+
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
|
| 382 |
+
|
| 383 |
+
B, _, dim = q.shape
|
| 384 |
+
r_q = q.reshape(B, q_h, q_w, dim)
|
| 385 |
+
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
|
| 386 |
+
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
|
| 387 |
+
|
| 388 |
+
attn = (
|
| 389 |
+
attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
|
| 390 |
+
).view(B, q_h * q_w, k_h * k_w)
|
| 391 |
+
|
| 392 |
+
return attn
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
class PatchEmbed(nn.Module):
|
| 396 |
+
"""
|
| 397 |
+
Image to Patch Embedding.
|
| 398 |
+
"""
|
| 399 |
+
|
| 400 |
+
def __init__(
|
| 401 |
+
self,
|
| 402 |
+
kernel_size: Tuple[int, int] = (16, 16),
|
| 403 |
+
stride: Tuple[int, int] = (16, 16),
|
| 404 |
+
padding: Tuple[int, int] = (0, 0),
|
| 405 |
+
in_chans: int = 3,
|
| 406 |
+
embed_dim: int = 768,
|
| 407 |
+
) -> None:
|
| 408 |
+
"""
|
| 409 |
+
Args:
|
| 410 |
+
kernel_size (Tuple): kernel size of the projection layer.
|
| 411 |
+
stride (Tuple): stride of the projection layer.
|
| 412 |
+
padding (Tuple): padding size of the projection layer.
|
| 413 |
+
in_chans (int): Number of input image channels.
|
| 414 |
+
embed_dim (int): Patch embedding dimension.
|
| 415 |
+
"""
|
| 416 |
+
super().__init__()
|
| 417 |
+
|
| 418 |
+
self.proj = nn.Conv2d(
|
| 419 |
+
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
|
| 420 |
+
)
|
| 421 |
+
|
| 422 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 423 |
+
x = self.proj(x)
|
| 424 |
+
# B C H W -> B H W C
|
| 425 |
+
x = x.permute(0, 2, 3, 1)
|
| 426 |
+
return x
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
def build_GOT_vit_b(checkpoint=None):
|
| 431 |
+
return _build_GOT_vision(
|
| 432 |
+
encoder_embed_dim=768,
|
| 433 |
+
encoder_depth=12,
|
| 434 |
+
encoder_num_heads=12,
|
| 435 |
+
encoder_global_attn_indexes=[2, 5, 8, 11],
|
| 436 |
+
checkpoint=checkpoint,
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
def _build_GOT_vision(
|
| 441 |
+
encoder_embed_dim,
|
| 442 |
+
encoder_depth,
|
| 443 |
+
encoder_num_heads,
|
| 444 |
+
encoder_global_attn_indexes,
|
| 445 |
+
checkpoint=None,
|
| 446 |
+
):
|
| 447 |
+
prompt_embed_dim = 256
|
| 448 |
+
image_size = 1024
|
| 449 |
+
vit_patch_size = 16
|
| 450 |
+
image_embedding_size = image_size // vit_patch_size
|
| 451 |
+
image_encoder=ImageEncoderViT(
|
| 452 |
+
depth=encoder_depth,
|
| 453 |
+
embed_dim=encoder_embed_dim,
|
| 454 |
+
img_size=image_size,
|
| 455 |
+
mlp_ratio=4,
|
| 456 |
+
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
|
| 457 |
+
num_heads=encoder_num_heads,
|
| 458 |
+
patch_size=vit_patch_size,
|
| 459 |
+
qkv_bias=True,
|
| 460 |
+
use_rel_pos=True,
|
| 461 |
+
global_attn_indexes=encoder_global_attn_indexes,
|
| 462 |
+
window_size=14,
|
| 463 |
+
out_chans=prompt_embed_dim,
|
| 464 |
+
)
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
return image_encoder
|
| 468 |
+
|
custom_got/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:77d6144039548b14253176b6eb264896bc39eba532f8894700f210a7fd2a5956
|
| 3 |
+
size 1432121416
|
custom_got/modeling_GOT.py
ADDED
|
@@ -0,0 +1,881 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM, StoppingCriteria, TextStreamer
|
| 2 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
| 3 |
+
from typing import List, Optional, Tuple, Union
|
| 4 |
+
from transformers.cache_utils import Cache
|
| 5 |
+
import requests
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from io import BytesIO
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
from torch.nn import CrossEntropyLoss
|
| 11 |
+
from .got_vision_b import build_GOT_vit_b
|
| 12 |
+
from torchvision import transforms
|
| 13 |
+
from torchvision.transforms.functional import InterpolationMode
|
| 14 |
+
import dataclasses
|
| 15 |
+
###
|
| 16 |
+
|
| 17 |
+
DEFAULT_IMAGE_TOKEN = "<image>"
|
| 18 |
+
DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
|
| 19 |
+
DEFAULT_IM_START_TOKEN = '<img>'
|
| 20 |
+
DEFAULT_IM_END_TOKEN = '</img>'
|
| 21 |
+
|
| 22 |
+
from enum import auto, Enum
|
| 23 |
+
class SeparatorStyle(Enum):
|
| 24 |
+
"""Different separator style."""
|
| 25 |
+
SINGLE = auto()
|
| 26 |
+
TWO = auto()
|
| 27 |
+
MPT = auto()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclasses.dataclass
|
| 31 |
+
class Conversation:
|
| 32 |
+
"""A class that keeps all conversation history."""
|
| 33 |
+
system: str
|
| 34 |
+
roles: List[str]
|
| 35 |
+
messages: List[List[str]]
|
| 36 |
+
offset: int
|
| 37 |
+
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
| 38 |
+
sep: str = "<|im_end|>"
|
| 39 |
+
sep2: str = None
|
| 40 |
+
version: str = "Unknown"
|
| 41 |
+
|
| 42 |
+
skip_next: bool = False
|
| 43 |
+
|
| 44 |
+
def get_prompt(self):
|
| 45 |
+
if self.sep_style == SeparatorStyle.SINGLE:
|
| 46 |
+
ret = self.system + self.sep + '\n'
|
| 47 |
+
for role, message in self.messages:
|
| 48 |
+
if message:
|
| 49 |
+
if type(message) is tuple:
|
| 50 |
+
message, _, _ = message
|
| 51 |
+
ret += role + ": " + message + self.sep
|
| 52 |
+
else:
|
| 53 |
+
ret += role + ":"
|
| 54 |
+
return ret
|
| 55 |
+
elif self.sep_style == SeparatorStyle.TWO:
|
| 56 |
+
seps = [self.sep, self.sep2]
|
| 57 |
+
ret = self.system + seps[0]
|
| 58 |
+
for i, (role, message) in enumerate(self.messages):
|
| 59 |
+
if message:
|
| 60 |
+
if type(message) is tuple:
|
| 61 |
+
message, _, _ = message
|
| 62 |
+
ret += role + ": " + message + seps[i % 2]
|
| 63 |
+
else:
|
| 64 |
+
ret += role + ":"
|
| 65 |
+
return ret
|
| 66 |
+
if self.sep_style == SeparatorStyle.MPT:
|
| 67 |
+
if self.system:
|
| 68 |
+
ret = self.system + self.sep
|
| 69 |
+
else:
|
| 70 |
+
ret = ''
|
| 71 |
+
for role, message in self.messages:
|
| 72 |
+
if message:
|
| 73 |
+
if type(message) is tuple:
|
| 74 |
+
message, _, _ = message
|
| 75 |
+
ret += role + message + self.sep
|
| 76 |
+
else:
|
| 77 |
+
ret += role
|
| 78 |
+
return ret
|
| 79 |
+
else:
|
| 80 |
+
raise ValueError(f"Invalid style: {self.sep_style}")
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def append_message(self, role, message):
|
| 84 |
+
self.messages.append([role, message])
|
| 85 |
+
|
| 86 |
+
def copy(self):
|
| 87 |
+
return Conversation(
|
| 88 |
+
system=self.system,
|
| 89 |
+
roles=self.roles,
|
| 90 |
+
messages=[[x, y] for x, y in self.messages],
|
| 91 |
+
offset=self.offset,
|
| 92 |
+
sep_style=self.sep_style,
|
| 93 |
+
sep=self.sep,
|
| 94 |
+
sep2=self.sep2)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class KeywordsStoppingCriteria(StoppingCriteria):
|
| 99 |
+
def __init__(self, keywords, tokenizer, input_ids):
|
| 100 |
+
self.keywords = keywords
|
| 101 |
+
self.keyword_ids = [tokenizer(keyword).input_ids for keyword in keywords]
|
| 102 |
+
self.keyword_ids = [keyword_id[0] for keyword_id in self.keyword_ids if type(keyword_id) is list and len(keyword_id) == 1]
|
| 103 |
+
self.tokenizer = tokenizer
|
| 104 |
+
self.start_len = None
|
| 105 |
+
self.input_ids = input_ids
|
| 106 |
+
|
| 107 |
+
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
| 108 |
+
if self.start_len is None:
|
| 109 |
+
self.start_len = self.input_ids.shape[1]
|
| 110 |
+
else:
|
| 111 |
+
for keyword_id in self.keyword_ids:
|
| 112 |
+
if output_ids[0, -1] == keyword_id:
|
| 113 |
+
return True
|
| 114 |
+
outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
|
| 115 |
+
for keyword in self.keywords:
|
| 116 |
+
if keyword in outputs:
|
| 117 |
+
return True
|
| 118 |
+
return False
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class GOTImageEvalProcessor:
|
| 122 |
+
def __init__(self, image_size=384, mean=None, std=None):
|
| 123 |
+
if mean is None:
|
| 124 |
+
mean = (0.48145466, 0.4578275, 0.40821073)
|
| 125 |
+
if std is None:
|
| 126 |
+
std = (0.26862954, 0.26130258, 0.27577711)
|
| 127 |
+
|
| 128 |
+
self.normalize = transforms.Normalize(mean, std)
|
| 129 |
+
|
| 130 |
+
self.transform = transforms.Compose(
|
| 131 |
+
[
|
| 132 |
+
transforms.Resize(
|
| 133 |
+
(image_size, image_size), interpolation=InterpolationMode.BICUBIC
|
| 134 |
+
),
|
| 135 |
+
transforms.ToTensor(),
|
| 136 |
+
self.normalize,
|
| 137 |
+
]
|
| 138 |
+
)
|
| 139 |
+
def __call__(self, item):
|
| 140 |
+
return self.transform(item)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class GOTConfig(Qwen2Config):
|
| 145 |
+
model_type = "GOT"
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class GOTQwenModel(Qwen2Model):
|
| 149 |
+
config_class = GOTConfig
|
| 150 |
+
|
| 151 |
+
def __init__(self, config: Qwen2Config):
|
| 152 |
+
super(GOTQwenModel, self).__init__(config)
|
| 153 |
+
|
| 154 |
+
self.vision_tower_high = build_GOT_vit_b()
|
| 155 |
+
|
| 156 |
+
self.mm_projector_vary = nn.Linear(1024, 1024)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def initialize_vision_modules(
|
| 160 |
+
self,
|
| 161 |
+
vision_tower,
|
| 162 |
+
pretrained_stage1_model=None,
|
| 163 |
+
freeze_vision_tower=False,
|
| 164 |
+
use_im_start_end=False,
|
| 165 |
+
vision_select_layer=-1,
|
| 166 |
+
dtype=torch.float16,
|
| 167 |
+
device="cuda"
|
| 168 |
+
):
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
image_processor_high = GOTImageEvalProcessor(image_size=1024)
|
| 172 |
+
|
| 173 |
+
self.vision_tower_high = self.vision_tower_high.to(dtype=dtype, device=device)
|
| 174 |
+
|
| 175 |
+
self.mm_projector_vary = self.mm_projector_vary.to(dtype=dtype, device=device)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
image_token_len = 256
|
| 179 |
+
|
| 180 |
+
self.config.vision_tower = vision_tower
|
| 181 |
+
self.config.image_token_len = image_token_len
|
| 182 |
+
|
| 183 |
+
self.config.use_im_start_end = True
|
| 184 |
+
|
| 185 |
+
self.config.vision_select_layer = vision_select_layer
|
| 186 |
+
self.config.freeze_vision_tower = freeze_vision_tower
|
| 187 |
+
|
| 188 |
+
return dict(
|
| 189 |
+
image_processor_high=image_processor_high,
|
| 190 |
+
image_token_len=image_token_len,
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def forward(
|
| 195 |
+
self,
|
| 196 |
+
input_ids: torch.LongTensor = None,
|
| 197 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 198 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 199 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 200 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 201 |
+
use_cache: Optional[bool] = None,
|
| 202 |
+
output_attentions: Optional[bool] = None,
|
| 203 |
+
output_hidden_states: Optional[bool] = None,
|
| 204 |
+
images: Optional[torch.FloatTensor] = None,
|
| 205 |
+
return_dict: Optional[bool] = None,
|
| 206 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
| 207 |
+
|
| 208 |
+
# HACK: replace back original embeddings for LLaVA pretraining
|
| 209 |
+
orig_embeds_params = getattr(self, 'orig_embeds_params', None)
|
| 210 |
+
if orig_embeds_params is not None:
|
| 211 |
+
with torch.no_grad():
|
| 212 |
+
self.get_input_embeddings().weight[:-self.num_new_tokens] = orig_embeds_params[:-self.num_new_tokens].data
|
| 213 |
+
|
| 214 |
+
if inputs_embeds is None:
|
| 215 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
vision_tower_high = getattr(self, 'vision_tower_high', None)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
if vision_tower_high is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
|
| 222 |
+
use_im_start_end = getattr(self.config, "use_im_start_end", -1)
|
| 223 |
+
|
| 224 |
+
vision_select_layer = getattr(self.config, "vision_select_layer", -1)
|
| 225 |
+
im_patch_token = getattr(self.config, "im_patch_token", -1)
|
| 226 |
+
im_start_token = getattr(self.config, "im_start_token", -1)
|
| 227 |
+
im_end_token = getattr(self.config, "im_end_token", -1)
|
| 228 |
+
freeze_vision_tower = getattr(self.config, "freeze_vision_tower", False)
|
| 229 |
+
|
| 230 |
+
im_patch_token = 151859
|
| 231 |
+
|
| 232 |
+
im_start_token = 151857
|
| 233 |
+
|
| 234 |
+
im_end_token = 151858
|
| 235 |
+
|
| 236 |
+
image_features = []
|
| 237 |
+
|
| 238 |
+
for image in images:
|
| 239 |
+
P, C, H, W = image.shape
|
| 240 |
+
if P == 1:
|
| 241 |
+
with torch.set_grad_enabled(False):
|
| 242 |
+
cnn_feature = vision_tower_high(image)
|
| 243 |
+
cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1) # 256*1024
|
| 244 |
+
image_feature = self.mm_projector_vary(cnn_feature)
|
| 245 |
+
image_features.append(image_feature)
|
| 246 |
+
|
| 247 |
+
else:
|
| 248 |
+
image_patches = torch.unbind(image)
|
| 249 |
+
image_patches_features = []
|
| 250 |
+
for image_patch in image_patches:
|
| 251 |
+
image_p = torch.stack([image_patch])
|
| 252 |
+
|
| 253 |
+
with torch.set_grad_enabled(False):
|
| 254 |
+
cnn_feature_p = vision_tower_high(image_p)
|
| 255 |
+
cnn_feature_p = cnn_feature_p.flatten(2).permute(0, 2, 1)
|
| 256 |
+
image_feature_p = self.mm_projector_vary(cnn_feature_p)
|
| 257 |
+
image_patches_features.append(image_feature_p)
|
| 258 |
+
image_feature = torch.cat(image_patches_features, dim=1)
|
| 259 |
+
image_features.append(image_feature)
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
dummy_image_features_2 = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
|
| 263 |
+
dummy_image_features = dummy_image_features_2
|
| 264 |
+
use_im_start_end = True
|
| 265 |
+
new_input_embeds = []
|
| 266 |
+
for cur_input_ids, cur_input_embeds, cur_image_features in zip(input_ids, inputs_embeds, image_features):
|
| 267 |
+
if (cur_input_ids == im_patch_token).sum() == 0:
|
| 268 |
+
cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
|
| 269 |
+
new_input_embeds.append(cur_input_embeds)
|
| 270 |
+
continue
|
| 271 |
+
|
| 272 |
+
if use_im_start_end:
|
| 273 |
+
if (cur_input_ids == im_start_token).sum() != (cur_input_ids == im_end_token).sum():
|
| 274 |
+
raise ValueError("The number of image start tokens and image end tokens should be the same.")
|
| 275 |
+
|
| 276 |
+
image_start_tokens = torch.where(cur_input_ids == im_start_token)[0]
|
| 277 |
+
for image_start_token_pos, per_cur_image_features in zip(image_start_tokens, cur_image_features):
|
| 278 |
+
per_cur_image_features = per_cur_image_features.to(device=cur_input_embeds.device)
|
| 279 |
+
num_patches = per_cur_image_features.shape[0]
|
| 280 |
+
|
| 281 |
+
if cur_input_ids[image_start_token_pos + num_patches + 1] != im_end_token:
|
| 282 |
+
raise ValueError("The image end token should follow the image start token.")
|
| 283 |
+
|
| 284 |
+
cur_input_embeds = torch.cat(
|
| 285 |
+
(
|
| 286 |
+
cur_input_embeds[:image_start_token_pos+1],
|
| 287 |
+
per_cur_image_features,
|
| 288 |
+
cur_input_embeds[image_start_token_pos + num_patches + 1:]
|
| 289 |
+
),
|
| 290 |
+
dim=0
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
new_input_embeds.append(cur_input_embeds)
|
| 295 |
+
else:
|
| 296 |
+
raise NotImplementedError
|
| 297 |
+
|
| 298 |
+
inputs_embeds = torch.stack(new_input_embeds, dim=0)
|
| 299 |
+
|
| 300 |
+
return super(GOTQwenModel, self).forward(
|
| 301 |
+
input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
|
| 302 |
+
inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids = position_ids,
|
| 303 |
+
output_attentions=output_attentions, output_hidden_states=output_hidden_states,
|
| 304 |
+
return_dict=return_dict
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
| 310 |
+
config_class = GOTConfig
|
| 311 |
+
# supports_gradient_checkpointing = True
|
| 312 |
+
|
| 313 |
+
def __init__(self, config):
|
| 314 |
+
super(Qwen2ForCausalLM, self).__init__(config)
|
| 315 |
+
self.model = GOTQwenModel(config)
|
| 316 |
+
|
| 317 |
+
self.vocab_size = config.vocab_size
|
| 318 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 319 |
+
|
| 320 |
+
# Initialize weights and apply final processing
|
| 321 |
+
self.post_init()
|
| 322 |
+
|
| 323 |
+
def get_model(self):
|
| 324 |
+
return self.model
|
| 325 |
+
|
| 326 |
+
def forward(
|
| 327 |
+
self,
|
| 328 |
+
input_ids: torch.LongTensor = None,
|
| 329 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 330 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 331 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 332 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 333 |
+
labels: Optional[torch.LongTensor] = None,
|
| 334 |
+
use_cache: Optional[bool] = None,
|
| 335 |
+
output_attentions: Optional[bool] = None,
|
| 336 |
+
output_hidden_states: Optional[bool] = None,
|
| 337 |
+
images: Optional[torch.FloatTensor] = None,
|
| 338 |
+
return_dict: Optional[bool] = None,
|
| 339 |
+
|
| 340 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 341 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 342 |
+
output_hidden_states = (
|
| 343 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 344 |
+
)
|
| 345 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 346 |
+
|
| 347 |
+
outputs = self.model(
|
| 348 |
+
input_ids=input_ids,
|
| 349 |
+
past_key_values=past_key_values,
|
| 350 |
+
attention_mask=attention_mask,
|
| 351 |
+
position_ids=position_ids,
|
| 352 |
+
inputs_embeds=inputs_embeds,
|
| 353 |
+
use_cache=use_cache,
|
| 354 |
+
output_attentions=output_attentions,
|
| 355 |
+
output_hidden_states=output_hidden_states,
|
| 356 |
+
images=images,
|
| 357 |
+
return_dict=return_dict
|
| 358 |
+
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
hidden_states = outputs[0]
|
| 362 |
+
logits = self.lm_head(hidden_states)
|
| 363 |
+
logits = logits.float()
|
| 364 |
+
|
| 365 |
+
# logits
|
| 366 |
+
|
| 367 |
+
loss = None
|
| 368 |
+
if labels is not None:
|
| 369 |
+
# Shift so that tokens < n predict n
|
| 370 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
| 371 |
+
shift_labels = labels[..., 1:].contiguous()
|
| 372 |
+
# Flatten the tokens
|
| 373 |
+
loss_fct = CrossEntropyLoss()
|
| 374 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
| 375 |
+
shift_labels = shift_labels.view(-1)
|
| 376 |
+
# Enable model parallelism
|
| 377 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
| 378 |
+
loss = loss_fct(shift_logits, shift_labels)
|
| 379 |
+
|
| 380 |
+
if not return_dict:
|
| 381 |
+
output = (logits,) + outputs[1:]
|
| 382 |
+
return (loss,) + output if loss is not None else output
|
| 383 |
+
|
| 384 |
+
return CausalLMOutputWithPast(
|
| 385 |
+
loss=loss,
|
| 386 |
+
logits=logits,
|
| 387 |
+
past_key_values=outputs.past_key_values,
|
| 388 |
+
hidden_states=outputs.hidden_states,
|
| 389 |
+
attentions=outputs.attentions,
|
| 390 |
+
)
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def prepare_inputs_for_generation(
|
| 394 |
+
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
| 395 |
+
):
|
| 396 |
+
# Omit tokens covered by past_key_values
|
| 397 |
+
if past_key_values is not None:
|
| 398 |
+
if isinstance(past_key_values, Cache):
|
| 399 |
+
cache_length = past_key_values.get_seq_length()
|
| 400 |
+
past_length = past_key_values.seen_tokens
|
| 401 |
+
max_cache_length = past_key_values.get_max_length()
|
| 402 |
+
else:
|
| 403 |
+
cache_length = past_length = past_key_values[0][0].shape[2]
|
| 404 |
+
max_cache_length = None
|
| 405 |
+
|
| 406 |
+
# Keep only the unprocessed tokens:
|
| 407 |
+
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
| 408 |
+
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
| 409 |
+
# input)
|
| 410 |
+
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
| 411 |
+
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
| 412 |
+
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
| 413 |
+
# input_ids based on the past_length.
|
| 414 |
+
elif past_length < input_ids.shape[1]:
|
| 415 |
+
input_ids = input_ids[:, past_length:]
|
| 416 |
+
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
| 417 |
+
|
| 418 |
+
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
| 419 |
+
if (
|
| 420 |
+
max_cache_length is not None
|
| 421 |
+
and attention_mask is not None
|
| 422 |
+
and cache_length + input_ids.shape[1] > max_cache_length
|
| 423 |
+
):
|
| 424 |
+
attention_mask = attention_mask[:, -max_cache_length:]
|
| 425 |
+
|
| 426 |
+
position_ids = kwargs.get("position_ids", None)
|
| 427 |
+
if attention_mask is not None and position_ids is None:
|
| 428 |
+
# create position_ids on the fly for batch generation
|
| 429 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
| 430 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
| 431 |
+
if past_key_values:
|
| 432 |
+
position_ids = position_ids[:, -input_ids.shape[1] :]
|
| 433 |
+
|
| 434 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
| 435 |
+
if inputs_embeds is not None and past_key_values is None:
|
| 436 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
| 437 |
+
else:
|
| 438 |
+
model_inputs = {"input_ids": input_ids}
|
| 439 |
+
|
| 440 |
+
model_inputs.update(
|
| 441 |
+
{
|
| 442 |
+
"position_ids": position_ids,
|
| 443 |
+
"past_key_values": past_key_values,
|
| 444 |
+
"use_cache": kwargs.get("use_cache"),
|
| 445 |
+
"attention_mask": attention_mask,
|
| 446 |
+
"images": kwargs.get("images", None),
|
| 447 |
+
}
|
| 448 |
+
)
|
| 449 |
+
return model_inputs
|
| 450 |
+
|
| 451 |
+
def initialize_vision_tokenizer(
|
| 452 |
+
self,
|
| 453 |
+
tokenizer,
|
| 454 |
+
freeze_lm_model=False,
|
| 455 |
+
pretrained_stage1_model=None,
|
| 456 |
+
device="cuda"
|
| 457 |
+
):
|
| 458 |
+
config = self.get_model().config
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
self.resize_token_embeddings(len(tokenizer))
|
| 462 |
+
|
| 463 |
+
config.im_patch_token = 151859
|
| 464 |
+
|
| 465 |
+
config.use_im_start_end = True
|
| 466 |
+
|
| 467 |
+
if config.use_im_start_end:
|
| 468 |
+
self.resize_token_embeddings(len(tokenizer))
|
| 469 |
+
config.im_start_token, config.im_end_token = 151857, 151858
|
| 470 |
+
|
| 471 |
+
def load_image(self, image_file):
|
| 472 |
+
if image_file.startswith('http') or image_file.startswith('https'):
|
| 473 |
+
response = requests.get(image_file)
|
| 474 |
+
image = Image.open(BytesIO(response.content)).convert('RGB')
|
| 475 |
+
else:
|
| 476 |
+
image = Image.open(image_file).convert('RGB')
|
| 477 |
+
return image
|
| 478 |
+
|
| 479 |
+
def disable_torch_init(self):
|
| 480 |
+
"""
|
| 481 |
+
Disable the redundant torch default initialization to accelerate model creation.
|
| 482 |
+
"""
|
| 483 |
+
import torch
|
| 484 |
+
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
|
| 485 |
+
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
|
| 486 |
+
|
| 487 |
+
def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False):
|
| 488 |
+
|
| 489 |
+
self.disable_torch_init()
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
image_processor_high = GOTImageEvalProcessor(image_size=1024)
|
| 493 |
+
|
| 494 |
+
use_im_start_end = True
|
| 495 |
+
|
| 496 |
+
image_token_len = 256
|
| 497 |
+
|
| 498 |
+
if gradio_input:
|
| 499 |
+
image = image_file.copy()
|
| 500 |
+
else:
|
| 501 |
+
image = self.load_image(image_file)
|
| 502 |
+
|
| 503 |
+
w, h = image.size
|
| 504 |
+
|
| 505 |
+
if ocr_type == 'format':
|
| 506 |
+
qs = 'OCR with format: '
|
| 507 |
+
else:
|
| 508 |
+
qs = 'OCR: '
|
| 509 |
+
|
| 510 |
+
if ocr_box:
|
| 511 |
+
bbox = eval(ocr_box)
|
| 512 |
+
if len(bbox) == 2:
|
| 513 |
+
bbox[0] = int(bbox[0]/w*1000)
|
| 514 |
+
bbox[1] = int(bbox[1]/h*1000)
|
| 515 |
+
if len(bbox) == 4:
|
| 516 |
+
bbox[0] = int(bbox[0]/w*1000)
|
| 517 |
+
bbox[1] = int(bbox[1]/h*1000)
|
| 518 |
+
bbox[2] = int(bbox[2]/w*1000)
|
| 519 |
+
bbox[3] = int(bbox[3]/h*1000)
|
| 520 |
+
if ocr_type == 'format':
|
| 521 |
+
qs = str(bbox) + ' ' + 'OCR with format: '
|
| 522 |
+
else:
|
| 523 |
+
qs = str(bbox) + ' ' + 'OCR: '
|
| 524 |
+
|
| 525 |
+
if ocr_color:
|
| 526 |
+
if ocr_type == 'format':
|
| 527 |
+
qs = '[' + ocr_color + ']' + ' ' + 'OCR with format: '
|
| 528 |
+
else:
|
| 529 |
+
qs = '[' + ocr_color + ']' + ' ' + 'OCR: '
|
| 530 |
+
|
| 531 |
+
if use_im_start_end:
|
| 532 |
+
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len + DEFAULT_IM_END_TOKEN + '\n' + qs
|
| 533 |
+
else:
|
| 534 |
+
qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
conv_mpt = Conversation(
|
| 538 |
+
system="""<|im_start|>system
|
| 539 |
+
You should follow the instructions carefully and explain your answers in detail.""",
|
| 540 |
+
# system = None,
|
| 541 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
| 542 |
+
version="mpt",
|
| 543 |
+
messages=(),
|
| 544 |
+
offset=0,
|
| 545 |
+
sep_style=SeparatorStyle.MPT,
|
| 546 |
+
sep="<|im_end|>",
|
| 547 |
+
)
|
| 548 |
+
|
| 549 |
+
conv = conv_mpt.copy()
|
| 550 |
+
conv.append_message(conv.roles[0], qs)
|
| 551 |
+
conv.append_message(conv.roles[1], None)
|
| 552 |
+
prompt = conv.get_prompt()
|
| 553 |
+
|
| 554 |
+
if print_prompt:
|
| 555 |
+
print(prompt)
|
| 556 |
+
|
| 557 |
+
inputs = tokenizer([prompt])
|
| 558 |
+
|
| 559 |
+
image_tensor_1 = image_processor_high(image)
|
| 560 |
+
|
| 561 |
+
input_ids = torch.as_tensor(inputs.input_ids).cuda()
|
| 562 |
+
|
| 563 |
+
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
| 564 |
+
keywords = [stop_str]
|
| 565 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
| 566 |
+
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 567 |
+
|
| 568 |
+
if stream_flag:
|
| 569 |
+
with torch.autocast("cuda", dtype=torch.bfloat16):
|
| 570 |
+
output_ids = self.generate(
|
| 571 |
+
input_ids,
|
| 572 |
+
images=[image_tensor_1.unsqueeze(0).half().cuda()],
|
| 573 |
+
do_sample=False,
|
| 574 |
+
num_beams = 1,
|
| 575 |
+
no_repeat_ngram_size = 20,
|
| 576 |
+
streamer=streamer,
|
| 577 |
+
max_new_tokens=4096,
|
| 578 |
+
stopping_criteria=[stopping_criteria]
|
| 579 |
+
)
|
| 580 |
+
else:
|
| 581 |
+
with torch.autocast("cuda", dtype=torch.bfloat16):
|
| 582 |
+
output_ids = self.generate(
|
| 583 |
+
input_ids,
|
| 584 |
+
images=[image_tensor_1.unsqueeze(0).half().cuda()],
|
| 585 |
+
do_sample=False,
|
| 586 |
+
num_beams = 1,
|
| 587 |
+
no_repeat_ngram_size = 20,
|
| 588 |
+
# streamer=streamer,
|
| 589 |
+
max_new_tokens=4096,
|
| 590 |
+
stopping_criteria=[stopping_criteria]
|
| 591 |
+
)
|
| 592 |
+
|
| 593 |
+
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
|
| 594 |
+
|
| 595 |
+
if outputs.endswith(stop_str):
|
| 596 |
+
outputs = outputs[:-len(stop_str)]
|
| 597 |
+
outputs = outputs.strip()
|
| 598 |
+
response_str = outputs
|
| 599 |
+
|
| 600 |
+
if render:
|
| 601 |
+
print('==============rendering===============')
|
| 602 |
+
from .render_tools import svg_to_html, content_mmd_to_html, tik_html, translation_table
|
| 603 |
+
|
| 604 |
+
if '**kern' in outputs:
|
| 605 |
+
import verovio
|
| 606 |
+
tk = verovio.toolkit()
|
| 607 |
+
tk.loadData(outputs)
|
| 608 |
+
tk.setOptions({"pageWidth": 2100, "footer": 'none',
|
| 609 |
+
'barLineWidth': 0.5, 'beamMaxSlope': 15,
|
| 610 |
+
'staffLineWidth': 0.2, 'spacingStaff': 6})
|
| 611 |
+
tk.getPageCount()
|
| 612 |
+
svg = tk.renderToSVG()
|
| 613 |
+
svg = svg.replace("overflow=\"inherit\"", "overflow=\"visible\"")
|
| 614 |
+
|
| 615 |
+
svg_to_html(svg, save_render_file)
|
| 616 |
+
|
| 617 |
+
if ocr_type == 'format' and '**kern' not in outputs:
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
if '\\begin{tikzpicture}' not in outputs:
|
| 621 |
+
html_path_2 = save_render_file
|
| 622 |
+
right_num = outputs.count('\\right')
|
| 623 |
+
left_num = outputs.count('\\left')
|
| 624 |
+
|
| 625 |
+
if right_num != left_num:
|
| 626 |
+
outputs = outputs.replace('\\left(', '(').replace('\\right)', ')').replace('\\left[', '[').replace('\\right]', ']').replace('\\left{', '{').replace('\\right}', '}').replace('\\left|', '|').replace('\\right|', '|').replace('\\left.', '.').replace('\\right.', '.')
|
| 627 |
+
|
| 628 |
+
|
| 629 |
+
outputs = outputs.replace('"', '``').replace('$', '')
|
| 630 |
+
|
| 631 |
+
outputs_list = outputs.split('\n')
|
| 632 |
+
gt= ''
|
| 633 |
+
for out in outputs_list:
|
| 634 |
+
gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
|
| 635 |
+
|
| 636 |
+
gt = gt[:-2]
|
| 637 |
+
|
| 638 |
+
|
| 639 |
+
lines = content_mmd_to_html
|
| 640 |
+
lines = lines.split("const text =")
|
| 641 |
+
new_web = lines[0] + 'const text =' + gt + lines[1]
|
| 642 |
+
|
| 643 |
+
else:
|
| 644 |
+
html_path_2 = save_render_file
|
| 645 |
+
outputs = outputs.translate(translation_table)
|
| 646 |
+
outputs_list = outputs.split('\n')
|
| 647 |
+
gt= ''
|
| 648 |
+
for out in outputs_list:
|
| 649 |
+
if out:
|
| 650 |
+
if '\\begin{tikzpicture}' not in out and '\\end{tikzpicture}' not in out:
|
| 651 |
+
while out[-1] == ' ':
|
| 652 |
+
out = out[:-1]
|
| 653 |
+
if out is None:
|
| 654 |
+
break
|
| 655 |
+
|
| 656 |
+
if out:
|
| 657 |
+
if out[-1] != ';':
|
| 658 |
+
gt += out[:-1] + ';\n'
|
| 659 |
+
else:
|
| 660 |
+
gt += out + '\n'
|
| 661 |
+
else:
|
| 662 |
+
gt += out + '\n'
|
| 663 |
+
|
| 664 |
+
|
| 665 |
+
lines = tik_html
|
| 666 |
+
lines = lines.split("const text =")
|
| 667 |
+
new_web = lines[0] + gt + lines[1]
|
| 668 |
+
|
| 669 |
+
with open(html_path_2, 'w') as web_f_new:
|
| 670 |
+
web_f_new.write(new_web)
|
| 671 |
+
return response_str
|
| 672 |
+
|
| 673 |
+
def dynamic_preprocess(self, image, min_num=1, max_num=6, image_size=1024, use_thumbnail=True):
|
| 674 |
+
|
| 675 |
+
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
|
| 676 |
+
best_ratio_diff = float('inf')
|
| 677 |
+
best_ratio = (1, 1)
|
| 678 |
+
area = width * height
|
| 679 |
+
for ratio in target_ratios:
|
| 680 |
+
target_aspect_ratio = ratio[0] / ratio[1]
|
| 681 |
+
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
|
| 682 |
+
if ratio_diff < best_ratio_diff:
|
| 683 |
+
best_ratio_diff = ratio_diff
|
| 684 |
+
best_ratio = ratio
|
| 685 |
+
elif ratio_diff == best_ratio_diff:
|
| 686 |
+
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
|
| 687 |
+
best_ratio = ratio
|
| 688 |
+
# print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')
|
| 689 |
+
return best_ratio
|
| 690 |
+
|
| 691 |
+
orig_width, orig_height = image.size
|
| 692 |
+
aspect_ratio = orig_width / orig_height
|
| 693 |
+
|
| 694 |
+
# calculate the existing image aspect ratio
|
| 695 |
+
target_ratios = set(
|
| 696 |
+
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
|
| 697 |
+
i * j <= max_num and i * j >= min_num)
|
| 698 |
+
# print(target_ratios)
|
| 699 |
+
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
| 700 |
+
|
| 701 |
+
# find the closest aspect ratio to the target
|
| 702 |
+
target_aspect_ratio = find_closest_aspect_ratio(
|
| 703 |
+
aspect_ratio, target_ratios, orig_width, orig_height, image_size)
|
| 704 |
+
|
| 705 |
+
# print(target_aspect_ratio)
|
| 706 |
+
# calculate the target width and height
|
| 707 |
+
target_width = image_size * target_aspect_ratio[0]
|
| 708 |
+
target_height = image_size * target_aspect_ratio[1]
|
| 709 |
+
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
|
| 710 |
+
|
| 711 |
+
# resize the image
|
| 712 |
+
resized_img = image.resize((target_width, target_height))
|
| 713 |
+
processed_images = []
|
| 714 |
+
for i in range(blocks):
|
| 715 |
+
box = (
|
| 716 |
+
(i % (target_width // image_size)) * image_size,
|
| 717 |
+
(i // (target_width // image_size)) * image_size,
|
| 718 |
+
((i % (target_width // image_size)) + 1) * image_size,
|
| 719 |
+
((i // (target_width // image_size)) + 1) * image_size
|
| 720 |
+
)
|
| 721 |
+
# split the image
|
| 722 |
+
split_img = resized_img.crop(box)
|
| 723 |
+
processed_images.append(split_img)
|
| 724 |
+
assert len(processed_images) == blocks
|
| 725 |
+
if use_thumbnail and len(processed_images) != 1:
|
| 726 |
+
thumbnail_img = image.resize((image_size, image_size))
|
| 727 |
+
processed_images.append(thumbnail_img)
|
| 728 |
+
return processed_images
|
| 729 |
+
|
| 730 |
+
|
| 731 |
+
def chat_crop(self, tokenizer, image_file, ocr_type, render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False):
|
| 732 |
+
# Model
|
| 733 |
+
self.disable_torch_init()
|
| 734 |
+
multi_page=False
|
| 735 |
+
|
| 736 |
+
|
| 737 |
+
image_processor_high = GOTImageEvalProcessor(image_size=1024)
|
| 738 |
+
|
| 739 |
+
use_im_start_end = True
|
| 740 |
+
|
| 741 |
+
|
| 742 |
+
image_token_len = 256
|
| 743 |
+
|
| 744 |
+
image_list = []
|
| 745 |
+
|
| 746 |
+
# if len(image_file_list)>1:
|
| 747 |
+
# multi_page = True
|
| 748 |
+
|
| 749 |
+
if multi_page:
|
| 750 |
+
qs = 'OCR with format across multi pages: '
|
| 751 |
+
# only for png files
|
| 752 |
+
# import glob
|
| 753 |
+
# from natsort import natsorted
|
| 754 |
+
# patches = glob.glob(image_file + '/*png')
|
| 755 |
+
patches = image_file
|
| 756 |
+
# patches = natsorted(patches)
|
| 757 |
+
sub_images = []
|
| 758 |
+
for sub_image in patches:
|
| 759 |
+
sub_images.append(self.load_image(sub_image))
|
| 760 |
+
|
| 761 |
+
ll = len(patches)
|
| 762 |
+
# print(patches)
|
| 763 |
+
# print("len ll: ", ll)
|
| 764 |
+
|
| 765 |
+
else:
|
| 766 |
+
if ocr_type == 'format':
|
| 767 |
+
qs = 'OCR with format upon the patch reference: '
|
| 768 |
+
else:
|
| 769 |
+
qs = 'OCR upon the patch reference: '
|
| 770 |
+
if gradio_input:
|
| 771 |
+
img = image_file.copy()
|
| 772 |
+
else:
|
| 773 |
+
img = self.load_image(image_file)
|
| 774 |
+
sub_images = self.dynamic_preprocess(img)
|
| 775 |
+
ll = len(sub_images)
|
| 776 |
+
|
| 777 |
+
for image in sub_images:
|
| 778 |
+
image_tensor_1 = image_processor_high(image)
|
| 779 |
+
image_list.append(image_tensor_1)
|
| 780 |
+
|
| 781 |
+
|
| 782 |
+
image_list = torch.stack(image_list)
|
| 783 |
+
|
| 784 |
+
print('====new images batch size======: \n',image_list.shape)
|
| 785 |
+
|
| 786 |
+
|
| 787 |
+
if use_im_start_end:
|
| 788 |
+
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len*ll + DEFAULT_IM_END_TOKEN + '\n' + qs
|
| 789 |
+
else:
|
| 790 |
+
qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
|
| 791 |
+
|
| 792 |
+
|
| 793 |
+
conv_mpt = Conversation(
|
| 794 |
+
system="""<|im_start|>system
|
| 795 |
+
You should follow the instructions carefully and explain your answers in detail.""",
|
| 796 |
+
# system = None,
|
| 797 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
| 798 |
+
version="mpt",
|
| 799 |
+
messages=(),
|
| 800 |
+
offset=0,
|
| 801 |
+
sep_style=SeparatorStyle.MPT,
|
| 802 |
+
sep="<|im_end|>",
|
| 803 |
+
)
|
| 804 |
+
|
| 805 |
+
conv = conv_mpt.copy()
|
| 806 |
+
conv.append_message(conv.roles[0], qs)
|
| 807 |
+
conv.append_message(conv.roles[1], None)
|
| 808 |
+
prompt = conv.get_prompt()
|
| 809 |
+
|
| 810 |
+
if print_prompt:
|
| 811 |
+
print(prompt)
|
| 812 |
+
|
| 813 |
+
inputs = tokenizer([prompt])
|
| 814 |
+
|
| 815 |
+
input_ids = torch.as_tensor(inputs.input_ids).cuda()
|
| 816 |
+
|
| 817 |
+
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
| 818 |
+
keywords = [stop_str]
|
| 819 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
| 820 |
+
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 821 |
+
|
| 822 |
+
if stream_flag:
|
| 823 |
+
with torch.autocast("cuda", dtype=torch.bfloat16):
|
| 824 |
+
output_ids = self.generate(
|
| 825 |
+
input_ids,
|
| 826 |
+
images=[image_list.half().cuda()],
|
| 827 |
+
do_sample=False,
|
| 828 |
+
num_beams = 1,
|
| 829 |
+
# no_repeat_ngram_size = 20,
|
| 830 |
+
streamer=streamer,
|
| 831 |
+
max_new_tokens=4096,
|
| 832 |
+
stopping_criteria=[stopping_criteria]
|
| 833 |
+
)
|
| 834 |
+
else:
|
| 835 |
+
with torch.autocast("cuda", dtype=torch.bfloat16):
|
| 836 |
+
output_ids = self.generate(
|
| 837 |
+
input_ids,
|
| 838 |
+
images=[image_list.half().cuda()],
|
| 839 |
+
do_sample=False,
|
| 840 |
+
num_beams = 1,
|
| 841 |
+
# no_repeat_ngram_size = 20,
|
| 842 |
+
# streamer=streamer,
|
| 843 |
+
max_new_tokens=4096,
|
| 844 |
+
stopping_criteria=[stopping_criteria]
|
| 845 |
+
)
|
| 846 |
+
|
| 847 |
+
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
|
| 848 |
+
|
| 849 |
+
if outputs.endswith(stop_str):
|
| 850 |
+
outputs = outputs[:-len(stop_str)]
|
| 851 |
+
outputs = outputs.strip()
|
| 852 |
+
response_str = outputs
|
| 853 |
+
|
| 854 |
+
if render:
|
| 855 |
+
print('==============rendering===============')
|
| 856 |
+
from .render_tools import content_mmd_to_html
|
| 857 |
+
html_path_2 = save_render_file
|
| 858 |
+
right_num = outputs.count('\\right')
|
| 859 |
+
left_num = outputs.count('\\left')
|
| 860 |
+
|
| 861 |
+
if right_num != left_num:
|
| 862 |
+
outputs = outputs.replace('\\left(', '(').replace('\\right)', ')').replace('\\left[', '[').replace('\\right]', ']').replace('\\left{', '{').replace('\\right}', '}').replace('\\left|', '|').replace('\\right|', '|').replace('\\left.', '.').replace('\\right.', '.')
|
| 863 |
+
|
| 864 |
+
|
| 865 |
+
outputs = outputs.replace('"', '``').replace('$', '')
|
| 866 |
+
|
| 867 |
+
outputs_list = outputs.split('\n')
|
| 868 |
+
gt= ''
|
| 869 |
+
for out in outputs_list:
|
| 870 |
+
gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
|
| 871 |
+
|
| 872 |
+
gt = gt[:-2]
|
| 873 |
+
|
| 874 |
+
lines = content_mmd_to_html
|
| 875 |
+
lines = lines.split("const text =")
|
| 876 |
+
new_web = lines[0] + 'const text =' + gt + lines[1]
|
| 877 |
+
|
| 878 |
+
with open(html_path_2, 'w') as web_f_new:
|
| 879 |
+
web_f_new.write(new_web)
|
| 880 |
+
|
| 881 |
+
return response_str
|
custom_got/qwen.tiktoken
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
custom_got/render_tools.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
punctuation_dict = {
|
| 3 |
+
",": ",",
|
| 4 |
+
"。": ".",
|
| 5 |
+
|
| 6 |
+
}
|
| 7 |
+
translation_table = str.maketrans(punctuation_dict)
|
| 8 |
+
|
| 9 |
+
def svg_to_html(svg_content, output_filename):
|
| 10 |
+
|
| 11 |
+
html_content = f"""
|
| 12 |
+
<!DOCTYPE html>
|
| 13 |
+
<html lang="en">
|
| 14 |
+
<head>
|
| 15 |
+
<meta charset="UTF-8">
|
| 16 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 17 |
+
<title>SVG Embedded in HTML</title>
|
| 18 |
+
</head>
|
| 19 |
+
<body>
|
| 20 |
+
<svg width="2100" height="15000" xmlns="http://www.w3.org/2000/svg">
|
| 21 |
+
{svg_content}
|
| 22 |
+
</svg>
|
| 23 |
+
</body>
|
| 24 |
+
</html>
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
with open(output_filename, 'w') as file:
|
| 28 |
+
file.write(html_content)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
content_mmd_to_html = """<!DOCTYPE html>
|
| 33 |
+
<html lang="en" data-lt-installed="true"><head>
|
| 34 |
+
<meta charset="UTF-8">
|
| 35 |
+
<title>Title</title>
|
| 36 |
+
<script>
|
| 37 |
+
const text =
|
| 38 |
+
</script>
|
| 39 |
+
<style>
|
| 40 |
+
#content {
|
| 41 |
+
max-width: 800px;
|
| 42 |
+
margin: auto;
|
| 43 |
+
}
|
| 44 |
+
</style>
|
| 45 |
+
<script>
|
| 46 |
+
let script = document.createElement('script');
|
| 47 |
+
script.src = "https://cdn.jsdelivr.net/npm/mathpix-markdown-it@1.3.6/es5/bundle.js";
|
| 48 |
+
document.head.append(script);
|
| 49 |
+
|
| 50 |
+
script.onload = function() {
|
| 51 |
+
const isLoaded = window.loadMathJax();
|
| 52 |
+
if (isLoaded) {
|
| 53 |
+
console.log('Styles loaded!')
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
const el = window.document.getElementById('content-text');
|
| 57 |
+
if (el) {
|
| 58 |
+
const options = {
|
| 59 |
+
htmlTags: true
|
| 60 |
+
};
|
| 61 |
+
const html = window.render(text, options);
|
| 62 |
+
el.outerHTML = html;
|
| 63 |
+
}
|
| 64 |
+
};
|
| 65 |
+
</script>
|
| 66 |
+
</head>
|
| 67 |
+
<body>
|
| 68 |
+
<div id="content"><div id="content-text"></div></div>
|
| 69 |
+
</body>
|
| 70 |
+
</html>
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
tik_html = """
|
| 76 |
+
<!DOCTYPE html>
|
| 77 |
+
|
| 78 |
+
<html>
|
| 79 |
+
|
| 80 |
+
<head>
|
| 81 |
+
<meta charset="UTF-8">
|
| 82 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 83 |
+
<title>Document</title>
|
| 84 |
+
<link rel="stylesheet" type="text/css" href="https://tikzjax.com/v1/fonts.css">
|
| 85 |
+
<script src="https://tikzjax.com/v1/tikzjax.js"></script>
|
| 86 |
+
</head>
|
| 87 |
+
<body>
|
| 88 |
+
<script type="text/tikz">
|
| 89 |
+
const text =
|
| 90 |
+
</script>
|
| 91 |
+
</body>
|
| 92 |
+
</html>"""
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# print(tik_html)
|
custom_got/special_tokens_map.json
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"pad_token": {
|
| 3 |
+
"content": "<|endoftext|>",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": false,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
}
|
| 9 |
+
}
|
custom_got/tokenization_qwen.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba Cloud.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
"""Tokenization classes for QWen."""
|
| 7 |
+
|
| 8 |
+
import base64
|
| 9 |
+
import logging
|
| 10 |
+
import os
|
| 11 |
+
import unicodedata
|
| 12 |
+
from typing import Collection, Dict, List, Set, Tuple, Union
|
| 13 |
+
|
| 14 |
+
import tiktoken
|
| 15 |
+
from transformers import PreTrainedTokenizer, AddedToken
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
VOCAB_FILES_NAMES = {"vocab_file": "qwen.tiktoken"}
|
| 21 |
+
|
| 22 |
+
PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
|
| 23 |
+
ENDOFTEXT = "<|endoftext|>"
|
| 24 |
+
IMSTART = "<|im_start|>"
|
| 25 |
+
IMEND = "<|im_end|>"
|
| 26 |
+
# as the default behavior is changed to allow special tokens in
|
| 27 |
+
# regular texts, the surface forms of special tokens need to be
|
| 28 |
+
# as different as possible to minimize the impact
|
| 29 |
+
EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
|
| 30 |
+
SPECIAL_TOKENS = (
|
| 31 |
+
ENDOFTEXT,
|
| 32 |
+
IMSTART,
|
| 33 |
+
IMEND,
|
| 34 |
+
) + EXTRAS
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
|
| 38 |
+
with open(tiktoken_bpe_file, "rb") as f:
|
| 39 |
+
contents = f.read()
|
| 40 |
+
return {
|
| 41 |
+
base64.b64decode(token): int(rank)
|
| 42 |
+
for token, rank in (line.split() for line in contents.splitlines() if line)
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
class QWenTokenizer(PreTrainedTokenizer):
|
| 46 |
+
"""QWen tokenizer."""
|
| 47 |
+
|
| 48 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
| 49 |
+
|
| 50 |
+
def __init__(
|
| 51 |
+
self,
|
| 52 |
+
vocab_file,
|
| 53 |
+
errors="replace",
|
| 54 |
+
image_start_tag='<img>',
|
| 55 |
+
image_end_tag='</img>',
|
| 56 |
+
image_pad_tag='<imgpad>',
|
| 57 |
+
ref_start_tag='<ref>',
|
| 58 |
+
ref_end_tag='</ref>',
|
| 59 |
+
box_start_tag='<box>',
|
| 60 |
+
box_end_tag='</box>',
|
| 61 |
+
quad_start_tag='<quad>',
|
| 62 |
+
quad_end_tag='</quad>',
|
| 63 |
+
**kwargs,
|
| 64 |
+
):
|
| 65 |
+
super().__init__(**kwargs)
|
| 66 |
+
|
| 67 |
+
self.image_start_tag = image_start_tag
|
| 68 |
+
self.image_end_tag = image_end_tag
|
| 69 |
+
self.image_pad_tag = image_pad_tag
|
| 70 |
+
self.ref_start_tag = ref_start_tag
|
| 71 |
+
self.ref_end_tag = ref_end_tag
|
| 72 |
+
self.box_start_tag = box_start_tag
|
| 73 |
+
self.box_end_tag = box_end_tag
|
| 74 |
+
self.quad_start_tag = quad_start_tag
|
| 75 |
+
self.quad_end_tag = quad_end_tag
|
| 76 |
+
self.IMAGE_ST = (
|
| 77 |
+
ref_start_tag, ref_end_tag,
|
| 78 |
+
box_start_tag, box_end_tag,
|
| 79 |
+
quad_start_tag, quad_end_tag,
|
| 80 |
+
image_start_tag, image_end_tag,
|
| 81 |
+
image_pad_tag
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
self.errors = errors # how to handle errors in decoding
|
| 85 |
+
|
| 86 |
+
self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: dict[bytes, int]
|
| 87 |
+
self.special_tokens = {
|
| 88 |
+
token: index
|
| 89 |
+
for index, token in enumerate(
|
| 90 |
+
SPECIAL_TOKENS + self.IMAGE_ST, start=len(self.mergeable_ranks)
|
| 91 |
+
)
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
self.img_start_id = self.special_tokens[self.image_start_tag]
|
| 95 |
+
self.img_end_id = self.special_tokens[self.image_end_tag]
|
| 96 |
+
self.img_pad_id = self.special_tokens[self.image_pad_tag]
|
| 97 |
+
self.ref_start_id = self.special_tokens[self.ref_start_tag]
|
| 98 |
+
self.ref_end_id = self.special_tokens[self.ref_end_tag]
|
| 99 |
+
self.box_start_id = self.special_tokens[self.box_start_tag]
|
| 100 |
+
self.box_end_id = self.special_tokens[self.box_end_tag]
|
| 101 |
+
self.quad_start_id = self.special_tokens[self.quad_start_tag]
|
| 102 |
+
self.quad_end_id = self.special_tokens[self.quad_end_tag]
|
| 103 |
+
|
| 104 |
+
enc = tiktoken.Encoding(
|
| 105 |
+
"Qwen",
|
| 106 |
+
pat_str=PAT_STR,
|
| 107 |
+
mergeable_ranks=self.mergeable_ranks,
|
| 108 |
+
special_tokens=self.special_tokens,
|
| 109 |
+
)
|
| 110 |
+
assert (
|
| 111 |
+
len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
|
| 112 |
+
), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"
|
| 113 |
+
|
| 114 |
+
self.decoder = {
|
| 115 |
+
v: k for k, v in self.mergeable_ranks.items()
|
| 116 |
+
} # type: dict[int, bytes|str]
|
| 117 |
+
self.decoder.update({v: k for k, v in self.special_tokens.items()})
|
| 118 |
+
|
| 119 |
+
self.tokenizer = enc # type: tiktoken.Encoding
|
| 120 |
+
|
| 121 |
+
self.eod_id = self.tokenizer.eot_token
|
| 122 |
+
self.im_start_id = self.special_tokens[IMSTART]
|
| 123 |
+
self.im_end_id = self.special_tokens[IMEND]
|
| 124 |
+
|
| 125 |
+
def __len__(self) -> int:
|
| 126 |
+
return self.tokenizer.n_vocab
|
| 127 |
+
|
| 128 |
+
def get_vocab(self) -> Dict[bytes, int]:
|
| 129 |
+
return self.mergeable_ranks
|
| 130 |
+
|
| 131 |
+
def convert_tokens_to_ids(
|
| 132 |
+
self, tokens: Union[bytes, str, List[Union[bytes, str]]]
|
| 133 |
+
) -> List[int]:
|
| 134 |
+
ids = []
|
| 135 |
+
if isinstance(tokens, (str, bytes)):
|
| 136 |
+
if tokens in self.special_tokens:
|
| 137 |
+
return self.special_tokens[tokens]
|
| 138 |
+
else:
|
| 139 |
+
return self.mergeable_ranks.get(tokens)
|
| 140 |
+
for token in tokens:
|
| 141 |
+
if token in self.special_tokens:
|
| 142 |
+
ids.append(self.special_tokens[token])
|
| 143 |
+
else:
|
| 144 |
+
ids.append(self.mergeable_ranks.get(token))
|
| 145 |
+
return ids
|
| 146 |
+
|
| 147 |
+
def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
|
| 148 |
+
if not special_tokens and new_tokens:
|
| 149 |
+
raise ValueError('Adding regular tokens is not supported')
|
| 150 |
+
for token in new_tokens:
|
| 151 |
+
surface_form = token.content if isinstance(token, AddedToken) else token
|
| 152 |
+
if surface_form not in SPECIAL_TOKENS:
|
| 153 |
+
raise ValueError('Adding unknown special tokens is not supported')
|
| 154 |
+
return 0
|
| 155 |
+
|
| 156 |
+
def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
|
| 157 |
+
"""
|
| 158 |
+
Save only the vocabulary of the tokenizer (vocabulary).
|
| 159 |
+
|
| 160 |
+
Returns:
|
| 161 |
+
`Tuple(str)`: Paths to the files saved.
|
| 162 |
+
"""
|
| 163 |
+
file_path = os.path.join(save_directory, "qwen.tiktoken")
|
| 164 |
+
with open(file_path, "w", encoding="utf8") as w:
|
| 165 |
+
for k, v in self.mergeable_ranks.items():
|
| 166 |
+
line = base64.b64encode(k).decode("utf8") + " " + str(v) + "\n"
|
| 167 |
+
w.write(line)
|
| 168 |
+
return (file_path,)
|
| 169 |
+
|
| 170 |
+
def tokenize(
|
| 171 |
+
self,
|
| 172 |
+
text: str,
|
| 173 |
+
allowed_special: Union[Set, str] = "all",
|
| 174 |
+
disallowed_special: Union[Collection, str] = (),
|
| 175 |
+
**kwargs,
|
| 176 |
+
) -> List[Union[bytes, str]]:
|
| 177 |
+
"""
|
| 178 |
+
Converts a string in a sequence of tokens.
|
| 179 |
+
|
| 180 |
+
Args:
|
| 181 |
+
text (`str`):
|
| 182 |
+
The sequence to be encoded.
|
| 183 |
+
allowed_special (`Literal["all"]` or `set`):
|
| 184 |
+
The surface forms of the tokens to be encoded as special tokens in regular texts.
|
| 185 |
+
Default to "all".
|
| 186 |
+
disallowed_special (`Literal["all"]` or `Collection`):
|
| 187 |
+
The surface forms of the tokens that should not be in regular texts and trigger errors.
|
| 188 |
+
Default to an empty tuple.
|
| 189 |
+
|
| 190 |
+
kwargs (additional keyword arguments, *optional*):
|
| 191 |
+
Will be passed to the underlying model specific encode method.
|
| 192 |
+
|
| 193 |
+
Returns:
|
| 194 |
+
`List[bytes|str]`: The list of tokens.
|
| 195 |
+
"""
|
| 196 |
+
tokens = []
|
| 197 |
+
text = unicodedata.normalize("NFC", text)
|
| 198 |
+
|
| 199 |
+
# this implementation takes a detour: text -> token id -> token surface forms
|
| 200 |
+
for t in self.tokenizer.encode(
|
| 201 |
+
text, allowed_special=allowed_special, disallowed_special=disallowed_special
|
| 202 |
+
):
|
| 203 |
+
tokens.append(self.decoder[t])
|
| 204 |
+
return tokens
|
| 205 |
+
|
| 206 |
+
def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
|
| 207 |
+
"""
|
| 208 |
+
Converts a sequence of tokens in a single string.
|
| 209 |
+
"""
|
| 210 |
+
text = ""
|
| 211 |
+
temp = b""
|
| 212 |
+
for t in tokens:
|
| 213 |
+
if isinstance(t, str):
|
| 214 |
+
if temp:
|
| 215 |
+
text += temp.decode("utf-8", errors=self.errors)
|
| 216 |
+
temp = b""
|
| 217 |
+
text += t
|
| 218 |
+
elif isinstance(t, bytes):
|
| 219 |
+
temp += t
|
| 220 |
+
else:
|
| 221 |
+
raise TypeError("token should only be of type types or str")
|
| 222 |
+
if temp:
|
| 223 |
+
text += temp.decode("utf-8", errors=self.errors)
|
| 224 |
+
return text
|
| 225 |
+
|
| 226 |
+
@property
|
| 227 |
+
def vocab_size(self):
|
| 228 |
+
return self.tokenizer.n_vocab
|
| 229 |
+
|
| 230 |
+
def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
|
| 231 |
+
"""Converts an id to a token, special tokens included"""
|
| 232 |
+
if index in self.decoder:
|
| 233 |
+
return self.decoder[index]
|
| 234 |
+
raise ValueError("unknown ids")
|
| 235 |
+
|
| 236 |
+
def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
|
| 237 |
+
"""Converts a token to an id using the vocab, special tokens included"""
|
| 238 |
+
if token in self.special_tokens:
|
| 239 |
+
return self.special_tokens[token]
|
| 240 |
+
if token in self.mergeable_ranks:
|
| 241 |
+
return self.mergeable_ranks[token]
|
| 242 |
+
raise ValueError("unknown token")
|
| 243 |
+
|
| 244 |
+
def _tokenize(self, text: str, **kwargs):
|
| 245 |
+
"""
|
| 246 |
+
Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based
|
| 247 |
+
vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).
|
| 248 |
+
|
| 249 |
+
Do NOT take care of added tokens.
|
| 250 |
+
"""
|
| 251 |
+
raise NotImplementedError
|
| 252 |
+
|
| 253 |
+
def _decode(
|
| 254 |
+
self,
|
| 255 |
+
token_ids: Union[int, List[int]],
|
| 256 |
+
skip_special_tokens: bool = False,
|
| 257 |
+
errors: str = None,
|
| 258 |
+
**kwargs,
|
| 259 |
+
) -> str:
|
| 260 |
+
if isinstance(token_ids, int):
|
| 261 |
+
token_ids = [token_ids]
|
| 262 |
+
if skip_special_tokens:
|
| 263 |
+
token_ids = [i for i in token_ids if i < self.eod_id]
|
| 264 |
+
return self.tokenizer.decode(token_ids, errors=errors or self.errors)
|
custom_got/tokenizer_config.json
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"added_tokens_decoder": {},
|
| 3 |
+
"auto_map": {
|
| 4 |
+
"AutoTokenizer": [
|
| 5 |
+
"tokenization_qwen.QWenTokenizer",
|
| 6 |
+
null
|
| 7 |
+
]
|
| 8 |
+
},
|
| 9 |
+
"clean_up_tokenization_spaces": true,
|
| 10 |
+
"model_max_length": 8000,
|
| 11 |
+
"pad_token": "<|endoftext|>",
|
| 12 |
+
"padding_side": "right",
|
| 13 |
+
"tokenizer_class": "QWenTokenizer"
|
| 14 |
+
}
|
dataset.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5c52de9875d5559635129df10b3a3466167b4041def7208b435c72531c970320
|
| 3 |
+
size 23713278
|
dataset_creation.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
dataset = pd.read_csv('ocr_task/data_80k/data.csv')
|
| 7 |
+
labels = dataset['image_file']
|
| 8 |
+
text = dataset['text']
|
| 9 |
+
json_data = []
|
| 10 |
+
images_path = 'drive/MyDrive/data_80k/output_images/'
|
| 11 |
+
for i in range(len(labels)):
|
| 12 |
+
json_data.append(
|
| 13 |
+
{
|
| 14 |
+
"query": "<image>",
|
| 15 |
+
"response": text[i],
|
| 16 |
+
"images": [os.path.join(images_path, labels[i])],
|
| 17 |
+
}
|
| 18 |
+
)
|
| 19 |
+
with open('dataset.json', 'w') as f:
|
| 20 |
+
json.dump(json_data, f)
|
| 21 |
+
|
main.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import AutoModel, AutoTokenizer
|
| 2 |
+
import torch
|
| 3 |
+
from byaldi.RAGModel import RAGMultiModalModel
|
| 4 |
+
from byaldi.colpali import ColPaliModel
|
| 5 |
+
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
| 6 |
+
from qwen_vl_utils import process_vision_info
|
| 7 |
+
import torch
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
colpali_model = ColPaliModel.from_pretrained('vidore/colpali')
|
| 13 |
+
print(colpali_model.doc_id_to_metadata)
|
| 14 |
+
model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", torch_dtype=torch.bfloat16).eval()
|
| 15 |
+
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True)
|
| 16 |
+
messages = [
|
| 17 |
+
{
|
| 18 |
+
"role": "user",
|
| 19 |
+
"content": [
|
| 20 |
+
{
|
| 21 |
+
"type": "image",
|
| 22 |
+
"image": Image.open('template.jpg'),
|
| 23 |
+
},
|
| 24 |
+
{"type": "text", "text": 'Return full text of the document as a plain text'},
|
| 25 |
+
],
|
| 26 |
+
}
|
| 27 |
+
]
|
| 28 |
+
|
| 29 |
+
text = processor.apply_chat_template(
|
| 30 |
+
messages, tokenize=False, add_generation_prompt=True
|
| 31 |
+
)
|
| 32 |
+
img = Image.open('docs/hindi_template.jpg')
|
| 33 |
+
inputs = processor(
|
| 34 |
+
text=text,
|
| 35 |
+
images=img,
|
| 36 |
+
padding=True,
|
| 37 |
+
return_tensors="pt",
|
| 38 |
+
)
|
| 39 |
+
inputs = inputs.to("cpu")
|
| 40 |
+
generated_ids = model.generate(**inputs, max_new_tokens=5000)
|
| 41 |
+
generated_ids_trimmed = [
|
| 42 |
+
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
| 43 |
+
]
|
| 44 |
+
output_text = processor.batch_decode(
|
| 45 |
+
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
| 46 |
+
)
|
| 47 |
+
print(output_text)
|
main_got.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import AutoModel, AutoTokenizer
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def extract_text(image_path):
|
| 6 |
+
if torch.cuda.is_available():
|
| 7 |
+
device = torch.device('cuda') # If cuda is available, use it, otherwise use CPU
|
| 8 |
+
else:
|
| 9 |
+
device = torch.device('cpu')
|
| 10 |
+
|
| 11 |
+
tokenizer = AutoTokenizer.from_pretrained('custom_got',
|
| 12 |
+
trust_remote_code=True # Allows custom code to load model from hub
|
| 13 |
+
)
|
| 14 |
+
model = AutoModel.from_pretrained('custom_got',
|
| 15 |
+
trust_remote_code=True,
|
| 16 |
+
low_cpu_mem_usage=True,
|
| 17 |
+
device_map=device.type,
|
| 18 |
+
use_safetensors=True, # This format is faster, more memory efficient
|
| 19 |
+
# and provides safe deserialization unlike pickle-based one
|
| 20 |
+
pad_token_id=tokenizer.eos_token_id, # Set the pad token from tokenizer
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
image_file = image_path
|
| 24 |
+
# Extract text
|
| 25 |
+
res = model.chat(tokenizer, image_file, ocr_type='ocr')
|
| 26 |
+
return res
|
requirements.txt
ADDED
|
Binary file (7.14 kB). View file
|
|
|