MHamzaShahid committed on
Commit
a3c60c5
·
verified ·
1 Parent(s): 92c3eaa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -16
app.py CHANGED
@@ -10,7 +10,9 @@ from sklearn.pipeline import make_pipeline
10
  from sklearn.base import BaseEstimator, TransformerMixin
11
  from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
12
 
13
- # ========== 1️⃣ Define Custom Preprocessing Functions ==========
 
 
14
 
15
  def temp_cat(X):
16
  if isinstance(X, pd.DataFrame):
@@ -29,6 +31,14 @@ def temp_cat(X):
29
  )
30
  return X
31
 
 
 
 
 
 
 
 
 
32
  def proxy_humidity(X):
33
  if isinstance(X, pd.DataFrame):
34
  X["proxy_humidity"] = X["average_rain_fall_mm_per_year"] / (X["avg_temp"] + 1)
@@ -38,7 +48,49 @@ def proxy_humidity(X):
38
  X["proxy_humidity"] = X["average_rain_fall_mm_per_year"] / (X["avg_temp"] + 1)
39
  return X
40
 
41
- # ========== 2️⃣ Define Custom Transformer Class ==========
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  class CorrelationThresholdSelector(BaseEstimator, TransformerMixin):
44
  def __init__(self, threshold=0.9, target_threshold=0.0, method="pearson", min_variance=0.0):
@@ -75,22 +127,17 @@ class CorrelationThresholdSelector(BaseEstimator, TransformerMixin):
75
  target_corr_series = X_df.corrwith(y_series, method=self.method).abs().fillna(0.0)
76
  target_corr = target_corr_series.values
77
 
78
- visited = set()
79
- drops = set()
80
 
81
  for i in range(n_features):
82
  if i in visited or i in low_var_idx:
83
  continue
84
-
85
  correlated_idx = set(np.where(corr_mat[i] > self.threshold)[0].tolist())
86
  cluster = {i} | correlated_idx
87
  visited |= cluster
88
-
89
  if len(cluster) == 1:
90
  continue
91
-
92
  best = max(cluster, key=lambda idx: (target_corr[idx], X_df.iloc[:, idx].var()))
93
-
94
  if self.target_threshold > 0 and target_corr[best] < self.target_threshold:
95
  drops |= cluster
96
  else:
@@ -114,15 +161,24 @@ class CorrelationThresholdSelector(BaseEstimator, TransformerMixin):
114
  return X_arr[:, sel]
115
 
116
 
117
- # ========== 3️⃣ Register them for joblib to find ==========
 
 
118
  sys.modules['__main__'].temp_cat = temp_cat
 
119
  sys.modules['__main__'].proxy_humidity = proxy_humidity
120
  sys.modules['__main__'].CorrelationThresholdSelector = CorrelationThresholdSelector
121
 
122
- # ========== 4️⃣ Initialize FastAPI ==========
 
 
 
123
  app = FastAPI(title="🌾 Crop Yield Predictor API", version="1.0")
124
 
125
- # ========== 5️⃣ Load Trained Model ==========
 
 
 
126
  try:
127
  model = joblib.load("CropYieldPredictor.pkl")
128
  print("✅ Model loaded successfully!")
@@ -130,7 +186,10 @@ except Exception as e:
130
  print(f"❌ Error loading model: {e}")
131
  model = None
132
 
133
- # ========== 6️⃣ Define Input Schema ==========
 
 
 
134
  class CropInput(BaseModel):
135
  Area: str
136
  Item: str
@@ -139,11 +198,15 @@ class CropInput(BaseModel):
139
  pesticides_tonnes: float
140
  avg_temp: float
141
 
142
- # ========== 7️⃣ Routes ==========
 
 
 
143
  @app.get("/")
144
  def home():
145
  return {"message": "🌾 Crop Yield Predictor API is live and running!"}
146
 
 
147
  @app.post("/predict")
148
  def predict_yield(data: CropInput):
149
  if model is None:
@@ -159,11 +222,16 @@ def predict_yield(data: CropInput):
159
  "predicted_yield_kg_per_ha": float(predicted_yield_kg_ha),
160
  "message": "✅ Prediction successful!"
161
  }
162
-
163
  except Exception as e:
164
- return {"error": str(e), "message": "❌ Prediction failed due to preprocessing or feature mismatch."}
 
 
 
 
165
 
166
- # ========== 8️⃣ Local Run ==========
 
 
167
  if __name__ == "__main__":
168
  import uvicorn
169
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
10
  from sklearn.base import BaseEstimator, TransformerMixin
11
  from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
12
 
13
+ # ================================
14
+ # 1️⃣ Custom Preprocessing Functions
15
+ # ================================
16
 
17
  def temp_cat(X):
18
  if isinstance(X, pd.DataFrame):
 
31
  )
32
  return X
33
 
34
+
35
+ def clean(X):
36
+ if isinstance(X, pd.DataFrame):
37
+ return X.dropna()
38
+ else:
39
+ return pd.DataFrame(X).dropna()
40
+
41
+
42
  def proxy_humidity(X):
43
  if isinstance(X, pd.DataFrame):
