tabito12345678910 commited on
Commit
11e93ed
·
1 Parent(s): 8610011

Fix age range handling and model architecture

Browse files
Files changed (1) hide show
  1. inference_gohan_cid.py +19 -4
inference_gohan_cid.py CHANGED
@@ -55,7 +55,7 @@ class GohanCIDInferenceEngine:
55
  return None
56
  # Use training-script hyperparameters
57
  model = FTTransformer.make_baseline(
58
- n_num_features=8,
59
  cat_cardinalities=self.cat_cardinalities,
60
  d_out=len(self.all_cids),
61
  d_token=768, # Use the actual saved model's d_token
@@ -103,9 +103,24 @@ class GohanCIDInferenceEngine:
103
  X_cat.append(self._encode_categorical(self.cat_encoders[col], '__UNKNOWN__'))
104
  X_cat = torch.tensor([X_cat], dtype=torch.long)
105
 
106
- # Numerical features (8 features to match training script)
107
- num_cols = ['LAT', 'LONG', 'DELIVERY_NUM', 'MEDIAN_GENDER_RATIO', 'MODE_TOP_AGE_RANGE_1', 'MODE_TOP_AGE_RANGE_2', 'MODE_TOP_AGE_RANGE_3', 'TOTAL_VOLUME']
108
- X_num = torch.tensor([[float(df.get(c, pd.Series([0])).iloc[0]) for c in num_cols]], dtype=torch.float32)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  return X_cat, X_num
110
 
111
  def predict(self, data: Dict) -> List[Dict]:
 
55
  return None
56
  # Use training-script hyperparameters
57
  model = FTTransformer.make_baseline(
58
+ n_num_features=5, # Updated: 5 numerical features (age ranges are now categorical)
59
  cat_cardinalities=self.cat_cardinalities,
60
  d_out=len(self.all_cids),
61
  d_token=768, # Use the actual saved model's d_token
 
103
  X_cat.append(self._encode_categorical(self.cat_encoders[col], '__UNKNOWN__'))
104
  X_cat = torch.tensor([X_cat], dtype=torch.long)
105
 
106
+ # Numerical features (5 features to match training script - age ranges are now categorical)
107
+ # Remove age range fields from numerical features since they're now categorical
108
+ num_cols = ['LAT', 'LONG', 'DELIVERY_NUM', 'MEDIAN_GENDER_RATIO', 'TOTAL_VOLUME']
109
+ X_num = []
110
+ for col in num_cols:
111
+ if col in df.columns:
112
+ try:
113
+ X_num.append(float(df[col].iloc[0]))
114
+ except (ValueError, TypeError):
115
+ X_num.append(0.0)
116
+ else:
117
+ # Provide default values for missing fields
118
+ if col == 'TOTAL_VOLUME':
119
+ X_num.append(0.0) # Default total volume
120
+ else:
121
+ X_num.append(0.0)
122
+
123
+ X_num = torch.tensor([X_num], dtype=torch.float32)
124
  return X_cat, X_num
125
 
126
  def predict(self, data: Dict) -> List[Dict]: