gamaly committed on
Commit
a9d4f37
·
verified Β·
1 Parent(s): 7bfc262

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +167 -115
app.py CHANGED
@@ -1,161 +1,201 @@
1
- """Gradio app for Maritime Intelligence Classifier."""
2
  import gradio as gr
3
  from setfit import SetFitModel
 
4
  from pathlib import Path
5
  import os
6
 
7
- # Try to load model from Hugging Face Hub first, then fall back to local
8
- # Set MODEL_PATH environment variable or update this line with your Hugging Face repo ID
9
- MODEL_PATH = os.getenv("MODEL_PATH", "gamaly/maritime-intelligence-classifier")
10
- LOCAL_MODEL_PATH = "./maritime_classifier"
 
 
11
 
12
- # Load model
13
- print("Loading model...")
14
- print(f"MODEL_PATH: {MODEL_PATH}")
15
- print(f"LOCAL_MODEL_PATH: {LOCAL_MODEL_PATH}")
16
- model = None
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  try:
19
- # Check if MODEL_PATH is a Hugging Face repo (contains "/" and doesn't exist locally)
20
- if "/" in MODEL_PATH and not Path(MODEL_PATH).exists():
21
- print(f"Loading from Hugging Face Hub: {MODEL_PATH}")
22
- model = SetFitModel.from_pretrained(MODEL_PATH)
23
- print(f"βœ“ Successfully loaded model from Hugging Face: {MODEL_PATH}")
24
- # Check if local model path exists
25
- elif Path(LOCAL_MODEL_PATH).exists():
26
- print(f"Loading from local path: {LOCAL_MODEL_PATH}")
27
- model = SetFitModel.from_pretrained(LOCAL_MODEL_PATH)
28
- print(f"βœ“ Successfully loaded model from local path: {LOCAL_MODEL_PATH}")
29
- # If MODEL_PATH is a local path that exists
30
- elif Path(MODEL_PATH).exists():
31
- print(f"Loading from local path: {MODEL_PATH}")
32
- model = SetFitModel.from_pretrained(MODEL_PATH)
33
- print(f"βœ“ Successfully loaded model from local path: {MODEL_PATH}")
34
- # Default: try MODEL_PATH as Hugging Face repo
35
  else:
36
- print(f"Attempting to load from Hugging Face Hub: {MODEL_PATH}")
37
- model = SetFitModel.from_pretrained(MODEL_PATH)
38
- print(f"βœ“ Successfully loaded model from Hugging Face: {MODEL_PATH}")
39
  except Exception as e:
40
- print(f"❌ Error loading model: {e}")
41
- print(f" Attempted paths:")
42
- print(f" - Hugging Face: {MODEL_PATH}")
43
- print(f" - Local: {LOCAL_MODEL_PATH}")
44
- import traceback
45
- print("\nFull traceback:")
46
- traceback.print_exc()
47
- model = None
48
 
49
- if model is None:
50
- print("\n⚠️ WARNING: Model failed to load. The app will not work correctly.")
51
- print(" Please check:")
52
- print(f" 1. Model exists at: https://huggingface.co/{MODEL_PATH}")
53
- print(" 2. Internet connection is available")
54
- print(" 3. All dependencies are installed (setfit, sentence-transformers, etc.)")
55
  else:
56
- print("\nβœ… Model loaded successfully! Ready for inference.")
 
57
 
 
 
 
58
  def truncate_text(text, max_tokens=256):
59
- """
60
- Truncate text to approximately max_tokens.
61
- Uses a simple word-based approximation (roughly 1 token = 0.75 words).
62
- """
63
  if not text:
64
  return text
65
 
66
- # Rough approximation: 1 token β‰ˆ 0.75 words (conservative estimate)
67
  max_words = int(max_tokens * 0.75)
68
  words = text.split()
69
 
70
  if len(words) <= max_words:
71
  return text
72
 
73
- # Truncate and add ellipsis
74
  truncated = " ".join(words[:max_words])
75
  return truncated + "... [truncated]"
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  def predict_text(text):
78
- """Predict whether text is actionable (YES) or not (NO)."""
79
- if model is None:
80
- return "Error: Model not loaded. Please check the console logs.", 0.0, "error"
81
 
82
  if not text or not text.strip():
83
  return "Please enter some text to classify.", 0.0, "neutral"
84
 
85
  try:
86
- # Note: SetFit uses the base model's max_length (256 tokens for all-MiniLM-L6-v2)
87
- # The model will automatically truncate longer texts, but we can pre-truncate
88
- # to ensure we're using the most relevant part (beginning of text)
89
- # For longer articles, the beginning usually contains the most important info
90
-
91
- # Check approximate length (rough estimate: 1 token β‰ˆ 0.75 words)
92
  word_count = len(text.split())
