Kshitijk20 commited on
Commit
679faff
·
1 Parent(s): 1bf4542

lazy model loading

Browse files
Files changed (1) hide show
  1. app.py +5 -14
app.py CHANGED
@@ -13,7 +13,6 @@ import pandas as pd
13
  from src.utils.ml_utils.model.estimator import NetworkSecurityModel
14
  from contextlib import asynccontextmanager
15
  import mlflow
16
- import dagshub
17
 
18
  ca = certifi.where()
19
  load_dotenv()
@@ -30,17 +29,9 @@ from src.utils.main_utils.utils import load_object, save_object
30
  from fastapi.templating import Jinja2Templates
31
  templates = Jinja2Templates(directory="./templates")
32
 
33
- # Initialize DagHub for MLflow tracking
34
- try:
35
- dagshub.init(repo_owner='kshitijk146', repo_name='MLOPS_project_network_Security_system', mlflow=True)
36
- MLFLOW_AVAILABLE = True
37
- logging.info("✅ MLflow tracking initialized")
38
- except Exception as e:
39
- logging.warning(f"⚠️ MLflow initialization failed: {e}")
40
- MLFLOW_AVAILABLE = False
41
-
42
  # Cache for loaded models
43
  MODEL_CACHE = {"model": None, "preprocessor": None}
 
44
 
45
  def load_models_from_mlflow():
46
  """Load latest models from MLflow"""
@@ -152,11 +143,11 @@ async def training_route():
152
  training_pipeline = Trainingpipeline()
153
  training_pipeline.run_pipeline()
154
 
155
- # Reload models from MLflow after training
156
- if MLFLOW_AVAILABLE:
157
- load_models_from_mlflow()
158
 
159
- return Response("✅ Training completed and models loaded from MLflow!")
160
  except Exception as e:
161
  raise NetworkSecurityException(e, sys)
162
 
 
13
  from src.utils.ml_utils.model.estimator import NetworkSecurityModel
14
  from contextlib import asynccontextmanager
15
  import mlflow
 
16
 
17
  ca = certifi.where()
18
  load_dotenv()
 
29
  from fastapi.templating import Jinja2Templates
30
  templates = Jinja2Templates(directory="./templates")
31
 
 
 
 
 
 
 
 
 
 
32
  # Cache for loaded models
33
  MODEL_CACHE = {"model": None, "preprocessor": None}
34
+ MLFLOW_AVAILABLE = True # Assume available, model_trainer.py handles initialization
35
 
36
  def load_models_from_mlflow():
37
  """Load latest models from MLflow"""
 
143
  training_pipeline = Trainingpipeline()
144
  training_pipeline.run_pipeline()
145
 
146
+ # Clear model cache so next prediction loads fresh models
147
+ MODEL_CACHE["model"] = None
148
+ MODEL_CACHE["preprocessor"] = None
149
 
150
+ return Response("✅ Training completed! Models logged to MLflow. Call /predict to use them.")
151
  except Exception as e:
152
  raise NetworkSecurityException(e, sys)
153