tabito123 commited on
Commit
0fa4db6
·
verified ·
1 Parent(s): d401fb9

Update inference_gohan_cid.py

Browse files
Files changed (1) hide show
  1. inference_gohan_cid.py +1 -1
inference_gohan_cid.py CHANGED
@@ -104,7 +104,7 @@ class GohanCIDInferenceEngine:
104
  X_cat = torch.tensor([X_cat], dtype=torch.long)
105
 
106
  # Numerical features (8 features to match training script)
107
- num_cols = ['LAT', 'LONG', 'DELIVERY_NUM', 'MEDIAN_GENDER_RATIO', 'MODE_TOP_AGE_RANGE_1', 'MODE_TOP_AGE_RANGE_2', 'MODE_TOP_AGE_RANGE_3', 'TOTAL_VOLUME']
108
  X_num = torch.tensor([[float(df.get(c, pd.Series([0])).iloc[0]) for c in num_cols]], dtype=torch.float32)
109
  return X_cat, X_num
110
 
 
104
  X_cat = torch.tensor([X_cat], dtype=torch.long)
105
 
106
  # Numerical features (8 features to match training script)
107
+ num_cols = ['LAT', 'LONG', 'DELIVERY_NUM', 'MEDIAN_GENDER_RATIO', 'TOTAL_VOLUME']
108
  X_num = torch.tensor([[float(df.get(c, pd.Series([0])).iloc[0]) for c in num_cols]], dtype=torch.float32)
109
  return X_cat, X_num
110