boatbomber committed on
Commit 819ced0 · verified · 1 Parent(s): 1c275ca

Upload training code

training/convert_atf.py ADDED
@@ -0,0 +1,371 @@
+ import re
+ from collections import Counter
+ from typing import Optional
+
+
+ class ParsedATF:
+     """Represents a parsed ATF document with methods to extract data."""
+
+     # Face types
+     ALL_FACES = [
+         "obverse",
+         "reverse",
+         "left",
+         "right",
+         "top",
+         "bottom",
+     ]
+
+     def __init__(
+         self, transliterations: dict, unicodes: dict, info: dict, used_signs: set
+     ):
+         """
+         Initialize parsed ATF data.
+
+         Args:
+             transliterations: Dictionary mapping face names to transliteration strings
+             unicodes: Dictionary mapping face names to unicode strings
+             info: Metadata dictionary (e.g., language)
+             used_signs: Set of signs that appear in the document
+         """
+         self._transliterations = transliterations
+         self._unicodes = unicodes
+         self._info = info
+         self._used_signs = used_signs
+
+     def get_used_signs(self) -> set[str]:
+         """Get the set of used signs."""
+         return self._used_signs
+
+     def get_transliteration(self, face: str) -> Optional[str]:
+         """
+         Get the transliteration for a given face.
+
+         Args:
+             face: The face name (e.g., 'obverse', 'reverse')
+
+         Returns:
+             The transliteration as a string with lines separated by newlines,
+             or None if the face has no content
+         """
+         return self._transliterations.get(face)
+
+     def get_unicode(self, face: str) -> Optional[str]:
+         """
+         Get the unicode representation for a given face.
+
+         Args:
+             face: The face name (e.g., 'obverse', 'reverse')
+
+         Returns:
+             The unicode representation as a string with lines separated by newlines,
+             or None if the face has no content
+         """
+         return self._unicodes.get(face)
+
+     def get_all_unicodes(self) -> dict[str, str]:
+         """
+         Get unicode for all faces that have content.
+
+         Returns:
+             Dictionary mapping '<face>_unicode' keys to unicode strings
+         """
+         return {
+             f"{face}_unicode": self.get_unicode(face)
+             for face in self.ALL_FACES
+             if self.get_unicode(face) is not None
+         }
+
+     def get_all_transliterations(self) -> dict[str, str]:
+         """
+         Get transliteration for all faces that have content.
+
+         Returns:
+             Dictionary mapping '<face>_transliteration' keys to transliteration strings
+         """
+         return {
+             f"{face}_transliteration": self.get_transliteration(face)
+             for face in self.ALL_FACES
+             if self.get_transliteration(face) is not None
+         }
+
+     @property
+     def info(self) -> dict:
+         """Get parsing info (e.g., language)."""
+         return self._info
+
+
+ class ATFConverter:
+     """Converter for ATF (ASCII Transliteration Format) cuneiform text."""
+
+     # Face types
+     ALL_FACES = [
+         "obverse",
+         "reverse",
+         "left",
+         "right",
+         "top",
+         "bottom",
+     ]
+
+     FACE_REMAPPING = {
+         "surface a": "obverse",
+         "surface b": "reverse",
+     }
+
+     # Special tokens
+     SPECIAL_TOKENS = [
+         "<B>",  # broken
+         "<M>",  # one or more missing tokens
+         "<S>",  # blank space
+         "<D>",  # divine determinative
+         "<munus>",  # woman or young woman
+         "<ansze>",
+         "<ki>",
+         "<disz>",
+         "x",  # unknown sign
+     ]
+
+     def __init__(self, token_path: str = "./data/cuneiform_vocab.tsv"):
+         """
+         Initialize the ATF converter.
+
+         Args:
+             token_path: Path to the cuneiform vocabulary file
+         """
+         self.text2sign = self._load_token_mapping(token_path)
+
+         # Counters for statistics
+         self.vocab_freq = Counter()
+         self.new_tokens = Counter()
+         self.langs = Counter()
+         self.unknown_faces = Counter()
+
+     def _load_token_mapping(self, token_path: str) -> dict[str, str]:
+         """Load the text-to-sign mapping from the vocabulary TSV."""
+         text2sign = {}
+         with open(token_path) as f:
+             for t in f:
+                 try:
+                     k, s = t.strip("\n").split("\t")
+                 except ValueError:
+                     print(t)
+                     continue
+                 text2sign[k] = s.replace(" ", "")
+
+         return text2sign
+
+     def _remove_at(self, x: str) -> Optional[str]:
+         """Remove @c or @t suffixes from tokens."""
+         if x.endswith("@c)") or x.endswith("@t)"):
+             return x[:-3] + ")"
+         return None
+
+     def _remove_spaces(self, x: list[str]) -> list[str]:
+         """Collapse consecutive space tokens into one."""
+         new_x = []
+         for item in x:
+             if item == "<S>" and len(new_x) > 0 and new_x[-1] == "<S>":
+                 continue
+             new_x.append(item)
+         return new_x
+
+     def parse(self, raw_text: str) -> Optional[ParsedATF]:
+         """
+         Parse ATF text and extract transliterations and unicode.
+
+         Args:
+             raw_text: The raw ATF text to parse
+
+         Returns:
+             ParsedATF object if parsing succeeded, None if the language is not supported
+         """
+         token_text = {"default": []}
+         info = {}
+
+         curr_face = "default"
+         sep = "\n"
+         if "\\n" in raw_text:
+             sep = "\\n"
+
+         for line in raw_text.split(sep):
+             line = line.strip()
+
+             if line.startswith("&") or line.startswith("'&"):
+                 # metadata
+                 pass
+             elif line.startswith("#atf"):
+                 info["lang"] = line.split("lang ")[-1].strip()
+                 self.langs[info["lang"]] += 1
+                 if info["lang"] not in ["sux", "akk", "sux, akk", "akk _sux"]:
+                     # Skip languages other than Sumerian (sux) and Akkadian (akk)
+                     return None
+             elif (
+                 line.startswith("#")
+                 or line.startswith(">>")
+                 or line.startswith("<<")
+                 or line.startswith("||")
+             ):
+                 # comment/link
+                 continue
+             elif line.startswith("$"):
+                 if "broken" in line:
+                     token_text[curr_face].append("<B>")
+             elif line.startswith("@"):
+                 key = line[1:].strip().strip("?")
+                 if key in self.ALL_FACES:
+                     curr_face = key
+                     token_text[key] = []
+                 elif key.startswith("column"):
+                     token_text[curr_face].append("<COL>")
+                 else:
+                     self.unknown_faces[key] += 1
+             else:
+                 # Process line content
+                 self._process_line_content(line, curr_face, token_text)
+
+         # Build transliterations and unicodes from token_text
+         transliterations, unicodes, used_signs = self._build_outputs(token_text)
+         return ParsedATF(transliterations, unicodes, info, used_signs)
+
+     def _process_line_content(self, line: str, curr_face: str, token_text: dict):
+         """Process a content line and extract tokens."""
+         # Special symbols
+         line = line.replace("{d}", "<D>")
+
+         # Break determinatives out of their braces, e.g. "{ki}" -> " ki "
+         for x in re.findall(r"\{.*?\}", line):
+             line = line.replace(x, " " + x[1:-1] + " ")
+
+         line = line.replace("($ blank space $)", "<S>")
+
+         # Remove underscores
+         line = line.replace("_", " ")
+
+         # Remove damage hash marks
+         line = line.replace("#", "")
+
+         # Remove question marks and exclamation marks
+         line = line.replace("?", "")
+         line = line.replace("!", "")
+
+         # Remove bracketed (restored) text
+         for x in re.findall(r"\[.*?\]", line):
+             line = line.replace(x, "")
+
+         parts = line.split(". ")
+
+         if len(parts) >= 2:
+             # Make sure only the leading line number is split off
+             if len(parts) > 2:
+                 parts = parts[0], ". ".join(parts[1:])
+
+             line_num, text = parts
+             if curr_face != "":
+                 tokens = text.split(" ")
+                 signs = []
+                 for i, t in enumerate(tokens):
+                     # if i > 0 and len(signs) > 0:
+                     #     signs.append("<S>")  # insert a space between words
+
+                     if "-" in t:
+                         ts = t.split("-")
+                         for x in ts:
+                             x = x.strip()
+                             if len(x) == 0:
+                                 continue
+                             if x in self.text2sign:
+                                 self.vocab_freq[x] += 1
+                                 signs.append(self.text2sign[x])
+                             else:
+                                 new_x = self._remove_at(x)
+                                 if new_x and new_x in self.text2sign:
+                                     signs.append(self.text2sign[new_x])
+                                 else:
+                                     self.new_tokens[x] += 1
+                     elif t in self.text2sign:
+                         signs.append(self.text2sign[t])
+                     elif t in self.SPECIAL_TOKENS:
+                         self.vocab_freq[t] += 1
+                         signs.append(t)
+                     else:
+                         new_x = self._remove_at(t)
+                         if new_x and new_x in self.text2sign:
+                             signs.append(self.text2sign[new_x])
+                         elif len(t.strip()) > 0:
+                             self.new_tokens[t] += 1
+
+                 signs = self._remove_spaces(signs)
+                 token_text[curr_face].append(
+                     {"raw": text, "num": line_num, "sign": signs}
+                 )
+
+     def _build_outputs(
+         self, token_text: dict
+     ) -> tuple[dict[str, str], dict[str, str], set[str]]:
+         """Build transliteration and unicode outputs from parsed token_text."""
+         transliterations = {}
+         unicodes = {}
+         used_signs = set()
+
+         for face, lines in token_text.items():
+             face_key = self.FACE_REMAPPING.get(face, face)
+
+             # List of columns, each column is a list of lines
+             face_transliterations: list[list[str]] = []
+             face_unicodes: list[list[str]] = []
+
+             current_column = {"transliteration": [], "unicode": []}
+
+             for line in lines:
+                 if line == "<COL>":
+                     if len(current_column["transliteration"]) > 0:
+                         face_transliterations.append(current_column["transliteration"])
+                     if len(current_column["unicode"]) > 0:
+                         face_unicodes.append(current_column["unicode"])
+                     current_column = {"transliteration": [], "unicode": []}
+                     continue
+
+                 if isinstance(line, str):
+                     # Stray marker (e.g. "<B>") with no parsed line data
+                     continue
+
+                 used_signs.update(line.get("sign", ["<B>"]))
+
+                 current_column["transliteration"].append(line.get("raw", "<B>"))
+                 current_column["unicode"].append(" ".join(line.get("sign", ["<B>"])))
+
+             if len(current_column["transliteration"]) > 0:
+                 face_transliterations.append(current_column["transliteration"])
+             if len(current_column["unicode"]) > 0:
+                 face_unicodes.append(current_column["unicode"])
+
+             if len(face_transliterations) == 1:
+                 # No need for column markers as there is only one column
+                 transliterations[face_key] = "\n".join(face_transliterations[0])
+             else:
+                 transliterations[face_key] = "\n".join(
+                     [
+                         f"@column {i+1}\n" + "\n".join(column)
+                         for i, column in enumerate(face_transliterations)
+                     ]
+                 )
+
+             if len(face_unicodes) == 1:
+                 # No need for column markers as there is only one column
+                 unicodes[face_key] = "\n".join(face_unicodes[0])
+             else:
+                 unicodes[face_key] = "\n".join(
+                     [
+                         f"@column {i+1}\n" + "\n".join(column)
+                         for i, column in enumerate(face_unicodes)
+                     ]
+                 )
+
+         return transliterations, unicodes, used_signs
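
For reference, a minimal usage sketch of this converter (assuming a vocabulary TSV exists at the default `token_path`, as the constructor implies):

    from convert_atf import ATFConverter

    converter = ATFConverter()          # loads ./data/cuneiform_vocab.tsv
    parsed = converter.parse(atf_text)  # atf_text: one ATF document as a string
    if parsed is not None:              # None means the language tag was not sux/akk
        print(parsed.get_unicode("obverse"))      # sign string for one face
        print(parsed.get_all_transliterations())  # {"<face>_transliteration": ...}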
training/cuneiform_ocr_eval.ipynb ADDED
@@ -0,0 +1,254 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "e4ca0fb0",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import torch\n",
+     "from PIL import Image\n",
+     "from tqdm.auto import tqdm\n",
+     "from transformers import AutoModelForCausalLM, AutoProcessor"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "a961375e",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Load dataset\n",
+     "from get_cdli_dataset import get_dataset, IMG_CACHE\n",
+     "\n",
+     "dataset = get_dataset()\n",
+     "test_dataset = dataset[\"test\"]\n",
+     "\n",
+     "print(test_dataset)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "e226c45c",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Load the model\n",
+     "\n",
+     "# model_path = \"PaddlePaddle/PaddleOCR-VL\"  # base\n",
+     "# model_path = \"./outputs/sft\"\n",
+     "model_path = \"../\"\n",
+     "\n",
+     "model = AutoModelForCausalLM.from_pretrained(\n",
+     "    model_path, trust_remote_code=True, torch_dtype=torch.bfloat16\n",
+     ").to(\"cuda\").eval()\n",
+     "processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "97b9a2cb",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import pyxdameraulevenshtein as dl\n",
+     "\n",
+     "\n",
+     "def compute_ter(expected_ids: list[int], predicted_ids: list[int]) -> float:\n",
+     "    \"\"\"\n",
+     "    Compute Token Error Rate (TER) between ground truth and completion tokens.\n",
+     "    TER = (substitutions + deletions + insertions) / len(ground_truth)\n",
+     "\n",
+     "    TER suits cuneiform OCR better than CER because:\n",
+     "    - Multi-character Unicode signs count as 1 token instead of multiple chars\n",
+     "    - Special tokens like @obverse/@reverse count as 1 token\n",
+     "    \"\"\"\n",
+     "\n",
+     "    if len(expected_ids) == 0:\n",
+     "        return 0.0 if len(predicted_ids) == 0 else 1.0\n",
+     "\n",
+     "    # Calculate edit distance on token sequences\n",
+     "    distance = dl.damerau_levenshtein_distance(expected_ids, predicted_ids)\n",
+     "\n",
+     "    # TER is the edit distance normalized by the truth token count\n",
+     "    ter = distance / max(1, len(expected_ids))\n",
+     "\n",
+     "    return ter"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "859c4fc2",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Run inference on all test examples\n",
+     "results = []\n",
+     "total_ter = 0.0\n",
+     "\n",
+     "pbar = tqdm(test_dataset, desc=\"Evaluating on test set\")\n",
+     "\n",
+     "for idx, example in enumerate(pbar):\n",
+     "    expected = example[\"unicode\"]\n",
+     "    expected_ids = processor.tokenizer.encode(expected, add_special_tokens=False)\n",
+     "\n",
+     "    # Load image\n",
+     "    with Image.open(IMG_CACHE / f\"P{str(example['id']).rjust(6, '0')}.jpg\").convert(\n",
+     "        \"RGB\"\n",
+     "    ) as image:\n",
+     "        # Prepare input\n",
+     "        messages = [\n",
+     "            {\n",
+     "                \"role\": \"user\",\n",
+     "                \"content\": [\n",
+     "                    {\"type\": \"image\", \"image\": image},\n",
+     "                    {\"type\": \"text\", \"text\": \"OCR:\"},\n",
+     "                ],\n",
+     "            },\n",
+     "        ]\n",
+     "\n",
+     "        inputs = processor.apply_chat_template(\n",
+     "            messages,\n",
+     "            tokenize=True,\n",
+     "            add_generation_prompt=True,\n",
+     "            return_dict=True,\n",
+     "            return_tensors=\"pt\",\n",
+     "        ).to(\"cuda\")\n",
+     "\n",
+     "        # Generate prediction\n",
+     "        with torch.no_grad():\n",
+     "            output_ids = model.generate(\n",
+     "                **inputs,\n",
+     "                use_cache=True,\n",
+     "                max_new_tokens=int(len(expected_ids) * 1.2),\n",
+     "                repetition_penalty=1.03,\n",
+     "            )\n",
+     "\n",
+     "    # Strip the prompt tokens and the trailing EOS token\n",
+     "    predicted_ids = output_ids[0][inputs[\"input_ids\"].shape[1] :][:-1].tolist()\n",
+     "\n",
+     "    # Compute TER for this example\n",
+     "    ter = compute_ter(expected_ids, predicted_ids)\n",
+     "    total_ter += ter\n",
+     "\n",
+     "    pbar.set_postfix_str(f\"AVG TER={total_ter / (idx+1):.3f}\")\n",
+     "\n",
+     "    prediction = processor.decode(\n",
+     "        predicted_ids,\n",
+     "        skip_special_tokens=False,\n",
+     "    ).strip()\n",
+     "\n",
+     "    # Store results\n",
+     "    results.append(\n",
+     "        {\n",
+     "            \"id\": example[\"id\"],\n",
+     "            \"expected\": expected,\n",
+     "            \"prediction\": prediction,\n",
+     "            \"ter\": ter,\n",
+     "        }\n",
+     "    )\n",
+     "    tqdm.write(f\"\\033[94m\\nID: {example['id']} | TER: {ter:.4f}\\033[0m\")\n",
+     "    tqdm.write(f\"\\033[92mExpected:\\033[0m\\n{expected}\")\n",
+     "    tqdm.write(f\"\\033[91mPredicted:\\033[0m\\n{prediction}\")\n",
+     "\n",
+     "# Compute averages\n",
+     "average_ter = total_ter / len(test_dataset)\n",
+     "print(f\"\\n{'='*60}\")\n",
+     "print(f\"Average Token Error Rate (TER): {average_ter:.4f} ({average_ter*100:.2f}%)\")\n",
+     "print(f\"{'='*60}\")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "3c6a8e02",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Show examples: best and worst predictions (sorted by TER)\n",
+     "sorted_results = sorted(results, key=lambda x: x[\"ter\"])\n",
+     "\n",
+     "print(\"=\"*60)\n",
+     "print(\"BEST PREDICTIONS (Lowest TER)\")\n",
+     "print(\"=\"*60)\n",
+     "for i in range(min(10, len(sorted_results))):\n",
+     "    r = sorted_results[i]\n",
+     "    print(f\"\\nExample {i+1} - ID: {r['id']} - TER: {r['ter']:.4f}\")\n",
+     "    print(f\"Expected:\\n{r['expected']}\")\n",
+     "    print(f\"Predicted:\\n{r['prediction']}\")\n",
+     "    print(\"-\"*60)\n",
+     "\n",
+     "print(\"\\n\" + \"=\"*60)\n",
+     "print(\"WORST PREDICTIONS (Highest TER)\")\n",
+     "print(\"=\"*60)\n",
+     "for i in range(min(10, len(sorted_results))):\n",
+     "    r = sorted_results[-(i+1)]\n",
+     "    print(f\"\\nExample {i+1} - ID: {r['id']} - TER: {r['ter']:.4f}\")\n",
+     "    print(f\"Expected:\\n{r['expected']}\")\n",
+     "    print(f\"Predicted:\\n{r['prediction']}\")\n",
+     "    print(\"-\"*60)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "d5ceae30",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# TER distribution statistics\n",
+     "import numpy as np\n",
+     "\n",
+     "ter_values = [r[\"ter\"] for r in results]\n",
+     "\n",
+     "print(\"=\"*60)\n",
+     "print(\"TER (TOKEN ERROR RATE) DISTRIBUTION STATISTICS\")\n",
+     "print(\"=\"*60)\n",
+     "print(f\"Mean TER: {np.mean(ter_values):.4f} ({np.mean(ter_values)*100:.2f}%)\")\n",
+     "print(f\"Median TER: {np.median(ter_values):.4f} ({np.median(ter_values)*100:.2f}%)\")\n",
+     "print(f\"Std Dev: {np.std(ter_values):.4f}\")\n",
+     "print(f\"Min TER: {np.min(ter_values):.4f} ({np.min(ter_values)*100:.2f}%)\")\n",
+     "print(f\"Max TER: {np.max(ter_values):.4f} ({np.max(ter_values)*100:.2f}%)\")\n",
+     "print(\"\\nPercentiles:\")\n",
+     "print(f\"  25th: {np.percentile(ter_values, 25):.4f}\")\n",
+     "print(f\"  50th: {np.percentile(ter_values, 50):.4f}\")\n",
+     "print(f\"  75th: {np.percentile(ter_values, 75):.4f}\")\n",
+     "print(f\"  90th: {np.percentile(ter_values, 90):.4f}\")\n",
+     "print(f\"  95th: {np.percentile(ter_values, 95):.4f}\")\n",
+     "print(f\"  98th: {np.percentile(ter_values, 98):.4f}\")\n",
+     "\n",
+     "# Count perfect predictions\n",
+     "perfect_predictions = sum(1 for ter in ter_values if ter == 0.0)\n",
+     "print(f\"\\nPerfect predictions (TER=0%): {perfect_predictions}/{len(ter_values)} ({perfect_predictions/len(ter_values)*100:.2f}%)\")\n",
+     "\n",
+     "# Count predictions with TER < 0.5 (less than 50% error)\n",
+     "good_predictions = sum(1 for ter in ter_values if ter < 0.5)\n",
+     "print(f\"Good predictions (TER<50%): {good_predictions}/{len(ter_values)} ({good_predictions/len(ter_values)*100:.2f}%)\")"
+    ]
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": ".venv",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.13.6"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 5
+ }
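
As a quick sanity check of the TER definition in this notebook, a worked example (illustrative token ids, assuming the `compute_ter` cell has been run):

    # expected [7, 8, 9, 10] vs predicted [7, 9, 10, 11]:
    # delete 8, insert 11 -> Damerau-Levenshtein distance 2
    # TER = 2 / len(expected) = 2 / 4 = 0.5
    assert compute_ter([7, 8, 9, 10], [7, 9, 10, 11]) == 0.5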
training/cuneiform_ocr_grpo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
training/cuneiform_ocr_sft.ipynb ADDED
@@ -0,0 +1,397 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "fd2siqgrq6w",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# CRITICAL: This patch MUST run BEFORE importing unsloth!\n",
+     "# Fix Unsloth's gradient checkpointing for models with keyword-only forward arguments\n",
+     "\n",
+     "import sys\n",
+     "import torch\n",
+     "import os\n",
+     "\n",
+     "# Import unsloth_zoo.peft_utils first so it's in sys.modules\n",
+     "os.environ[\"UNSLOTH_IS_PRESENT\"] = \"1\"\n",
+     "import unsloth_zoo.peft_utils\n",
+     "\n",
+     "# Now patch the function before anything else imports it\n",
+     "def patched_requires_grad_pre_hook(module, input):\n",
+     "    \"\"\"Patched hook that handles empty input tuples gracefully\"\"\"\n",
+     "    type_input = type(input)\n",
+     "    if type_input is torch.Tensor:\n",
+     "        input.requires_grad_(True)\n",
+     "    elif type_input is tuple or type_input is list:\n",
+     "        if len(input) == 0:\n",
+     "            # Empty tuple = keyword-only args. This is fine, gradients flow through kwargs\n",
+     "            return\n",
+     "        if len(input) > 0 and torch.is_floating_point(input[0]):\n",
+     "            input[0].requires_grad_(True)\n",
+     "\n",
+     "# Keep a reference to the original function\n",
+     "original_func = sys.modules['unsloth_zoo.peft_utils'].requires_grad_for_gradient_checkpointing\n",
+     "\n",
+     "# Create a wrapper that mirrors the original logic but uses our patched hook\n",
+     "def patched_requires_grad_for_gradient_checkpointing(model):\n",
+     "    \"\"\"Wrapper that replicates the original logic but registers the patched pre-hook\"\"\"\n",
+     "    import re\n",
+     "    import inspect\n",
+     "\n",
+     "    # Define the other helper functions we need\n",
+     "    def requires_grad_post_hook(module, input, output):\n",
+     "        try:\n",
+     "            if hasattr(output, \"loss\") and output.loss is not None:\n",
+     "                output.loss.requires_grad_(True)\n",
+     "            elif hasattr(output, \"logits\") and output.logits is not None:\n",
+     "                output.logits.requires_grad_(True)\n",
+     "            elif type(output) is torch.Tensor:\n",
+     "                output.requires_grad_(True)\n",
+     "        except Exception:\n",
+     "            pass\n",
+     "\n",
+     "    def register_other_hooks(hook_name, hook_func_name, module, hooks_dict_name):\n",
+     "        # Remove any previously registered hook of the same name before re-registering\n",
+     "        if not hasattr(module, hooks_dict_name):\n",
+     "            return\n",
+     "        hooks_dict = getattr(module, hooks_dict_name)\n",
+     "        for hook_id, hook_fn in list(hooks_dict.items()):\n",
+     "            if hook_func_name in str(hook_fn):\n",
+     "                del hooks_dict[hook_id]\n",
+     "\n",
+     "    # Find the first parameter with requires_grad\n",
+     "    param = None\n",
+     "    for name, param in model.named_parameters():\n",
+     "        if param.requires_grad:\n",
+     "            break\n",
+     "    if param is None or not param.requires_grad:\n",
+     "        return\n",
+     "\n",
+     "    name = re.sub(r\"\\.([\\d]{1,})\\.\", r\"[\\1].\", name)\n",
+     "    name_components = name.split(\".\")\n",
+     "    if len(name_components) == 0:\n",
+     "        raise RuntimeError(\"Unsloth: Model has 0 layers?\")\n",
+     "\n",
+     "    # Find the module to hook\n",
+     "    final_where = None\n",
+     "    for j in range(len(name_components) - 1, 0, -1):\n",
+     "        name_curr = name_components[j]\n",
+     "        name_pre = \"model.\" + \".\".join(name_components[:j])\n",
+     "        if re.search(r\"\\[[\\d]{1,}\\]\", name_pre):\n",
+     "            continue\n",
+     "        module = eval(name_pre)\n",
+     "        if hasattr(module, \"forward\"):\n",
+     "            try:\n",
+     "                forward = inspect.getsource(module.forward)\n",
+     "            except Exception:\n",
+     "                continue\n",
+     "            if f\"self.{name_curr}(\" in forward:\n",
+     "                final_where = j + 1\n",
+     "                break\n",
+     "            module_list = re.sub(r\"\\[[\\d]{1,}\\]\", \"\", name_curr)\n",
+     "            if f\"in self.{module_list}:\" in forward:\n",
+     "                final_where = j\n",
+     "                break\n",
+     "            elif re.search(r\"for [^\\s]{3,} in self\\.\" + module_list, forward):\n",
+     "                final_where = j\n",
+     "                break\n",
+     "\n",
+     "    if final_where is None:\n",
+     "        # Fall back to hooking the innermost module that exposes embeddings\n",
+     "        for module_name, module in model.named_modules():\n",
+     "            if not hasattr(module, \"get_input_embeddings\"):\n",
+     "                break\n",
+     "        register_other_hooks(\"requires_grad_post_hook\", \"requires_grad_post_hook\", module, \"_forward_hooks\")\n",
+     "        module.register_forward_hook(requires_grad_post_hook)\n",
+     "        return\n",
+     "\n",
+     "    module_name = \"model.\" + \".\".join(name_components[:final_where])\n",
+     "    module = eval(module_name)\n",
+     "\n",
+     "    if hasattr(module, \"config\") and module.config.__class__.__name__ in (\"CLIPVisionConfig\", \"SiglipVisionConfig\"):\n",
+     "        old_module = model\n",
+     "        for module_name, module in model.named_modules():\n",
+     "            if not hasattr(module, \"get_input_embeddings\"):\n",
+     "                break\n",
+     "            old_module = module\n",
+     "        module = old_module\n",
+     "\n",
+     "    print(f\"Unsloth: Making `{module_name}` require gradients\")\n",
+     "\n",
+     "    # Try the post-hook first\n",
+     "    if hasattr(module, \"get_input_embeddings\"):\n",
+     "        try:\n",
+     "            module = module.get_input_embeddings()\n",
+     "            register_other_hooks(\"requires_grad_post_hook\", \"requires_grad_post_hook\", module, \"_forward_hooks\")\n",
+     "            module.register_forward_hook(requires_grad_post_hook)\n",
+     "            return\n",
+     "        except Exception:\n",
+     "            pass\n",
+     "\n",
+     "    # Use our patched pre-hook\n",
+     "    register_other_hooks(\"requires_grad_pre_hook\", \"requires_grad_pre_hook\", module, \"_forward_pre_hooks\")\n",
+     "    module.register_forward_pre_hook(patched_requires_grad_pre_hook)\n",
+     "\n",
+     "# Replace in sys.modules so unsloth picks up the patched version on import\n",
+     "sys.modules['unsloth_zoo.peft_utils'].requires_grad_for_gradient_checkpointing = patched_requires_grad_for_gradient_checkpointing\n",
+     "\n",
+     "print(\"✓ Patched Unsloth gradient checkpointing BEFORE imports\")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "c2c30bc6",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "from unsloth import FastVisionModel\n",
+     "from unsloth.trainer import UnslothVisionDataCollator"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "4326b62e",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import torch\n",
+     "from PIL import Image\n",
+     "from transformers import AutoModel, AutoProcessor\n",
+     "from trl import SFTTrainer, SFTConfig"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "d5e899ca",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "from get_cdli_dataset import atf_converter, get_dataset, IMG_CACHE\n",
+     "\n",
+     "# Load dataset\n",
+     "dataset = get_dataset()\n",
+     "\n",
+     "train_dataset = dataset[\"train\"]\n",
+     "test_dataset = dataset[\"test\"]\n",
+     "\n",
+     "print(train_dataset, test_dataset)\n"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "9e0aa56b",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Load processor and model\n",
+     "model, tokenizer = FastVisionModel.from_pretrained(\n",
+     "    \"PaddlePaddle/PaddleOCR-VL\",\n",
+     "    cache_dir=\"./hf_cache/models\",\n",
+     "    trust_remote_code=True,\n",
+     "    load_in_4bit=False,\n",
+     "    auto_model=AutoModel,\n",
+     "    full_finetuning=True,\n",
+     "    unsloth_force_compile=True,\n",
+     "    use_gradient_checkpointing=\"unsloth\",\n",
+     "    max_seq_length=16000,\n",
+     ")\n",
+     "\n",
+     "processor = AutoProcessor.from_pretrained(\n",
+     "    \"PaddlePaddle/PaddleOCR-VL\",\n",
+     "    cache_dir=\"./hf_cache/models\",\n",
+     "    trust_remote_code=True,\n",
+     ")\n",
+     "\n",
+     "processor.tokenizer = tokenizer\n"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "28656983",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "used_signs = set()\n",
+     "for example in train_dataset:\n",
+     "    parsed = atf_converter.parse(example[\"atf\"])\n",
+     "    used_signs.update(parsed.get_used_signs())\n",
+     "for example in test_dataset:\n",
+     "    parsed = atf_converter.parse(example[\"atf\"])\n",
+     "    used_signs.update(parsed.get_used_signs())\n",
+     "\n",
+     "print(f\"Base model vocab size: {len(processor.tokenizer)}\")\n",
+     "\n",
+     "# Add the cuneiform signs to the model vocab\n",
+     "num_added_tokens = processor.tokenizer.add_tokens(list(used_signs))\n",
+     "num_added_special_tokens = processor.tokenizer.add_special_tokens(\n",
+     "    {\n",
+     "        \"additional_special_tokens\": [f\"@{face}\" for face in atf_converter.ALL_FACES]\n",
+     "        + atf_converter.SPECIAL_TOKENS\n",
+     "    },\n",
+     "    replace_additional_special_tokens=False,\n",
+     ")\n",
+     "\n",
+     "print(\n",
+     "    f\"Added {num_added_tokens} tokens and {num_added_special_tokens} special tokens to tokenizer\"\n",
+     ")\n",
+     "\n",
+     "# Grow the embedding matrix; new rows are initialized from the mean of existing embeddings\n",
+     "model.resize_token_embeddings(len(processor.tokenizer))\n",
+     "\n",
+     "print(f\"New model vocab size: {len(processor.tokenizer)}\")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "5100b97c",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Configure training\n",
+     "sft_training_args = SFTConfig(\n",
+     "    output_dir=\"./outputs/sft\",\n",
+     "    # max_steps=50,  # Remove for full run\n",
+     "    num_train_epochs=2,\n",
+     "    per_device_train_batch_size=2,\n",
+     "    per_device_eval_batch_size=2,\n",
+     "    gradient_accumulation_steps=1,\n",
+     "    learning_rate=2e-5,\n",
+     "    optim=\"adamw_8bit\",\n",
+     "    warmup_ratio=0.05,\n",
+     "    weight_decay=0.001,\n",
+     "    lr_scheduler_type=\"linear\",\n",
+     "    bf16=True,\n",
+     "    save_strategy=\"steps\",\n",
+     "    save_steps=200,\n",
+     "    eval_strategy=\"steps\",\n",
+     "    eval_steps=1000,\n",
+     "    logging_steps=1,\n",
+     "    report_to=\"none\",\n",
+     "    dataloader_num_workers=0,\n",
+     "    # You MUST set the items below for vision finetuning:\n",
+     "    remove_unused_columns=False,\n",
+     "    dataset_text_field=\"\",\n",
+     "    dataset_kwargs={\"skip_prepare_dataset\": True},\n",
+     "    max_length=16000,\n",
+     ")\n",
+     "\n",
+     "# Initialize trainer\n",
+     "sft_trainer = SFTTrainer(\n",
+     "    model=model,\n",
+     "    processing_class=processor,\n",
+     "    data_collator=UnslothVisionDataCollator(\n",
+     "        model,\n",
+     "        processor,\n",
+     "        train_on_responses_only=False,  # Fixed: was masking all tokens with True\n",
+     "        instruction_part=\"User: \",\n",
+     "        response_part=\"Assistant: \",\n",
+     "        pad_to_multiple_of=2,\n",
+     "        resize_dimension=\"max\",\n",
+     "        formatting_func=lambda example: {\n",
+     "            \"images\": [\n",
+     "                Image.open(IMG_CACHE / f\"P{str(example['id']).rjust(6, '0')}.jpg\")\n",
+     "            ],\n",
+     "            \"messages\": [\n",
+     "                # User message with the image and the task prompt\n",
+     "                {\n",
+     "                    \"role\": \"user\",\n",
+     "                    \"content\": [\n",
+     "                        {\n",
+     "                            \"type\": \"image\",\n",
+     "                            \"image\": Image.open(\n",
+     "                                IMG_CACHE / f\"P{str(example['id']).rjust(6, '0')}.jpg\"\n",
+     "                            ),\n",
+     "                        },\n",
+     "                        {\"type\": \"text\", \"text\": \"OCR:\"},\n",
+     "                    ],\n",
+     "                },\n",
+     "                # Assistant message with the completion text\n",
+     "                {\n",
+     "                    \"role\": \"assistant\",\n",
+     "                    \"content\": [{\"type\": \"text\", \"text\": example[\"unicode\"]}],\n",
+     "                },\n",
+     "            ],\n",
+     "        },\n",
+     "    ),\n",
+     "    args=sft_training_args,\n",
+     "    train_dataset=train_dataset,\n",
+     "    eval_dataset=test_dataset,\n",
+     ")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "97e8455e",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "gpu_stats = torch.cuda.get_device_properties(0)\n",
+     "start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
+     "max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n",
+     "print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n",
+     "print(f\"{start_gpu_memory} GB of memory reserved.\")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "e6103bbe",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "sft_trainer_stats = sft_trainer.train(resume_from_checkpoint=False)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "36dc79b9",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
+     "used_memory_for_training = round(used_memory - start_gpu_memory, 3)\n",
+     "used_percentage = round(used_memory / max_memory * 100, 3)\n",
+     "training_percentage = round(used_memory_for_training / max_memory * 100, 3)\n",
+     "print(f\"{sft_trainer_stats.metrics['train_runtime']} seconds used for training.\")\n",
+     "print(\n",
+     "    f\"{round(sft_trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\"\n",
+     ")\n",
+     "print(f\"Peak reserved memory = {used_memory} GB.\")\n",
+     "print(f\"Peak reserved memory for training = {used_memory_for_training} GB.\")\n",
+     "print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n",
+     "print(f\"Peak reserved memory for training % of max memory = {training_percentage} %.\")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "50bdf718",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Save model (the processor/tokenizer is saved alongside it)\n",
+     "processor.save_pretrained(sft_training_args.output_dir)\n",
+     "model.save_pretrained(sft_training_args.output_dir)\n"
+    ]
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": ".venv",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.13.6"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 5
+ }
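
The vocab-extension cell above relies on `resize_token_embeddings` to initialize the new rows; recent transformers versions default to mean-based initialization. On a version without that behavior, the mean init can be done explicitly; a minimal sketch (helper name hypothetical, and an untied output head would need the same treatment):

    import torch

    def mean_init_new_embeddings(model, num_new: int):
        # Hypothetical helper: set the last num_new rows of the input
        # embedding matrix to the mean of the pre-existing rows.
        with torch.no_grad():
            emb = model.get_input_embeddings().weight
            emb[-num_new:] = emb[:-num_new].mean(dim=0, keepdim=True)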
training/get_cdli_dataset.py ADDED
@@ -0,0 +1,279 @@
+ import concurrent.futures
+ import math
+ import time
+ from io import BytesIO
+ from pathlib import Path
+
+ import requests
+ from convert_atf import ATFConverter
+ from datasets import Dataset
+ from PIL import Image
+ from tqdm.auto import tqdm
+
+ atf_converter = ATFConverter()
+
+ IMG_CACHE = Path("./data/cdli_images")
+ IMG_CACHE.mkdir(exist_ok=True, parents=True)
+ MAX_IMG_RES = 2048
+ DOWNLOAD_MODE = False
+
+
+ def smart_resize(
+     height: int,
+     width: int,
+     factor: int = 28,
+     min_pixels: int = 28 * 28 * 130,
+     max_pixels: int = 28 * 28 * 1280,
+ ):
+     """Rescales the image so that the following conditions are met:
+
+     1. Both dimensions (height and width) are divisible by 'factor'.
+     2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+     3. The aspect ratio of the image is maintained as closely as possible.
+     """
+     # Unlike the upstream implementation, tiny images are upscaled to `factor`
+     # instead of raising:
+     # if height < factor or width < factor:
+     #     raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
+
+     if height < factor:
+         print(f"smart_resize: height={height} < factor={factor}, reset height=factor")
+         width = round((width * factor) / height)
+         height = factor
+
+     if width < factor:
+         print(f"smart_resize: width={width} < factor={factor}, reset width=factor")
+         height = round((height * factor) / width)
+         width = factor
+
+     if max(height, width) / min(height, width) > 200:
+         raise ValueError(
+             f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
+         )
+     h_bar = round(height / factor) * factor
+     w_bar = round(width / factor) * factor
+     if h_bar * w_bar > max_pixels:
+         beta = math.sqrt((height * width) / max_pixels)
+         h_bar = math.floor(height / beta / factor) * factor
+         w_bar = math.floor(width / beta / factor) * factor
+     elif h_bar * w_bar < min_pixels:
+         beta = math.sqrt(min_pixels / (height * width))
+         h_bar = math.ceil(height * beta / factor) * factor
+         w_bar = math.ceil(width * beta / factor) * factor
+     return h_bar, w_bar
+
+
+ def resize_image(img_path):
+     with Image.open(img_path).convert("RGB") as image:
+         width, height = image.size
+         # Scale down if larger than MAX_IMG_RES
+         if width > MAX_IMG_RES or height > MAX_IMG_RES:
+             scale = MAX_IMG_RES / max(width, height)
+             height = int(height * scale)
+             width = int(width * scale)
+         # Always ensure dimensions are multiples of 28 for vision model compatibility
+         new_height, new_width = smart_resize(height, width)
+         if new_height != image.height or new_width != image.width:
+             image = image.resize((new_width, new_height), Image.LANCZOS)
+             image.save(img_path)
+
+
+ def resize_cached_images():
+     img_paths = list(IMG_CACHE.glob("*.jpg"))
+     pbar = tqdm(total=len(img_paths), desc="Resizing cached images")
+
+     with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor:
+         futures = [executor.submit(resize_image, img_path) for img_path in img_paths]
+         for future in concurrent.futures.as_completed(futures):
+             pbar.update(1)
+
+     pbar.close()
+
+
+ def get_image(photo_id: int):
+     file_name = f"P{str(photo_id).rjust(6, '0')}.jpg"
+     url = f"https://cdli.earth/dl/photo/{file_name}"
+     cache_file = IMG_CACHE / file_name
+
+     try:
+         if cache_file.exists():
+             tqdm.write(f"Found {file_name} in cache")
+             image = Image.open(cache_file).convert("RGB")
+         else:
+             response = requests.get(url, timeout=5)
+             response.raise_for_status()
+             image = Image.open(BytesIO(response.content)).convert("RGB")
+
+             tqdm.write(f"Downloaded {file_name}")
+
+             width, height = image.size
+             # Scale down if larger than MAX_IMG_RES
+             if width > MAX_IMG_RES or height > MAX_IMG_RES:
+                 scale = MAX_IMG_RES / max(width, height)
+                 height = int(height * scale)
+                 width = int(width * scale)
+             # Always ensure dimensions are multiples of 28 for vision model compatibility
+             new_height, new_width = smart_resize(height, width)
+             if new_height != image.height or new_width != image.width:
+                 image = image.resize((new_width, new_height), Image.LANCZOS)
+
+             image.save(cache_file)
+             time.sleep(0.02)  # Rate limiting
+     except requests.exceptions.Timeout:
+         tqdm.write(f"Timeout downloading {file_name}")
+         return None
+     except requests.exceptions.RequestException as e:
+         tqdm.write(f"Error downloading {file_name}: {e}")
+         return None
+     except Exception as e:
+         tqdm.write(f"Error processing {file_name}: {type(e).__name__}: {e}")
+         return None
+
+     return image
+
+
+ def count_repetitions(text: str) -> int:
+     """
+     Count the total number of repeated character occurrences in a string.
+     E.g., "122233" has 3 repetitions (2 appears 2 extra times, 3 appears 1 extra time).
+     """
+     if len(text) < 2:
+         return 0
+
+     return len(text) - len(set(text))
+
+
+ def get_dataset(file="./data/cdli_dataset.parquet"):
+     if Path(file).exists():
+         return Dataset.from_parquet(file).train_test_split(test_size=1000, seed=42)
+
+     # 1. Get all the ids from cdli.atf (source: https://github.com/cdli-gh/data/raw/refs/heads/master/cdliatf_unblocked.atf)
+     cdli_raw = Path("./data/cdli.atf").read_text(encoding="utf-8").split("&P")
+     cdli_filtered = [
+         section.strip()
+         for section in cdli_raw
+         if section.strip()  # Ignore empty sections
+         and "@tablet" in section  # Only include tablets
+         and len(section) > 50  # Ignore short sections
+         and len(section) < 1000  # Ignore long sections
+         and any(lang in section for lang in ["sux", "akk"])  # Limit supported languages
+     ]
+
+     ids = []
+     atfs = []
+     unicodes = []
+
+     for section in tqdm(cdli_filtered, desc="Parsing CDLI dump"):
+         # Take the ID from the first line (everything before the '='); skip if not numeric
+         lines = section.splitlines()
+         id_part = lines[0].split("=")[0].strip()
+         if not id_part.isdigit():
+             continue
+
+         atf = "\n".join(
+             [
+                 line
+                 for line in lines[1:]
+                 if not (
+                     line.startswith("# ")
+                     or line.startswith(">>")
+                     or line.startswith("<<")
+                     or line.startswith("||")
+                 )
+             ]
+         )
+         parsed = atf_converter.parse(atf)
+         if parsed is None:
+             tqdm.write(f"=====\033[91m {id_part} skip (parse fail) \033[0m=====")
+             continue
+
+         unicode_parts = [
+             f"@{face}\n{parsed.get_unicode(face)}"
+             for face in parsed.ALL_FACES
+             if parsed.get_unicode(face)
+         ]
+
+         # Skip tablets that are too long or too short
+         unicode_len = sum(len(part) for part in unicode_parts)
+         if unicode_len > 300 or unicode_len < 20:
+             tqdm.write(f"=====\033[91m {id_part} skip (too short/long) \033[0m=====")
+             continue
+         # Skip tablets that are poorly converted to unicode
+         if sum(part.count("x") for part in unicode_parts) >= 2:
+             tqdm.write(f"=====\033[91m {id_part} skip (missing symbols) \033[0m=====")
+             continue
+
+         unicode = "\n".join(unicode_parts)
+
+         # Drop the super repetitive admin tablets (the model gets stuck repeating the common phrases)
+         if count_repetitions(unicode) / len(unicode) > 0.7:
+             tqdm.write(f"=====\033[91m {id_part} skip (too repetitive) \033[0m=====")
+             continue
+
+         # Skip if we don't have an image for this ATF
+         if DOWNLOAD_MODE:
+             image = get_image(int(id_part))
+         elif (IMG_CACHE / f"P{str(int(id_part)).rjust(6, '0')}.jpg").exists():
+             image = Image.open(
+                 IMG_CACHE / f"P{str(int(id_part)).rjust(6, '0')}.jpg"
+             ).convert("RGB")
+         else:
+             tqdm.write(f"=====\033[91m {id_part} skip (no img) \033[0m=====")
+             continue
+
+         if not image:
+             tqdm.write(f"=====\033[91m {id_part} skip (no img) \033[0m=====")
+             continue
+
+         # Drop low-res, black-and-white, or non-isolated-background images
+         try:
+             if min(image.size) < 100:
+                 tqdm.write(f"=====\033[91m {id_part} skip (lowres) \033[0m=====")
+                 continue
+
+             scale = 150 / image.height
+             small_image = image.resize(
+                 (int(image.width * scale), int(image.height * scale)), Image.LANCZOS
+             )
+             pixels = list(small_image.getdata())
+             small_image.close()
+             image.close()
+
+             bw_pixels = sum(1 for r, g, b in pixels if r == g == b)
+             bw_percent = bw_pixels / len(pixels)
+             if bw_percent > 0.95 or bw_percent < 0.1:
+                 tqdm.write(
+                     f"=====\033[91m {id_part} skip (bw {bw_percent*100:.1f}%) \033[0m====="
+                 )
+                 continue
+
+             if sum(1 for r, g, b in pixels if r == g == b == 0) / len(pixels) < 0.15:
+                 tqdm.write(
+                     f"=====\033[91m {id_part} skip (not on black background) \033[0m====="
+                 )
+                 continue
+         except Exception as e:
+             tqdm.write(
+                 f"=====\033[91m {id_part} skip (err img check: {e}) \033[0m====="
+             )
+             continue
+
+         ids.append(int(id_part))
+         atfs.append(atf)
+         unicodes.append(unicode)
+
+         tqdm.write(f"=====\033[32m {id_part} unicode (len {unicode_len}) \033[0m=====")
+
+     dataset = Dataset.from_dict(
+         {
+             "id": ids,
+             "atf": atfs,
+             "unicode": unicodes,
+         }
+     )
+
+     dataset.to_parquet(file)
+     return dataset.train_test_split(test_size=1000, seed=42)
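
To illustrate the constraints in `smart_resize`, a worked example (values computed by hand from the formulas above):

    # A 1000x700 photo: each dimension is rounded to the nearest multiple of 28.
    # round(1000 / 28) * 28 = 1008; round(700 / 28) * 28 = 700.
    # 1008 * 700 = 705,600 pixels, inside [28*28*130, 28*28*1280] = [101,920, 1,003,520],
    # so no further area scaling is applied.
    print(smart_resize(1000, 700))  # -> (1008, 700)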