93
  token_estimate = int(word_count / 0.75)
94
 
95
- # If text is significantly longer than 256 tokens, truncate intelligently
96
- # (SetFit will truncate anyway, but we can control which part)
97
- if token_estimate > 300: # Give some buffer
98
- # For news articles, the beginning usually has the key info
99
- # But we could also try: beginning + end, or just beginning
100
  processed_text = truncate_text(text, max_tokens=256)
101
- print(f"⚠️ Text truncated from ~{token_estimate} tokens to ~256 tokens")
102
  else:
103
  processed_text = text
104
 
105
  # Make prediction
106
- prediction = model.predict([processed_text])[0]
107
 
108
- # Get probabilities (handle version compatibility)
109
  try:
110
- probabilities = model.predict_proba([processed_text])[0]
111
  confidence = probabilities[prediction] * 100
112
- except AttributeError as e:
113
- # Fallback if predict_proba fails due to version mismatch
114
- # Use a simple confidence estimate based on prediction
115
- print(f"Warning: predict_proba failed ({e}), using fallback confidence")
116
- # For binary classification, we can estimate confidence from the decision function
117
- # or just use a default high confidence
118
- confidence = 85.0 # Default confidence when we can't get probabilities
119
-
120
- # Convert to labels
121
- label = "YES (Actionable)" if prediction == 1 else "NO (Not Actionable)"
122
 
123
- # Determine status for styling
124
  status = "actionable" if prediction == 1 else "not_actionable"
125
 
126
  return label, confidence, status
127
  except Exception as e:
128
- error_msg = f"Error during prediction: {str(e)}"
129
- print(error_msg)
130
- import traceback
131
- traceback.print_exc()
132
- return error_msg, 0.0, "error"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
  def get_explanation(status):
135
  """Get explanation based on prediction status."""
136
  explanations = {
137
- "actionable": "βœ“ This text contains actionable vessel-specific evidence (e.g., specific vessel names, crimes, incidents).",
138
- "not_actionable": "βœ— This text does not contain actionable vessel-specific evidence (e.g., general maritime news, non-specific information).",
139
  "error": "⚠️ An error occurred. Please check the model is properly loaded.",
140
  "neutral": ""
141
  }
142
  return explanations.get(status, "")
143
 
144
- # Create Gradio interface
145
- # Note: theme parameter moved to launch() in Gradio 6.0+
 
146
  with gr.Blocks(title="Maritime Intelligence Classifier") as app:
147
  gr.Markdown(
148
  """
149
  # 🚒 Maritime Intelligence Classifier
150
 
151
- Classify maritime news articles as containing **actionable vessel-specific evidence** (YES) or not (NO).
152
-
153
- **Actionable articles** typically include:
154
- - Specific vessel names
155
- - Specific crimes or incidents
156
- - Evidence that can be used for investigation
157
-
158
- **Non-actionable articles** are general maritime news without specific vessel details.
159
  """
160
  )
161
 
@@ -168,9 +208,11 @@ with gr.Blocks(title="Maritime Intelligence Classifier") as app:
168
  max_lines=20
169
  )
170
 
171
- submit_btn = gr.Button("Classify", variant="primary", size="lg")
172
 
173
  with gr.Column(scale=1):
 
 
174
  prediction_output = gr.Label(
175
  label="Prediction",
176
  value={"YES (Actionable)": 0.0, "NO (Not Actionable)": 0.0}
@@ -183,6 +225,13 @@ with gr.Blocks(title="Maritime Intelligence Classifier") as app:
183
  )
184
 
185
  explanation_output = gr.Markdown()
 
 
 
 
 
 
 
186
 
187
  # Example texts
188
  gr.Markdown("### πŸ“ Example Texts")
@@ -190,10 +239,10 @@ with gr.Blocks(title="Maritime Intelligence Classifier") as app:
190
  example_yes = gr.Examples(
191
  examples=[
192
  ["The fishing vessel Marine 707 was involved in the disappearance of fisheries observer Samuel Abayateye in Ghanaian waters. The observer's decapitated body was found weeks later."],
193
- ["Authorities detained the Meng Xin 15 after discovering evidence of illegal saiko transshipment and threats against fisheries observers."],
194
  ],
195
  inputs=text_input,
196
- label="YES Examples (Actionable)"
197
  )
198
 
