Spaces:
Sleeping
Sleeping
Bachstelze commited on
Commit ·
a5e8cf9
1
Parent(s): 71bcc71
add asynchronous model loading
Browse files
app.py
CHANGED
|
@@ -2,6 +2,8 @@ import gradio as gr
|
|
| 2 |
import pandas as pd
|
| 3 |
import pickle
|
| 4 |
import os
|
|
|
|
|
|
|
| 5 |
from A5.CorrelationFilter import CorrelationFilter
|
| 6 |
|
| 7 |
|
|
@@ -33,6 +35,10 @@ CLASSIFICATION_FEATURE_NAMES = None
|
|
| 33 |
CLASSIFICATION_CLASSES = None
|
| 34 |
CLASSIFICATION_METRICS = None
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
BODY_REGION_RECOMMENDATIONS = {
|
| 37 |
'Upper Body': (
|
| 38 |
"Focus on shoulder mobility, thoracic spine extension, "
|
|
@@ -44,22 +50,31 @@ BODY_REGION_RECOMMENDATIONS = {
|
|
| 44 |
|
| 45 |
|
| 46 |
def load_champion_model():
|
| 47 |
-
global model, FEATURE_NAMES, MODEL_METRICS
|
| 48 |
|
| 49 |
if os.path.exists(MODEL_PATH):
|
| 50 |
print(f"Loading champion model from {MODEL_PATH}")
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
return False
|
| 64 |
|
| 65 |
|
|
@@ -68,30 +83,42 @@ def load_classification_model():
|
|
| 68 |
global CLASSIFICATION_FEATURE_NAMES
|
| 69 |
global CLASSIFICATION_CLASSES
|
| 70 |
global CLASSIFICATION_METRICS
|
|
|
|
| 71 |
|
| 72 |
if os.path.exists(CLASSIFICATION_MODEL_PATH):
|
| 73 |
print(f"Loading classification model from {CLASSIFICATION_MODEL_PATH}")
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
return False
|
| 90 |
|
| 91 |
|
| 92 |
def predict_score(*feature_values):
|
| 93 |
if model is None:
|
| 94 |
-
|
|
|
|
|
|
|
| 95 |
|
| 96 |
features_df = pd.DataFrame([feature_values], columns=FEATURE_NAMES)
|
| 97 |
raw_score = model.predict(features_df)[0]
|
|
@@ -130,7 +157,9 @@ def predict_score(*feature_values):
|
|
| 130 |
|
| 131 |
def predict_weakest_link(*feature_values):
|
| 132 |
if classification_model is None:
|
| 133 |
-
|
|
|
|
|
|
|
| 134 |
|
| 135 |
features_df = pd.DataFrame(
|
| 136 |
[feature_values], columns=CLASSIFICATION_FEATURE_NAMES)
|
|
@@ -224,9 +253,12 @@ def load_classification_example():
|
|
| 224 |
|
| 225 |
|
| 226 |
def create_interface():
|
|
|
|
|
|
|
| 227 |
if FEATURE_NAMES is None:
|
|
|
|
| 228 |
return gr.Interface(
|
| 229 |
-
fn=lambda:
|
| 230 |
inputs=[],
|
| 231 |
outputs="text",
|
| 232 |
title="Error: Model not loaded"
|
|
@@ -413,12 +445,29 @@ def create_interface():
|
|
| 413 |
|
| 414 |
return demo
|
| 415 |
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
|
|
|
| 419 |
load_champion_model()
|
| 420 |
load_classification_model()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 421 |
|
| 422 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 423 |
demo = create_interface()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 424 |
demo.launch(share=False, server_name="0.0.0.0", server_port=7860)
|
|
|
|
| 2 |
import pandas as pd
|
| 3 |
import pickle
|
| 4 |
import os
|
| 5 |
+
import threading
|
| 6 |
+
import time
|
| 7 |
from A5.CorrelationFilter import CorrelationFilter
|
| 8 |
|
| 9 |
|
|
|
|
| 35 |
CLASSIFICATION_CLASSES = None
|
| 36 |
CLASSIFICATION_METRICS = None
|
| 37 |
|
| 38 |
+
# Loading state tracking
|
| 39 |
+
models_loaded = False
|
| 40 |
+
loading_error = None
|
| 41 |
+
|
| 42 |
BODY_REGION_RECOMMENDATIONS = {
|
| 43 |
'Upper Body': (
|
| 44 |
"Focus on shoulder mobility, thoracic spine extension, "
|
|
|
|
| 50 |
|
| 51 |
|
| 52 |
def load_champion_model():
|
| 53 |
+
global model, FEATURE_NAMES, MODEL_METRICS, loading_error
|
| 54 |
|
| 55 |
if os.path.exists(MODEL_PATH):
|
| 56 |
print(f"Loading champion model from {MODEL_PATH}")
|
| 57 |
+
start_time = time.perf_counter()
|
| 58 |
+
try:
|
| 59 |
+
with open(MODEL_PATH, "rb") as f:
|
| 60 |
+
artifact = pickle.load(f)
|
| 61 |
+
|
| 62 |
+
model = artifact["model"]
|
| 63 |
+
FEATURE_NAMES = artifact["feature_columns"]
|
| 64 |
+
MODEL_METRICS = artifact.get("test_metrics", {})
|
| 65 |
+
|
| 66 |
+
elapsed_time = time.perf_counter() - start_time
|
| 67 |
+
print(f"Model loaded: {len(FEATURE_NAMES)} features")
|
| 68 |
+
print(f"Test R2: {MODEL_METRICS.get('r2', 'N/A')}")
|
| 69 |
+
print(f"Model loading time: {elapsed_time:.2f} seconds")
|
| 70 |
+
return True
|
| 71 |
+
except Exception as e:
|
| 72 |
+
loading_error = f"Error loading champion model: {e}"
|
| 73 |
+
print(loading_error)
|
| 74 |
+
return False
|
| 75 |
+
|
| 76 |
+
loading_error = f"Champion model not found at {MODEL_PATH}"
|
| 77 |
+
print(loading_error)
|
| 78 |
return False
|
| 79 |
|
| 80 |
|
|
|
|
| 83 |
global CLASSIFICATION_FEATURE_NAMES
|
| 84 |
global CLASSIFICATION_CLASSES
|
| 85 |
global CLASSIFICATION_METRICS
|
| 86 |
+
global loading_error
|
| 87 |
|
| 88 |
if os.path.exists(CLASSIFICATION_MODEL_PATH):
|
| 89 |
print(f"Loading classification model from {CLASSIFICATION_MODEL_PATH}")
|
| 90 |
+
start_time = time.perf_counter()
|
| 91 |
+
try:
|
| 92 |
+
with open(CLASSIFICATION_MODEL_PATH, "rb") as f:
|
| 93 |
+
artifact = pickle.load(f)
|
| 94 |
+
|
| 95 |
+
classification_model = artifact["model"]
|
| 96 |
+
CLASSIFICATION_FEATURE_NAMES = artifact["feature_columns"]
|
| 97 |
+
CLASSIFICATION_CLASSES = artifact["classes"]
|
| 98 |
+
CLASSIFICATION_METRICS = artifact.get("test_metrics", {})
|
| 99 |
+
|
| 100 |
+
len_features = len(CLASSIFICATION_FEATURE_NAMES)
|
| 101 |
+
elapsed_time = time.perf_counter() - start_time
|
| 102 |
+
print(
|
| 103 |
+
f"Classification model loaded: {len_features} features")
|
| 104 |
+
print(f"Classes: {CLASSIFICATION_CLASSES}")
|
| 105 |
+
print(f"Classification model loading time: {elapsed_time:.2f} seconds")
|
| 106 |
+
return True
|
| 107 |
+
except Exception as e:
|
| 108 |
+
loading_error = f"Error loading classification model: {e}"
|
| 109 |
+
print(loading_error)
|
| 110 |
+
return False
|
| 111 |
+
|
| 112 |
+
loading_error = f"Classification model not found at {CLASSIFICATION_MODEL_PATH}"
|
| 113 |
+
print(loading_error)
|
| 114 |
return False
|
| 115 |
|
| 116 |
|
| 117 |
def predict_score(*feature_values):
|
| 118 |
if model is None:
|
| 119 |
+
if loading_error:
|
| 120 |
+
return "Error", loading_error, ""
|
| 121 |
+
return "Error", "Model not loaded yet", ""
|
| 122 |
|
| 123 |
features_df = pd.DataFrame([feature_values], columns=FEATURE_NAMES)
|
| 124 |
raw_score = model.predict(features_df)[0]
|
|
|
|
| 157 |
|
| 158 |
def predict_weakest_link(*feature_values):
|
| 159 |
if classification_model is None:
|
| 160 |
+
if loading_error:
|
| 161 |
+
return "Error", loading_error, ""
|
| 162 |
+
return "Error", "Classification model not loaded yet", ""
|
| 163 |
|
| 164 |
features_df = pd.DataFrame(
|
| 165 |
[feature_values], columns=CLASSIFICATION_FEATURE_NAMES)
|
|
|
|
| 253 |
|
| 254 |
|
| 255 |
def create_interface():
|
| 256 |
+
global models_loaded
|
| 257 |
+
|
| 258 |
if FEATURE_NAMES is None:
|
| 259 |
+
error_message = loading_error if loading_error else "Model not loaded"
|
| 260 |
return gr.Interface(
|
| 261 |
+
fn=lambda: error_message,
|
| 262 |
inputs=[],
|
| 263 |
outputs="text",
|
| 264 |
title="Error: Model not loaded"
|
|
|
|
| 445 |
|
| 446 |
return demo
|
| 447 |
|
| 448 |
+
def load_models_async():
|
| 449 |
+
global models_loaded
|
| 450 |
+
start_time = time.perf_counter()
|
| 451 |
+
print("Starting asynchronous model loading...")
|
| 452 |
load_champion_model()
|
| 453 |
load_classification_model()
|
| 454 |
+
models_loaded = True
|
| 455 |
+
elapsed_time = time.perf_counter() - start_time
|
| 456 |
+
print(f"Model loading complete (total time: {elapsed_time:.2f} seconds)")
|
| 457 |
+
|
| 458 |
+
if __name__ == "__main__":
|
| 459 |
+
# Load models asynchronously in background threads
|
| 460 |
|
| 461 |
+
# Start model loading in background thread
|
| 462 |
+
loading_thread = threading.Thread(target=load_models_async, daemon=True)
|
| 463 |
+
loading_thread.start()
|
| 464 |
+
|
| 465 |
+
# Create the interface immediately (models loading in background)
|
| 466 |
demo = create_interface()
|
| 467 |
+
|
| 468 |
+
# Add loading status to the interface
|
| 469 |
+
if not models_loaded:
|
| 470 |
+
print("Models are loading in the background...")
|
| 471 |
+
print("You can use the interface while models load.")
|
| 472 |
+
|
| 473 |
demo.launch(share=False, server_name="0.0.0.0", server_port=7860)
|