Bachstelze commited on
Commit
a5e8cf9
·
1 Parent(s): 71bcc71

add asynchronous model loading

Browse files
Files changed (1) hide show
  1. app.py +84 -35
app.py CHANGED
@@ -2,6 +2,8 @@ import gradio as gr
2
  import pandas as pd
3
  import pickle
4
  import os
 
 
5
  from A5.CorrelationFilter import CorrelationFilter
6
 
7
 
@@ -33,6 +35,10 @@ CLASSIFICATION_FEATURE_NAMES = None
33
  CLASSIFICATION_CLASSES = None
34
  CLASSIFICATION_METRICS = None
35
 
 
 
 
 
36
  BODY_REGION_RECOMMENDATIONS = {
37
  'Upper Body': (
38
  "Focus on shoulder mobility, thoracic spine extension, "
@@ -44,22 +50,31 @@ BODY_REGION_RECOMMENDATIONS = {
44
 
45
 
46
  def load_champion_model():
47
- global model, FEATURE_NAMES, MODEL_METRICS
48
 
49
  if os.path.exists(MODEL_PATH):
50
  print(f"Loading champion model from {MODEL_PATH}")
51
- with open(MODEL_PATH, "rb") as f:
52
- artifact = pickle.load(f)
53
-
54
- model = artifact["model"]
55
- FEATURE_NAMES = artifact["feature_columns"]
56
- MODEL_METRICS = artifact.get("test_metrics", {})
57
-
58
- print(f"Model loaded: {len(FEATURE_NAMES)} features")
59
- print(f"Test R2: {MODEL_METRICS.get('r2', 'N/A')}")
60
- return True
61
-
62
- print(f"Champion model not found at {MODEL_PATH}")
 
 
 
 
 
 
 
 
 
63
  return False
64
 
65
 
@@ -68,30 +83,42 @@ def load_classification_model():
68
  global CLASSIFICATION_FEATURE_NAMES
69
  global CLASSIFICATION_CLASSES
70
  global CLASSIFICATION_METRICS
 
71
 
72
  if os.path.exists(CLASSIFICATION_MODEL_PATH):
73
  print(f"Loading classification model from {CLASSIFICATION_MODEL_PATH}")
74
- with open(CLASSIFICATION_MODEL_PATH, "rb") as f:
75
- artifact = pickle.load(f)
76
-
77
- classification_model = artifact["model"]
78
- CLASSIFICATION_FEATURE_NAMES = artifact["feature_columns"]
79
- CLASSIFICATION_CLASSES = artifact["classes"]
80
- CLASSIFICATION_METRICS = artifact.get("test_metrics", {})
81
-
82
- len_features = len(CLASSIFICATION_FEATURE_NAMES)
83
- print(
84
- f"Classification model loaded: {len_features} features")
85
- print(f"Classes: {CLASSIFICATION_CLASSES}")
86
- return True
87
-
88
- print(f"Classification model not found at {CLASSIFICATION_MODEL_PATH}")
 
 
 
 
 
 
 
 
 
89
  return False
90
 
91
 
92
  def predict_score(*feature_values):
93
  if model is None:
94
- return "Error", "Model not loaded", ""
 
 
95
 
96
  features_df = pd.DataFrame([feature_values], columns=FEATURE_NAMES)
97
  raw_score = model.predict(features_df)[0]
@@ -130,7 +157,9 @@ def predict_score(*feature_values):
130
 
131
  def predict_weakest_link(*feature_values):
132
  if classification_model is None:
133
- return "Error", "Model not loaded", ""
 
 
134
 
135
  features_df = pd.DataFrame(
136
  [feature_values], columns=CLASSIFICATION_FEATURE_NAMES)
@@ -224,9 +253,12 @@ def load_classification_example():
224
 
225
 
226
  def create_interface():
 
 
227
  if FEATURE_NAMES is None:
 
228
  return gr.Interface(
229
- fn=lambda: "Model not loaded",
230
  inputs=[],
231
  outputs="text",
232
  title="Error: Model not loaded"
@@ -413,12 +445,29 @@ def create_interface():
413
 
414
  return demo
415
 
416
-
417
- if __name__ == "__main__":
418
- # load the pickled models
 
419
  load_champion_model()
420
  load_classification_model()
 
 
 
 
 
 
421
 
422
- # create the interface
 
 
 
 
423
  demo = create_interface()
 
 
 
 
 
 
424
  demo.launch(share=False, server_name="0.0.0.0", server_port=7860)
 
2
  import pandas as pd
3
  import pickle
4
  import os
5
+ import threading
6
+ import time
7
  from A5.CorrelationFilter import CorrelationFilter
8
 
9
 
 
35
  CLASSIFICATION_CLASSES = None
36
  CLASSIFICATION_METRICS = None
37
 
38
+ # Loading state tracking
39
+ models_loaded = False
40
+ loading_error = None
41
+
42
  BODY_REGION_RECOMMENDATIONS = {
43
  'Upper Body': (
44
  "Focus on shoulder mobility, thoracic spine extension, "
 
50
 
51
 
52
  def load_champion_model():
53
+ global model, FEATURE_NAMES, MODEL_METRICS, loading_error
54
 
55
  if os.path.exists(MODEL_PATH):
56
  print(f"Loading champion model from {MODEL_PATH}")
57
+ start_time = time.perf_counter()
58
+ try:
59
+ with open(MODEL_PATH, "rb") as f:
60
+ artifact = pickle.load(f)
61
+
62
+ model = artifact["model"]
63
+ FEATURE_NAMES = artifact["feature_columns"]
64
+ MODEL_METRICS = artifact.get("test_metrics", {})
65
+
66
+ elapsed_time = time.perf_counter() - start_time
67
+ print(f"Model loaded: {len(FEATURE_NAMES)} features")
68
+ print(f"Test R2: {MODEL_METRICS.get('r2', 'N/A')}")
69
+ print(f"Model loading time: {elapsed_time:.2f} seconds")
70
+ return True
71
+ except Exception as e:
72
+ loading_error = f"Error loading champion model: {e}"
73
+ print(loading_error)
74
+ return False
75
+
76
+ loading_error = f"Champion model not found at {MODEL_PATH}"
77
+ print(loading_error)
78
  return False
79
 
80
 
 
83
  global CLASSIFICATION_FEATURE_NAMES
84
  global CLASSIFICATION_CLASSES
85
  global CLASSIFICATION_METRICS
86
+ global loading_error
87
 
88
  if os.path.exists(CLASSIFICATION_MODEL_PATH):
89
  print(f"Loading classification model from {CLASSIFICATION_MODEL_PATH}")
90
+ start_time = time.perf_counter()
91
+ try:
92
+ with open(CLASSIFICATION_MODEL_PATH, "rb") as f:
93
+ artifact = pickle.load(f)
94
+
95
+ classification_model = artifact["model"]
96
+ CLASSIFICATION_FEATURE_NAMES = artifact["feature_columns"]
97
+ CLASSIFICATION_CLASSES = artifact["classes"]
98
+ CLASSIFICATION_METRICS = artifact.get("test_metrics", {})
99
+
100
+ len_features = len(CLASSIFICATION_FEATURE_NAMES)
101
+ elapsed_time = time.perf_counter() - start_time
102
+ print(
103
+ f"Classification model loaded: {len_features} features")
104
+ print(f"Classes: {CLASSIFICATION_CLASSES}")
105
+ print(f"Classification model loading time: {elapsed_time:.2f} seconds")
106
+ return True
107
+ except Exception as e:
108
+ loading_error = f"Error loading classification model: {e}"
109
+ print(loading_error)
110
+ return False
111
+
112
+ loading_error = f"Classification model not found at {CLASSIFICATION_MODEL_PATH}"
113
+ print(loading_error)
114
  return False
115
 
116
 
117
  def predict_score(*feature_values):
118
  if model is None:
119
+ if loading_error:
120
+ return "Error", loading_error, ""
121
+ return "Error", "Model not loaded yet", ""
122
 
123
  features_df = pd.DataFrame([feature_values], columns=FEATURE_NAMES)
124
  raw_score = model.predict(features_df)[0]
 
157
 
158
  def predict_weakest_link(*feature_values):
159
  if classification_model is None:
160
+ if loading_error:
161
+ return "Error", loading_error, ""
162
+ return "Error", "Classification model not loaded yet", ""
163
 
164
  features_df = pd.DataFrame(
165
  [feature_values], columns=CLASSIFICATION_FEATURE_NAMES)
 
253
 
254
 
255
  def create_interface():
256
+ global models_loaded
257
+
258
  if FEATURE_NAMES is None:
259
+ error_message = loading_error if loading_error else "Model not loaded"
260
  return gr.Interface(
261
+ fn=lambda: error_message,
262
  inputs=[],
263
  outputs="text",
264
  title="Error: Model not loaded"
 
445
 
446
  return demo
447
 
448
+ def load_models_async():
449
+ global models_loaded
450
+ start_time = time.perf_counter()
451
+ print("Starting asynchronous model loading...")
452
  load_champion_model()
453
  load_classification_model()
454
+ models_loaded = True
455
+ elapsed_time = time.perf_counter() - start_time
456
+ print(f"Model loading complete (total time: {elapsed_time:.2f} seconds)")
457
+
458
+ if __name__ == "__main__":
459
+ # Load models asynchronously in background threads
460
 
461
+ # Start model loading in background thread
462
+ loading_thread = threading.Thread(target=load_models_async, daemon=True)
463
+ loading_thread.start()
464
+
465
+ # Create the interface immediately (models loading in background)
466
  demo = create_interface()
467
+
468
+ # Add loading status to the interface
469
+ if not models_loaded:
470
+ print("Models are loading in the background...")
471
+ print("You can use the interface while models load.")
472
+
473
  demo.launch(share=False, server_name="0.0.0.0", server_port=7860)