Spaces:
Running
Running
| import argparse | |
| import albumentations as A | |
| import csv | |
| import huggingface_hub | |
| import numpy as np | |
| import onnxruntime as ort | |
| import os | |
| import yaml | |
| from PIL import Image | |
| from train import CutMax, ResizeWithPad | |
# Both model artifacts are fetched from the same Hugging Face Hub repository.
_HF_REPO_ID = "storia/font-classify-onnx"

# Each download call yields the path that is later opened / loaded in main().
CONFIG_PATH = huggingface_hub.hf_hub_download(
    filename="model_config.yaml",
    repo_id=_HF_REPO_ID,
)
MODEL_PATH = huggingface_hub.hf_hub_download(
    filename="model.onnx",
    repo_id=_HF_REPO_ID,
)

# Tab-separated table mapping a model class name to (font name, version).
# NOTE(review): this is a machine-specific absolute Windows path; the script
# only works where this exact file exists - confirm before deploying elsewhere.
MAPPING_PATH = r"C:\Users\fmaul\Documents\KULIAH\COMPRO\Laravel\autentik-laravel\FastAPI\model\font-classify-main\google_fonts_mapping.tsv"
def parse_args():
    """Parse the command-line options of the inference script.

    Returns:
        argparse.Namespace: namespace with one attribute, ``data_folder``,
        the directory whose images inference is run on.
    """
    cli = argparse.ArgumentParser(
        description="Inference with pretrained model from Storia"
    )
    cli.add_argument(
        "--data_folder",
        default=r"C:\Users\fmaul\Documents\KULIAH\COMPRO\Laravel\autentik-laravel\FastAPI\model\Checkpoint\image",
        type=str,
        help="Path to images to run inference on",
    )
    # With no explicit argument list, argparse reads sys.argv[1:].
    return cli.parse_args()
def softmax(x, axis=0):
    """Compute softmax probabilities along ``axis``.

    Args:
        x: Array-like of scores (logits).
        axis: Axis along which the probabilities sum to one. Defaults to 0,
            so 1-D inputs and existing callers behave as before.

    Returns:
        numpy.ndarray with the same shape as ``x``; entries along ``axis``
        are non-negative and sum to one.
    """
    scores = np.asarray(x)
    # Subtract the maximum taken along the reduction axis, not the global one.
    # With a global maximum, a slice lying far below it underflows to all
    # zeros and the division below becomes 0/0 = NaN.
    shifted = scores - np.max(scores, axis=axis, keepdims=True)
    e_x = np.exp(shifted)
    # keepdims keeps the denominator broadcastable for any choice of axis.
    return e_x / e_x.sum(axis=axis, keepdims=True)
def _load_font_mapping(path):
    """Read the font mapping TSV into ``{filename: (font_name, version)}``."""
    mapping = {}
    # newline="" is the csv module's documented way to open its input files.
    with open(path, "r", encoding="utf-8", newline="") as f:
        reader = csv.reader(f, delimiter="\t")
        next(reader, None)  # The first row is a header, not data.
        for row in reader:
            if not row:
                # Blank lines would otherwise break the unpacking below.
                continue
            filename, font_name, version = row
            mapping[filename] = (font_name, version)
    return mapping


def main(args):
    """Classify the font of every image in ``args.data_folder``.

    Loads the model configuration, the font-name mapping and the ONNX model,
    then prints one line per image: the file name followed by the mapped font
    name and version (or the raw class name when no mapping entry exists).

    Args:
        args: Namespace with a ``data_folder`` attribute naming the directory
            of input images.
    """
    with open(CONFIG_PATH, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    input_size = config["size"]
    classnames = config["classnames"]

    google_font_mapping = _load_font_mapping(MAPPING_PATH)
    session = ort.InferenceSession(MODEL_PATH)

    transform = A.Compose(
        [
            A.Lambda(image=CutMax(1024)),
            A.Lambda(image=ResizeWithPad((input_size, input_size))),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    # Sorted so the output order does not depend on the filesystem.
    for image_file in sorted(os.listdir(args.data_folder)):
        image_path = os.path.join(args.data_folder, image_file)
        if not os.path.isfile(image_path):
            # Sub-directories cannot be opened as images.
            continue

        # The context manager closes the underlying file handle promptly.
        with Image.open(image_path) as img:
            image = np.array(img.convert("RGB"))

        image = transform(image=image)["image"]
        # Move the channel dimension to the front.
        image = np.transpose(image, (2, 0, 1))
        # Add a dummy batch dimension.
        image = np.expand_dims(image, 0)

        # First model output, first (and only) item of the batch.
        logits = session.run(None, {"input": image})[0][0]
        probs = softmax(logits)
        predicted = classnames[probs.argmax(0)]

        # Fall back to the raw class name instead of unpacking None when the
        # predicted class has no entry in the mapping file.
        print(image_file, *google_font_mapping.get(predicted, (predicted,)))
if __name__ == "__main__":
    # Script entry point: parse the CLI options and run inference with them.
    main(parse_args())