# gohan-api-light / inference_gohan_cid.py
# Commit 8e060d3 (tabito12345678910):
#   "Fix syntax error: add missing except block in _load_encoders method"
#!/usr/bin/env python3
"""
Gohan CID Product Recommendation Inference Engine
"""
import os
import sys
import json
import pickle
import torch
import numpy as np
import pandas as pd
from typing import Dict, List, Tuple
# rtdl (FT-Transformer) is an optional dependency: when it is missing the
# engine still starts, and predict() serves a score-less fallback from the
# product master instead of model inference.
try:
    from rtdl import FTTransformer
except ImportError:
    FTTransformer = None
class GohanCIDInferenceEngine:
    """Product-category recommendation engine for Gohan CID customers.

    On construction it loads the JSON encoder artifacts, the product master
    CSV, and (optionally) an FT-Transformer checkpoint.  If the model, its
    weights, or the ``rtdl`` package are unavailable, the engine degrades to
    a score-less fallback that returns the head of the product master.
    """

    def __init__(self, model_path: str, encoders_dir: str, product_master_path: str):
        """Store artifact paths and eagerly load encoders, master data, and model.

        Args:
            model_path: Path to the saved PyTorch ``state_dict`` file.
            encoders_dir: Directory holding the ``*1.json`` encoder artifacts.
            product_master_path: CSV mapping CATEGORY_ID to CATEGORY_NAME.
        """
        self.model_path = model_path
        self.encoders_dir = encoders_dir
        self.product_master_path = product_master_path
        self._load_encoders()
        self._load_product_master()
        self.model = self._load_model()

    def _load_encoders(self):
        """Load the JSON encoder artifacts produced by the training script.

        Best-effort: a missing or corrupt artifact must not crash startup, so
        on any failure the attributes are set to empty structures and the
        engine runs in no-model fallback mode.
        (NOTE: despite an older comment, there is no pickle fallback here.)
        """
        def _read_json(name: str):
            # All artifacts are UTF-8 JSON written by the training pipeline.
            with open(os.path.join(self.encoders_dir, name), 'r', encoding='utf-8') as f:
                return json.load(f)

        try:
            self.idx_to_cid = _read_json('idx_to_cid1.json')
            self.all_cids = _read_json('all_cids1.json')
            self.cat_encoders = _read_json('cat_encoders1.json')
            self.cat_cardinalities = _read_json('cat_cardinalities1.json')
        except Exception as e:
            # stderr for consistency with _load_model's warning path.
            print(f"Error loading encoders: {e}", file=sys.stderr)
            self.idx_to_cid = {}
            self.all_cids = []
            self.cat_encoders = {}
            self.cat_cardinalities = []

    def _load_product_master(self):
        """Load the product master CSV, normalizing ID/name columns to uppercase."""
        if os.path.exists(self.product_master_path):
            pm = pd.read_csv(self.product_master_path, encoding='utf-8-sig')
            # Normalize to uppercase canonical names if present in lowercase.
            cols = {c.lower(): c for c in pm.columns}
            if 'category_id' in cols and 'category_name' in cols:
                pm = pm.rename(columns={cols['category_id']: 'CATEGORY_ID',
                                        cols['category_name']: 'CATEGORY_NAME'})
            self.product_master = pm
        else:
            # Empty frame with the expected columns keeps downstream access safe.
            self.product_master = pd.DataFrame(columns=['CATEGORY_ID', 'CATEGORY_NAME'])

    def _load_model(self):
        """Build the FT-Transformer and load weights; return None on any failure.

        Returns:
            The model in eval mode, or None when rtdl is unavailable or the
            checkpoint cannot be loaded (no-model fallback mode).
        """
        if FTTransformer is None:
            return None
        # Hyperparameters must mirror the training script exactly.
        model = FTTransformer.make_baseline(
            n_num_features=5,  # 5 numerical features (age ranges are categorical)
            cat_cardinalities=self.cat_cardinalities,
            d_out=len(self.all_cids),
            d_token=1024,        # matches the saved checkpoint's d_token
            n_blocks=8,
            attention_dropout=0.15,
            ffn_d_hidden=1024,   # matches the saved checkpoint's ffn_d_hidden
            ffn_dropout=0.15,
            residual_dropout=0.10
        )
        if os.path.exists(self.model_path):
            try:
                try:
                    # weights_only=True avoids arbitrary-code pickle execution
                    # when loading a checkpoint (a plain state_dict needs no more).
                    state = torch.load(self.model_path, map_location='cpu', weights_only=True)
                except TypeError:
                    # Older torch without the weights_only kwarg.
                    state = torch.load(self.model_path, map_location='cpu')
                model.load_state_dict(state)
            except Exception as e:
                print(f"⚠️ Could not load weights: {e}. Falling back to no-model mode.", file=sys.stderr)
                return None
        model.eval()
        return model

    def _encode_categorical(self, value_map: Dict[str, int], value: str) -> int:
        """Map a raw categorical value to its integer index.

        cat_encoders is loaded from JSON, whose object keys are always str;
        coerce the incoming value so int/float categoricals (e.g. years)
        can still match instead of always falling through to __UNKNOWN__.
        """
        key = value if isinstance(value, str) else str(value)
        if key in value_map:
            return int(value_map[key])
        return int(value_map.get('__UNKNOWN__', 0))

    def _preprocess(self, data: Dict) -> Tuple[torch.Tensor, torch.Tensor]:
        """Turn one input record into (categorical, numerical) feature tensors.

        Args:
            data: Dict with English field names as provided by the client.
        Returns:
            (X_cat [1, n_cat] long tensor, X_num [1, 5] float tensor).
        Raises:
            ValueError: if any required input field is missing.
        """
        required_en = [
            'INDUSTRY', 'EMPLOYEE_RANGE', 'FRIDGE_RANGE', 'PAYMENT_METHOD', 'PREFECTURE',
            'FIRST_YEAR', 'FIRST_MONTH', 'LAT', 'LONG', 'DELIVERY_NUM', 'MEDIAN_GENDER_RATIO',
            'MODE_TOP_AGE_RANGE_1', 'MODE_TOP_AGE_RANGE_2', 'MODE_TOP_AGE_RANGE_3'
        ]
        missing = [k for k in required_en if k not in data]
        if missing:
            raise ValueError(f"Missing required inputs: {missing}")
        df = pd.DataFrame([data])
        # Categorical features: iterate cat_encoders' key order (English keys);
        # absent columns encode as __UNKNOWN__.
        X_cat = []
        for col in self.cat_encoders.keys():
            if col in df.columns:
                X_cat.append(self._encode_categorical(self.cat_encoders[col], df[col].iloc[0]))
            else:
                X_cat.append(self._encode_categorical(self.cat_encoders[col], '__UNKNOWN__'))
        X_cat = torch.tensor([X_cat], dtype=torch.long)
        # Numerical features: exactly 5, matching the training script
        # (age ranges moved to the categorical side).
        num_cols = ['LAT', 'LONG', 'DELIVERY_NUM', 'MEDIAN_GENDER_RATIO', 'TOTAL_VOLUME']
        X_num = []
        for col in num_cols:
            if col in df.columns:
                try:
                    X_num.append(float(df[col].iloc[0]))
                except (ValueError, TypeError):
                    X_num.append(0.0)
            else:
                # Missing numeric fields (e.g. optional TOTAL_VOLUME) default to 0.0.
                X_num.append(0.0)
        X_num = torch.tensor([X_num], dtype=torch.float32)
        return X_cat, X_num

    def predict(self, data: Dict) -> List[Dict]:
        """Return ranked category recommendations for one customer record.

        Args:
            data: Input feature dict; an optional 'topK' entry limits the
                number of results returned.
        Returns:
            List of {"category_id", "category_name", "score"} dicts, best first.
        """
        top_k = data.get('topK')
        if self.model is None:
            # Fallback: no scores available; return head of the product master.
            k = int(top_k) if top_k is not None else 30
            pm = self.product_master[['CATEGORY_ID', 'CATEGORY_NAME']].dropna()
            return [
                {"category_id": int(r['CATEGORY_ID']),
                 "category_name": str(r['CATEGORY_NAME']),
                 "score": 0.0}
                for r in pm.head(k).to_dict(orient='records')
            ]
        X_cat, X_num = self._preprocess(data)
        with torch.no_grad():
            logits = self.model(X_num, X_cat)
        scores = torch.sigmoid(logits).flatten().cpu().numpy()
        indices = np.argsort(scores)[::-1]
        # Build CATEGORY_ID -> CATEGORY_NAME once; zip over columns is far
        # faster than iterrows(), and DataFrame.get tolerates missing columns.
        pm = self.product_master
        name_map = {
            int(cid): str(name)
            for cid, name in zip(pm.get('CATEGORY_ID', []), pm.get('CATEGORY_NAME', []))
            if pd.notna(cid)
        }
        results = []
        for idx in indices:
            if idx >= len(self.all_cids):
                continue
            cid = int(self.all_cids[idx])
            results.append({
                "category_id": cid,
                "category_name": name_map.get(cid, "不明"),
                "score": float(scores[idx])
            })
        # BUGFIX: the model path previously ignored 'topK' (only the no-model
        # fallback honored it). Apply it when supplied; default behavior
        # (return all categories) is unchanged.
        if top_k is not None:
            results = results[:int(top_k)]
        return results