| import concurrent.futures
|
| import math
|
| import time
|
| from io import BytesIO
|
| from pathlib import Path
|
|
|
| import requests
|
| from convert_atf import ATFConverter
|
| from datasets import Dataset
|
| from PIL import Image
|
| from tqdm.auto import tqdm
|
|
|
# Shared converter instance used to turn ATF transliterations into Unicode cuneiform.
atf_converter = ATFConverter()

# On-disk cache for downloaded CDLI tablet photographs (created if missing).
IMG_CACHE = Path("./data/cdli_images")
IMG_CACHE.mkdir(exist_ok=True, parents=True)
# Hard cap on either image dimension before the smart_resize grid-snapping step.
MAX_IMG_RES = 2048
# When True, missing images are fetched from cdli.earth; otherwise only the
# local cache is consulted and uncached tablets are skipped.
DOWNLOAD_MODE = False
|
|
|
|
|
def smart_resize(
    height: int,
    width: int,
    factor: int = 28,
    min_pixels: int = 28 * 28 * 130,
    max_pixels: int = 28 * 28 * 1280,
):
    """Compute a resized (height, width) pair suitable for patch-based models.

    The returned dimensions satisfy three constraints:

    1. Both are exact multiples of ``factor``.
    2. Their product lies within ``[min_pixels, max_pixels]``.
    3. The original aspect ratio is preserved as closely as possible.

    Raises:
        ValueError: If the aspect ratio exceeds 200:1 after the minimum-size
            adjustment below.
    """
    # Guarantee each dimension is at least ``factor``, rescaling the other
    # dimension to keep the aspect ratio.
    if height < factor:
        print(f"smart_resize: height={height} < factor={factor}, reset height=factor")
        width = round((width * factor) / height)
        height = factor
    if width < factor:
        print(f"smart_resize: width={width} < factor={factor}, reset width=factor")
        height = round((height * factor) / width)
        width = factor

    long_side, short_side = max(height, width), min(height, width)
    if long_side / short_side > 200:
        raise ValueError(
            f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
        )

    # Snap each dimension to the nearest multiple of ``factor``, then correct
    # the total pixel count if it fell outside the allowed band.
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    area = h_bar * w_bar
    if area > max_pixels:
        # Too large: shrink by beta and round *down* so we stay under the cap.
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = math.floor(height / beta / factor) * factor
        w_bar = math.floor(width / beta / factor) * factor
    elif area < min_pixels:
        # Too small: grow by beta and round *up* so we clear the floor.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar
|
|
|
|
|
def resize_image(img_path):
    """Resize one cached JPEG in place to the smart_resize grid.

    The file is only rewritten when the dimensions actually change, so an
    already-conforming image is left untouched (no lossy JPEG re-encode).

    Args:
        img_path: Path to a JPEG inside the image cache.
    """
    # BUGFIX: the original `with Image.open(p).convert("RGB") as image:` put the
    # *converted copy* under the context manager, so the file handle held by
    # Image.open was never the managed object — and it was still open while
    # save() overwrote the same path. Convert inside the with-block instead;
    # convert() forces a full load, so the data survives closing the file.
    with Image.open(img_path) as src:
        image = src.convert("RGB")

    width, height = image.size

    # Clamp the longest side to MAX_IMG_RES, preserving aspect ratio.
    if width > MAX_IMG_RES or height > MAX_IMG_RES:
        scale = MAX_IMG_RES / max(width, height)
        height = int(height * scale)
        width = int(width * scale)

    new_height, new_width = smart_resize(height, width)
    if new_height != image.height or new_width != image.width:
        image = image.resize((new_width, new_height), Image.LANCZOS)
        image.save(img_path)
