Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.responses import HTMLResponse | |
| import numpy as np | |
| from PIL import Image | |
| import io | |
| import onnxruntime as ort | |
| from pydantic import BaseModel | |
| import time | |
| from pathlib import Path | |
| import cv2 | |
| import albumentations as A | |
| import pandas as pd | |
| import os | |
| from huggingface_hub import HfApi | |
# Hugging Face Hub client; the access token is read from the
# "jms_hf_token" environment variable.
api = HfApi(token=os.getenv("jms_hf_token"))

# Local directory that receives the downloaded artefacts (created on demand).
model_dir = "app_models"
os.makedirs(model_dir, exist_ok=True)

# Fetch the recognition model, the corner model and the card embedding table
# from the "KennethTM/jms-model" model repository into model_dir.
recog_path_local = api.hf_hub_download(
    repo_id="KennethTM/jms-model",
    filename="recog_model.onnx",
    local_dir=model_dir,
    repo_type="model",
)
corner_path_local = api.hf_hub_download(
    repo_id="KennethTM/jms-model",
    filename="corner_model.onnx",
    local_dir=model_dir,
    repo_type="model",
)
card_data_path = api.hf_hub_download(
    repo_id="KennethTM/jms-model",
    filename="card_data_embeddings.parquet",
    local_dir=model_dir,
    repo_type="model",
)
# FastAPI application object for the card-scanner service.
app = FastAPI(
    title="just-mtg-scan",
    description="Just a Magic: The Gathering card scanner",
    version="1.0.0",
)
# ONNX Runtime sessions: one for corner detection, one for card recognition.
corner_session = ort.InferenceSession(corner_path_local)
recog_session = ort.InferenceSession(recog_path_local)

# Reference data. Row i of ref_embeddings and card_metadata[i] come from the
# same DataFrame row, so an index into one is valid for the other.
df = pd.read_parquet(card_data_path)
ref_embeddings = np.vstack(df["embedding"].values).astype(np.float32)
card_metadata = df[
    [
        "card_id",
        "name",
        "uri",
        "card_url",
        "image_url",
        "lang",
        "rarity",
        "set_name",
        "set",
    ]
].to_dict("records")
# The DataFrame is no longer needed once both structures are extracted.
del df
# Per-task preprocessing settings: model input size in pixels plus the
# per-channel mean / std used for normalization.
task_config = {
    "recog": {
        "image_width": 160,
        "image_height": 224,
        "means": [0.5, 0.5, 0.5],
        "stds": [0.5, 0.5, 0.5],
    },
    "corner": {
        "image_width": 256,
        "image_height": 256,
        "means": [0.5, 0.5, 0.5],
        "stds": [0.5, 0.5, 0.5],
    },
}
def perspective_transform(image: np.ndarray, corners: np.ndarray) -> np.ndarray:
    """Warp the quadrilateral described by ``corners`` onto the recognition input rectangle.

    Args:
        image: Source image; its first two dimensions are read as (height, width).
        corners: Eight values interpreted as four (x, y) points, normalized so that
            multiplying x by the image width and y by the image height gives pixel
            coordinates. The points are mapped, in order, to the top-left,
            top-right, bottom-right and bottom-left corners of the output.
            This array is not modified.

    Returns:
        The warped image, sized to the "recog" width and height from ``task_config``.
    """
    h, w = image.shape[:2]

    # Denormalize on a copy: reshape() normally returns a view, so scaling it
    # in place would silently overwrite the caller's array.
    pts = corners.reshape(4, 2).copy()
    pts[:, 0] *= w
    pts[:, 1] *= h

    # Destination rectangle with the recognition model's input dimensions.
    dst_width = task_config["recog"]["image_width"]
    dst_height = task_config["recog"]["image_height"]
    dst_pts = np.array(
        [
            [0, 0],
            [dst_width - 1, 0],
            [dst_width - 1, dst_height - 1],
            [0, dst_height - 1],
        ],
        dtype=np.float32,
    )

    # Homography from the detected corners to the destination rectangle.
    M = cv2.getPerspectiveTransform(pts.astype(np.float32), dst_pts)
    return cv2.warpPerspective(image, M, (dst_width, dst_height))
# Corner-model preprocessing: scale the longest side to the corner input
# size, then pad with constant zeros up to the full corner input size.
onnx_transform_corner = A.Compose(
    [
        A.LongestMaxSize(max_size=task_config["corner"]["image_height"]),
        A.PadIfNeeded(
            min_height=task_config["corner"]["image_height"],
            min_width=task_config["corner"]["image_width"],
            border_mode=cv2.BORDER_CONSTANT,
            fill=0,
        ),
    ]
)

