{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5f1c2a90",
   "metadata": {},
   "source": [
    "# Font classification inference (local ONNX model)\n",
    "\n",
    "Runs the pretrained Storia font classifier (`model.onnx`) on every image in a folder and prints, for each file, the font name and version looked up in `google_fonts_mapping.tsv`.\n",
    "\n",
    "**Expected layout under the project folder**\n",
    "\n",
    "- `Checkpoint/model.onnx` and `Checkpoint/model_config.yaml`\n",
    "- `font-classify-main/google_fonts_mapping.tsv`\n",
    "- `font-classify-main/train.py`, which provides `CutMax` and `ResizeWithPad` (assumed location, please confirm)\n",
    "\n",
    "Set the `FONT_CLASSIFY_ROOT` environment variable before starting the kernel to use a different project folder."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a29bcadb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import csv\n",
    "import os\n",
    "import sys\n",
    "\n",
    "import albumentations as A\n",
    "import numpy as np\n",
    "import onnxruntime as ort\n",
    "import yaml\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d3e4b12",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c84f0a6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ==========================\n",
    "# LOCAL MODEL & CONFIG PATHS\n",
    "# ==========================\n",
    "# The project folder can be overridden with the FONT_CLASSIFY_ROOT environment\n",
    "# variable; the default is the original local layout.\n",
    "PROJECT_ROOT = os.environ.get(\n",
    "    \"FONT_CLASSIFY_ROOT\",\n",
    "    r\"C:\\Users\\fmaul\\Documents\\KULIAH\\COMPRO\\Font Classify\",\n",
    ")\n",
    "\n",
    "BASE_PATH = os.path.join(PROJECT_ROOT, \"Checkpoint\")\n",
    "CONFIG_PATH = os.path.join(BASE_PATH, \"model_config.yaml\")\n",
    "MODEL_PATH = os.path.join(BASE_PATH, \"model.onnx\")\n",
    "\n",
    "REPO_PATH = os.path.join(PROJECT_ROOT, \"font-classify-main\")\n",
    "MAPPING_PATH = os.path.join(REPO_PATH, \"google_fonts_mapping.tsv\")\n",
    "DEFAULT_DATA_FOLDER = os.path.join(REPO_PATH, \"sample_data\", \"fonts\")\n",
    "\n",
    "# `train` is a project-local module, not an installed package. Importing it\n",
    "# from a kernel started in another folder raised ModuleNotFoundError, so the\n",
    "# checkout is put on sys.path first.\n",
    "# NOTE(review): assumes train.py sits directly in REPO_PATH - confirm.\n",
    "if REPO_PATH not in sys.path:\n",
    "    sys.path.insert(0, REPO_PATH)\n",
    "\n",
    "from train import CutMax, ResizeWithPad"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1b9d750",
   "metadata": {},
   "source": [
    "## Inference helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "133ad4f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_args(argv=None):\n",
    "    \"\"\"Parse the command-line options for an inference run.\n",
    "\n",
    "    Args:\n",
    "        argv: Optional list of argument strings. ``None`` (the default)\n",
    "            reads ``sys.argv[1:]``.\n",
    "\n",
    "    Returns:\n",
    "        argparse.Namespace whose ``data_folder`` attribute is the folder of\n",
    "        images to run inference on.\n",
    "    \"\"\"\n",
    "    parser = argparse.ArgumentParser(\n",
    "        description=\"Inference with pretrained Storia model (local)\"\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--data_folder\",\n",
    "        type=str,\n",
    "        default=DEFAULT_DATA_FOLDER,\n",
    "        help=\"Path to images to run inference on\",\n",
    "    )\n",
    "    # A Jupyter kernel is started with its own arguments (for example\n",
    "    # ``-f <connection file>``). Strict parse_args() rejects them and exits,\n",
    "    # so unrecognised arguments are ignored instead.\n",
    "    args, _unknown = parser.parse_known_args(argv)\n",
    "    return args\n",
    "\n",
    "\n",
    "def softmax(x, axis=0):\n",
    "    \"\"\"Return the softmax of ``x`` along ``axis``.\n",
    "\n",
    "    The maximum along ``axis`` is subtracted before exponentiating so that\n",
    "    large values do not overflow; the subtraction cancels out in the ratio.\n",
    "\n",
    "    Args:\n",
    "        x: Array-like of scores.\n",
    "        axis: Axis along which the outputs sum to 1 (default 0).\n",
    "\n",
    "    Returns:\n",
    "        numpy.ndarray with the same shape as ``x``.\n",
    "    \"\"\"\n",
    "    x = np.asarray(x)\n",
    "    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))\n",
    "    return e_x / e_x.sum(axis=axis, keepdims=True)\n",
    "\n",
    "\n",
    "def _load_font_mapping(path):\n",
    "    \"\"\"Read the TSV at ``path`` into ``{filename: (font_name, version)}``.\"\"\"\n",
    "    mapping = {}\n",
    "    with open(path, \"r\", encoding=\"utf-8\", newline=\"\") as f:\n",
    "        reader = csv.reader(f, delimiter=\"\\t\")\n",
    "        next(reader, None)  # skip the header row\n",
    "        for row in reader:\n",
    "            if not row:\n",
    "                # csv.reader yields [] for a blank line, which cannot be\n",
    "                # unpacked into three columns.\n",
    "                continue\n",
    "            filename, font_name, version = row\n",
    "            mapping[filename] = (font_name, version)\n",
    "    return mapping\n",
    "\n",
    "\n",
    "def main(args):\n",
    "    \"\"\"Classify every image in ``args.data_folder`` and print the results.\n",
    "\n",
    "    Prints one line per image: file name, font name, version. Predictions\n",
    "    that are missing from the mapping are printed as ``Unknown -``.\n",
    "    Sub-directories are skipped silently; files that cannot be opened as\n",
    "    images are skipped with a message on stderr.\n",
    "\n",
    "    Args:\n",
    "        args: Object with a ``data_folder`` attribute, as returned by\n",
    "            ``parse_args``.\n",
    "    \"\"\"\n",
    "    # ===== Load config =====\n",
    "    # Explicit UTF-8, matching how the mapping file is read.\n",
    "    with open(CONFIG_PATH, \"r\", encoding=\"utf-8\") as f:\n",
    "        config = yaml.safe_load(f)\n",
    "\n",
    "    input_size = config[\"size\"]\n",
    "    classnames = config[\"classnames\"]\n",
    "\n",
    "    # ===== Load font mapping =====\n",
    "    google_font_mapping = _load_font_mapping(MAPPING_PATH)\n",
    "\n",
    "    # ===== Load ONNX model =====\n",
    "    session = ort.InferenceSession(MODEL_PATH, providers=[\"CPUExecutionProvider\"])\n",
    "\n",
    "    # ===== Preprocessing =====\n",
    "    transform = A.Compose(\n",
    "        [\n",
    "            A.Lambda(image=CutMax(1024)),\n",
    "            A.Lambda(image=ResizeWithPad((input_size, input_size))),\n",
    "            A.Normalize(mean=[0.485, 0.456, 0.406],\n",
    "                        std=[0.229, 0.224, 0.225]),\n",
    "        ]\n",
    "    )\n",
    "\n",
    "    # ===== Inference =====\n",
    "    # os.listdir returns entries in arbitrary order; sort for stable output.\n",
    "    for image_file in sorted(os.listdir(args.data_folder)):\n",
    "        image_path = os.path.join(args.data_folder, image_file)\n",
    "        if not os.path.isfile(image_path):\n",
    "            continue  # sub-directories cannot be opened as images\n",
    "\n",
    "        try:\n",
    "            # The with-block closes the file handle once pixels are copied.\n",
    "            with Image.open(image_path) as img:\n",
    "                image = np.array(img.convert(\"RGB\"))\n",
    "        except OSError as err:\n",
    "            # PIL.UnidentifiedImageError is a subclass of OSError.\n",
    "            print(f\"Skipping {image_file}: {err}\", file=sys.stderr)\n",
    "            continue\n",
    "\n",
    "        image = transform(image=image)[\"image\"]\n",
    "\n",
    "        image = np.transpose(image, (2, 0, 1))  # HWC -> CHW\n",
    "        image = np.expand_dims(image, 0)  # add batch dim\n",
    "\n",
    "        logits = session.run(None, {\"input\": image})[0][0]\n",
    "        probs = softmax(logits)\n",
    "\n",
    "        predicted = classnames[probs.argmax()]\n",
    "        font_name, version = google_font_mapping.get(predicted, (\"Unknown\", \"-\"))\n",
    "\n",
    "        print(image_file, font_name, version)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09ac3f5d",
   "metadata": {},
   "source": [
    "## Run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b62d8e41",
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    args = parse_args()\n",
    "    main(args)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}