{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0f3a9c1e",
   "metadata": {},
   "source": [
    "# Token Error Rate (TER) evaluation\n",
    "\n",
    "Runs the OCR model over the `test` split returned by `get_cdli_dataset.get_dataset()` and scores every prediction against the expected `unicode` transcription with the Token Error Rate defined in `compute_ter`.\n",
    "\n",
    "Run the notebook top to bottom (Restart Kernel, then Run All). The model is moved to a CUDA device when it is loaded."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4ca0fb0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pyxdameraulevenshtein as dl\n",
    "import torch\n",
    "from PIL import Image\n",
    "from tqdm.auto import tqdm\n",
    "from transformers import AutoModelForCausalLM, AutoProcessor\n",
    "\n",
    "from get_cdli_dataset import IMG_CACHE, get_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7d41c9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration: everything a reader might want to tune\n",
    "\n",
    "# Alternative checkpoints:\n",
    "#   \"PaddlePaddle/PaddleOCR-VL\"  (base)\n",
    "#   \"./outputs/sft\"\n",
    "model_path = \"../\"\n",
    "\n",
    "PROMPT = \"OCR:\"  # text instruction sent together with each image\n",
    "MAX_NEW_TOKENS_RATIO = 1.2  # generation budget as a multiple of the expected token count\n",
    "REPETITION_PENALTY = 1.03\n",
    "N_EXAMPLES_SHOWN = 10  # how many best / worst predictions to print\n",
    "GOOD_TER_THRESHOLD = 0.5  # predictions below this TER are counted as \"good\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a961375e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load dataset\n",
    "dataset = get_dataset()\n",
    "test_dataset = dataset[\"test\"]\n",
    "\n",
    "print(test_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e226c45c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the model and its processor from the same checkpoint\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_path, trust_remote_code=True, torch_dtype=torch.bfloat16\n",
    ").to(\"cuda\").eval()\n",
    "processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97b9a2cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Metric and helper functions used by the evaluation cells below\n",
    "\n",
    "\n",
    "def compute_ter(expected_ids: list[int], predicted_ids: list[int]) -> float:\n",
    "    \"\"\"\n",
    "    Compute the Token Error Rate (TER) between expected and predicted token ids.\n",
    "\n",
    "    TER = edit_distance(expected_ids, predicted_ids) / len(expected_ids)\n",
    "\n",
    "    The edit distance is the one returned by\n",
    "    `pyxdameraulevenshtein.damerau_levenshtein_distance` (Damerau-Levenshtein:\n",
    "    a transposition of adjacent tokens counts as one edit, next to\n",
    "    substitutions, deletions and insertions).\n",
    "\n",
    "    TER is better than CER for cuneiform OCR as:\n",
    "    - Multi-character Unicode signs count as 1 token instead of multiple chars\n",
    "    - Special tokens like @obverse/@reverse count as 1 token\n",
    "\n",
    "    Args:\n",
    "        expected_ids: token ids of the ground-truth text.\n",
    "        predicted_ids: token ids produced by the model.\n",
    "\n",
    "    Returns:\n",
    "        The error rate as a float. It is 0.0 for an exact match and can exceed\n",
    "        1.0 when the prediction is longer than the reference. For an empty\n",
    "        reference it is 0.0 if the prediction is empty as well, else 1.0.\n",
    "    \"\"\"\n",
    "    if not expected_ids:\n",
    "        return 0.0 if not predicted_ids else 1.0\n",
    "\n",
    "    distance = dl.damerau_levenshtein_distance(expected_ids, predicted_ids)\n",
    "\n",
    "    # The reference is non-empty here, so the division is safe.\n",
    "    return distance / len(expected_ids)\n",
    "\n",
    "\n",
    "def collect_eos_ids(*candidates) -> set[int]:\n",
    "    \"\"\"\n",
    "    Merge end-of-sequence ids from several sources into one set.\n",
    "\n",
    "    Each candidate may be None, a single int, or a list/tuple of ints.\n",
    "    None entries are skipped.\n",
    "    \"\"\"\n",
    "    eos_ids: set[int] = set()\n",
    "    for candidate in candidates:\n",
    "        if candidate is None:\n",
    "            continue\n",
    "        if isinstance(candidate, (list, tuple, set)):\n",
    "            eos_ids.update(int(token_id) for token_id in candidate)\n",
    "        else:\n",
    "            eos_ids.add(int(candidate))\n",
    "    return eos_ids\n",
    "\n",
    "\n",
    "def strip_trailing_eos(token_ids: list[int], eos_ids: set[int]) -> list[int]:\n",
    "    \"\"\"\n",
    "    Return `token_ids` without a final end-of-sequence token.\n",
    "\n",
    "    Only a last token that is in `eos_ids` is removed. When generation stopped\n",
    "    because the token budget ran out, the last token is real content and must\n",
    "    be kept, otherwise the prediction is silently shortened by one token.\n",
    "    \"\"\"\n",
    "    if token_ids and token_ids[-1] in eos_ids:\n",
    "        return token_ids[:-1]\n",
    "    return token_ids\n",
    "\n",
    "\n",
    "def load_rgb_image(path):\n",
    "    \"\"\"\n",
    "    Open the image at `path` and return an RGB copy of it.\n",
    "\n",
    "    The file is opened in a `with` block so its handle is closed on exit;\n",
    "    `convert` returns a separate image object that stays usable afterwards.\n",
    "    \"\"\"\n",
    "    with Image.open(path) as raw_image:\n",
    "        return raw_image.convert(\"RGB\")\n",
    "\n",
    "\n",
    "def show_examples(title: str, examples: list[dict]) -> None:\n",
    "    \"\"\"\n",
    "    Print a titled list of result records.\n",
    "\n",
    "    Each record must have the keys `id`, `ter`, `expected` and `prediction`.\n",
    "    \"\"\"\n",
    "    print(\"=\"*60)\n",
    "    print(title)\n",
    "    print(\"=\"*60)\n",
    "    for rank, r in enumerate(examples, start=1):\n",
    "        print(f\"\\nExample {rank} - ID: {r['id']} - TER: {r['ter']:.4f}\")\n",
    "        print(f\"Expected:\\n{r['expected']}\")\n",
    "        print(f\"Predicted:\\n{r['prediction']}\")\n",
    "        print(\"-\"*60)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "859c4fc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run inference on all test examples\n",
    "\n",
    "# Ids that can end a generation; used to strip a trailing EOS from predictions.\n",
    "eos_ids = collect_eos_ids(\n",
    "    processor.tokenizer.eos_token_id,\n",
    "    getattr(getattr(model, \"generation_config\", None), \"eos_token_id\", None),\n",
    ")\n",
    "\n",
    "results = []\n",
    "total_ter = 0.0\n",
    "\n",
    "pbar = tqdm(test_dataset, desc=\"Evaluating on test set\")\n",
    "\n",
    "for idx, example in enumerate(pbar):\n",
    "    expected = example[\"unicode\"]\n",
    "    expected_ids = processor.tokenizer.encode(expected, add_special_tokens=False)\n",
    "\n",
    "    # Load image\n",
    "    image = load_rgb_image(IMG_CACHE / f\"P{str(example['id']).rjust(6, '0')}.jpg\")\n",
    "\n",
    "    # Prepare input\n",
    "    messages = [\n",
    "        {\n",
    "            \"role\": \"user\",\n",
    "            \"content\": [\n",
    "                {\"type\": \"image\", \"image\": image},\n",
    "                {\"type\": \"text\", \"text\": PROMPT},\n",
    "            ],\n",
    "        },\n",
    "    ]\n",
    "\n",
    "    inputs = processor.apply_chat_template(\n",
    "        messages,\n",
    "        tokenize=True,\n",
    "        add_generation_prompt=True,\n",
    "        return_dict=True,\n",
    "        return_tensors=\"pt\",\n",
    "    ).to(model.device)  # follow the model instead of repeating the device name\n",
    "\n",
    "    # The budget scales with the reference length; keep it >= 1 so that an\n",
    "    # empty or very short reference does not request zero new tokens.\n",
    "    max_new_tokens = max(1, int(len(expected_ids) * MAX_NEW_TOKENS_RATIO))\n",
    "\n",
    "    # Generate prediction\n",
    "    with torch.no_grad():\n",
    "        output_ids = model.generate(\n",
    "            **inputs,\n",
    "            use_cache=True,\n",
    "            max_new_tokens=max_new_tokens,\n",
    "            repetition_penalty=REPETITION_PENALTY,\n",
    "        )\n",
    "\n",
    "    # Keep only the newly generated tokens, then drop a trailing EOS if present.\n",
    "    prompt_length = inputs[\"input_ids\"].shape[1]\n",
    "    generated_ids = output_ids[0][prompt_length:].tolist()\n",
    "    predicted_ids = strip_trailing_eos(generated_ids, eos_ids)\n",
    "\n",
    "    # Compute TER for this example\n",
    "    ter = compute_ter(expected_ids, predicted_ids)\n",
    "    total_ter += ter\n",
    "\n",
    "    pbar.set_postfix_str(f\"AVG TER={total_ter / (idx+1):.3f}\")\n",
    "\n",
    "    prediction = processor.decode(\n",
    "        predicted_ids,\n",
    "        skip_special_tokens=False,\n",
    "    ).strip()\n",
    "\n",
    "    # Store results\n",
    "    results.append(\n",
    "        {\n",
    "            \"id\": example[\"id\"],\n",
    "            \"expected\": expected,\n",
    "            \"prediction\": prediction,\n",
    "            \"ter\": ter,\n",
    "        }\n",
    "    )\n",
    "    tqdm.write(f\"\\033[94m\\nID: {example['id']} | TER: {ter:.4f}\\033[0m\")\n",
    "    tqdm.write(f\"\\033[92mExpected:\\033[0m\\n{expected}\")\n",
    "    tqdm.write(f\"\\033[91mPredicted:\\033[0m\\n{prediction}\")\n",
    "\n",
    "# Compute averages (NaN when nothing was evaluated, instead of dividing by zero)\n",
    "average_ter = total_ter / len(results) if results else float(\"nan\")\n",
    "print(f\"\\n{'='*60}\")\n",
    "print(f\"Average Token Error Rate (TER): {average_ter:.4f} ({average_ter*100:.2f}%)\")\n",
    "print(f\"{'='*60}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c6a8e02",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show examples: best and worst predictions (sorted by TER)\n",
    "sorted_results = sorted(results, key=lambda x: x[\"ter\"])\n",
    "\n",
    "show_examples(\"BEST PREDICTIONS (Lowest TER)\", sorted_results[:N_EXAMPLES_SHOWN])\n",
    "print()\n",
    "show_examples(\"WORST PREDICTIONS (Highest TER)\", sorted_results[::-1][:N_EXAMPLES_SHOWN])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5ceae30",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TER distribution statistics\n",
    "ter_values = np.array([r[\"ter\"] for r in results], dtype=float)\n",
    "n_results = ter_values.size\n",
    "\n",
    "print(\"=\"*60)\n",
    "print(\"TER (TOKEN ERROR RATE) DISTRIBUTION STATISTICS\")\n",
    "print(\"=\"*60)\n",
    "\n",
    "if n_results == 0:\n",
    "    # min/max/percentile are undefined on an empty array\n",
    "    print(\"No results to summarise.\")\n",
    "else:\n",
    "    mean_ter = ter_values.mean()\n",
    "    median_ter = np.median(ter_values)\n",
    "    min_ter = ter_values.min()\n",
    "    max_ter = ter_values.max()\n",
    "\n",
    "    print(f\"Mean TER: {mean_ter:.4f} ({mean_ter*100:.2f}%)\")\n",
    "    print(f\"Median TER: {median_ter:.4f} ({median_ter*100:.2f}%)\")\n",
    "    print(f\"Std Dev: {ter_values.std():.4f}\")\n",
    "    print(f\"Min TER: {min_ter:.4f} ({min_ter*100:.2f}%)\")\n",
    "    print(f\"Max TER: {max_ter:.4f} ({max_ter*100:.2f}%)\")\n",
    "\n",
    "    print(\"\\nPercentiles:\")\n",
    "    for q in (25, 50, 75, 90, 95, 98):\n",
    "        print(f\" {q}th: {np.percentile(ter_values, q):.4f}\")\n",
    "\n",
    "    # Count perfect predictions\n",
    "    perfect_predictions = int((ter_values == 0.0).sum())\n",
    "    print(f\"\\nPerfect predictions (TER=0%): {perfect_predictions}/{n_results} ({perfect_predictions/n_results*100:.2f}%)\")\n",
    "\n",
    "    # Count predictions with TER < 0.5 (less than 50% error)\n",
    "    good_predictions = int((ter_values < GOOD_TER_THRESHOLD).sum())\n",
    "    print(f\"Good predictions (TER<50%): {good_predictions}/{n_results} ({good_predictions/n_results*100:.2f}%)\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}