|
|
|
|
|
def resize_cached_images():
    """Run resize_image over every JPEG in the cache using a thread pool.

    This is a best-effort pass: future results are never inspected, so any
    exception raised inside resize_image is silently discarded.
    """
    cached = list(IMG_CACHE.glob("*.jpg"))
    progress = tqdm(cached)

    with concurrent.futures.ThreadPoolExecutor(max_workers=20) as pool:
        in_flight = [pool.submit(resize_image, path) for path in cached]
        # Tick the bar as each worker finishes, regardless of outcome.
        for _done in concurrent.futures.as_completed(in_flight):
            progress.update(1)

    progress.close()
|
|
|
|
|
def get_image(id: int):
    """Fetch a CDLI tablet photo by numeric P-number, with on-disk caching.

    Args:
        id: Numeric CDLI identifier; 123 maps to the file "P000123.jpg".

    Returns:
        An RGB PIL image already snapped to the smart_resize grid, or
        ``None`` if the download or processing failed.
    """
    file_name = f"P{str(id).rjust(6, '0')}.jpg"
    url = f"https://cdli.earth/dl/photo/{file_name}"
    cache_file = IMG_CACHE / file_name

    try:
        from_cache = cache_file.exists()
        if from_cache:
            tqdm.write(f"Found {file_name} in cache")
            # Convert inside the with-block so the cache file handle is the
            # managed object and is always closed (convert() forces a load).
            with Image.open(cache_file) as src:
                image = src.convert("RGB")
        else:
            response = requests.get(url, timeout=5)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")

            tqdm.write(f"Downloaded {file_name}")

        width, height = image.size

        # Clamp the longest side to MAX_IMG_RES, preserving aspect ratio.
        if width > MAX_IMG_RES or height > MAX_IMG_RES:
            scale = MAX_IMG_RES / max(width, height)
            height = int(height * scale)
            width = int(width * scale)

        new_height, new_width = smart_resize(height, width)
        resized = new_height != image.height or new_width != image.width
        if resized:
            image = image.resize((new_width, new_height), Image.LANCZOS)

        # BUGFIX: the original saved unconditionally, so every cache hit
        # re-encoded the JPEG (cumulative quality loss + pointless I/O).
        # Only write when the image is new or its pixels actually changed.
        if resized or not from_cache:
            image.save(cache_file)
        # Light rate-limit between requests to be polite to cdli.earth.
        time.sleep(0.02)
    except requests.exceptions.Timeout:
        tqdm.write(f"Timeout downloading {file_name}")
        return None
    except requests.exceptions.RequestException as e:
        tqdm.write(f"Error downloading {file_name}: {e}")
        return None
    except Exception as e:
        tqdm.write(f"Error processing {file_name}: {type(e).__name__}: {e}")
        return None

    return image
|
|
|
|
|
| def count_repetitions(text: str) -> int:
|
| """
|
| Count the total number of repeated token occurrences in a sequence.
|
| E.g., 122233 has 3 repetitions (2 appears 2 extra times, 3 appears 1 extra time).
|
| """
|
| if len(text) < 2:
|
| return 0
|
|
|
| return len(text) - len(set(text))
|
|
|
|
|
| def get_dataset(file="./data/cdli_dataset.parquet"):
|
| if Path(file).exists():
|
| return Dataset.from_parquet(file).train_test_split(test_size=1000, seed=42)
|
|
|
|
|
| cdli_raw = Path("./data/cdli.atf").read_text(encoding="utf-8").split("&P")
|
| cdli_filtered = [
|
| section.strip()
|
| for section in cdli_raw
|
| if section.strip()
|
| and "@tablet" in section
|
| and len(section) > 50
|
| and len(section) < 1000
|
| and any(lang in section for lang in ["sux", "akk"])
|
| ]
|
|
|
| ids = []
|
| atfs = []
|
| unicodes = []
|
|
|
| for section in tqdm(cdli_filtered, desc="Parsing CDLI dump"):
|
|
|
| lines = section.splitlines()
|
| id_part = lines[0].split("=")[0].strip()
|
| if not id_part.isdigit():
|
| continue
|
|
|
| atf = "\n".join(
|
| [
|
| line
|
| for line in lines[1:]
|
| if not (
|
| line.startswith("# ")
|
| or line.startswith(">>")
|
| or line.startswith("<<")
|
| or line.startswith("||")
|
| )
|
| ]
|
| )
|
| parsed = atf_converter.parse(atf)
|
| if parsed is None:
|
| tqdm.write(f"=====\033[91m {id_part} skip (parse fail) \033[0m=====")
|
| continue
|
|
|
| unicode_parts = [
|
| f"@{face}\n{parsed.get_unicode(face)}"
|
| for face in parsed.ALL_FACES
|
| if parsed.get_unicode(face)
|
| ]
|
|
|
|
|
| unicode_len = sum([len(part) for part in unicode_parts])
|
| if unicode_len > 300 or unicode_len < 20:
|
| tqdm.write(f"=====\033[91m {id_part} skip (too short/long) \033[0m=====")
|
| continue
|
|
|
| if sum([part.count("x") for part in unicode_parts]) >= 2:
|
| tqdm.write(f"=====\033[91m {id_part} skip (missing symbols) \033[0m=====")
|
| continue
|
|
|
| unicode = "\n".join(unicode_parts)
|
|
|
|
|
| if count_repetitions(unicode) / len(unicode) > 0.7:
|
| tqdm.write(f"=====\033[91m {id_part} skip (too repetitive) \033[0m=====")
|
| continue
|
|
|
|
|
| if DOWNLOAD_MODE:
|
| image = get_image(int(id_part))
|
| elif (IMG_CACHE / f"P{str(int(id_part)).rjust(6, '0')}.jpg").exists():
|
| image = Image.open(
|
| IMG_CACHE / f"P{str(int(id_part)).rjust(6, '0')}.jpg"
|
| ).convert("RGB")
|
| else:
|
| tqdm.write(f"=====\033[91m {id_part} skip (no img) \033[0m=====")
|
| continue
|
|
|
| if not image:
|
| tqdm.write(f"=====\033[91m {id_part} skip (no img) \033[0m=====")
|
| continue
|
|
|
|
|
| try:
|
| if min(image.size) < 100:
|
| tqdm.write(f"=====\033[91m {id_part} skip (lowres) \033[0m=====")
|
| continue
|
|
|
| scale = 150 / image.height
|
| small_image = image.resize(
|
| (int(image.width * scale), int(image.height * scale)), Image.LANCZOS
|
| )
|
| pixels = list(small_image.getdata())
|
| small_image.close()
|
| image.close()
|
|
|
| bw_pixels = sum(1 for r, g, b in pixels if r == g == b)
|
| bw_percent = bw_pixels / len(pixels)
|
| if bw_percent > 0.95 or bw_percent < 0.1:
|
| tqdm.write(
|
| f"=====\033[91m {id_part} skip (bw {bw_percent*100:.1f}%) \033[0m====="
|
| )
|
| continue
|
|
|
| if sum(1 for r, g, b in pixels if r == g == b == 0) / len(pixels) < 0.15:
|
| tqdm.write(
|
| f"=====\033[91m {id_part} skip (not on black background) \033[0m====="
|
| )
|
| continue
|
| except Exception as e:
|
| tqdm.write(
|
| f"=====\033[91m {id_part} skip (err img check: {e}) \033[0m====="
|
| )
|
| continue
|
|
|
| ids.append(int(id_part))
|
| atfs.append(atf)
|
| unicodes.append(unicode)
|
|
|
| tqdm.write(f"=====\033[32m {id_part} unicode (len {unicode_len}) \033[0m=====")
|
|
|
| dataset = Dataset.from_dict(
|
| {
|
| "id": ids,
|
| "atf": atfs,
|
| "unicode": unicodes,
|
| }
|
| )
|
|
|
| dataset.to_parquet(file)
|
| return dataset.train_test_split(test_size=1000, seed=42)
|
|
|