Kshitijk20 commited on
Commit
5aa701f
·
1 Parent(s): 1b347a0

adding mlfow registered model loading

Browse files
Files changed (2) hide show
  1. app.py +23 -5
  2. src/components/model_trainer.py +10 -1
app.py CHANGED
@@ -49,9 +49,17 @@ def load_models_from_mlflow():
49
  logging.error("MLflow not available")
50
  return False
51
 
 
 
52
  # Get the latest run from the experiment
53
  client = mlflow.tracking.MlflowClient()
54
- experiment = client.get_experiment_by_name("Default")
 
 
 
 
 
 
55
 
56
  if experiment is None:
57
  logging.warning("No MLflow experiment found. Train model first.")
@@ -89,20 +97,30 @@ def load_models_from_mlflow():
89
 
90
  except Exception as e:
91
  logging.error(f"Error loading models from MLflow: {e}")
 
 
92
  return False
93
 
94
  @asynccontextmanager
95
  async def lifespan(app: FastAPI):
96
  """Load models on startup"""
97
- logging.info("===== Application Startup - Loading models from MLflow =====")
98
 
99
  if MLFLOW_AVAILABLE:
100
- success = load_models_from_mlflow()
101
- if not success:
102
- logging.warning("⚠️ Could not load models from MLflow. Please train first via /train endpoint.")
 
 
 
 
 
 
 
103
  else:
104
  logging.warning("⚠️ MLflow not available. Please train via /train endpoint.")
105
 
 
106
  yield
107
  logging.info("===== Application Shutdown =====")
108
 
 
49
  logging.error("MLflow not available")
50
  return False
51
 
52
+ logging.info("Searching for latest MLflow run...")
53
+
54
  # Get the latest run from the experiment
55
  client = mlflow.tracking.MlflowClient()
56
+
57
+ # Try to get experiment, if it doesn't exist, no models are trained yet
58
+ try:
59
+ experiment = client.get_experiment_by_name("Default")
60
+ except Exception as e:
61
+ logging.warning(f"Could not get experiment: {e}")
62
+ return False
63
 
64
  if experiment is None:
65
  logging.warning("No MLflow experiment found. Train model first.")
 
97
 
98
  except Exception as e:
99
  logging.error(f"Error loading models from MLflow: {e}")
100
+ import traceback
101
+ logging.error(traceback.format_exc())
102
  return False
103
 
104
  @asynccontextmanager
105
  async def lifespan(app: FastAPI):
106
  """Load models on startup"""
107
+ logging.info("===== Application Startup - Checking for models =====")
108
 
109
  if MLFLOW_AVAILABLE:
110
+ try:
111
+ # Try to load models but don't block startup if it fails
112
+ logging.info("Attempting to load models from MLflow...")
113
+ success = load_models_from_mlflow()
114
+ if success:
115
+ logging.info("✅ Models loaded successfully from MLflow")
116
+ else:
117
+ logging.warning("⚠️ No models found in MLflow. Train via /train endpoint.")
118
+ except Exception as e:
119
+ logging.warning(f"⚠️ Could not load from MLflow: {e}. Train via /train endpoint.")
120
  else:
121
  logging.warning("⚠️ MLflow not available. Please train via /train endpoint.")
122
 
123
+ logging.info("✅ Application ready to serve requests")
124
  yield
125
  logging.info("===== Application Shutdown =====")
126
 
src/components/model_trainer.py CHANGED
@@ -21,7 +21,16 @@ import dagshub
21
  import os
22
  from dotenv import load_dotenv
23
  load_dotenv()
24
- dagshub.init(repo_owner='kshitijk146', repo_name='MLOPS_project_network_Security_system', mlflow=True)
 
 
 
 
 
 
 
 
 
25
  class ModelTrainer:
26
  def __init__(self, model_trainer_config: Model_trainer_config, data_transformation_artifact: DataTransformationArtifact):
27
  try:
 
21
  import os
22
  from dotenv import load_dotenv
23
  load_dotenv()
24
+
25
+ # Only initialize DagHub once, and make it optional
26
+ try:
27
+ if not os.getenv("DAGSHUB_INITIALIZED"):
28
+ dagshub.init(repo_owner='kshitijk146', repo_name='MLOPS_project_network_Security_system', mlflow=True)
29
+ os.environ["DAGSHUB_INITIALIZED"] = "1"
30
+ logging.info("✅ DagHub/MLflow initialized in model_trainer")
31
+ except Exception as e:
32
+ logging.warning(f"⚠️ DagHub initialization failed: {e}. Continuing without MLflow tracking.")
33
+
34
  class ModelTrainer:
35
  def __init__(self, model_trainer_config: Model_trainer_config, data_transformation_artifact: DataTransformationArtifact):
36
  try: