gamaly committed on
Commit
327be00
·
verified ·
1 Parent(s): fa9a3ac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -12
app.py CHANGED
@@ -11,27 +11,54 @@ LOCAL_MODEL_PATH = "./maritime_classifier"
11
 
12
  # Load model
13
  print("Loading model...")
 
 
 
 
14
  try:
15
- # Try Hugging Face Hub first
16
  if "/" in MODEL_PATH and not Path(MODEL_PATH).exists():
 
 
 
 
 
 
 
 
 
 
 
17
  model = SetFitModel.from_pretrained(MODEL_PATH)
18
- print(f"✓ Loaded model from Hugging Face: {MODEL_PATH}")
 
19
  else:
20
- # Try local path
21
- if Path(LOCAL_MODEL_PATH).exists():
22
- model = SetFitModel.from_pretrained(LOCAL_MODEL_PATH)
23
- print(f"✓ Loaded model from local path: {LOCAL_MODEL_PATH}")
24
- else:
25
- raise FileNotFoundError(f"Model not found at {MODEL_PATH} or {LOCAL_MODEL_PATH}")
26
  except Exception as e:
27
- print(f"⚠️ Error loading model: {e}")
28
- print("Make sure the model is trained or uploaded to Hugging Face")
 
 
 
 
 
29
  model = None
30
 
 
 
 
 
 
 
 
 
 
31
  def predict_text(text):
32
  """Predict whether text is actionable (YES) or not (NO)."""
33
  if model is None:
34
- return "Error: Model not loaded. Please train the model first.", 0.0, "error"
35
 
36
  if not text or not text.strip():
37
  return "Please enter some text to classify.", 0.0, "neutral"
@@ -52,7 +79,11 @@ def predict_text(text):
52
 
53
  return label, confidence, status
54
  except Exception as e:
55
- return f"Error during prediction: {str(e)}", 0.0, "error"
 
 
 
 
56
 
57
  def get_explanation(status):
58
  """Get explanation based on prediction status."""
 
11
 
12
  # Load model
13
  print("Loading model...")
14
+ print(f"MODEL_PATH: {MODEL_PATH}")
15
+ print(f"LOCAL_MODEL_PATH: {LOCAL_MODEL_PATH}")
16
+ model = None
17
+
18
  try:
19
+ # Check if MODEL_PATH is a Hugging Face repo (contains "/" and doesn't exist locally)
20
  if "/" in MODEL_PATH and not Path(MODEL_PATH).exists():
21
+ print(f"Loading from Hugging Face Hub: {MODEL_PATH}")
22
+ model = SetFitModel.from_pretrained(MODEL_PATH)
23
+ print(f"✓ Successfully loaded model from Hugging Face: {MODEL_PATH}")
24
+ # Check if local model path exists
25
+ elif Path(LOCAL_MODEL_PATH).exists():
26
+ print(f"Loading from local path: {LOCAL_MODEL_PATH}")
27
+ model = SetFitModel.from_pretrained(LOCAL_MODEL_PATH)
28
+ print(f"✓ Successfully loaded model from local path: {LOCAL_MODEL_PATH}")
29
+ # If MODEL_PATH is a local path that exists
30
+ elif Path(MODEL_PATH).exists():
31
+ print(f"Loading from local path: {MODEL_PATH}")
32
  model = SetFitModel.from_pretrained(MODEL_PATH)
33
+ print(f"✓ Successfully loaded model from local path: {MODEL_PATH}")
34
+ # Default: try MODEL_PATH as Hugging Face repo
35
  else:
36
+ print(f"Attempting to load from Hugging Face Hub: {MODEL_PATH}")
37
+ model = SetFitModel.from_pretrained(MODEL_PATH)
38
+ print(f"✓ Successfully loaded model from Hugging Face: {MODEL_PATH}")
 
 
 
39
  except Exception as e:
40
+ print(f"Error loading model: {e}")
41
+ print(f" Attempted paths:")
42
+ print(f" - Hugging Face: {MODEL_PATH}")
43
+ print(f" - Local: {LOCAL_MODEL_PATH}")
44
+ import traceback
45
+ print("\nFull traceback:")
46
+ traceback.print_exc()
47
  model = None
48
 
49
+ if model is None:
50
+ print("\n⚠️ WARNING: Model failed to load. The app will not work correctly.")
51
+ print(" Please check:")
52
+ print(f" 1. Model exists at: https://huggingface.co/{MODEL_PATH}")
53
+ print(" 2. Internet connection is available")
54
+ print(" 3. All dependencies are installed (setfit, sentence-transformers, etc.)")
55
+ else:
56
+ print("\n✅ Model loaded successfully! Ready for inference.")
57
+
58
  def predict_text(text):
59
  """Predict whether text is actionable (YES) or not (NO)."""
60
  if model is None:
61
+ return "Error: Model not loaded. Please check the console logs.", 0.0, "error"
62
 
63
  if not text or not text.strip():
64
  return "Please enter some text to classify.", 0.0, "neutral"
 
79
 
80
  return label, confidence, status
81
  except Exception as e:
82
+ error_msg = f"Error during prediction: {str(e)}"
83
+ print(error_msg)
84
+ import traceback
85
+ traceback.print_exc()
86
+ return error_msg, 0.0, "error"
87
 
88
  def get_explanation(status):
89
  """Get explanation based on prediction status."""