Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
|
@@ -25,7 +25,7 @@ os.makedirs(model_dir, exist_ok=True)
|
|
| 25 |
recog_path_local = api.hf_hub_download(repo_id="KennethTM/jms-model", filename="recog_model.onnx", local_dir=model_dir, repo_type="model")
|
| 26 |
corner_path_local = api.hf_hub_download(repo_id="KennethTM/jms-model", filename="corner_model.onnx", local_dir=model_dir, repo_type="model")
|
| 27 |
card_data_path = api.hf_hub_download(repo_id="KennethTM/jms-model", filename="card_data_minimal.parquet", local_dir=model_dir, repo_type="model")
|
| 28 |
-
card_embeddings_path = api.hf_hub_download(repo_id="KennethTM/jms-model", filename="
|
| 29 |
task_config_path = api.hf_hub_download(repo_id="KennethTM/jms-model", filename="task_config.json", local_dir=model_dir, repo_type="model")
|
| 30 |
|
| 31 |
# Initialize FastAPI app
|
|
@@ -41,10 +41,10 @@ recog_session = ort.InferenceSession(recog_path_local)
|
|
| 41 |
|
| 42 |
# Load reference embeddings and card data
|
| 43 |
df = pd.read_parquet(card_data_path)
|
| 44 |
-
ref_embeddings = np.load(card_embeddings_path)['embeddings']
|
| 45 |
|
| 46 |
# Pre-compute card info as list of dicts for faster access (avoid DataFrame iloc overhead)
|
| 47 |
-
card_metadata = df[['
|
| 48 |
del df # Free DataFrame memory after extracting needed data
|
| 49 |
|
| 50 |
with open(task_config_path) as f:
|
|
@@ -115,12 +115,9 @@ def preprocess_onnx(image: np.ndarray, task: str) -> np.ndarray:
|
|
| 115 |
return image
|
| 116 |
|
| 117 |
class Card(BaseModel):
|
| 118 |
-
id: str
|
| 119 |
name: str
|
| 120 |
-
uri: str
|
| 121 |
scryfall_uri: str
|
| 122 |
image_url: str
|
| 123 |
-
lang: str
|
| 124 |
rarity: str
|
| 125 |
set_name: str
|
| 126 |
set: str
|
|
@@ -191,12 +188,9 @@ async def predict(file: UploadFile = File(...)) -> Card:
|
|
| 191 |
prediction_time_ms = int((t1 - t0) * 1000)
|
| 192 |
|
| 193 |
return Card(
|
| 194 |
-
id=card_info['card_id'],
|
| 195 |
name=card_info['name'],
|
| 196 |
-
uri=card_info['uri'],
|
| 197 |
scryfall_uri=card_info['card_url'],
|
| 198 |
image_url=card_info['image_url'],
|
| 199 |
-
lang=card_info['lang'],
|
| 200 |
rarity=card_info['rarity'],
|
| 201 |
set_name=card_info['set_name'],
|
| 202 |
set=card_info['set'],
|
|
|
|
| 25 |
recog_path_local = api.hf_hub_download(repo_id="KennethTM/jms-model", filename="recog_model.onnx", local_dir=model_dir, repo_type="model")
|
| 26 |
corner_path_local = api.hf_hub_download(repo_id="KennethTM/jms-model", filename="corner_model.onnx", local_dir=model_dir, repo_type="model")
|
| 27 |
card_data_path = api.hf_hub_download(repo_id="KennethTM/jms-model", filename="card_data_minimal.parquet", local_dir=model_dir, repo_type="model")
|
| 28 |
+
card_embeddings_path = api.hf_hub_download(repo_id="KennethTM/jms-model", filename="card_embeddings_float16.npz", local_dir=model_dir, repo_type="model")
|
| 29 |
task_config_path = api.hf_hub_download(repo_id="KennethTM/jms-model", filename="task_config.json", local_dir=model_dir, repo_type="model")
|
| 30 |
|
| 31 |
# Initialize FastAPI app
|
|
|
|
| 41 |
|
| 42 |
# Load reference embeddings and card data
|
| 43 |
df = pd.read_parquet(card_data_path)
|
| 44 |
+
ref_embeddings = np.load(card_embeddings_path)['embeddings'].astype(np.float32)
|
| 45 |
|
| 46 |
# Pre-compute card info as list of dicts for faster access (avoid DataFrame iloc overhead)
|
| 47 |
+
card_metadata = df[['name', 'card_url', 'image_url', 'rarity', 'set_name', 'set']].to_dict('records')
|
| 48 |
del df # Free DataFrame memory after extracting needed data
|
| 49 |
|
| 50 |
with open(task_config_path) as f:
|
|
|
|
| 115 |
return image
|
| 116 |
|
| 117 |
class Card(BaseModel):
|
|
|
|
| 118 |
name: str
|
|
|
|
| 119 |
scryfall_uri: str
|
| 120 |
image_url: str
|
|
|
|
| 121 |
rarity: str
|
| 122 |
set_name: str
|
| 123 |
set: str
|
|
|
|
| 188 |
prediction_time_ms = int((t1 - t0) * 1000)
|
| 189 |
|
| 190 |
return Card(
|
|
|
|
| 191 |
name=card_info['name'],
|
|
|
|
| 192 |
scryfall_uri=card_info['card_url'],
|
| 193 |
image_url=card_info['image_url'],
|
|
|
|
| 194 |
rarity=card_info['rarity'],
|
| 195 |
set_name=card_info['set_name'],
|
| 196 |
set=card_info['set'],
|