# Recognition-model preprocessing: linear-interpolation resize to the
# recognition input size.
onnx_transform_recog = A.Resize(
    height=task_config["recog"]["image_height"],
    width=task_config["recog"]["image_width"],
    interpolation=cv2.INTER_LINEAR,
)
def preprocess_onnx(image: np.ndarray, task: str) -> np.ndarray:
    """Convert an image into the batched, normalized tensor fed to the ``task`` model.

    Args:
        image: Array in height x width x channel layout (three channels, to match
            the three-element means / stds). Pixel values are divided by 255, so
            8-bit input is presumably expected.
        task: Key into ``task_config``; "corner" and "recog" additionally trigger
            their geometric transform.

    Returns:
        float32 array of shape (1, channels, height, width).
    """
    if task == "corner":
        # Longest side resized to the corner input size, then padded.
        image = onnx_transform_corner(image=image)["image"]
    elif task == "recog":
        target_hw = (
            task_config["recog"]["image_height"],
            task_config["recog"]["image_width"],
        )
        # Only resize when the image is not already at the recognition size.
        if image.shape[:2] != target_hw:
            image = onnx_transform_recog(image=image)["image"]

    # Bring pixel values into float32 and divide by 255.
    scaled = image.astype(np.float32) / 255.0

    # Per-channel normalization with the task's mean / std.
    mean = np.array(task_config[task]["means"], dtype=np.float32)
    std = np.array(task_config[task]["stds"], dtype=np.float32)
    normalized = (scaled - mean) / std

    # HWC -> CHW, then prepend the batch axis.
    return normalized.transpose(2, 0, 1)[np.newaxis, ...]
# Response schema describing the card matched by the prediction handler.
# "id" and "set" shadow Python builtins but are the field names of the
# response, so they are kept unchanged.
class Card(BaseModel):
    id: str  # filled from the "card_id" metadata entry
    name: str
    uri: str
    scryfall_uri: str  # filled from the "card_url" metadata entry
    image_url: str
    lang: str
    rarity: str
    set_name: str
    set: str
    prediction_time: int  # milliseconds
async def root():
    """Serve the ``index.html`` file located next to this module.

    Returns:
        HTMLResponse carrying the file contents with status 200.

    Raises:
        HTTPException: 404 when ``index.html`` does not exist.
    """
    # NOTE(review): no route decorator is visible on this handler in this view;
    # confirm it is registered on ``app``.
    html_path = Path(__file__).parent / "index.html"
    try:
        # Decode explicitly as UTF-8 instead of the platform's locale default,
        # and read directly rather than checking exists() first (no race).
        html = html_path.read_text(encoding="utf-8")
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail="index.html not found") from None
    return HTMLResponse(content=html, status_code=200)
async def predict(file: UploadFile = File(...)) -> Card:
    """Identify the card shown in an uploaded 256x256 image.

    Steps: validate the upload, detect the card corners, warp the card to the
    recognition input size, embed it, and return the metadata of the reference
    embedding with the highest dot-product score.

    Args:
        file: Uploaded image; must declare an ``image/*`` content type and be
            exactly 256x256 pixels.

    Returns:
        Card with the best match's metadata and the inference time in milliseconds.

    Raises:
        HTTPException: 400 when the upload is not declared as an image, cannot
            be decoded, or is not 256x256 pixels.
    """
    # NOTE(review): no route decorator is visible on this handler in this view;
    # confirm it is registered on ``app``.

    # The content type can be absent on an upload; treat that as "not an image"
    # instead of failing on None.startswith.
    content_type = file.content_type or ""
    if not content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    contents = await file.read()

    # Opening only parses the header; undecodable data is a client error.
    try:
        image = Image.open(io.BytesIO(contents))
    except OSError:
        raise HTTPException(status_code=400, detail="Invalid or corrupted image file") from None

    # Image must be 256x256. Checked before pixel data is decoded so that
    # wrongly sized uploads are rejected cheaply.
    if not (image.width == 256 and image.height == 256):
        raise HTTPException(status_code=400, detail="Image must be 256x256 pixels")

    # Decode to an RGB numpy array; truncated files fail here rather than at open().
    try:
        if image.mode != "RGB":
            image = image.convert("RGB")
        image_rgb = np.array(image)
    except OSError:
        raise HTTPException(status_code=400, detail="Invalid or corrupted image file") from None

    # Time the whole inference pipeline.
    t0 = time.perf_counter()

    # Corner detection: the first output row holds the corner coordinates that
    # perspective_transform reshapes to four (x, y) points.
    corner_input = preprocess_onnx(image_rgb, task="corner")
    corner_outputs = corner_session.run(None, {corner_session.get_inputs()[0].name: corner_input})
    corners = corner_outputs[0][0]

    # Rectify the card to the recognition input size.
    warped_image = perspective_transform(image_rgb, corners)

    # Recognition: embed the rectified card.
    recog_input = preprocess_onnx(warped_image, task="recog")
    recog_outputs = recog_session.run(None, {recog_session.get_inputs()[0].name: recog_input})
    query_embedding = recog_outputs[0][0]

    # Dot product against every reference embedding. This equals cosine
    # similarity only if both sides are L2-normalized (TODO confirm).
    similarities = np.dot(ref_embeddings, query_embedding)

    # card_metadata is a list aligned with ref_embeddings; use a plain int index.
    best_idx = int(np.argmax(similarities))
    card_info = card_metadata[best_idx]

    t1 = time.perf_counter()
    prediction_time_ms = int((t1 - t0) * 1000)

    return Card(
        id=card_info["card_id"],
        name=card_info["name"],
        uri=card_info["uri"],
        scryfall_uri=card_info["card_url"],
        image_url=card_info["image_url"],
        lang=card_info["lang"],
        rarity=card_info["rarity"],
        set_name=card_info["set_name"],
        set=card_info["set"],
        prediction_time=prediction_time_ms,
    )
# When executed as a script, serve the app with uvicorn on all interfaces,
# port 8000.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)