# just-mtg-scan / main.py
# Author: KennethTM
# Commit 8ea0aee: "Update main.py" (verified)
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import HTMLResponse
import numpy as np
from PIL import Image
import io
import onnxruntime as ort
from pydantic import BaseModel
import time
from pathlib import Path
import cv2
import albumentations as A
import pandas as pd
import os
from huggingface_hub import HfApi
# Hugging Face Hub client; the access token comes from the `jms_hf_token`
# environment variable (None when unset).
api = HfApi(token=os.getenv("jms_hf_token"))

# Local directory that receives the model weights and the card embeddings.
model_dir = "app_models"
os.makedirs(model_dir, exist_ok=True)


def _fetch_artifact(filename: str) -> str:
    """Download *filename* from the KennethTM/jms-model repo into model_dir.

    Returns the local path reported by hf_hub_download.
    """
    return api.hf_hub_download(
        repo_id="KennethTM/jms-model",
        filename=filename,
        local_dir=model_dir,
        repo_type="model",
    )


# Same download order as before: recognition model, corner model, card data.
recog_path_local = _fetch_artifact("recog_model.onnx")
corner_path_local = _fetch_artifact("corner_model.onnx")
card_data_path = _fetch_artifact("card_data_embeddings.parquet")
# The FastAPI application object that the route decorators below register on.
# Keyword order is irrelevant; the values are the app's published metadata.
app = FastAPI(
    version="1.0.0",
    description="Just a Magic: The Gathering card scanner",
    title="just-mtg-scan",
)
# Load ONNX models: one locates the card's corners, the other embeds the
# rectified card image.
corner_session = ort.InferenceSession(corner_path_local)
recog_session = ort.InferenceSession(recog_path_local)

# Load reference embeddings and card data.
df = pd.read_parquet(card_data_path)
# Stack the per-row "embedding" cells into a 2-D float32 matrix (one row per card).
ref_embeddings = np.vstack(df['embedding'].to_numpy()).astype(np.float32)
# The /predict similarity search is a plain dot product against these rows.
# L2-normalizing them makes that ranking equivalent to cosine similarity; the
# query's norm is a common positive factor and never changes the argmax.
# Without this, reference vectors with a larger magnitude could win regardless
# of direction. It is a no-op (up to rounding) if the stored vectors are already
# unit length. The 1e-12 floor avoids dividing by zero for an all-zero row.
ref_embeddings /= np.maximum(
    np.linalg.norm(ref_embeddings, axis=1, keepdims=True), 1e-12
)
# Row i of ref_embeddings corresponds to card_metadata[i].
card_metadata = df[['card_id', 'name', 'uri', 'card_url', 'image_url', 'lang', 'rarity', 'set_name', 'set']].to_dict('records')
del df  # Free DataFrame memory after extracting needed data
# Per-model preprocessing settings: the input size each ONNX model expects and
# the per-channel mean/std applied after scaling pixels to [0, 1].
task_config = {
    "recog": dict(
        image_width=160,
        image_height=224,
        means=[0.5, 0.5, 0.5],
        stds=[0.5, 0.5, 0.5],
    ),
    "corner": dict(
        image_width=256,
        image_height=256,
        means=[0.5, 0.5, 0.5],
        stds=[0.5, 0.5, 0.5],
    ),
}
def perspective_transform(image: np.ndarray, corners: np.ndarray) -> np.ndarray:
    """Warp the quadrilateral described by *corners* into an upright crop.

    Args:
        image: Source image; its height/width are used to denormalize *corners*.
        corners: Eight values, reshaped to four (x, y) points given as fractions
            of the image width/height. They are matched in order to the
            destination rectangle's top-left, top-right, bottom-right and
            bottom-left corners. The array is NOT modified.

    Returns:
        The warped image, sized to task_config["recog"]
        (image_height rows x image_width columns).
    """
    h, w = image.shape[:2]
    # Denormalize on a float32 copy. Scaling the reshaped view in place would
    # mutate the caller's array (here, the corner model's output), and an
    # integer input would truncate the scaled coordinates.
    pts = corners.reshape(4, 2).astype(np.float32)
    pts[:, 0] *= w
    pts[:, 1] *= h
    # Destination rectangle at the recognition model's input size.
    dst_width = task_config["recog"]["image_width"]
    dst_height = task_config["recog"]["image_height"]
    dst_pts = np.array([
        [0, 0],
        [dst_width - 1, 0],
        [dst_width - 1, dst_height - 1],
        [0, dst_height - 1],
    ], dtype=np.float32)
    # Compute the 3x3 homography and apply it.
    M = cv2.getPerspectiveTransform(pts, dst_pts)
    return cv2.warpPerspective(image, M, (dst_width, dst_height))
def _build_corner_transform(cfg: dict) -> A.Compose:
    """Build the letterbox pipeline for the corner model.

    The image is rescaled so its longest side equals cfg["image_height"]. It is
    then padded with constant zeros up to cfg["image_height"] x cfg["image_width"].
    """
    target_h = cfg["image_height"]
    target_w = cfg["image_width"]
    return A.Compose([
        A.LongestMaxSize(max_size=target_h),
        A.PadIfNeeded(
            min_height=target_h,
            min_width=target_w,
            border_mode=cv2.BORDER_CONSTANT,
            fill=0,
        ),
    ])


onnx_transform_corner = _build_corner_transform(task_config["corner"])

