Alimiji committed on
Commit
e97af1b
·
verified ·
1 Parent(s): 4f34da4

Upload src/api/main.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/api/main.py +301 -0
src/api/main.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""FastAPI application for weather temperature prediction.

Deployed on Hugging Face Spaces: https://alimiji-weather-prediction-api.hf.space
Last deployment trigger: 2026-01-19
"""

import json
import logging
import os
from contextlib import asynccontextmanager
from pathlib import Path

import joblib
import numpy as np
import requests
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware

from .schemas import (
    BatchPredictionRequest,
    BatchPredictionResponse,
    HealthResponse,
    ModelMetrics,
    PredictionResponse,
    WeatherFeatures,
)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global model storage — populated by load_model() during startup (lifespan)
# or via the /model/reload endpoint.
model = None        # fitted regressor loaded with joblib
model_info = None   # dict parsed from model_info.json (run_id, params, ...)
metrics = None      # dict parsed from metrics.json (train/valid/test scores)

# Paths — ROOT resolves to the repository root (three levels up from src/api/main.py)
ROOT = Path(__file__).resolve().parent.parent.parent
MODEL_PATH = ROOT / "models" / "random_forest" / "Production" / "model.pkl"
MODEL_INFO_PATH = ROOT / "models" / "random_forest" / "Production" / "model_info.json"
METRICS_PATH = ROOT / "metrics.json"

# Feature order (must match training)
FEATURE_ORDER = [
    "min_temp",
    "max_temp",
    "global_radiation",
    "sunshine",
    "cloud_cover",
    "precipitation",
    "pressure",
    "snow_depth",
]
54
+
55
+
56
def download_model_from_dagshub():
    """Download model artifacts from DagsHub's DVC storage via HTTP.

    Reads the ``DAGSHUB_TOKEN`` env var (required) and ``DAGSHUB_USER``
    (optional, defaults to "Alimiji"). Each file is streamed to its target
    path under ``ROOT`` and its MD5 is compared against the hash recorded
    in dvc.lock (mismatch is logged as a warning, not treated as fatal).

    Returns:
        bool: True if every file was downloaded and exists on disk,
        False on missing token, HTTP failure, or any unexpected error.
    """
    import hashlib  # local import: only needed for download verification

    dagshub_token = os.getenv("DAGSHUB_TOKEN")
    dagshub_user = os.getenv("DAGSHUB_USER", "Alimiji")

    if not dagshub_token:
        logger.warning("DAGSHUB_TOKEN not set, skipping model download")
        return False

    # DVC file hashes from dvc.lock
    files_to_download = [
        ("models/random_forest/Production/model_info.json", "006851e7426c173879e57b2b91201121"),
        ("models/random_forest/Production/model.pkl", "44bebd223b998cf7e177aed1e73de3a6"),
    ]

    base_url = "https://dagshub.com/Alimiji/mlops_alimiji1.dvc/files/md5"

    logger.info(f"ROOT path: {ROOT}")
    logger.info(f"MODEL_PATH: {MODEL_PATH}")

    try:
        for file_path, md5_hash in files_to_download:
            full_path = ROOT / file_path
            logger.info(f"Target path: {full_path}")

            full_path.parent.mkdir(parents=True, exist_ok=True)

            # DagsHub DVC storage URL format: /files/md5/{first2chars}/{remaining}
            url = f"{base_url}/{md5_hash[:2]}/{md5_hash[2:]}"
            logger.info(f"Downloading {file_path} from {url}...")

            # Context manager ensures the streamed connection is released
            # even if raise_for_status() or iteration fails mid-download.
            with requests.get(
                url,
                auth=(dagshub_user, dagshub_token),
                stream=True,
                timeout=(30, 600),  # 30s connect, 600s read timeout for large file
            ) as response:
                response.raise_for_status()

                total_size = int(response.headers.get("content-length", 0))
                logger.info(f"File size: {total_size / 1024 / 1024:.1f} MB")

                downloaded = 0
                digest = hashlib.md5()
                with open(full_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=65536):
                        if not chunk:  # filter out keep-alive chunks
                            continue
                        f.write(chunk)
                        digest.update(chunk)
                        downloaded += len(chunk)

            logger.info(f"Downloaded {file_path}: {downloaded / 1024 / 1024:.1f} MB")

            # Non-fatal integrity check against the dvc.lock hash.
            # NOTE(review): assumes DagsHub serves raw file content whose MD5
            # equals the DVC md5 — warn only, in case that assumption is wrong.
            if digest.hexdigest() != md5_hash:
                logger.warning(
                    f"MD5 mismatch for {file_path}: expected {md5_hash}, "
                    f"got {digest.hexdigest()}"
                )

            # Verify file exists
            if full_path.exists():
                logger.info(f"Verified: {full_path} exists, size: {full_path.stat().st_size}")
            else:
                logger.error(f"File not found after download: {full_path}")
                return False

        return True

    except requests.RequestException as e:
        logger.error(f"Failed to download model: {e}")
        return False
    except Exception as e:
        logger.error(f"Unexpected error downloading model: {e}")
        import traceback

        logger.error(traceback.format_exc())
        return False
124
+
125
+
126
def load_model() -> None:
    """Load the trained model and metadata.

    Populates the module-level ``model``, ``model_info`` and ``metrics``
    globals. If the model file is missing, attempts a one-shot download
    from DagsHub first.

    Raises:
        FileNotFoundError: if the model file is still absent after the
            download attempt.
    """
    global model, model_info, metrics

    # Try to download model if not present
    if not MODEL_PATH.exists():
        logger.info("Model not found locally, downloading from DagsHub...")
        download_model_from_dagshub()

    # Re-check: the download above is best-effort and may have failed.
    if not MODEL_PATH.exists():
        logger.error(f"Model not found at {MODEL_PATH}")
        raise FileNotFoundError(f"Model not found at {MODEL_PATH}")

    logger.info(f"Loading model from {MODEL_PATH}")
    model = joblib.load(MODEL_PATH)

    # Model metadata is optional — fall back to an "unknown" run id.
    if MODEL_INFO_PATH.exists():
        model_info = json.loads(MODEL_INFO_PATH.read_text(encoding="utf-8"))
        logger.info(f"Model info loaded: run_id={model_info.get('run_id', 'unknown')}")
    else:
        model_info = {"run_id": "unknown"}

    # Metrics are optional too; /metrics responds 404 while they are None.
    if METRICS_PATH.exists():
        metrics = json.loads(METRICS_PATH.read_text(encoding="utf-8"))
        logger.info("Metrics loaded successfully")

    logger.info("Model loaded successfully")
153
+
154
+
155
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager.

    Loads the model at startup; a missing model is non-fatal so the API
    can still come up in a degraded state (health check reports it).
    """
    # Startup
    try:
        load_model()
    except FileNotFoundError as e:
        # Keep serving: /health shows "degraded", prediction routes return 503.
        logger.warning(f"Model not loaded at startup: {e}")
    yield
    # Shutdown
    logger.info("Shutting down API")
