from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import HTMLResponse
import numpy as np
from PIL import Image
import io
import onnxruntime as ort
from pydantic import BaseModel
import time
from pathlib import Path
import cv2
import albumentations as A
import pandas as pd
import os
import json
from huggingface_hub import HfApi
# Hub client authenticated via the "jms_hf_token" env var; if the variable is
# unset, token is None and the client falls back to anonymous access.
api = HfApi(token=os.getenv("jms_hf_token"))

# Local cache directory for the model artifacts pulled from the Hub.
model_dir = "app_models"
os.makedirs(model_dir, exist_ok=True)

def _fetch(filename: str) -> str:
    """Download one artifact from the KennethTM/jms-model repo into model_dir
    (no-op if already cached) and return its local path."""
    return api.hf_hub_download(repo_id="KennethTM/jms-model", filename=filename,
                               local_dir=model_dir, repo_type="model")

recog_path_local = _fetch("recog_model.onnx")
corner_path_local = _fetch("corner_model.onnx")
card_data_path = _fetch("card_data_minimal.parquet")
card_embeddings_path = _fetch("card_embeddings_float16.npz")
task_config_path = _fetch("task_config.json")
# FastAPI application instance (metadata shown in the generated OpenAPI docs).
app = FastAPI(
    title="just-mtg-scan",
    description="Just a Magic: The Gathering card scanner",
    version="1.0.0"
)
# ONNX Runtime sessions for the two models: corner detection and card
# recognition. NOTE(review): no execution providers specified — defaults apply.
corner_session = ort.InferenceSession(corner_path_local)
recog_session = ort.InferenceSession(recog_path_local)
# Reference card table and embedding matrix. Embeddings are stored as float16
# on disk and upcast to float32 for the similarity dot product.
df = pd.read_parquet(card_data_path)
ref_embeddings = np.load(card_embeddings_path)['embeddings'].astype(np.float32)
# Pre-compute card info as list of dicts for faster access (avoid DataFrame iloc overhead)
card_metadata = df[['name', 'card_url', 'image_url', 'rarity', 'set_name', 'set']].to_dict('records')
del df # Free DataFrame memory after extracting needed data
# Per-task preprocessing config (input sizes, normalization means/stds) keyed
# by task name ("corner" / "recog").
with open(task_config_path) as f:
    task_config = json.load(f)
def perspective_transform(image: np.ndarray, corners: np.ndarray) -> np.ndarray:
    """Warp the detected card region to an axis-aligned rectangle.

    Args:
        image: HxWx3 source image.
        corners: 8 normalized coordinates (four (x, y) pairs in [0, 1]) —
            assumed ordered top-left, top-right, bottom-right, bottom-left to
            match dst_pts below (TODO confirm against the corner model output).

    Returns:
        Warped image sized to the recognition model's input
        (task_config["recog"] width x height).
    """
    h, w = image.shape[:2]
    # Denormalize corners on a COPY: the previous `corners.reshape(4, 2)`
    # returned a view, so the in-place `*=` mutated the caller's array.
    pts = corners.reshape(4, 2).astype(np.float32, copy=True)
    pts[:, 0] *= w
    pts[:, 1] *= h
    # Destination rectangle matching the recognition model's input size.
    dst_width = task_config["recog"]["image_width"]
    dst_height = task_config["recog"]["image_height"]
    dst_pts = np.array([
        [0, 0],
        [dst_width - 1, 0],
        [dst_width - 1, dst_height - 1],
        [0, dst_height - 1]
    ], dtype=np.float32)
    # Compute the 3x3 homography and warp.
    M = cv2.getPerspectiveTransform(pts, dst_pts)
    warped = cv2.warpPerspective(image, M, (dst_width, dst_height))
    return warped
# Corner-model preprocessing: scale the longest side to the model's input
# height, then zero-pad to the full input size (preserves aspect ratio).
onnx_transform_corner = A.Compose([
    A.LongestMaxSize(max_size=task_config["corner"]["image_height"]),
    A.PadIfNeeded(min_height=task_config["corner"]["image_height"],
                  min_width=task_config["corner"]["image_width"],
                  border_mode=cv2.BORDER_CONSTANT, fill=0),
])
# Recognition-model preprocessing: plain bilinear resize to the input size.
onnx_transform_recog = A.Resize(height=task_config["recog"]["image_height"],
                                width=task_config["recog"]["image_width"],
                                interpolation=cv2.INTER_LINEAR)