199
  example_no = gr.Examples(
@@ -202,14 +251,15 @@ with gr.Blocks(title="Maritime Intelligence Classifier") as app:
202
  ["Marine scientists are studying the effects of ocean acidification on coral reefs in tropical waters."],
203
  ],
204
  inputs=text_input,
205
- label="NO Examples (Not Actionable)"
206
  )
207
 
208
- # Connect the prediction function
209
- def update_prediction(text):
 
210
  label, confidence, status = predict_text(text)
211
 
212
- # Create label dict for gradio Label component
213
  if status == "actionable":
214
  label_dict = {"YES (Actionable)": confidence / 100, "NO (Not Actionable)": (100 - confidence) / 100}
215
  elif status == "not_actionable":
@@ -219,18 +269,22 @@ with gr.Blocks(title="Maritime Intelligence Classifier") as app:
219
 
220
  explanation = get_explanation(status)
221
 
222
- return label_dict, confidence, explanation
 
 
 
 
223
 
224
  submit_btn.click(
225
- fn=update_prediction,
226
  inputs=text_input,
227
- outputs=[prediction_output, confidence_output, explanation_output]
228
  )
229
 
230
  text_input.submit(
231
- fn=update_prediction,
232
  inputs=text_input,
233
- outputs=[prediction_output, confidence_output, explanation_output]
234
  )
235
 