44
  X["proxy_humidity"] = X["average_rain_fall_mm_per_year"] / (X["avg_temp"] + 1)
 
48
  X["proxy_humidity"] = X["average_rain_fall_mm_per_year"] / (X["avg_temp"] + 1)
49
  return X
50
 
51
+
52
+ # ================================
53
+ # 2️⃣ Transformers and Pipelines
54
+ # ================================
55
+
56
+ temp_cat_transformer = FunctionTransformer(temp_cat)
57
+ temp_cat_pipeline = make_pipeline(
58
+ temp_cat_transformer,
59
+ OrdinalEncoder(
60
+ handle_unknown='use_encoded_value',
61
+ unknown_value=-1
62
+ )
63
+ )
64
+
65
+ clean_transformer = FunctionTransformer(clean)
66
+ clean_pipeline = make_pipeline(
67
+ clean_transformer,
68
+ StandardScaler()
69
+ )
70
+
71
+ cat_pipeline = make_pipeline(
72
+ SimpleImputer(strategy="most_frequent"),
73
+ OrdinalEncoder(handle_unknown='use_encoded_value', unknown_value=-1)
74
+ )
75
+
76
+ proxy_humidity_transformer = FunctionTransformer(proxy_humidity)
77
+ proxy_humidity_pipeline = make_pipeline(
78
+ proxy_humidity_transformer,
79
+ StandardScaler()
80
+ )
81
+
82
+ square_transformer = FunctionTransformer(np.square)
83
+ square_pipeline = make_pipeline(square_transformer, StandardScaler())
84
+
85
+ log_transformer = FunctionTransformer(np.log1p)
86
+ log_pipeline = make_pipeline(log_transformer, StandardScaler())
87
+
88
+ default_num_pipeline = make_pipeline(StandardScaler())
89
+
90
+
91
+ # ================================
92
+ # 3️⃣ Custom Feature Selector
93
+ # ================================
94
 
95
  class CorrelationThresholdSelector(BaseEstimator, TransformerMixin):
96
  def __init__(self, threshold=0.9, target_threshold=0.0, method="pearson", min_variance=0.0):
 
127
  target_corr_series = X_df.corrwith(y_series, method=self.method).abs().fillna(0.0)
128
  target_corr = target_corr_series.values
129
 
130
+ visited, drops = set(), set()
 
131
 
132
  for i in range(n_features):
133
  if i in visited or i in low_var_idx:
134
  continue
 
135
  correlated_idx = set(np.where(corr_mat[i] > self.threshold)[0].tolist())
136
  cluster = {i} | correlated_idx
137
  visited |= cluster
 
138
  if len(cluster) == 1:
139
  continue
 
140
  best = max(cluster, key=lambda idx: (target_corr[idx], X_df.iloc[:, idx].var()))
 
141
  if self.target_threshold > 0 and target_corr[best] < self.target_threshold:
142
  drops |= cluster
143
  else:
 
161
  return X_arr[:, sel]
162
 
163
 
164
+ # ================================
165
+ # 4️⃣ Register All Functions for joblib
166
+ # ================================
167
  sys.modules['__main__'].temp_cat = temp_cat
168
+ sys.modules['__main__'].clean = clean
169
  sys.modules['__main__'].proxy_humidity = proxy_humidity
170
  sys.modules['__main__'].CorrelationThresholdSelector = CorrelationThresholdSelector
171
 
172
+
173
+ # ================================
174
+ # 5️⃣ Initialize FastAPI
175
+ # ================================
176
  app = FastAPI(title="🌾 Crop Yield Predictor API", version="1.0")
177
 
178
+
179
+ # ================================
180
+ # 6️⃣ Load Model
181
+ # ================================
182
  try:
183
  model = joblib.load("CropYieldPredictor.pkl")
184
  print("✅ Model loaded successfully!")
 
186
  print(f"❌ Error loading model: {e}")
187
  model = None
188
 
189
+
190
+ # ================================
191
+ # 7️⃣ Define Input Schema
192
+ # ================================
193
  class CropInput(BaseModel):
194
  Area: str
195
  Item: str
 
198
  pesticides_tonnes: float
199
  avg_temp: float
200
 
201
+
202
+ # ================================
203
+ # 8️⃣ Routes
204
+ # ================================
205
  @app.get("/")
206
  def home():
207
  return {"message": "🌾 Crop Yield Predictor API is live and running!"}
208
 
209
+
210
  @app.post("/predict")
211
  def predict_yield(data: CropInput):
212
  if model is None:
 
222
  "predicted_yield_kg_per_ha": float(predicted_yield_kg_ha),
223
  "message": "✅ Prediction successful!"
224
  }
 
225
  except Exception as e:
226
+ return {
227
+ "error": str(e),
228
+ "message": "❌ Prediction failed due to preprocessing or feature mismatch."
229
+ }
230
+
231
 
232
+ # ================================
233
+ # 9️⃣ Local Run
234
+ # ================================
235
  if __name__ == "__main__":
236
  import uvicorn
237
  uvicorn.run(app, host="0.0.0.0", port=7860)