def preprocess_onnx(image: np.ndarray, task: str) -> np.ndarray:
    """Turn an HxWx3 uint8 image into a normalized NCHW float32 tensor for
    the model selected by `task` ("corner" or "recog")."""
    recog_hw = (task_config["recog"]["image_height"], task_config["recog"]["image_width"])
    # Recognition inputs are resized only when not already at target size;
    # corner inputs always go through the resize-and-pad pipeline.
    if task == "recog" and image.shape[:2] != recog_hw:
        image = onnx_transform_recog(image=image)['image']
    if task == "corner":
        image = onnx_transform_corner(image=image)['image']
    # Scale to [0, 1] and apply the task's per-channel normalization.
    cfg = task_config[task]
    means = np.array(cfg["means"], dtype=np.float32)
    stds = np.array(cfg["stds"], dtype=np.float32)
    normalized = (image.astype(np.float32) / 255.0 - means) / stds
    # HWC -> CHW, then prepend a batch axis of 1.
    return normalized.transpose(2, 0, 1)[np.newaxis, ...]
class Card(BaseModel):
    """Response payload for /predict: metadata of the best-matching card."""
    name: str
    scryfall_uri: str  # populated from the card_url column of the card table
    image_url: str
    rarity: str
    set_name: str
    set: str  # set code; shadows the builtin, but renaming would break the API
    prediction_time: int # milliseconds
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the single-page frontend (index.html next to this module)."""
    index_file = Path(__file__).parent / "index.html"
    if index_file.exists():
        return HTMLResponse(content=index_file.read_text(), status_code=200)
    raise HTTPException(status_code=404, detail="index.html not found")
@app.post("/predict")
async def predict(file: UploadFile = File(...)) -> Card:
    """Identify the MTG card in an uploaded 256x256 image.

    Pipeline: corner detection -> perspective rectification -> embedding ->
    nearest-neighbor lookup against the reference card embeddings.

    Raises:
        HTTPException 400: non-image upload, undecodable image bytes, or an
            image that is not exactly 256x256 pixels.
    """
    # content_type can be None when the client omits the header entirely;
    # the previous bare .startswith() would raise AttributeError (a 500).
    if not file.content_type or not file.content_type.startswith('image/'):
        raise HTTPException(status_code=400, detail="File must be an image")
    # Read and decode the image; report corrupt/non-image bytes as a client
    # error instead of a 500. PIL's UnidentifiedImageError subclasses OSError.
    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents))
        image.load()  # force full decode now so truncated data fails here
    except OSError:
        raise HTTPException(status_code=400, detail="File must be an image")
    # Convert to RGB if needed
    if image.mode != 'RGB':
        image = image.convert('RGB')
    # Image must be exactly 256x256
    if image.size != (256, 256):
        raise HTTPException(status_code=400, detail="Image must be 256x256 pixels")
    image_rgb = np.array(image)
    # Start timing for the entire inference process
    t0 = time.perf_counter()
    # Corner model: predicts 8 normalized coordinates (four corner points).
    corner_input = preprocess_onnx(image_rgb, task="corner")
    corner_outputs = corner_session.run(None, {corner_session.get_inputs()[0].name: corner_input})
    corners = corner_outputs[0][0]
    # Rectify the card, then embed it with the recognition model.
    warped_image = perspective_transform(image_rgb, corners)
    recog_input = preprocess_onnx(warped_image, task="recog")
    recog_outputs = recog_session.run(None, {recog_session.get_inputs()[0].name: recog_input})
    query_embedding = recog_outputs[0][0]
    # Similarity by dot product — NOTE(review): assumes both the reference
    # embeddings and the model output are L2-normalized; verify upstream.
    similarities = np.dot(ref_embeddings, query_embedding)
    best_idx = int(np.argmax(similarities))
    # Retrieve card metadata from the pre-computed list (faster than iloc).
    card_info = card_metadata[best_idx]
    prediction_time_ms = int((time.perf_counter() - t0) * 1000)
    return Card(
        name=card_info['name'],
        scryfall_uri=card_info['card_url'],
        image_url=card_info['image_url'],
        rarity=card_info['rarity'],
        set_name=card_info['set_name'],
        set=card_info['set'],
        prediction_time=prediction_time_ms,
    )
# Local development entry point; binds on all interfaces at port 8000.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)