In [1]:
import argparse
import albumentations as A
import csv
import numpy as np
import onnxruntime as ort
import os
import yaml

from PIL import Image
from train import CutMax, ResizeWithPad

ModuleNotFoundError: No module named 'train'

In [None]:



# ---------------------------------------------------------------------------
# Local locations of the exported model, its YAML config and the Google Fonts
# mapping table. These are absolute Windows paths on the author's machine;
# adjust them before running elsewhere.
# ---------------------------------------------------------------------------
_PROJECT_DIR = r"C:\Users\fmaul\Documents\KULIAH\COMPRO\Font Classify"

# Folder holding the checkpoint artefacts (config + ONNX weights).
BASE_PATH = _PROJECT_DIR + r"\Checkpoint"

CONFIG_PATH, MODEL_PATH = (
    os.path.join(BASE_PATH, leaf) for leaf in ("model_config.yaml", "model.onnx")
)

# TSV table mapping a predicted class name to (font name, version).
MAPPING_PATH = _PROJECT_DIR + r"\font-classify-main\google_fonts_mapping.tsv"


def parse_args(argv=None):
    """Parse the command-line options for local font-classification inference.

    Args:
        argv: List of argument strings to parse. ``None`` (the default) means
            ``sys.argv[1:]``, exactly like ``ArgumentParser.parse_args``.

    Returns:
        argparse.Namespace with a ``data_folder`` attribute.

    Unrecognised arguments are reported on stderr and otherwise ignored
    instead of aborting. This matters when the code runs inside a Jupyter
    kernel, whose ``sys.argv`` carries the kernel's own ``-f <file>`` option;
    a strict ``parse_args()`` would exit with "unrecognized arguments".
    """
    import sys

    parser = argparse.ArgumentParser(
        description="Inference with pretrained Storia model (local)"
    )
    parser.add_argument(
        "--data_folder",
        type=str,
        default=r"C:\Users\fmaul\Documents\KULIAH\COMPRO\Font Classify\font-classify-main\sample_data\fonts",
        help="Path to images to run inference on",
    )
    args, unknown = parser.parse_known_args(argv)
    if unknown:
        # Don't fail, but don't hide a mistyped flag either.
        print(
            f"Ignoring unrecognized arguments: {' '.join(unknown)}",
            file=sys.stderr,
        )
    return args


def softmax(x, axis=0):
    """Return the softmax of *x* along *axis*.

    Args:
        x: Array-like of scores (logits).
        axis: Axis to normalise over. Defaults to 0, which for a 1-D logit
            vector gives the same result as before.

    Returns:
        np.ndarray with the shape of *x*; entries along *axis* are
        non-negative and sum to 1.
    """
    x = np.asarray(x)
    # Subtract the max along the *same* axis we normalise over. A single global
    # max is fine for 1-D input, but for N-D input it can underflow an entire
    # slice to exp(...) == 0 and produce 0/0 = NaN.
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / e_x.sum(axis=axis, keepdims=True)


def _load_config(path):
    """Parse the YAML model config at *path* and return the resulting object."""
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


def _load_font_mapping(path):
    """Read the tab-separated font mapping at *path*.

    The first row is a header and is skipped. Each following row must have
    exactly three columns ``filename, font_name, version``; the result maps
    ``filename`` to ``(font_name, version)``. Blank lines (e.g. a trailing
    newline) are skipped instead of raising ``ValueError`` on unpacking.
    """
    mapping = {}
    with open(path, "r", encoding="utf-8", newline="") as f:
        reader = csv.reader(f, delimiter="\t")
        next(reader, None)  # header row
        for row in reader:
            if not row:
                continue
            filename, font_name, version = row
            mapping[filename] = (font_name, version)
    return mapping


def _build_transform(input_size):
    """Return the albumentations pipeline applied to every RGB image.

    Applies ``CutMax(1024)`` and ``ResizeWithPad((input_size, input_size))``
    (both imported from ``train`` — presumably the training-time
    preprocessing; confirm), then normalises with the standard ImageNet
    mean/std.
    """
    return A.Compose(
        [
            A.Lambda(image=CutMax(1024)),
            A.Lambda(image=ResizeWithPad((input_size, input_size))),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )


def _preprocess(image_path, transform):
    """Load *image_path* as RGB, transform it and return a batched CHW array."""
    # Context manager so the file handle is released even for multi-frame
    # formats, which PIL does not close automatically after loading.
    with Image.open(image_path) as img:
        rgb = np.array(img.convert("RGB"))
    image = transform(image=rgb)["image"]
    image = np.transpose(image, (2, 0, 1))  # HWC -> CHW
    # Add the batch dim and hand ONNX Runtime a dense buffer, not a view.
    return np.ascontiguousarray(np.expand_dims(image, 0))


def main(args):
    """Classify the font of every image file in ``args.data_folder``.

    For each image prints ``<file name> <font name> <version>``. Predicted
    classes missing from the Google Fonts mapping print as ``Unknown -``.
    Entries that are not regular files are skipped. Files that PIL cannot open
    are also skipped, with a note on stderr.

    Args:
        args: Namespace with a ``data_folder`` attribute (see ``parse_args``).
    """
    import sys

    config = _load_config(CONFIG_PATH)
    input_size = config["size"]
    classnames = config["classnames"]

    google_font_mapping = _load_font_mapping(MAPPING_PATH)

    session = ort.InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])
    # Ask the model for its input name rather than hard-coding "input". The
    # original call only ever fed one input, so the model has a single one.
    input_name = session.get_inputs()[0].name

    transform = _build_transform(input_size)

    # Sorted for a deterministic output order (os.listdir order is arbitrary).
    for image_file in sorted(os.listdir(args.data_folder)):
        image_path = os.path.join(args.data_folder, image_file)
        if not os.path.isfile(image_path):
            continue  # sub-directories etc. cannot be opened as images

        try:
            batch = _preprocess(image_path, transform)
        except OSError as err:
            # Covers PIL.UnidentifiedImageError (an OSError subclass): e.g.
            # stray non-image files such as Thumbs.db in the folder.
            print(f"Skipping {image_file}: {err}", file=sys.stderr)
            continue

        logits = session.run(None, {input_name: batch})[0][0]
        probs = softmax(logits)

        predicted = classnames[probs.argmax()]
        font_name, version = google_font_mapping.get(predicted, ("Unknown", "-"))

        print(image_file, font_name, version)


if __name__ == "__main__":
    # Entry point, also hit when this file runs as a top-level notebook cell.
    # `args` stays a module-level name so later cells can inspect it.
    args = parse_args()
    main(args=args)