# Bilinear resize to the recognition model's input size. preprocess_onnx applies
# it only when an image does not already have that size.
onnx_transform_recog = A.Resize(
    width=task_config["recog"]["image_width"],
    height=task_config["recog"]["image_height"],
    interpolation=cv2.INTER_LINEAR,
)
def preprocess_onnx(image: np.ndarray, task: str) -> np.ndarray:
    """Turn an image into a normalized, batched CHW float array for *task*'s model.

    - "corner": letterboxed with onnx_transform_corner.
    - "recog": resized with onnx_transform_recog unless it already matches the
      recognition input size.

    Pixels are then divided by 255 (assumes 8-bit input) and normalized with the
    task's per-channel means/stds. The result is transposed HWC -> CHW (expects an
    H x W x 3 image, matching the three-entry means/stds), and a leading batch
    axis of size 1 is added. Raises KeyError for a task missing from task_config.
    """
    if task == "corner":
        image = onnx_transform_corner(image=image)["image"]
    elif task == "recog":
        recog_cfg = task_config["recog"]
        if image.shape[:2] != (recog_cfg["image_height"], recog_cfg["image_width"]):
            image = onnx_transform_recog(image=image)["image"]

    pixels = image.astype(np.float32) / 255.0
    cfg = task_config[task]
    mean = np.asarray(cfg["means"], dtype=np.float32)
    std = np.asarray(cfg["stds"], dtype=np.float32)
    normalized = (pixels - mean) / std

    # HWC -> CHW, then prepend the batch dimension.
    return normalized.transpose(2, 0, 1)[np.newaxis, ...]
# Response body of POST /predict: metadata of the best-matching reference card
# plus the time the scan pipeline took. The field names are part of the public
# JSON response, so `id` and `set` keep their names even though they shadow
# Python builtins.
class Card(BaseModel):
    id: str  # filled from the "card_id" column of the card data
    name: str
    uri: str
    scryfall_uri: str  # filled from the "card_url" column of the card data
    image_url: str
    lang: str
    rarity: str
    set_name: str
    set: str
    prediction_time: int # milliseconds
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the index.html file located next to this module.

    Raises:
        HTTPException: 404 if index.html does not exist.
    """
    html_path = Path(__file__).parent / "index.html"
    try:
        # Decode explicitly as UTF-8 (assumes the page is saved as UTF-8).
        # Without an encoding argument, read_text() uses the platform's locale
        # encoding, which can garble or reject non-ASCII characters.
        html = html_path.read_text(encoding="utf-8")
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail="index.html not found") from None
    return HTMLResponse(content=html, status_code=200)
def _identify_card(image_rgb: np.ndarray) -> tuple:
    """Run the scan pipeline on an RGB image and return (card_info, elapsed_ms).

    Steps:
      1. Predict the card's corners.
      2. Warp the card to the recognition input size.
      3. Embed it.
      4. Pick the reference card with the highest dot-product score.

    card_info is the matching card_metadata record. elapsed_ms is the wall-clock
    time of these steps in whole milliseconds.
    """
    t0 = time.perf_counter()

    # Corner detection; the first row of the first output holds the normalized corners.
    corner_input = preprocess_onnx(image_rgb, task="corner")
    corner_outputs = corner_session.run(None, {corner_session.get_inputs()[0].name: corner_input})
    corners = corner_outputs[0][0]

    # Rectify the card and embed it with the recognition model.
    warped_image = perspective_transform(image_rgb, corners)
    recog_input = preprocess_onnx(warped_image, task="recog")
    recog_outputs = recog_session.run(None, {recog_session.get_inputs()[0].name: recog_input})
    query_embedding = recog_outputs[0][0]

    # Nearest neighbour by dot product. This ranks like cosine similarity only if
    # the reference rows are unit length (TODO confirm). The query's norm is a
    # common positive factor, so it never changes the argmax.
    similarities = np.dot(ref_embeddings, query_embedding)
    best_idx = int(np.argmax(similarities))
    card_info = card_metadata[best_idx]  # rows align with ref_embeddings

    elapsed_ms = int((time.perf_counter() - t0) * 1000)
    return card_info, elapsed_ms


@app.post("/predict")
async def predict(file: UploadFile = File(...)) -> Card:
    """Identify the Magic card shown in an uploaded 256x256 image.

    Raises:
        HTTPException: 400 if the upload is not declared as an image, cannot be
            decoded, or is not exactly 256x256 pixels.
    """
    from fastapi.concurrency import run_in_threadpool

    # content_type is None when the client omits the part's Content-Type header.
    if not (file.content_type or "").startswith('image/'):
        raise HTTPException(status_code=400, detail="File must be an image")

    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents))
        # Image.open is lazy. Force decoding here so corrupt or truncated data is
        # reported as a client error instead of surfacing later as a 500.
        image.load()
        if image.mode != 'RGB':
            image = image.convert('RGB')
    except (OSError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError (unrecognized format) is a subclass of OSError.
        raise HTTPException(status_code=400, detail="Could not decode image") from err

    # Image must be 256x256. The corner model's letterbox transform then adds no
    # padding, so the predicted normalized corners line up with this image when
    # perspective_transform denormalizes them by its width and height.
    if image.size != (256, 256):
        raise HTTPException(status_code=400, detail="Image must be 256x256 pixels")

    image_rgb = np.array(image)

    # The ONNX inference is CPU-bound. Running it in the threadpool keeps the
    # event loop free to serve other requests while a scan is in progress.
    card_info, prediction_time_ms = await run_in_threadpool(_identify_card, image_rgb)

    return Card(
        id=card_info["card_id"],
        name=card_info['name'],
        uri=card_info['uri'],
        scryfall_uri=card_info['card_url'],
        image_url=card_info['image_url'],
        lang=card_info['lang'],
        rarity=card_info['rarity'],
        set_name=card_info['set_name'],
        set=card_info['set'],
        prediction_time=prediction_time_ms,
    )
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)