Upload training code
- training/convert_atf.py +371 -0
- training/cuneiform_ocr_eval.ipynb +254 -0
- training/cuneiform_ocr_grpo.ipynb +0 -0
- training/cuneiform_ocr_sft.ipynb +397 -0
- training/get_cdli_dataset.py +279 -0
training/convert_atf.py
ADDED
@@ -0,0 +1,371 @@
import re
from collections import Counter
from typing import Optional


class ParsedATF:
    """Represents a parsed ATF document with methods to extract data."""

    # Face types
    ALL_FACES = [
        "obverse",
        "reverse",
        "left",
        "right",
        "top",
        "bottom",
    ]

    def __init__(
        self, transliterations: dict, unicodes: dict, info: dict, used_signs: set
    ):
        """
        Initialize parsed ATF data.

        Args:
            transliterations: Dictionary mapping face names to transliteration strings
            unicodes: Dictionary mapping face names to unicode strings
            info: Metadata dictionary (e.g., language)
            used_signs: Set of all signs encountered while parsing
        """
        self._transliterations = transliterations
        self._unicodes = unicodes
        self._info = info
        self._used_signs = used_signs

    def get_used_signs(self) -> set[str]:
        """Get the set of used signs."""
        return self._used_signs

    def get_transliteration(self, face: str) -> Optional[str]:
        """
        Get the transliteration for a given face.

        Args:
            face: The face name (e.g., 'obverse', 'reverse')

        Returns:
            The transliteration as a string with lines separated by newlines,
            or None if the face has no content
        """
        if face in self._transliterations:
            return self._transliterations[face]
        return None

    def get_unicode(self, face: str) -> Optional[str]:
        """
        Get the unicode representation for a given face.

        Args:
            face: The face name (e.g., 'obverse', 'reverse')

        Returns:
            The unicode representation as a string with lines separated by newlines,
            or None if the face has no content
        """
        if face in self._unicodes:
            return self._unicodes[face]
        return None

    def get_all_unicodes(self) -> dict[str, Optional[str]]:
        """
        Get unicode for all faces.

        Returns:
            Dictionary mapping '<face>_unicode' keys to unicode strings
        """
        return {
            f"{face}_unicode": self.get_unicode(face)
            for face in self.ALL_FACES
            if self.get_unicode(face) is not None
        }

    def get_all_transliterations(self) -> dict[str, Optional[str]]:
        """
        Get transliterations for all faces.

        Returns:
            Dictionary mapping '<face>_transliteration' keys to transliteration strings
        """
        return {
            f"{face}_transliteration": self.get_transliteration(face)
            for face in self.ALL_FACES
            if self.get_transliteration(face) is not None
        }

    @property
    def info(self) -> dict:
        """Get parsing info (e.g., language)."""
        return self._info


class ATFConverter:
    """Converter for ATF (ASCII Transliteration Format) cuneiform text."""

    # Face types
    ALL_FACES = [
        "obverse",
        "reverse",
        "left",
        "right",
        "top",
        "bottom",
    ]

    FACE_REMAPPING = {
        "surface a": "obverse",
        "surface b": "reverse",
    }

    # Special tokens
    SPECIAL_TOKENS = [
        "<B>",  # broken
        "<M>",  # missing one or more tokens
        "<S>",  # blank space
        "<D>",  # divine determinative
        "<munus>",  # woman determinative
        "<ansze>",
        "<ki>",
        "<disz>",
        "x",  # unknown signs
    ]

    def __init__(self, token_path: str = "./data/cuneiform_vocab.tsv"):
        """
        Initialize the ATF converter.

        Args:
            token_path: Path to the cuneiform vocabulary file
        """
        self.text2sign = self._load_token_mapping(token_path)

        # Counters for statistics
        self.vocab_freq = Counter()
        self.new_tokens = Counter()
        self.langs = Counter()
        self.unknown_faces = Counter()

    def _load_token_mapping(self, token_path: str) -> dict[str, str]:
        """Load the transliteration-to-sign mapping from the vocabulary TSV."""
        text2sign = {}
        with open(token_path) as f:
            for t in f:
                try:
                    k, s = t.strip("\n").split("\t")
                except ValueError:
                    print(t)
                    continue
                text2sign[k] = s.replace(" ", "")

        return text2sign

    def _remove_at(self, x: str) -> Optional[str]:
        """Remove @c or @t suffixes from tokens."""
        if x.endswith("@c)") or x.endswith("@t)"):
            return x[:-3] + ")"
        return None

    def _remove_spaces(self, x: list[str]) -> list[str]:
        """Collapse consecutive space tokens into one."""
        new_x = []
        for item in x:
            if item == "<S>" and len(new_x) > 0 and new_x[-1] == "<S>":
                continue
            new_x.append(item)
        return new_x

    def parse(self, raw_text: str) -> Optional[ParsedATF]:
        """
        Parse ATF text and extract transliterations and unicode.

        Args:
            raw_text: The raw ATF text to parse

        Returns:
            ParsedATF object if parsing succeeded, None if the language is not supported
        """
        token_text = {"default": []}
        info = {}

        curr_face = "default"
        sep = "\n"
        if "\\n" in raw_text:
            sep = "\\n"

        for line in raw_text.split(sep):
            line = line.strip()

            if line.startswith("&") or line.startswith("'&"):
                # Metadata
                pass
            elif line.startswith("#atf"):
                info["lang"] = line.split("lang ")[-1].strip()
                self.langs[info["lang"]] += 1
                if info["lang"] not in ["sux", "akk", "sux, akk", "akk _sux"]:
                    # Do not process languages other than Sumerian or Akkadian
                    return None
            elif (
                line.startswith("#")
                or line.startswith(">>")
                or line.startswith("<<")
                or line.startswith("||")
            ):
                # Comment/link
                continue
            elif line.startswith("$"):
                if "broken" in line:
                    token_text[curr_face].append("<B>")
            elif line.startswith("@"):
                key = line[1:].strip().strip("?")
                if key in self.ALL_FACES:
                    curr_face = key
                    token_text[key] = []
                elif key.startswith("column"):
                    token_text[curr_face].append("<COL>")
                else:
                    self.unknown_faces[key] += 1
            else:
                # Process line content
                self._process_line_content(line, curr_face, token_text)

        # Build transliterations and unicodes from token_text
        transliterations, unicodes, used_signs = self._build_outputs(token_text)
        return ParsedATF(transliterations, unicodes, info, used_signs)

    def _process_line_content(self, line: str, curr_face: str, token_text: dict):
        """Process a content line and extract tokens."""
        # Special symbols
        line = line.replace("{d}", "<D>")

        # Unwrap remaining determinatives: {x} -> x
        for x in re.findall(r"\{.*?\}", line):
            line = line.replace(x, " " + x[1:-1] + " ")

        line = line.replace("($ blank space $)", "<S>")

        # Remove underscores
        line = line.replace("_", " ")

        # Remove damage hash marks
        line = line.replace("#", "")

        # Remove question marks and exclamation marks
        line = line.replace("?", "")
        line = line.replace("!", "")

        # Remove bracketed restorations
        for x in re.findall(r"\[.*?\]", line):
            line = line.replace(x, "")

        parts = line.split(". ")

        if len(parts) >= 2:
            # Make sure only the leading line number is split off
            if len(parts) > 2:
                parts = parts[0], ". ".join(parts[1:])

            line_num, text = parts
            if curr_face != "":
                tokens = text.split(" ")
                signs = []
                for i, t in enumerate(tokens):
                    # if i > 0 and len(signs) > 0:
                    #     signs.append("<S>")  # insert a space between words

                    if "-" in t:
                        ts = t.split("-")
                        for x in ts:
                            x = x.strip()
                            if len(x) == 0:
                                continue
                            if x in self.text2sign:
                                self.vocab_freq[x] += 1
                                signs.append(self.text2sign[x])
                            else:
                                new_x = self._remove_at(x)
                                if new_x and new_x in self.text2sign:
                                    signs.append(self.text2sign[new_x])
                                else:
                                    self.new_tokens[x] += 1
                    elif t in self.text2sign:
                        signs.append(self.text2sign[t])
                    elif t in self.SPECIAL_TOKENS:
                        self.vocab_freq[t] += 1
                        signs.append(t)
                    else:
                        new_x = self._remove_at(t)
                        if new_x and new_x in self.text2sign:
                            signs.append(self.text2sign[new_x])
                        else:
                            if len(t.strip()) > 0:
                                self.new_tokens[t] += 1

                signs = self._remove_spaces(signs)
                token_text[curr_face].append(
                    {"raw": text, "num": line_num, "sign": signs}
                )

    def _build_outputs(
        self, token_text: dict
    ) -> tuple[dict[str, str], dict[str, str], set[str]]:
        """Build transliteration and unicode outputs from parsed token_text."""
        transliterations = {}
        unicodes = {}
        used_signs = set()

        for face in token_text.keys():
            lines = token_text[face]
            face_key = self.FACE_REMAPPING.get(face, face)

            # List of columns, each column is a list of lines
            face_transliterations: list[list[str]] = []
            face_unicodes: list[list[str]] = []

            current_column = {"transliteration": [], "unicode": []}

            for line in lines:
                if line == "<COL>":
                    if len(current_column["transliteration"]) > 0:
                        face_transliterations.append(current_column["transliteration"])
                    if len(current_column["unicode"]) > 0:
                        face_unicodes.append(current_column["unicode"])
                    current_column = {"transliteration": [], "unicode": []}
                    continue

                if isinstance(line, str):
                    continue

                used_signs.update(line.get("sign", ["<B>"]))

                current_column["transliteration"].append(line.get("raw", "<B>"))
                current_column["unicode"].append(" ".join(line.get("sign", ["<B>"])))

            if len(current_column["transliteration"]) > 0:
                face_transliterations.append(current_column["transliteration"])
            if len(current_column["unicode"]) > 0:
                face_unicodes.append(current_column["unicode"])

            if len(face_transliterations) == 1:
                # No need for column markers as there is only one column
                transliterations[face_key] = "\n".join(face_transliterations[0])
            else:
                transliterations[face_key] = "\n".join(
                    [
                        f"@column {i+1}\n" + "\n".join(column)
                        for i, column in enumerate(face_transliterations)
                    ]
                )

            if len(face_unicodes) == 1:
                # No need for column markers as there is only one column
                unicodes[face_key] = "\n".join(face_unicodes[0])
            else:
                unicodes[face_key] = "\n".join(
                    [
                        f"@column {i+1}\n" + "\n".join(column)
                        for i, column in enumerate(face_unicodes)
                    ]
                )

        return transliterations, unicodes, used_signs
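
For reference, a minimal usage sketch of the converter above. The ATF snippet and its contents are invented for illustration, and the vocabulary file at ./data/cuneiform_vocab.tsv is assumed to exist; real input comes from the CDLI dump processed in get_cdli_dataset.py.

from convert_atf import ATFConverter

converter = ATFConverter()  # loads ./data/cuneiform_vocab.tsv

# Hypothetical two-face tablet; real ATF comes from the CDLI dump
atf = (
    "#atf: lang sux\n"
    "@tablet\n"
    "@obverse\n"
    "1. dumu lugal\n"
    "@reverse\n"
    "$ broken\n"
)

parsed = converter.parse(atf)
if parsed is not None:  # None means the language is not sux/akk
    print(parsed.info)                            # {'lang': 'sux'}
    print(parsed.get_transliteration("obverse"))  # 'dumu lugal'
    print(parsed.get_unicode("obverse"))          # mapped signs joined with spaces
    print(parsed.get_used_signs())                # set of signs seen while parsing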
training/cuneiform_ocr_eval.ipynb
ADDED
@@ -0,0 +1,254 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4ca0fb0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from PIL import Image\n",
    "from tqdm.auto import tqdm\n",
    "from transformers import AutoModelForCausalLM, AutoProcessor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a961375e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load dataset\n",
    "from get_cdli_dataset import get_dataset, IMG_CACHE\n",
    "\n",
    "dataset = get_dataset()\n",
    "test_dataset = dataset[\"test\"]\n",
    "\n",
    "print(test_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e226c45c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the model\n",
    "\n",
    "# model_path = \"PaddlePaddle/PaddleOCR-VL\"  # base\n",
    "# model_path = \"./outputs/sft\"\n",
    "model_path = \"../\"\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_path, trust_remote_code=True, torch_dtype=torch.bfloat16\n",
    ").to(\"cuda\").eval()\n",
    "processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97b9a2cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pyxdameraulevenshtein as dl\n",
    "\n",
    "def compute_ter(expected_ids: list[int], predicted_ids: list[int]) -> float:\n",
    "    \"\"\"\n",
    "    Compute Token Error Rate (TER) between ground truth and completion tokens.\n",
    "    TER = (substitutions + deletions + insertions) / len(ground_truth)\n",
    "\n",
    "    TER suits cuneiform OCR better than CER because:\n",
    "    - Multi-character Unicode signs count as 1 token instead of multiple chars\n",
    "    - Special tokens like @obverse/@reverse count as 1 token\n",
    "    \"\"\"\n",
    "\n",
    "    if len(expected_ids) == 0:\n",
    "        return 0.0 if len(predicted_ids) == 0 else 1.0\n",
    "\n",
    "    # Calculate edit distance on token sequences\n",
    "    distance = dl.damerau_levenshtein_distance(expected_ids, predicted_ids)\n",
    "\n",
    "    # TER is the edit distance normalized by the truth token count\n",
    "    ter = distance / max(1, len(expected_ids))\n",
    "\n",
    "    return ter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "859c4fc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run inference on all test examples\n",
    "results = []\n",
    "total_ter = 0.0\n",
    "\n",
    "pbar = tqdm(test_dataset, desc=\"Evaluating on test set\")\n",
    "\n",
    "for idx, example in enumerate(pbar):\n",
    "    expected = example[\"unicode\"]\n",
    "    expected_ids = processor.tokenizer.encode(expected, add_special_tokens=False)\n",
    "\n",
    "    # Load image\n",
    "    with Image.open(IMG_CACHE / f\"P{str(example['id']).rjust(6, '0')}.jpg\").convert(\n",
    "        \"RGB\"\n",
    "    ) as image:\n",
    "        # Prepare input\n",
    "        messages = [\n",
    "            {\n",
    "                \"role\": \"user\",\n",
    "                \"content\": [\n",
    "                    {\"type\": \"image\", \"image\": image},\n",
    "                    {\"type\": \"text\", \"text\": \"OCR:\"},\n",
    "                ],\n",
    "            },\n",
    "        ]\n",
    "\n",
    "        inputs = processor.apply_chat_template(\n",
    "            messages,\n",
    "            tokenize=True,\n",
    "            add_generation_prompt=True,\n",
    "            return_dict=True,\n",
    "            return_tensors=\"pt\",\n",
    "        ).to(\"cuda\")\n",
    "\n",
    "        # Generate prediction\n",
    "        with torch.no_grad():\n",
    "            output_ids = model.generate(\n",
    "                **inputs,\n",
    "                use_cache=True,\n",
    "                max_new_tokens=int(len(expected_ids) * 1.2),\n",
    "                repetition_penalty=1.03,\n",
    "            )\n",
    "\n",
    "    predicted_ids = output_ids[0][inputs[\"input_ids\"].shape[1] :][:-1].tolist()\n",
    "\n",
    "    # Compute TER for this example\n",
    "    ter = compute_ter(expected_ids, predicted_ids)\n",
    "    total_ter += ter\n",
    "\n",
    "    pbar.set_postfix_str(f\"AVG TER={total_ter / (idx+1):.3f}\")\n",
    "\n",
    "    prediction = processor.decode(\n",
    "        predicted_ids,\n",
    "        skip_special_tokens=False,\n",
    "    ).strip()\n",
    "\n",
    "    # Store results\n",
    "    results.append(\n",
    "        {\n",
    "            \"id\": example[\"id\"],\n",
    "            \"expected\": expected,\n",
    "            \"prediction\": prediction,\n",
    "            \"ter\": ter,\n",
    "        }\n",
    "    )\n",
    "    tqdm.write(f\"\\033[94m\\nID: {example['id']} | TER: {ter:.4f}\\033[0m\")\n",
    "    tqdm.write(f\"\\033[92mExpected:\\033[0m\\n{expected}\")\n",
    "    tqdm.write(f\"\\033[91mPredicted:\\033[0m\\n{prediction}\")\n",
    "\n",
    "# Compute averages\n",
    "average_ter = total_ter / len(test_dataset)\n",
    "print(f\"\\n{'='*60}\")\n",
    "print(f\"Average Token Error Rate (TER): {average_ter:.4f} ({average_ter*100:.2f}%)\")\n",
    "print(f\"{'='*60}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c6a8e02",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show examples: best and worst predictions (sorted by TER)\n",
    "sorted_results = sorted(results, key=lambda x: x[\"ter\"])\n",
    "\n",
    "print(\"=\"*60)\n",
    "print(\"BEST PREDICTIONS (Lowest TER)\")\n",
    "print(\"=\"*60)\n",
    "for i in range(min(10, len(sorted_results))):\n",
    "    r = sorted_results[i]\n",
    "    print(f\"\\nExample {i+1} - ID: {r['id']} - TER: {r['ter']:.4f}\")\n",
    "    print(f\"Expected:\\n{r['expected']}\")\n",
    "    print(f\"Predicted:\\n{r['prediction']}\")\n",
    "    print(\"-\"*60)\n",
    "\n",
    "print(\"\\n\" + \"=\"*60)\n",
    "print(\"WORST PREDICTIONS (Highest TER)\")\n",
    "print(\"=\"*60)\n",
    "for i in range(min(10, len(sorted_results))):\n",
    "    r = sorted_results[-(i+1)]\n",
    "    print(f\"\\nExample {i+1} - ID: {r['id']} - TER: {r['ter']:.4f}\")\n",
    "    print(f\"Expected:\\n{r['expected']}\")\n",
    "    print(f\"Predicted:\\n{r['prediction']}\")\n",
    "    print(\"-\"*60)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5ceae30",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TER distribution statistics\n",
    "import numpy as np\n",
    "\n",
    "ter_values = [r[\"ter\"] for r in results]\n",
    "\n",
    "print(\"=\"*60)\n",
    "print(\"TER (TOKEN ERROR RATE) DISTRIBUTION STATISTICS\")\n",
    "print(\"=\"*60)\n",
    "print(f\"Mean TER: {np.mean(ter_values):.4f} ({np.mean(ter_values)*100:.2f}%)\")\n",
    "print(f\"Median TER: {np.median(ter_values):.4f} ({np.median(ter_values)*100:.2f}%)\")\n",
    "print(f\"Std Dev: {np.std(ter_values):.4f}\")\n",
    "print(f\"Min TER: {np.min(ter_values):.4f} ({np.min(ter_values)*100:.2f}%)\")\n",
    "print(f\"Max TER: {np.max(ter_values):.4f} ({np.max(ter_values)*100:.2f}%)\")\n",
    "print(f\"\\nPercentiles:\")\n",
    "print(f\"  25th: {np.percentile(ter_values, 25):.4f}\")\n",
    "print(f\"  50th: {np.percentile(ter_values, 50):.4f}\")\n",
    "print(f\"  75th: {np.percentile(ter_values, 75):.4f}\")\n",
    "print(f\"  90th: {np.percentile(ter_values, 90):.4f}\")\n",
    "print(f\"  95th: {np.percentile(ter_values, 95):.4f}\")\n",
    "print(f\"  98th: {np.percentile(ter_values, 98):.4f}\")\n",
    "\n",
    "# Count perfect predictions\n",
    "perfect_predictions = sum(1 for ter in ter_values if ter == 0.0)\n",
    "print(f\"\\nPerfect predictions (TER=0%): {perfect_predictions}/{len(ter_values)} ({perfect_predictions/len(ter_values)*100:.2f}%)\")\n",
    "\n",
    "# Count predictions with TER < 0.5 (less than 50% error)\n",
    "good_predictions = sum(1 for ter in ter_values if ter < 0.5)\n",
    "print(f\"Good predictions (TER<50%): {good_predictions}/{len(ter_values)} ({good_predictions/len(ter_values)*100:.2f}%)\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
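
To make the metric concrete, a small worked example for the compute_ter defined above; the token IDs are arbitrary placeholders, not real tokenizer output.

# Ground truth has 5 tokens; the prediction substitutes one and drops one.
expected_ids = [11, 22, 33, 44, 55]
predicted_ids = [11, 99, 33, 44]  # 22 -> 99 (substitution), 55 missing (deletion)

print(compute_ter(expected_ids, predicted_ids))  # 2 edits / 5 reference tokens = 0.4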
training/cuneiform_ocr_grpo.ipynb
ADDED
The diff for this file is too large to render. See raw diff.
training/cuneiform_ocr_sft.ipynb
ADDED
@@ -0,0 +1,397 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd2siqgrq6w",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CRITICAL: This patch MUST run BEFORE importing unsloth!\n",
    "# Fix Unsloth's gradient checkpointing for models with keyword-only forward arguments\n",
    "\n",
    "import sys\n",
    "import torch\n",
    "import os\n",
    "\n",
    "# Import unsloth_zoo.peft_utils first so it's in sys.modules\n",
    "os.environ[\"UNSLOTH_IS_PRESENT\"] = \"1\"\n",
    "import unsloth_zoo.peft_utils\n",
    "\n",
    "# Now patch the function before anything else imports it\n",
    "def patched_requires_grad_pre_hook(module, input):\n",
    "    \"\"\"Patched hook that handles empty input tuples gracefully\"\"\"\n",
    "    type_input = type(input)\n",
    "    if type_input is torch.Tensor:\n",
    "        input.requires_grad_(True)\n",
    "    elif type_input is tuple or type_input is list:\n",
    "        if len(input) == 0:\n",
    "            # Empty tuple = keyword-only args. This is fine, gradients flow through kwargs\n",
    "            return\n",
    "        if len(input) > 0 and torch.is_floating_point(input[0]):\n",
    "            input[0].requires_grad_(True)\n",
    "\n",
    "# Get the original function\n",
    "original_func = sys.modules['unsloth_zoo.peft_utils'].requires_grad_for_gradient_checkpointing\n",
    "\n",
    "# Create wrapper that uses our patched hook\n",
    "def patched_requires_grad_for_gradient_checkpointing(model):\n",
    "    \"\"\"Wrapper that calls original but uses patched hook\"\"\"\n",
    "    import re\n",
    "    import inspect\n",
    "\n",
    "    # Define the other helper functions we need\n",
    "    def requires_grad_post_hook(module, input, output):\n",
    "        try:\n",
    "            if hasattr(output, \"loss\") and output.loss is not None:\n",
    "                output.loss.requires_grad_(True)\n",
    "            elif hasattr(output, \"logits\") and output.logits is not None:\n",
    "                output.logits.requires_grad_(True)\n",
    "            elif type(output) is torch.Tensor:\n",
    "                output.requires_grad_(True)\n",
    "        except: pass\n",
    "\n",
    "    def register_other_hooks(hook_name, hook_func_name, module, hooks_dict_name):\n",
    "        if not hasattr(module, hooks_dict_name): return\n",
    "        hooks_dict = getattr(module, hooks_dict_name)\n",
    "        for hook_id, hook_fn in list(hooks_dict.items()):\n",
    "            if hook_func_name in str(hook_fn):\n",
    "                del hooks_dict[hook_id]\n",
    "\n",
    "    # Find first parameter with requires_grad\n",
    "    param = None\n",
    "    for name, param in model.named_parameters():\n",
    "        if param.requires_grad: break\n",
    "    if param is None: return\n",
    "\n",
    "    name = re.sub(r\"\\.([\\d]{1,})\\.\", r\"[\\1].\", name)\n",
    "    name_components = name.split(\".\")\n",
    "    if len(name_components) == 0:\n",
    "        raise RuntimeError(\"Unsloth: Model has 0 layers?\")\n",
    "\n",
    "    # Find the module to hook\n",
    "    final_where = None\n",
    "    for j in range(len(name_components)-1, 0, -1):\n",
    "        name_curr = name_components[j]\n",
    "        name_pre = \"model.\" + \".\".join(name_components[:j])\n",
    "        if re.search(r\"\\[[\\d]{1,}\\]\", name_pre): continue\n",
    "        module = eval(name_pre)\n",
    "        if hasattr(module, \"forward\"):\n",
    "            try: forward = inspect.getsource(module.forward)\n",
    "            except: continue\n",
    "            if f\"self.{name_curr}(\" in forward:\n",
    "                final_where = j + 1\n",
    "                break\n",
    "            module_list = re.sub(r\"\\[[\\d]{1,}\\]\", \"\", name_curr)\n",
    "            if f\"in self.{module_list}:\" in forward:\n",
    "                final_where = j\n",
    "                break\n",
    "            elif re.search(r\"for [^\\s]{3,} in self\\.\" + module_list, forward):\n",
    "                final_where = j\n",
    "                break\n",
    "\n",
    "    if final_where is None:\n",
    "        for module_name, module in model.named_modules():\n",
    "            if not hasattr(module, \"get_input_embeddings\"): break\n",
    "        register_other_hooks(\"requires_grad_post_hook\", \"requires_grad_post_hook\", module, \"_forward_hooks\")\n",
    "        module.register_forward_hook(requires_grad_post_hook)\n",
    "        return\n",
    "\n",
    "    module_name = \"model.\" + \".\".join(name_components[:final_where])\n",
    "    module = eval(module_name)\n",
    "\n",
    "    if hasattr(module, \"config\") and module.config.__class__.__name__ in (\"CLIPVisionConfig\", \"SiglipVisionConfig\"):\n",
    "        old_module = model\n",
    "        for module_name, module in model.named_modules():\n",
    "            if not hasattr(module, \"get_input_embeddings\"): break\n",
    "            old_module = module\n",
    "        module = old_module\n",
    "\n",
    "    print(f\"Unsloth: Making `{module_name}` require gradients\")\n",
    "\n",
    "    # Try post-hook first\n",
    "    if hasattr(module, \"get_input_embeddings\"):\n",
    "        try:\n",
    "            module = module.get_input_embeddings()\n",
    "            register_other_hooks(\"requires_grad_post_hook\", \"requires_grad_post_hook\", module, \"_forward_hooks\")\n",
    "            module.register_forward_hook(requires_grad_post_hook)\n",
    "            return\n",
    "        except: pass\n",
    "\n",
    "    # Use our patched pre-hook\n",
    "    register_other_hooks(\"requires_grad_pre_hook\", \"requires_grad_pre_hook\", module, \"_forward_pre_hooks\")\n",
    "    module.register_forward_pre_hook(patched_requires_grad_pre_hook)\n",
    "\n",
    "# Replace in sys.modules\n",
    "sys.modules['unsloth_zoo.peft_utils'].requires_grad_for_gradient_checkpointing = patched_requires_grad_for_gradient_checkpointing\n",
    "\n",
    "print(\"✓ Patched Unsloth gradient checkpointing BEFORE imports\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2c30bc6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from unsloth import FastVisionModel\n",
    "from unsloth.trainer import UnslothVisionDataCollator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4326b62e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from PIL import Image\n",
    "from transformers import AutoModel, AutoProcessor\n",
    "from trl import SFTTrainer, SFTConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5e899ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "from get_cdli_dataset import atf_converter, get_dataset, IMG_CACHE\n",
    "\n",
    "# Load dataset\n",
    "dataset = get_dataset()\n",
    "\n",
    "train_dataset = dataset[\"train\"]\n",
    "test_dataset = dataset[\"test\"]\n",
    "\n",
    "print(train_dataset, test_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e0aa56b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load processor and model\n",
    "model, tokenizer = FastVisionModel.from_pretrained(\n",
    "    \"PaddlePaddle/PaddleOCR-VL\",\n",
    "    cache_dir=\"./hf_cache/models\",\n",
    "    trust_remote_code=True,\n",
    "    load_in_4bit=False,\n",
    "    auto_model=AutoModel,\n",
    "    full_finetuning=True,\n",
    "    unsloth_force_compile=True,\n",
    "    use_gradient_checkpointing=\"unsloth\",\n",
    "    max_seq_length=16000,\n",
    ")\n",
    "\n",
    "processor = AutoProcessor.from_pretrained(\n",
    "    \"PaddlePaddle/PaddleOCR-VL\",\n",
    "    cache_dir=\"./hf_cache/models\",\n",
    "    trust_remote_code=True,\n",
    ")\n",
    "\n",
    "processor.tokenizer = tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28656983",
   "metadata": {},
   "outputs": [],
   "source": [
    "used_signs = set()\n",
    "for example in train_dataset:\n",
    "    parsed = atf_converter.parse(example[\"atf\"])\n",
    "    used_signs.update(parsed.get_used_signs())\n",
    "for example in test_dataset:\n",
    "    parsed = atf_converter.parse(example[\"atf\"])\n",
    "    used_signs.update(parsed.get_used_signs())\n",
    "\n",
    "print(f\"Base model vocab size: {len(processor.tokenizer)}\")\n",
    "\n",
    "# Add the cuneiform signs to the model vocab\n",
    "num_added_tokens = processor.tokenizer.add_tokens(list(used_signs))\n",
    "num_added_special_tokens = processor.tokenizer.add_special_tokens(\n",
    "    {\n",
    "        \"additional_special_tokens\": [f\"@{face}\" for face in atf_converter.ALL_FACES]\n",
    "        + atf_converter.SPECIAL_TOKENS\n",
    "    },\n",
    "    replace_additional_special_tokens=False,\n",
    ")\n",
    "\n",
    "print(\n",
    "    f\"Added {num_added_tokens} tokens and {num_added_special_tokens} special tokens to tokenizer\"\n",
    ")\n",
    "\n",
    "# Resize the embedding matrix for the new tokens (recent transformers mean-initializes the new rows)\n",
    "model.resize_token_embeddings(len(processor.tokenizer))\n",
    "\n",
    "print(f\"New model vocab size: {len(processor.tokenizer)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5100b97c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configure training\n",
    "sft_training_args = SFTConfig(\n",
    "    output_dir=\"./outputs/sft\",\n",
    "    # max_steps=50,  # Remove for full run\n",
    "    num_train_epochs=2,\n",
    "    per_device_train_batch_size=2,\n",
    "    per_device_eval_batch_size=2,\n",
    "    gradient_accumulation_steps=1,\n",
    "    learning_rate=2e-5,\n",
    "    optim=\"adamw_8bit\",\n",
    "    warmup_ratio=0.05,\n",
    "    weight_decay=0.001,\n",
    "    lr_scheduler_type=\"linear\",\n",
    "    bf16=True,\n",
    "    save_strategy=\"steps\",\n",
    "    save_steps=200,\n",
    "    eval_strategy=\"steps\",\n",
    "    eval_steps=1000,\n",
    "    logging_steps=1,\n",
    "    report_to=\"none\",\n",
    "    dataloader_num_workers=0,\n",
    "    # You MUST put the below items for vision finetuning:\n",
    "    remove_unused_columns=False,\n",
    "    dataset_text_field=\"\",\n",
    "    dataset_kwargs={\"skip_prepare_dataset\": True},\n",
    "    max_length=16000,\n",
    ")\n",
    "\n",
    "# Initialize trainer\n",
    "sft_trainer = SFTTrainer(\n",
    "    model=model,\n",
    "    processing_class=processor,\n",
    "    data_collator=UnslothVisionDataCollator(\n",
    "        model,\n",
    "        processor,\n",
    "        train_on_responses_only=False,  # Fixed: was masking all tokens with True\n",
    "        instruction_part=\"User: \",\n",
    "        response_part=\"Assistant: \",\n",
    "        pad_to_multiple_of=2,\n",
    "        resize_dimension=\"max\",\n",
    "        formatting_func=lambda example: {\n",
    "            \"images\": [\n",
    "                Image.open(IMG_CACHE / f\"P{str(example['id']).rjust(6, '0')}.jpg\")\n",
    "            ],\n",
    "            \"messages\": [\n",
    "                # Add user message with image and task prompt\n",
    "                {\n",
    "                    \"role\": \"user\",\n",
    "                    \"content\": [\n",
    "                        {\n",
    "                            \"type\": \"image\",\n",
    "                            \"image\": Image.open(\n",
    "                                IMG_CACHE / f\"P{str(example['id']).rjust(6, '0')}.jpg\"\n",
    "                            ),\n",
    "                        },\n",
    "                        {\"type\": \"text\", \"text\": \"OCR:\"},\n",
    "                    ],\n",
    "                },\n",
    "                # Add assistant message with completion text\n",
    "                {\n",
    "                    \"role\": \"assistant\",\n",
    "                    \"content\": [{\"type\": \"text\", \"text\": example[\"unicode\"]}],\n",
    "                },\n",
    "            ],\n",
    "        },\n",
    "    ),\n",
    "    args=sft_training_args,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=test_dataset,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97e8455e",
   "metadata": {},
   "outputs": [],
   "source": [
    "gpu_stats = torch.cuda.get_device_properties(0)\n",
    "start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
    "max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n",
    "print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n",
    "print(f\"{start_gpu_memory} GB of memory reserved.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6103bbe",
   "metadata": {},
   "outputs": [],
   "source": [
    "sft_trainer_stats = sft_trainer.train(resume_from_checkpoint=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36dc79b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
    "used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n",
    "used_percentage = round(used_memory / max_memory * 100, 3)\n",
    "lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n",
    "print(f\"{sft_trainer_stats.metrics['train_runtime']} seconds used for training.\")\n",
    "print(\n",
    "    f\"{round(sft_trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\"\n",
    ")\n",
    "print(f\"Peak reserved memory = {used_memory} GB.\")\n",
    "print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n",
    "print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n",
    "print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50bdf718",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save model\n",
    "processor.save_pretrained(sft_training_args.output_dir)\n",
    "model.save_pretrained(sft_training_args.output_dir, processor)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
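
The vocabulary-extension cell above leans on resize_token_embeddings to initialize the new rows; recent transformers releases mean-initialize them by default (the mean_resizing behavior). As a hedged fallback sketch for older versions, the same initialization can be done by hand; num_added_tokens and num_added_special_tokens come from that cell.

import torch

# Mean-initialize the freshly added embedding rows by hand.
# Assumes `model` has already been resized with resize_token_embeddings().
num_new = num_added_tokens + num_added_special_tokens
with torch.no_grad():
    embeddings = model.get_input_embeddings().weight
    embeddings[-num_new:] = embeddings[:-num_new].mean(dim=0)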
training/get_cdli_dataset.py
ADDED
@@ -0,0 +1,279 @@
import concurrent.futures
import math
import time
from io import BytesIO
from pathlib import Path

import requests
from convert_atf import ATFConverter
from datasets import Dataset
from PIL import Image
from tqdm.auto import tqdm

atf_converter = ATFConverter()

IMG_CACHE = Path("./data/cdli_images")
IMG_CACHE.mkdir(exist_ok=True, parents=True)
MAX_IMG_RES = 2048
DOWNLOAD_MODE = False


def smart_resize(
    height: int,
    width: int,
    factor: int = 28,
    min_pixels: int = 28 * 28 * 130,
    max_pixels: int = 28 * 28 * 1280,
):
    """Rescales the image so that the following conditions are met:

    1. Both dimensions (height and width) are divisible by 'factor'.

    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].

    3. The aspect ratio of the image is maintained as closely as possible.
    """
    if height < factor:
        print(f"smart_resize: height={height} < factor={factor}, reset height=factor")
        width = round((width * factor) / height)
        height = factor

    if width < factor:
        print(f"smart_resize: width={width} < factor={factor}, reset width=factor")
        height = round((height * factor) / width)
        width = factor

    if max(height, width) / min(height, width) > 200:
        raise ValueError(
            f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
        )
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = math.floor(height / beta / factor) * factor
        w_bar = math.floor(width / beta / factor) * factor
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar


def resize_image(img_path):
    with Image.open(img_path).convert("RGB") as image:
        width, height = image.size
        # Scale down if larger than MAX_IMG_RES
        if width > MAX_IMG_RES or height > MAX_IMG_RES:
            scale = MAX_IMG_RES / max(width, height)
            height = int(height * scale)
            width = int(width * scale)
        # Always ensure dimensions are multiples of 28 for vision model compatibility
        new_height, new_width = smart_resize(height, width)
        if new_height != image.height or new_width != image.width:
            image = image.resize((new_width, new_height), Image.LANCZOS)
            image.save(img_path)


def resize_cached_images():
    img_paths = list(IMG_CACHE.glob("*.jpg"))
    pbar = tqdm(total=len(img_paths))

    with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor:
        futures = [executor.submit(resize_image, img_path) for img_path in img_paths]
        for future in concurrent.futures.as_completed(futures):
            pbar.update(1)

    pbar.close()


def get_image(id: int):
    file_name = f"P{str(id).rjust(6, '0')}.jpg"
    url = f"https://cdli.earth/dl/photo/{file_name}"
    cache_file = IMG_CACHE / file_name

    try:
        if cache_file.exists():
            tqdm.write(f"Found {file_name} in cache")
            image = Image.open(cache_file).convert("RGB")
        else:
            response = requests.get(url, timeout=5)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")

            tqdm.write(f"Downloaded {file_name}")

            width, height = image.size
            # Scale down if larger than MAX_IMG_RES
            if width > MAX_IMG_RES or height > MAX_IMG_RES:
                scale = MAX_IMG_RES / max(width, height)
                height = int(height * scale)
                width = int(width * scale)
            # Always ensure dimensions are multiples of 28 for vision model compatibility
            new_height, new_width = smart_resize(height, width)
            if new_height != image.height or new_width != image.width:
                image = image.resize((new_width, new_height), Image.LANCZOS)

            image.save(cache_file)
            time.sleep(0.02)  # Rate limiting
    except requests.exceptions.Timeout:
        tqdm.write(f"Timeout downloading {file_name}")
        return None
    except requests.exceptions.RequestException as e:
        tqdm.write(f"Error downloading {file_name}: {e}")
        return None
    except Exception as e:
        tqdm.write(f"Error processing {file_name}: {type(e).__name__}: {e}")
        return None

    return image


def count_repetitions(text: str) -> int:
    """
    Count the total number of repeated element occurrences in a sequence.
    E.g., "122233" has 3 repetitions (2 appears 2 extra times, 3 appears 1 extra time).
    """
    if len(text) < 2:
        return 0

    return len(text) - len(set(text))


def get_dataset(file="./data/cdli_dataset.parquet"):
    if Path(file).exists():
        return Dataset.from_parquet(file).train_test_split(test_size=1000, seed=42)

    # 1. Get all the ids from cdli.atf (source: https://github.com/cdli-gh/data/raw/refs/heads/master/cdliatf_unblocked.atf)
    cdli_raw = Path("./data/cdli.atf").read_text(encoding="utf-8").split("&P")
    cdli_filtered = [
        section.strip()
        for section in cdli_raw
        if section.strip()  # Ignore empty sections
        and "@tablet" in section  # Only include tablets
        and len(section) > 50  # Ignore short sections
        and len(section) < 1000  # Ignore long sections
        and any(lang in section for lang in ["sux", "akk"])  # Limit supported languages
    ]

    ids = []
    atfs = []
    unicodes = []

    for section in tqdm(cdli_filtered, desc="Parsing CDLI dump"):
        # The header line looks like "123456 = ..."; take the numeric ID before "=",
        # and skip the section if it is not parseable
        lines = section.splitlines()
        id_part = lines[0].split("=")[0].strip()
        if not id_part.isdigit():
            continue

        atf = "\n".join(
            [
                line
                for line in lines[1:]
                if not (
                    line.startswith("# ")
                    or line.startswith(">>")
                    or line.startswith("<<")
                    or line.startswith("||")
                )
            ]
        )
        parsed = atf_converter.parse(atf)
        if parsed is None:
            tqdm.write(f"=====\033[91m {id_part} skip (parse fail) \033[0m=====")
            continue

        unicode_parts = [
            f"@{face}\n{parsed.get_unicode(face)}"
            for face in parsed.ALL_FACES
            if parsed.get_unicode(face)
        ]

        # Skip massive (or near-empty) tablets
        unicode_len = sum(len(part) for part in unicode_parts)
        if unicode_len > 300 or unicode_len < 20:
            tqdm.write(f"=====\033[91m {id_part} skip (too short/long) \033[0m=====")
            continue
        # Skip tablets that are poorly translated to unicode
        if sum(part.count("x") for part in unicode_parts) >= 2:
            tqdm.write(f"=====\033[91m {id_part} skip (missing symbols) \033[0m=====")
            continue

        unicode = "\n".join(unicode_parts)

        # Drop the super repetitive admin tablets (the model ends up getting stuck repeating the common phrases)
        if count_repetitions(unicode) / len(unicode) > 0.7:
            tqdm.write(f"=====\033[91m {id_part} skip (too repetitive) \033[0m=====")
            continue

        # Ignore if we don't have an image for this atf
        if DOWNLOAD_MODE:
            image = get_image(int(id_part))
        elif (IMG_CACHE / f"P{str(int(id_part)).rjust(6, '0')}.jpg").exists():
            image = Image.open(
                IMG_CACHE / f"P{str(int(id_part)).rjust(6, '0')}.jpg"
            ).convert("RGB")
        else:
            tqdm.write(f"=====\033[91m {id_part} skip (no img) \033[0m=====")
            continue

        if not image:
            tqdm.write(f"=====\033[91m {id_part} skip (no img) \033[0m=====")
            continue

        # Drop low-res, B&W, or non-isolated-background images
        try:
            if min(image.size) < 100:
                tqdm.write(f"=====\033[91m {id_part} skip (lowres) \033[0m=====")
                continue

            scale = 150 / image.height
            small_image = image.resize(
                (int(image.width * scale), int(image.height * scale)), Image.LANCZOS
            )
            pixels = list(small_image.getdata())
            small_image.close()
            image.close()

            bw_pixels = sum(1 for r, g, b in pixels if r == g == b)
            bw_percent = bw_pixels / len(pixels)
            if bw_percent > 0.95 or bw_percent < 0.1:
                tqdm.write(
                    f"=====\033[91m {id_part} skip (bw {bw_percent*100:.1f}%) \033[0m====="
                )
                continue

            if sum(1 for r, g, b in pixels if r == g == b == 0) / len(pixels) < 0.15:
                tqdm.write(
                    f"=====\033[91m {id_part} skip (not on black background) \033[0m====="
                )
                continue
        except Exception as e:
            tqdm.write(
                f"=====\033[91m {id_part} skip (err img check: {e}) \033[0m====="
            )
            continue

        ids.append(int(id_part))
        atfs.append(atf)
        unicodes.append(unicode)

        tqdm.write(f"=====\033[32m {id_part} unicode (len {unicode_len}) \033[0m=====")

    dataset = Dataset.from_dict(
        {
            "id": ids,
            "atf": atfs,
            "unicode": unicodes,
        }
    )

    dataset.to_parquet(file)
    return dataset.train_test_split(test_size=1000, seed=42)
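
A quick worked check of smart_resize's guarantees on a typical tablet photo: 2000x3000 is 6,000,000 pixels, above max_pixels = 28*28*1280 = 1,003,520, so both sides are scaled down by beta = sqrt(6000000 / 1003520) ≈ 2.445 and floored to multiples of 28.

h, w = smart_resize(2000, 3000)
print(h, w)                      # 812 1204
print(h % 28, w % 28)            # 0 0 -- both divisible by 28
print(h * w <= 28 * 28 * 1280)   # True -- within the pixel budget
print(w / h)                     # ~1.483, close to the original 1.5 aspect ratio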