Spaces:
Runtime error
Runtime error
tabito12345678910 commited on
Commit ·
11e93ed
1
Parent(s): 8610011
Fix age range handling and model architecture
Browse files- inference_gohan_cid.py +19 -4
inference_gohan_cid.py
CHANGED
|
@@ -55,7 +55,7 @@ class GohanCIDInferenceEngine:
|
|
| 55 |
return None
|
| 56 |
# Use training-script hyperparameters
|
| 57 |
model = FTTransformer.make_baseline(
|
| 58 |
-
n_num_features=
|
| 59 |
cat_cardinalities=self.cat_cardinalities,
|
| 60 |
d_out=len(self.all_cids),
|
| 61 |
d_token=768, # Use the actual saved model's d_token
|
|
@@ -103,9 +103,24 @@ class GohanCIDInferenceEngine:
|
|
| 103 |
X_cat.append(self._encode_categorical(self.cat_encoders[col], '__UNKNOWN__'))
|
| 104 |
X_cat = torch.tensor([X_cat], dtype=torch.long)
|
| 105 |
|
| 106 |
-
# Numerical features (
|
| 107 |
-
|
| 108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
return X_cat, X_num
|
| 110 |
|
| 111 |
def predict(self, data: Dict) -> List[Dict]:
|
|
|
|
| 55 |
return None
|
| 56 |
# Use training-script hyperparameters
|
| 57 |
model = FTTransformer.make_baseline(
|
| 58 |
+
n_num_features=5, # Updated: 5 numerical features (age ranges are now categorical)
|
| 59 |
cat_cardinalities=self.cat_cardinalities,
|
| 60 |
d_out=len(self.all_cids),
|
| 61 |
d_token=768, # Use the actual saved model's d_token
|
|
|
|
| 103 |
X_cat.append(self._encode_categorical(self.cat_encoders[col], '__UNKNOWN__'))
|
| 104 |
X_cat = torch.tensor([X_cat], dtype=torch.long)
|
| 105 |
|
| 106 |
+
# Numerical features (5 features to match training script - age ranges are now categorical)
|
| 107 |
+
# Remove age range fields from numerical features since they're now categorical
|
| 108 |
+
num_cols = ['LAT', 'LONG', 'DELIVERY_NUM', 'MEDIAN_GENDER_RATIO', 'TOTAL_VOLUME']
|
| 109 |
+
X_num = []
|
| 110 |
+
for col in num_cols:
|
| 111 |
+
if col in df.columns:
|
| 112 |
+
try:
|
| 113 |
+
X_num.append(float(df[col].iloc[0]))
|
| 114 |
+
except (ValueError, TypeError):
|
| 115 |
+
X_num.append(0.0)
|
| 116 |
+
else:
|
| 117 |
+
# Provide default values for missing fields
|
| 118 |
+
if col == 'TOTAL_VOLUME':
|
| 119 |
+
X_num.append(0.0) # Default total volume
|
| 120 |
+
else:
|
| 121 |
+
X_num.append(0.0)
|
| 122 |
+
|
| 123 |
+
X_num = torch.tensor([X_num], dtype=torch.float32)
|
| 124 |
return X_cat, X_num
|
| 125 |
|
| 126 |
def predict(self, data: Dict) -> List[Dict]:
|