Kshitijk20 commited on
Commit
fc4ec5a
·
1 Parent(s): 679faff

lazy model loading

Browse files
Files changed (1) hide show
  1. app.py +65 -13
app.py CHANGED
@@ -29,6 +29,34 @@ from src.utils.main_utils.utils import load_object, save_object
29
  from fastapi.templating import Jinja2Templates
30
  templates = Jinja2Templates(directory="./templates")
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  # Cache for loaded models
33
  MODEL_CACHE = {"model": None, "preprocessor": None}
34
  MLFLOW_AVAILABLE = True # Assume available, model_trainer.py handles initialization
@@ -96,7 +124,22 @@ def load_models_from_mlflow():
96
  async def lifespan(app: FastAPI):
97
  """Initialize application on startup"""
98
  logging.info("===== Application Startup =====")
99
- logging.info("⚠️ Models will be loaded on first /train or /predict request")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  logging.info("✅ Application ready to serve requests")
101
 
102
  yield
@@ -122,17 +165,24 @@ app.add_middleware(
122
  @app.get("/")
123
  async def root():
124
  """Root endpoint with system status"""
125
- model_status = "✅ Ready" if MODEL_CACHE["model"] is not None else "⚠️ Not trained - call /train first"
 
 
 
 
 
 
126
 
127
  return {
128
  "status": "running",
129
  "service": "Network Security System - Phishing Detection",
130
  "model_status": model_status,
 
131
  "mlflow_enabled": MLFLOW_AVAILABLE,
132
  "endpoints": {
133
  "docs": "/docs",
134
- "train": "/train (trains and logs to MLflow)",
135
- "predict": "/predict (loads from MLflow)"
136
  }
137
  }
138
 
@@ -154,10 +204,13 @@ async def training_route():
154
  @app.post("/predict") # predict route
155
  async def predict_route(request: Request, file: UploadFile =File(...)):
156
  try:
157
- # Check if models are loaded
158
- if MODEL_CACHE["model"] is None or MODEL_CACHE["preprocessor"] is None:
159
- # Try to load from MLflow
160
- if not load_models_from_mlflow():
 
 
 
161
  return Response(
162
  "❌ No trained model available. Please call /train endpoint first.",
163
  status_code=400
@@ -168,17 +221,16 @@ async def predict_route(request: Request, file: UploadFile =File(...)):
168
  if 'Result' in df.columns:
169
  df = df.drop(columns=['Result'])
170
 
171
- # Use cached models from MLflow
172
- preprocessor = MODEL_CACHE["preprocessor"]
173
- model = MODEL_CACHE["model"]
174
 
175
  NSmodel = NetworkSecurityModel(preprocessing_object=preprocessor, trained_model_object=model)
176
  y_pred = NSmodel.predict(df)
177
  df['predicted_column'] = y_pred
178
 
179
  # Save predictions
180
- os.makedirs("final_model", exist_ok=True)
181
- df.to_csv("final_model/predicted.csv")
182
 
183
  table_html = df.to_html(classes='table table-striped')
184
  return templates.TemplateResponse("table.html", {"request": request, "table": table_html})
 
29
  from fastapi.templating import Jinja2Templates
30
  templates = Jinja2Templates(directory="./templates")
31
 
32
+ # Persistent storage paths
33
+ PERSISTENT_MODEL_DIR = "/data/models"
34
+ LOCAL_MODEL_DIR = "final_model"
35
+
36
def restore_models_from_persistent_storage(persistent_dir=None, local_dir=None):
    """Copy trained model artifacts from persistent storage into the local model dir.

    HuggingFace Spaces wipe the working directory on restart, so the model and
    preprocessor pickles are kept under /data and restored here on demand.

    Args:
        persistent_dir: Source directory; defaults to PERSISTENT_MODEL_DIR.
        local_dir: Destination directory; defaults to LOCAL_MODEL_DIR.

    Returns:
        bool: True if both artifacts were found and copied, False otherwise
        (including on any unexpected error, which is logged rather than raised).
    """
    # Local import kept (as in the original) to avoid touching the module
    # import block; hoisted out of the branch so it runs once per call path.
    import shutil

    # Resolve defaults at call time so module-level config changes are honored
    # and tests can supply their own directories.
    if persistent_dir is None:
        persistent_dir = PERSISTENT_MODEL_DIR
    if local_dir is None:
        local_dir = LOCAL_MODEL_DIR

    try:
        artifacts = ("model.pkl", "preprocessor.pkl")
        sources = [os.path.join(persistent_dir, name) for name in artifacts]

        # Both artifacts must exist; a model without its preprocessor is unusable.
        if not all(os.path.exists(src) for src in sources):
            logging.warning("⚠️ No models found in persistent storage")
            return False

        os.makedirs(local_dir, exist_ok=True)
        for name, src in zip(artifacts, sources):
            # copy2 preserves file metadata (timestamps) along with contents.
            shutil.copy2(src, os.path.join(local_dir, name))
        logging.info("✅ Models restored from persistent storage (/data/models)")
        return True
    except Exception as e:
        # Best-effort restore: log and report failure instead of crashing startup.
        logging.error(f"Error restoring models from persistent storage: {e}")
        return False
59
+
60
  # Cache for loaded models
61
  MODEL_CACHE = {"model": None, "preprocessor": None}
62
  MLFLOW_AVAILABLE = True # Assume available, model_trainer.py handles initialization
 
124
  async def lifespan(app: FastAPI):
125
  """Initialize application on startup"""
126
  logging.info("===== Application Startup =====")
127
+
128
+ # Try to restore models from persistent storage
129
+ model_path = f"{LOCAL_MODEL_DIR}/model.pkl"
130
+ preprocessor_path = f"{LOCAL_MODEL_DIR}/preprocessor.pkl"
131
+
132
+ # Check if local models exist
133
+ if os.path.exists(model_path) and os.path.exists(preprocessor_path):
134
+ logging.info("✅ Models found in local directory")
135
+ else:
136
+ # Try to restore from persistent storage
137
+ logging.info("Checking persistent storage for models...")
138
+ if restore_models_from_persistent_storage():
139
+ logging.info("✅ Models restored and ready for predictions")
140
+ else:
141
+ logging.warning("⚠️ No models available. Please call /train endpoint first.")
142
+
143
  logging.info("✅ Application ready to serve requests")
144
 
145
  yield
 
165
@app.get("/")
async def root():
    """Root endpoint with system status"""
    # A model counts as available if its pickle exists either locally or in
    # HuggingFace persistent storage (/data/models).
    local_exists = os.path.exists(f"{LOCAL_MODEL_DIR}/model.pkl")
    persistent_exists = os.path.exists(f"{PERSISTENT_MODEL_DIR}/model.pkl")

    model_status = (
        "✅ Ready"
        if local_exists or persistent_exists
        else "⚠️ Not trained - call /train first"
    )

    return {
        "status": "running",
        "service": "Network Security System - Phishing Detection",
        "model_status": model_status,
        "persistent_storage": persistent_exists,
        "mlflow_enabled": MLFLOW_AVAILABLE,
        "endpoints": {
            "docs": "/docs",
            "train": "/train (trains and saves to persistent storage)",
            "predict": "/predict (uses persistent models)",
        },
    }
188
 
 
204
  @app.post("/predict") # predict route
205
  async def predict_route(request: Request, file: UploadFile =File(...)):
206
  try:
207
+ model_path = f"{LOCAL_MODEL_DIR}/model.pkl"
208
+ preprocessor_path = f"{LOCAL_MODEL_DIR}/preprocessor.pkl"
209
+
210
+ # Check if models exist locally, if not try to restore from persistent storage
211
+ if not (os.path.exists(model_path) and os.path.exists(preprocessor_path)):
212
+ logging.info("Local models not found, restoring from persistent storage...")
213
+ if not restore_models_from_persistent_storage():
214
  return Response(
215
  "❌ No trained model available. Please call /train endpoint first.",
216
  status_code=400
 
221
  if 'Result' in df.columns:
222
  df = df.drop(columns=['Result'])
223
 
224
+ # Load models from local files
225
+ preprocessor = load_object(file_path=preprocessor_path)
226
+ model = load_object(file_path=model_path)
227
 
228
  NSmodel = NetworkSecurityModel(preprocessing_object=preprocessor, trained_model_object=model)
229
  y_pred = NSmodel.predict(df)
230
  df['predicted_column'] = y_pred
231
 
232
  # Save predictions
233
+ df.to_csv(f"{LOCAL_MODEL_DIR}/predicted.csv")
 
234
 
235
  table_html = df.to_html(classes='table table-striped')
236
  return templates.TemplateResponse("table.html", {"request": request, "table": table_html})