Kshitijk20 commited on
Commit
1b347a0
·
1 Parent(s): 2cc7b15

adding mlflow registered model loading

Browse files
Files changed (3) hide show
  1. Dockerfile +0 -2
  2. app.py +114 -13
  3. src/components/model_trainer.py +13 -7
Dockerfile CHANGED
@@ -26,8 +26,6 @@ RUN mkdir -p /app/data /app/final_model /app/templates
26
 
27
  # run the load_data_to_sqlite.py script to initialize the database
28
  RUN python load_data_to_sqlite.py
29
- # Train the model during build (this persists across container restarts)
30
- RUN python -c "from src.pipeline.training_pipeline import Trainingpipeline; Trainingpipeline().run_pipeline()"
31
 
32
  # Expose port 7860 (HF Space requirement)
33
  EXPOSE 7860
 
26
 
27
  # run the load_data_to_sqlite.py script to initialize the database
28
  RUN python load_data_to_sqlite.py
 
 
29
 
30
  # Expose port 7860 (HF Space requirement)
31
  EXPOSE 7860
app.py CHANGED
@@ -11,6 +11,9 @@ from fastapi.responses import Response
11
  from starlette.responses import RedirectResponse
12
  import pandas as pd
13
  from src.utils.ml_utils.model.estimator import NetworkSecurityModel
 
 
 
14
 
15
  ca = certifi.where()
16
  load_dotenv()
@@ -18,15 +21,92 @@ mongo_db_uri = os.getenv("MONGO_DB_URI")
18
 
19
  from src.constant.training_pipeline import DATA_INGESTION_COLLECTION_NAME
20
  from src.constant.training_pipeline import DATA_INGESTION_DATBASE_NANE
21
- from src.utils.main_utils.utils import load_object
22
  # import pymongo
23
 
24
  # client = pymongo.MongoClient(mongo_db_uri,tlsCAFile=ca)
25
  # database = client[DATA_INGESTION_DATBASE_NANE]
26
  # collection = database[DATA_INGESTION_COLLECTION_NAME]
27
  from fastapi.templating import Jinja2Templates
28
- templates = Jinja2Templates(directory="./templates")
29
- app = FastAPI()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  orgin = ["*"]
32
 
