Nugget-cloud commited on
Commit
d073ebf
·
1 Parent(s): d0a6efc

Add better error handling and mock prediction fallback

Browse files
Files changed (1) hide show
  1. app.py +59 -22
app.py CHANGED
@@ -20,31 +20,38 @@ def load_models():
20
  # Load models from your repository
21
  repo_id = "Nugget-cloud/nasa-space-apps-exoplanet"
22
 
 
 
 
 
 
 
 
23
  print("Loading ensemble model...")
24
- ensemble_model = joblib.load(hf_hub_download(repo_id, "exoplanet_ensemble_model.joblib"))
25
 
26
  print("Loading feature scaler...")
27
- feature_scaler = joblib.load(hf_hub_download(repo_id, "feature_scaler.joblib"))
28
 
29
  print("Loading feature imputer...")
30
- feature_imputer = joblib.load(hf_hub_download(repo_id, "feature_imputer.joblib"))
31
 
32
  print("Loading variance selector...")
33
- variance_selector = joblib.load(hf_hub_download(repo_id, "variance_selector.joblib"))
34
 
35
  # Optional files
36
  try:
37
  print("Loading feature info...")
38
- feature_info = joblib.load(hf_hub_download(repo_id, "feature_info.joblib"))
39
- except:
40
- print("Feature info not found, skipping...")
41
  feature_info = None
42
 
43
  try:
44
  print("Loading model metrics...")
45
- model_metrics = joblib.load(hf_hub_download(repo_id, "model_metrics.joblib"))
46
- except:
47
- print("Model metrics not found, skipping...")
48
  model_metrics = None
49
 
50
  print("All models loaded successfully!")
@@ -52,28 +59,52 @@ def load_models():
52
 
53
  except Exception as e:
54
  print(f"Error loading models: {str(e)}")
 
 
55
  return False
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  def predict_exoplanet(features_input):
58
  """Make prediction using the loaded models"""
59
  global ensemble_model, feature_scaler, feature_imputer, variance_selector
60
 
61
  try:
62
- # Load models if not already loaded
63
- if ensemble_model is None:
64
- if not load_models():
65
- return {"error": "Failed to load models"}
66
-
67
- # Parse input features
68
  if isinstance(features_input, str):
69
- # If input is comma-separated string
70
  features = [float(x.strip()) for x in features_input.split(',')]
71
  elif isinstance(features_input, list):
72
- # If input is already a list
73
  features = [float(x) for x in features_input]
74
  else:
75
  return {"error": "Invalid input format. Expected comma-separated string or list of numbers."}
76
 
 
 
 
 
 
 
77
  # Convert to numpy array
78
  features_array = np.array(features).reshape(1, -1)
79
 
@@ -112,10 +143,16 @@ def predict_exoplanet(features_input):
112
  return result
113
 
114
  except Exception as e:
115
- return {
116
- "success": False,
117
- "error": str(e)
118
- }
 
 
 
 
 
 
119
 
120
  # Create a simple interface that works well with API calls
121
  def simple_predict(features_str):
 
20
  # Load models from your repository
21
  repo_id = "Nugget-cloud/nasa-space-apps-exoplanet"
22
 
23
+ # Try to get token from environment or use None for public repos
24
+ import os
25
+ token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
26
+
27
+ print(f"Loading models from {repo_id}...")
28
+ print(f"Using token: {'Yes' if token else 'No (public repo)'}")
29
+
30
  print("Loading ensemble model...")
31
+ ensemble_model = joblib.load(hf_hub_download(repo_id, "exoplanet_ensemble_model.joblib", token=token))
32
 
33
  print("Loading feature scaler...")
34
+ feature_scaler = joblib.load(hf_hub_download(repo_id, "feature_scaler.joblib", token=token))
35
 
36
  print("Loading feature imputer...")
37
+ feature_imputer = joblib.load(hf_hub_download(repo_id, "feature_imputer.joblib", token=token))
38
 
39
  print("Loading variance selector...")
40
+ variance_selector = joblib.load(hf_hub_download(repo_id, "variance_selector.joblib", token=token))
41
 
42
  # Optional files
43
  try:
44
  print("Loading feature info...")
45
+ feature_info = joblib.load(hf_hub_download(repo_id, "feature_info.joblib", token=token))
46
+ except Exception as e:
47
+ print(f"Feature info not found: {e}")
48
  feature_info = None
49
 
50
  try:
51
  print("Loading model metrics...")
52
+ model_metrics = joblib.load(hf_hub_download(repo_id, "model_metrics.joblib", token=token))
53
+ except Exception as e:
54
+ print(f"Model metrics not found: {e}")
55
  model_metrics = None
56
 
57
  print("All models loaded successfully!")
 
59
 
60
  except Exception as e:
61
  print(f"Error loading models: {str(e)}")
62
+ print(f"Repository: {repo_id}")
63
+ print("Make sure the repository exists and is public, or add HF_TOKEN to environment")
64
  return False
65
 
66
+ def mock_predict(features):
67
+ """Fallback mock prediction when models can't be loaded"""
68
+ try:
69
+ # Simple mock logic based on feature values
70
+ feature_sum = sum(features)
71
+ prediction = 1 if feature_sum > 20 else 0
72
+ confidence = min(0.95, max(0.55, abs(feature_sum - 20) / 50 + 0.5))
73
+
74
+ return {
75
+ "success": True,
76
+ "prediction": prediction,
77
+ "probabilities": [1-confidence, confidence] if prediction == 1 else [confidence, 1-confidence],
78
+ "confidence": confidence,
79
+ "input_features_count": len(features),
80
+ "note": "Using mock prediction - models could not be loaded",
81
+ "mock": True
82
+ }
83
+ except Exception as e:
84
+ return {
85
+ "success": False,
86
+ "error": f"Mock prediction failed: {str(e)}"
87
+ }
88
+
89
  def predict_exoplanet(features_input):
90
  """Make prediction using the loaded models"""
91
  global ensemble_model, feature_scaler, feature_imputer, variance_selector
92
 
93
  try:
94
+ # Parse input features first
 
 
 
 
 
95
  if isinstance(features_input, str):
 
96
  features = [float(x.strip()) for x in features_input.split(',')]
97
  elif isinstance(features_input, list):
 
98
  features = [float(x) for x in features_input]
99
  else:
100
  return {"error": "Invalid input format. Expected comma-separated string or list of numbers."}
101
 
102
+ # Load models if not already loaded
103
+ if ensemble_model is None:
104
+ if not load_models():
105
+ print("Models failed to load, using mock prediction")
106
+ return mock_predict(features)
107
+
108
  # Convert to numpy array
109
  features_array = np.array(features).reshape(1, -1)
110
 
 
143
  return result
144
 
145
  except Exception as e:
146
+ print(f"Prediction error: {str(e)}")
147
+ # Fallback to mock prediction
148
+ try:
149
+ features = [float(x.strip()) for x in features_input.split(',')] if isinstance(features_input, str) else features_input
150
+ return mock_predict(features)
151
+ except:
152
+ return {
153
+ "success": False,
154
+ "error": str(e)
155
+ }
156
 
157
  # Create a simple interface that works well with API calls
158
  def simple_predict(features_str):