166
+
167
+
168
# Create FastAPI app
app = FastAPI(
    title="Weather Temperature Prediction API",
    description="API for predicting mean temperature based on weather features using Random Forest model",
    version="1.0.0",
    lifespan=lifespan,  # startup model load / shutdown logging
)

# CORS middleware
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# maximally permissive — confirm this is intended for a public demo Space.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
184
+
185
+
186
@app.get("/", tags=["Root"])
async def root():
    """Root endpoint: name of the service plus pointers to docs and health."""
    return dict(
        message="Weather Temperature Prediction API",
        docs="/docs",
        health="/health",
    )
194
+
195
+
196
@app.get("/health", response_model=HealthResponse, tags=["Health"])
async def health_check():
    """Report liveness and whether the prediction model is in memory."""
    loaded = model is not None  # single check reused for all three fields
    return HealthResponse(
        status="healthy" if loaded else "degraded",
        model_loaded=loaded,
        model_path=str(MODEL_PATH) if loaded else None,
    )
204
+
205
+
206
@app.post("/predict", response_model=PredictionResponse, tags=["Predictions"])
async def predict(features: WeatherFeatures):
    """Predict the mean temperature (Celsius) for one set of weather features.

    Responds 503 while no model is loaded.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Single-row feature matrix, columns in the exact training order.
    row = np.array([[getattr(features, name) for name in FEATURE_ORDER]])

    predicted = float(model.predict(row)[0])
    version = model_info.get("run_id", "unknown") if model_info else "unknown"

    return PredictionResponse(
        predicted_mean_temp=round(predicted, 2),
        model_version=version,
    )
227
+
228
+
229
@app.post("/predict/batch", response_model=BatchPredictionResponse, tags=["Predictions"])
async def predict_batch(request: BatchPredictionRequest):
    """Predict mean temperatures for a batch of weather instances.

    Accepts up to 1000 instances per request; responds 503 while no model
    is loaded.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # (n_instances, n_features) matrix, columns in training order.
    rows = [
        [getattr(instance, name) for name in FEATURE_ORDER]
        for instance in request.instances
    ]
    preds = model.predict(np.array(rows))

    version = model_info.get("run_id", "unknown") if model_info else "unknown"
    return BatchPredictionResponse(
        predictions=[round(float(p), 2) for p in preds],
        model_version=version,
        count=len(preds),
    )
250
+
251
+
252
@app.get("/metrics", response_model=ModelMetrics, tags=["Model Info"])
async def get_metrics():
    """Return train/valid/test metrics from the last evaluation (404 if absent)."""
    if metrics is None:
        raise HTTPException(status_code=404, detail="Metrics not available")

    # Flatten {split: {metric: value}} into the flat field names the
    # response model expects (train_rmse, valid_mae, test_r2, ...).
    flat = {
        f"{split}_{name}": metrics[split][name]
        for split in ("train", "valid", "test")
        for name in ("rmse", "mae", "r2")
    }
    return ModelMetrics(**flat)
269
+
270
+
271
@app.get("/model/info", tags=["Model Info"])
async def get_model_info():
    """Describe the loaded model: run metadata plus the feature order."""
    if model_info is None:
        raise HTTPException(status_code=404, detail="Model info not available")

    # Missing metadata keys come through as None, matching dict.get().
    payload = {
        key: model_info.get(key)
        for key in ("run_id", "experiment_name", "model_type", "params")
    }
    payload["features"] = FEATURE_ORDER
    return payload
284
+
285
+
286
@app.post("/model/reload", tags=["Model Info"])
async def reload_model():
    """Reload the model from disk.

    Returns:
        A success payload when the model loads.

    Raises:
        HTTPException: 404 when the model file is missing,
            500 for any other load failure.
    """
    try:
        load_model()
        return {"status": "success", "message": "Model reloaded successfully"}
    except FileNotFoundError as e:
        # Chain the cause so server logs keep the original traceback (B904).
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to reload model: {str(e)}") from e
296
+
297
+
298
if __name__ == "__main__":
    # Local development entry point — binds all interfaces on port 8000.
    # In deployment the ASGI server is expected to import `app` directly.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)