@@ -45,13 +125,17 @@ app.add_middleware(
45
  @app.get("/")
46
  async def root():
47
  """Root endpoint with system status"""
 
 
48
  return {
49
  "status": "running",
50
  "service": "Network Security System - Phishing Detection",
 
 
51
  "endpoints": {
52
  "docs": "/docs",
53
- "train": "/train",
54
- "predict": "/predict"
55
  }
56
  }
57
 
@@ -61,28 +145,45 @@ async def training_route():
61
  logging.info("Starting training pipeline...")
62
  training_pipeline = Trainingpipeline()
63
  training_pipeline.run_pipeline()
64
- return Response("Training completed successfully!")
 
 
 
 
 
65
  except Exception as e:
66
  raise NetworkSecurityException(e, sys)
67
 
68
  @app.post("/predict") # predict route
69
  async def predict_route(request: Request, file: UploadFile =File(...)):
70
  try:
 
 
 
 
 
 
 
 
 
71
  df = pd.read_csv(file.file)
72
  # Remove target column if it exists
73
  if 'Result' in df.columns:
74
  df = df.drop(columns=['Result'])
75
- preprocessor = load_object(file_path = "final_model/preprocessor.pkl")
76
- model = load_object(file_path= "final_model/model.pkl")
77
- NSmodel = NetworkSecurityModel(preprocessing_object= preprocessor, trained_model_object= model)
78
- print(df.iloc[0])
 
 
79
  y_pred = NSmodel.predict(df)
80
- print(y_pred)
81
  df['predicted_column'] = y_pred
82
- print(df['predicted_column'])
 
 
83
  df.to_csv("final_model/predicted.csv")
84
 
85
- table_html = df.to_html(classes = 'table table-striped')
86
  return templates.TemplateResponse("table.html", {"request": request, "table": table_html})
87
 
88
  except Exception as e:
 
11
  from starlette.responses import RedirectResponse
12
  import pandas as pd
13
  from src.utils.ml_utils.model.estimator import NetworkSecurityModel
14
+ from contextlib import asynccontextmanager
15
+ import mlflow
16
+ import dagshub
17
 
18
  ca = certifi.where()
19
  load_dotenv()
 
21
 
22
  from src.constant.training_pipeline import DATA_INGESTION_COLLECTION_NAME
23
  from src.constant.training_pipeline import DATA_INGESTION_DATBASE_NANE
24
+ from src.utils.main_utils.utils import load_object, save_object
25
  # import pymongo
26
 
27
  # client = pymongo.MongoClient(mongo_db_uri,tlsCAFile=ca)
28
  # database = client[DATA_INGESTION_DATBASE_NANE]
29
  # collection = database[DATA_INGESTION_COLLECTION_NAME]
30
  from fastapi.templating import Jinja2Templates
31
+ templates = Jinja2Templates(directory="./templates")
32
+
33
+ # Initialize DagHub for MLflow tracking
34
+ try:
35
+ dagshub.init(repo_owner='kshitijk146', repo_name='MLOPS_project_network_Security_system', mlflow=True)
36
+ MLFLOW_AVAILABLE = True
37
+ logging.info("✅ MLflow tracking initialized")
38
+ except Exception as e:
39
+ logging.warning(f"⚠️ MLflow initialization failed: {e}")
40
+ MLFLOW_AVAILABLE = False
41
+
42
+ # Cache for loaded models
43
+ MODEL_CACHE = {"model": None, "preprocessor": None}
44
+
45
+ def load_models_from_mlflow():
46
+ """Load latest models from MLflow"""
47
+ try:
48
+ if not MLFLOW_AVAILABLE:
49
+ logging.error("MLflow not available")
50
+ return False
51
+
52
+ # Get the latest run from the experiment
53
+ client = mlflow.tracking.MlflowClient()
54
+ experiment = client.get_experiment_by_name("Default")
55
+
56
+ if experiment is None:
57
+ logging.warning("No MLflow experiment found. Train model first.")
58
+ return False
59
+
60
+ runs = client.search_runs(
61
+ experiment_ids=[experiment.experiment_id],
62
+ order_by=["start_time DESC"],
63
+ max_results=1
64
+ )
65
+
66
+ if not runs:
67
+ logging.warning("No MLflow runs found. Train model first.")
68
+ return False
69
+
70
+ latest_run = runs[0]
71
+ run_id = latest_run.info.run_id
72
+
73
+ logging.info(f"Loading models from MLflow run: {run_id}")
74
+
75
+ # Load model and preprocessor
76
+ model_uri = f"runs:/{run_id}/model"
77
+ preprocessor_uri = f"runs:/{run_id}/preprocessor"
78
+
79
+ MODEL_CACHE["model"] = mlflow.sklearn.load_model(model_uri)
80
+ MODEL_CACHE["preprocessor"] = mlflow.sklearn.load_model(preprocessor_uri)
81
+
82
+ # Save to local directory as backup
83
+ os.makedirs("final_model", exist_ok=True)
84
+ save_object("final_model/model.pkl", MODEL_CACHE["model"])
85
+ save_object("final_model/preprocessor.pkl", MODEL_CACHE["preprocessor"])
86
+
87
+ logging.info("✅ Models loaded from MLflow and cached locally")
88
+ return True
89
+
90
+ except Exception as e:
91
+ logging.error(f"Error loading models from MLflow: {e}")
92
+ return False
93
+
94
+ @asynccontextmanager
95
+ async def lifespan(app: FastAPI):
96
+ """Load models on startup"""
97
+ logging.info("===== Application Startup - Loading models from MLflow =====")
98
+
99
+ if MLFLOW_AVAILABLE:
100
+ success = load_models_from_mlflow()
101
+ if not success:
102
+ logging.warning("⚠️ Could not load models from MLflow. Please train first via /train endpoint.")
103
+ else:
104
+ logging.warning("⚠️ MLflow not available. Please train via /train endpoint.")
105
+
106
+ yield
107
+ logging.info("===== Application Shutdown =====")
108
+
109
+ app = FastAPI(lifespan=lifespan)
110
 
111
  orgin = ["*"]
112
 
 
125
  @app.get("/")
126
  async def root():
127
  """Root endpoint with system status"""
128
+ model_status = "✅ Ready" if MODEL_CACHE["model"] is not None else "⚠️ Not trained - call /train first"
129
+
130
  return {
131
  "status": "running",
132
  "service": "Network Security System - Phishing Detection",
133
+ "model_status": model_status,
134
+ "mlflow_enabled": MLFLOW_AVAILABLE,
135
  "endpoints": {
136
  "docs": "/docs",
137
+ "train": "/train (trains and logs to MLflow)",
138
+ "predict": "/predict (loads from MLflow)"
139
  }
140
  }
141
 
 
145
  logging.info("Starting training pipeline...")
146
  training_pipeline = Trainingpipeline()
147
  training_pipeline.run_pipeline()
148
+
149
+ # Reload models from MLflow after training
150
+ if MLFLOW_AVAILABLE:
151
+ load_models_from_mlflow()
152
+
153
+ return Response("✅ Training completed and models loaded from MLflow!")
154
  except Exception as e:
155
  raise NetworkSecurityException(e, sys)
156
 
157
  @app.post("/predict") # predict route
158
  async def predict_route(request: Request, file: UploadFile =File(...)):
159
  try:
160
+ # Check if models are loaded
161
+ if MODEL_CACHE["model"] is None or MODEL_CACHE["preprocessor"] is None:
162
+ # Try to load from MLflow
163
+ if not load_models_from_mlflow():
164
+ return Response(
165
+ "❌ No trained model available. Please call /train endpoint first.",
166
+ status_code=400
167
+ )
168
+
169
  df = pd.read_csv(file.file)
170
  # Remove target column if it exists
171
  if 'Result' in df.columns:
172
  df = df.drop(columns=['Result'])
173
+
174
+ # Use cached models from MLflow
175
+ preprocessor = MODEL_CACHE["preprocessor"]
176
+ model = MODEL_CACHE["model"]
177
+
178
+ NSmodel = NetworkSecurityModel(preprocessing_object=preprocessor, trained_model_object=model)
179
  y_pred = NSmodel.predict(df)
 
180
  df['predicted_column'] = y_pred
181
+
182
+ # Save predictions
183
+ os.makedirs("final_model", exist_ok=True)
184
  df.to_csv("final_model/predicted.csv")
185
 
186
+ table_html = df.to_html(classes='table table-striped')
187
  return templates.TemplateResponse("table.html", {"request": request, "table": table_html})
188
 
189
  except Exception as e:
src/components/model_trainer.py CHANGED
@@ -30,8 +30,9 @@ class ModelTrainer:
30
  except Exception as e:
31
  raise NetworkSecurityException(e, sys) from e
32
 
33
- def track_mlflow(self,best_model, classificationmetric):
34
- with mlflow.start_run():
 
35
  f1_score = classificationmetric.f1_score
36
  precision_score = classificationmetric.precision_score
37
  recall_score = classificationmetric.recall_score
@@ -39,7 +40,14 @@ class ModelTrainer:
39
  mlflow.log_metric("f1_score", f1_score)
40
  mlflow.log_metric("precision_score", precision_score)
41
  mlflow.log_metric("recall_score", recall_score)
 
 
42
  mlflow.sklearn.log_model(best_model, "model")
 
 
 
 
 
43
 
44
  def train_model(self, x_train, y_train,x_test, y_test):
45
  models = {
@@ -104,15 +112,13 @@ class ModelTrainer:
104
  y_train_pred = best_model.predict(x_train)
105
  classification_train_metric= classification_score(y_true = y_train, y_pred=y_train_pred)
106
 
107
- # track mlfow
108
- self.track_mlflow(best_model, classification_train_metric)
109
-
110
-
111
-
112
  y_test_pred = best_model.predict(x_test)
113
  classification_test_metric = classification_score(y_true = y_test, y_pred=y_test_pred)
114
 
115
  preprocessor = load_object(file_path=self.data_transformation_artifact.transformed_object_file_path)
 
 
 
116
  model_dir_path = os.path.dirname(self.model_trainer_config.trained_model_file_path)
117
  os.makedirs(model_dir_path, exist_ok=True)
118
 
 
30
  except Exception as e:
31
  raise NetworkSecurityException(e, sys) from e
32
 
33
+ def track_mlflow(self, best_model, preprocessor, classificationmetric):
34
+ """Log model, preprocessor, and metrics to MLflow"""
35
+ with mlflow.start_run() as run:
36
  f1_score = classificationmetric.f1_score
37
  precision_score = classificationmetric.precision_score
38
  recall_score = classificationmetric.recall_score
 
40
  mlflow.log_metric("f1_score", f1_score)
41
  mlflow.log_metric("precision_score", precision_score)
42
  mlflow.log_metric("recall_score", recall_score)
43
+
44
+ # Log both model and preprocessor
45
  mlflow.sklearn.log_model(best_model, "model")
46
+ mlflow.sklearn.log_model(preprocessor, "preprocessor")
47
+
48
+ # Log run ID for easy retrieval
49
+ logging.info(f"✅ Models logged to MLflow - Run ID: {run.info.run_id}")
50
+ return run.info.run_id
51
 
52
  def train_model(self, x_train, y_train,x_test, y_test):
53
  models = {
 
112
  y_train_pred = best_model.predict(x_train)
113
  classification_train_metric= classification_score(y_true = y_train, y_pred=y_train_pred)
114
 
 
 
 
 
 
115
  y_test_pred = best_model.predict(x_test)
116
  classification_test_metric = classification_score(y_true = y_test, y_pred=y_test_pred)
117
 
118
  preprocessor = load_object(file_path=self.data_transformation_artifact.transformed_object_file_path)
119
+
120
+ # Track to MLflow (logs model + preprocessor)
121
+ self.track_mlflow(best_model, preprocessor, classification_train_metric)
122
  model_dir_path = os.path.dirname(self.model_trainer_config.trained_model_file_path)
123
  os.makedirs(model_dir_path, exist_ok=True)
124