236
  gr.Markdown(
@@ -238,15 +292,13 @@ with gr.Blocks(title="Maritime Intelligence Classifier") as app:
238
  ---
239
  ### ℹ️ About
240
 
241
- This classifier uses SetFit to identify maritime news articles containing actionable vessel-specific evidence.
242
- Built for The Outlaw Ocean Project.
 
243
 
244
- **Model**: SetFit (sentence-transformers/all-MiniLM-L6-v2 base)
245
  """
246
  )
247
 
248
  if __name__ == "__main__":
249
- app.launch(share=False, theme=gr.themes.Soft())
250
-
251
-
252
-
 
1
+ """Gradio app for Maritime Intelligence Classifier + Entity Extraction."""
2
  import gradio as gr
3
  from setfit import SetFitModel
4
+ from transformers import pipeline
5
  from pathlib import Path
6
  import os
7
 
8
+ # ============================================================
9
+ # MODEL PATHS
10
+ # ============================================================
11
+ # Classification model (SetFit)
12
+ CLASSIFIER_PATH = os.getenv("CLASSIFIER_PATH", "gamaly/maritime-intelligence-classifier")
13
+ LOCAL_CLASSIFIER_PATH = "./maritime_classifier"
14
 
15
+ # NER model (BERT) - UPDATE THIS WITH YOUR HF REPO
16
+ NER_PATH = os.getenv("NER_PATH", "gamaly/bert-vessel-ner") # ← Change to your repo!
17
+ LOCAL_NER_PATH = "./models/bert-vessel-ner"
 
 
18
 
19
+ # ============================================================
20
+ # LOAD MODELS
21
+ # ============================================================
22
+ print("="*60)
23
+ print("Loading models...")
24
+ print("="*60)
25
+
26
+ # Load Classification Model
27
+ classifier = None
28
+ try:
29
+ if "/" in CLASSIFIER_PATH and not Path(CLASSIFIER_PATH).exists():
30
+ print(f"Loading classifier from HuggingFace: {CLASSIFIER_PATH}")
31
+ classifier = SetFitModel.from_pretrained(CLASSIFIER_PATH)
32
+ elif Path(LOCAL_CLASSIFIER_PATH).exists():
33
+ print(f"Loading classifier from local: {LOCAL_CLASSIFIER_PATH}")
34
+ classifier = SetFitModel.from_pretrained(LOCAL_CLASSIFIER_PATH)
35
+ else:
36
+ print(f"Loading classifier from HuggingFace: {CLASSIFIER_PATH}")
37
+ classifier = SetFitModel.from_pretrained(CLASSIFIER_PATH)
38
+ print(f"βœ“ Classifier loaded")
39
+ except Exception as e:
40
+ print(f"❌ Classifier failed to load: {e}")
41
+
42
+ # Load NER Model
43
+ ner_model = None
44
  try:
45
+ if "/" in NER_PATH and not Path(NER_PATH).exists():
46
+ print(f"Loading NER from HuggingFace: {NER_PATH}")
47
+ ner_model = pipeline("ner", model=NER_PATH, aggregation_strategy="simple")
48
+ elif Path(LOCAL_NER_PATH).exists():
49
+ print(f"Loading NER from local: {LOCAL_NER_PATH}")
50
+ ner_model = pipeline("ner", model=LOCAL_NER_PATH, aggregation_strategy="simple")
 
 
 
 
 
 
 
 
 
 
51
  else:
52
+ print(f"Loading NER from HuggingFace: {NER_PATH}")
53
+ ner_model = pipeline("ner", model=NER_PATH, aggregation_strategy="simple")
54
+ print(f"βœ“ NER model loaded")
55
  except Exception as e:
56
+ print(f"❌ NER model failed to load: {e}")
 
 
 
 
 
 
 
57
 
58
+ print("="*60)
59
+ if classifier and ner_model:
60
+ print("βœ… All models loaded successfully!")
 
 
 
61
  else:
62
+ print("⚠️ Some models failed to load. Check logs above.")
63
+ print("="*60)
64
 
65
+ # ============================================================
66
+ # HELPER FUNCTIONS
67
+ # ============================================================
68
def truncate_text(text, max_tokens=256):
    """Clip *text* to roughly *max_tokens* tokens.

    Uses a word-count heuristic (~0.75 words per token); falsy input is
    returned unchanged, and clipped text gets a "... [truncated]" suffix.
    """
    if not text:
        return text

    word_budget = int(max_tokens * 0.75)
    tokens = text.split()

    if len(tokens) <= word_budget:
        return text

    return " ".join(tokens[:word_budget]) + "... [truncated]"
81
 
82
def extract_entities(text):
    """Extract VESSEL and ORG entities from *text* via the NER pipeline.

    Returns:
        (vessels, orgs): two lists of dicts {"text": str, "score": float}.
        Entities scoring below 0.5 are dropped. Duplicate entity names are
        collapsed, keeping the highest-scoring mention (the original kept
        whichever mention happened to come last). Both lists are empty when
        the NER model is unavailable, the input is blank, or inference fails.
    """
    if ner_model is None:
        return [], []

    if not text or not text.strip():
        return [], []

    try:
        # Best mention per entity name, bucketed by entity type.
        buckets = {"VESSEL": {}, "ORG": {}}

        for ent in ner_model(text):
            score = ent["score"]
            # Skip low confidence predictions
            if score < 0.5:
                continue

            # Clean up WordPiece artifacts that survive aggregation
            name = ent["word"].strip().replace(" ##", "").replace("##", "")

            group = ent["entity_group"]
            if group not in buckets:
                continue

            best = buckets[group].get(name)
            # Deduplicate: keep the highest-scoring mention of each name.
            if best is None or score > best["score"]:
                buckets[group][name] = {"text": name, "score": score}

        return list(buckets["VESSEL"].values()), list(buckets["ORG"].values())
    except Exception as e:
        print(f"NER error: {e}")
        return [], []
121
+
122
def predict_text(text):
    """Classify *text* as actionable (YES) or not actionable (NO).

    Note: despite the app's two-stage design, this function only classifies;
    entity extraction is handled separately by extract_entities().

    Returns:
        (label, confidence, status): a display label, a confidence percentage
        in [0, 100], and a status string ("actionable", "not_actionable",
        "error", or "neutral") used for downstream styling.
    """
    if classifier is None:
        return "Error: Classifier not loaded.", 0.0, "error"

    if not text or not text.strip():
        return "Please enter some text to classify.", 0.0, "neutral"

    try:
        # Pre-truncate long inputs so the most informative part (the start of
        # the article) is what the encoder sees. The 300 threshold leaves a
        # little headroom over the 256-token target before truncating.
        word_count = len(text.split())
        token_estimate = int(word_count / 0.75)

        if token_estimate > 300:
            processed_text = truncate_text(text, max_tokens=256)
        else:
            processed_text = text

        # Make prediction
        prediction = classifier.predict([processed_text])[0]

        # predict_proba can be missing across setfit versions; fall back to a
        # fixed confidence rather than failing the whole request.
        try:
            probabilities = classifier.predict_proba([processed_text])[0]
            confidence = probabilities[prediction] * 100
        except AttributeError:
            confidence = 85.0

        label = "YES (Actionable)" if prediction == 1 else "NO (Not Actionable)"
        status = "actionable" if prediction == 1 else "not_actionable"

        return label, confidence, status
    except Exception as e:
        print(f"Classification error: {e}")
        return f"Error: {str(e)}", 0.0, "error"
157
+
158
def format_entities(vessels, orgs):
    """Render the extracted vessel/org entity lists as a markdown summary.

    Each entry is a dict with "text" and "score" keys; scores are shown as
    whole percentages. Returns a placeholder string when both lists are empty.
    """
    if not (vessels or orgs):
        return "No entities detected."

    sections = []

    if vessels:
        rows = ["### 🚒 Vessels"]
        rows.extend(f"- **{item['text']}** ({item['score']:.0%})" for item in vessels)
        sections.append("\n".join(rows) + "\n\n")

    if orgs:
        rows = ["### 🏒 Organizations"]
        rows.extend(f"- **{item['text']}** ({item['score']:.0%})" for item in orgs)
        sections.append("\n".join(rows) + "\n")

    return "".join(sections)
177
 
178
def get_explanation(status):
    """Return the markdown explanation matching a prediction *status*.

    Unknown statuses (including "neutral") map to an empty string.
    """
    if status == "actionable":
        return "βœ“ This text contains actionable vessel-specific evidence."
    if status == "not_actionable":
        return "βœ— This text does not contain actionable vessel-specific evidence."
    if status == "error":
        return "⚠️ An error occurred. Please check the model is properly loaded."
    return ""
187
 
188
+ # ============================================================
189
+ # GRADIO APP
190
+ # ============================================================
191
  with gr.Blocks(title="Maritime Intelligence Classifier") as app:
192
  gr.Markdown(
193
  """
194
  # 🚒 Maritime Intelligence Classifier
195
 
196
+ **Two-stage analysis:**
197
+ 1. **Classification** - Is this article actionable?
198
+ 2. **Entity Extraction** - What vessels and organizations are mentioned?
 
 
 
 
 
199
  """
200
  )
201
 
 
208
  max_lines=20
209
  )
210
 
211
+ submit_btn = gr.Button("Analyze", variant="primary", size="lg")
212
 
213
  with gr.Column(scale=1):
214
+ # Classification results
215
+ gr.Markdown("### πŸ“Š Classification")
216
  prediction_output = gr.Label(
217
  label="Prediction",
218
  value={"YES (Actionable)": 0.0, "NO (Not Actionable)": 0.0}
 
225
  )
226
 
227
  explanation_output = gr.Markdown()
228
+
229
+ # Entity extraction results
230
+ gr.Markdown("---")
231
+ entities_output = gr.Markdown(
232
+ label="Extracted Entities",
233
+ value="### πŸ” Extracted Entities\nNo entities detected yet."
234
+ )
235
 
236
  # Example texts
237
  gr.Markdown("### πŸ“ Example Texts")
 
239
  example_yes = gr.Examples(
240
  examples=[
241
  ["The fishing vessel Marine 707 was involved in the disappearance of fisheries observer Samuel Abayateye in Ghanaian waters. The observer's decapitated body was found weeks later."],
242
+ ["Authorities detained the Meng Xin 15 after discovering evidence of illegal saiko transshipment. Pacific Seafood Inc. was identified as the vessel operator."],
243
  ],
244
  inputs=text_input,
245
+ label="Actionable Examples"
246
  )
247
 
248
  example_no = gr.Examples(
 
251
  ["Marine scientists are studying the effects of ocean acidification on coral reefs in tropical waters."],
252
  ],
253
  inputs=text_input,
254
+ label="Non-Actionable Examples"
255
  )
256
 
257
+ # Main analysis function
258
+ def analyze_text(text):
259
+ # Classification
260
  label, confidence, status = predict_text(text)
261
 
262
+ # Create label dict
263
  if status == "actionable":
264
  label_dict = {"YES (Actionable)": confidence / 100, "NO (Not Actionable)": (100 - confidence) / 100}
265
  elif status == "not_actionable":
 
269
 
270
  explanation = get_explanation(status)
271
 
272
+ # Entity extraction
273
+ vessels, orgs = extract_entities(text)
274
+ entities_md = "### πŸ” Extracted Entities\n" + format_entities(vessels, orgs)
275
+
276
+ return label_dict, confidence, explanation, entities_md
277
 
278
  submit_btn.click(
279
+ fn=analyze_text,
280
  inputs=text_input,
281
+ outputs=[prediction_output, confidence_output, explanation_output, entities_output]
282
  )
283
 
284
  text_input.submit(
285
+ fn=analyze_text,
286
  inputs=text_input,
287
+ outputs=[prediction_output, confidence_output, explanation_output, entities_output]
288
  )
289
 
290
  gr.Markdown(
 
292
  ---
293
  ### ℹ️ About
294
 
295
+ **Classification**: SetFit model identifies actionable maritime intelligence.
296
+
297
+ **Entity Extraction**: BERT-NER model extracts vessel names and organizations.
298
 
299
+ Built for The Outlaw Ocean Project.
300
  """
301
  )
302
 
303
  if __name__ == "__main__":
304
+ app.launch(share=False, theme=gr.themes.Soft())