gamaly's picture
Update app.py
a9d4f37 verified
"""Gradio app for Maritime Intelligence Classifier + Entity Extraction."""
import gradio as gr
from setfit import SetFitModel
from transformers import pipeline
from pathlib import Path
import os
# ============================================================
# MODEL PATHS
# ============================================================
# Each remote identifier may be overridden through the environment variable
# of the same name; the LOCAL_* values are fixed relative paths.

# SetFit classification model
LOCAL_CLASSIFIER_PATH = "./maritime_classifier"
CLASSIFIER_PATH = os.environ.get(
    "CLASSIFIER_PATH",
    "gamaly/maritime-intelligence-classifier",
)

# BERT NER model - UPDATE THIS WITH YOUR HF REPO
LOCAL_NER_PATH = "./models/bert-vessel-ner"
NER_PATH = os.environ.get("NER_PATH", "gamaly/bert-vessel-ner")  # ← Change to your repo!
# ============================================================
# LOAD MODELS
# ============================================================
def _select_model_source(configured, local_fallback):
    """Decide which location a model should be loaded from.

    Args:
        configured: Repo id or filesystem path taken from configuration.
        local_fallback: Local directory that is used when it exists on disk.

    Returns:
        A ``(source, origin)`` tuple. ``source`` is the value to hand to the
        loader; ``origin`` is ``"HuggingFace"`` or ``"local"`` and is only
        used for the log line.
    """
    configured_on_disk = Path(configured).exists()
    # A value containing "/" that is not present on disk is taken to be a
    # Hub repo id ("owner/name").
    if "/" in configured and not configured_on_disk:
        return configured, "HuggingFace"
    if Path(local_fallback).exists():
        return local_fallback, "local"
    # NOTE(review): when `configured` is itself an existing directory, the
    # bundled fallback above still takes precedence if present. That ordering
    # is inherited unchanged - confirm it is intended.
    return configured, ("local" if configured_on_disk else "HuggingFace")


print("=" * 60)
print("Loading models...")
print("=" * 60)

# Load Classification Model
# Loading is best-effort: a failure is logged and the model stays None so the
# rest of the app can still start and report the problem to the user.
classifier = None
try:
    _source, _origin = _select_model_source(CLASSIFIER_PATH, LOCAL_CLASSIFIER_PATH)
    print(f"Loading classifier from {_origin}: {_source}")
    classifier = SetFitModel.from_pretrained(_source)
    print("βœ“ Classifier loaded")
except Exception as e:
    print(f"❌ Classifier failed to load: {e}")

# Load NER Model
ner_model = None
try:
    _source, _origin = _select_model_source(NER_PATH, LOCAL_NER_PATH)
    print(f"Loading NER from {_origin}: {_source}")
    ner_model = pipeline("ner", model=_source, aggregation_strategy="simple")
    print("βœ“ NER model loaded")
except Exception as e:
    print(f"❌ NER model failed to load: {e}")

print("=" * 60)
# Test against None explicitly, matching the `is None` guards used by the
# helper functions, rather than relying on the truthiness of model objects.
if classifier is not None and ner_model is not None:
    print("βœ… All models loaded successfully!")
else:
    print("⚠️ Some models failed to load. Check logs above.")
print("=" * 60)
# ============================================================
# HELPER FUNCTIONS
# ============================================================
def truncate_text(text, max_tokens=256):
    """Shorten *text* to roughly ``max_tokens`` tokens using a word heuristic.

    The word budget is ``int(max_tokens * 0.75)`` whitespace-separated words.
    Falsy input and text within the budget are returned unchanged; otherwise
    the leading words are re-joined with single spaces and the marker
    ``"... [truncated]"`` is appended.
    """
    if not text:
        return text

    word_budget = int(max_tokens * 0.75)
    pieces = text.split()
    if len(pieces) > word_budget:
        return " ".join(pieces[:word_budget]) + "... [truncated]"
    return text
def extract_entities(text, min_score=0.5):
    """Extract VESSEL and ORG entities from text.

    Args:
        text: Raw article text handed to the module-level ``ner_model``.
        min_score: Entities scoring below this value are discarded.

    Returns:
        A ``(vessels, orgs)`` tuple of lists of ``{"text", "score"}`` dicts.
        Each distinct entity text appears once, in first-seen order, carrying
        the highest score observed for it. Both lists are empty when the NER
        model is unavailable, the text is blank, or extraction fails.
    """
    if ner_model is None:
        return [], []
    if not text or not text.strip():
        return [], []

    try:
        # One insertion-ordered dict per entity group, keyed by cleaned text.
        found = {"VESSEL": {}, "ORG": {}}
        for entity in ner_model(text):
            score = entity["score"]
            # Skip low confidence
            if score < min_score:
                continue
            bucket = found.get(entity["entity_group"])
            if bucket is None:
                continue
            # Clean up tokenization artifacts ("##" word-piece markers).
            name = entity["word"].strip().replace(" ##", "").replace("##", "").strip()
            if not name:
                # Nothing left after cleanup; an empty bullet is useless.
                continue
            # Deduplicate, keeping the most confident mention rather than
            # whichever happened to come last.
            previous = bucket.get(name)
            if previous is None or score > previous["score"]:
                bucket[name] = {"text": name, "score": score}
        return list(found["VESSEL"].values()), list(found["ORG"].values())
    except Exception as e:
        # Best-effort: a NER failure must not break classification output.
        print(f"NER error: {e}")
        return [], []
def predict_text(text):
    """Predict whether text is actionable.

    Args:
        text: Article text to classify with the module-level ``classifier``.

    Returns:
        A ``(label, confidence, status)`` tuple. ``label`` is the display
        string (or a message for the non-prediction statuses), ``confidence``
        is a plain ``float`` holding the predicted class's probability
        multiplied by 100, and ``status`` is one of ``"actionable"``,
        ``"not_actionable"``, ``"neutral"`` (blank input) or ``"error"``.
    """
    if classifier is None:
        return "Error: Classifier not loaded.", 0.0, "error"
    if not text or not text.strip():
        return "Please enter some text to classify.", 0.0, "neutral"

    try:
        # Truncate if needed: estimate tokens from the word count and only
        # shorten inputs whose estimate exceeds 300.
        word_count = len(text.split())
        token_estimate = int(word_count / 0.75)
        if token_estimate > 300:
            processed_text = truncate_text(text, max_tokens=256)
        else:
            processed_text = text

        # Make prediction
        prediction = classifier.predict([processed_text])[0]

        # Get probabilities
        try:
            probabilities = classifier.predict_proba([processed_text])[0]
            # Coerce to a built-in float so callers never receive a
            # tensor/NumPy scalar, matching the 0.0 returned on other paths.
            confidence = float(probabilities[prediction]) * 100
        except AttributeError:
            # Placeholder for classifiers without predict_proba; this is not
            # a measured probability.
            confidence = 85.0

        is_actionable = bool(prediction == 1)
        label = "YES (Actionable)" if is_actionable else "NO (Not Actionable)"
        status = "actionable" if is_actionable else "not_actionable"
        return label, confidence, status
    except Exception as e:
        # Report the failure through the return value so the UI can show it.
        print(f"Classification error: {e}")
        return f"Error: {e}", 0.0, "error"
def format_entities(vessels, orgs):
    """Render the extracted entities as a markdown snippet.

    Each non-empty group becomes a level-3 heading followed by one bullet per
    entity showing its text in bold and its score as a whole percentage. The
    vessel group is followed by a blank line. When both groups are empty a
    fixed "No entities detected." message is returned.
    """
    if not (vessels or orgs):
        return "No entities detected."

    parts = []
    if vessels:
        parts.append("### 🚒 Vessels\n")
        parts.extend(
            f"- **{item['text']}** ({item['score']:.0%})\n" for item in vessels
        )
        parts.append("\n")
    if orgs:
        parts.append("### 🏒 Organizations\n")
        parts.extend(
            f"- **{item['text']}** ({item['score']:.0%})\n" for item in orgs
        )
    return "".join(parts)
def get_explanation(status):
    """Return the user-facing sentence for a prediction status.

    Known statuses are "actionable", "not_actionable", "error" and "neutral";
    any other status yields an empty string.
    """
    messages = dict(
        actionable="βœ“ This text contains actionable vessel-specific evidence.",
        not_actionable="βœ— This text does not contain actionable vessel-specific evidence.",
        error="⚠️ An error occurred. Please check the model is properly loaded.",
        neutral="",
    )
    try:
        return messages[status]
    except KeyError:
        return ""
# ============================================================
# GRADIO APP
# ============================================================
# NOTE(review): the nesting of the layout below was reconstructed from the
# statement order - confirm it against the deployed UI.
with gr.Blocks(title="Maritime Intelligence Classifier") as app:
    gr.Markdown(
        "\n"
        "# 🚒 Maritime Intelligence Classifier\n"
        "**Two-stage analysis:**\n"
        "1. **Classification** - Is this article actionable?\n"
        "2. **Entity Extraction** - What vessels and organizations are mentioned?\n"
    )

    with gr.Row():
        # Left column: article input and the trigger button.
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                label="Article Text",
                placeholder="Paste or type the maritime news article text here...",
                lines=10,
                max_lines=20
            )
            submit_btn = gr.Button("Analyze", variant="primary", size="lg")

        # Right column: classification and entity results.
        with gr.Column(scale=1):
            # Classification results
            gr.Markdown("### πŸ“Š Classification")
            prediction_output = gr.Label(
                label="Prediction",
                value={"YES (Actionable)": 0.0, "NO (Not Actionable)": 0.0}
            )
            confidence_output = gr.Number(
                label="Confidence",
                value=0.0,
                precision=1
            )
            explanation_output = gr.Markdown()

            # Entity extraction results
            gr.Markdown("---")
            entities_output = gr.Markdown(
                label="Extracted Entities",
                value="### πŸ” Extracted Entities\nNo entities detected yet."
            )

    # Example texts
    gr.Markdown("### πŸ“ Example Texts")
    with gr.Row():
        example_yes = gr.Examples(
            examples=[
                ["The fishing vessel Marine 707 was involved in the disappearance of fisheries observer Samuel Abayateye in Ghanaian waters. The observer's decapitated body was found weeks later."],
                ["Authorities detained the Meng Xin 15 after discovering evidence of illegal saiko transshipment. Pacific Seafood Inc. was identified as the vessel operator."],
            ],
            inputs=text_input,
            label="Actionable Examples"
        )
        example_no = gr.Examples(
            examples=[
                ["A new maritime museum opened in the port city, showcasing historical ships and ocean exploration artifacts."],
                ["Marine scientists are studying the effects of ocean acidification on coral reefs in tropical waters."],
            ],
            inputs=text_input,
            label="Non-Actionable Examples"
        )

    # Main analysis function
    def analyze_text(text):
        """Run classification and entity extraction for the UI.

        Returns the four values bound to the output components, in order:
        the label-to-share mapping, the confidence number, the explanation
        markdown and the entities markdown.
        """
        # Classification
        label, confidence, status = predict_text(text)

        # Create label dict: the predicted class gets `confidence` percent,
        # the other class the remainder; non-predictions show zeros.
        if status == "actionable":
            label_dict = {"YES (Actionable)": confidence / 100, "NO (Not Actionable)": (100 - confidence) / 100}
        elif status == "not_actionable":
            label_dict = {"YES (Actionable)": (100 - confidence) / 100, "NO (Not Actionable)": confidence / 100}
        else:
            label_dict = {"YES (Actionable)": 0.0, "NO (Not Actionable)": 0.0}

        explanation = get_explanation(status)
        if status not in ("actionable", "not_actionable") and label:
            # For these statuses `label` carries predict_text's message
            # (e.g. the empty-input prompt or the error text); show it
            # instead of silently discarding it.
            explanation = f"{explanation}\n\n{label}".strip()

        # Entity extraction
        vessels, orgs = extract_entities(text)
        entities_md = "### πŸ” Extracted Entities\n" + format_entities(vessels, orgs)

        return label_dict, confidence, explanation, entities_md

    # The button click and pressing Enter in the textbox run the same analysis.
    for register in (submit_btn.click, text_input.submit):
        register(
            fn=analyze_text,
            inputs=text_input,
            outputs=[prediction_output, confidence_output, explanation_output, entities_output]
        )

    gr.Markdown(
        "\n"
        "---\n"
        "### ℹ️ About\n"
        "**Classification**: SetFit model identifies actionable maritime intelligence.\n"
        "**Entity Extraction**: BERT-NER model extracts vessel names and organizations.\n"
        "Built for The Outlaw Ocean Project.\n"
    )
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script,
    # without requesting a public share link.
    # NOTE(review): whether launch() accepts `theme` depends on the installed
    # Gradio version - confirm against the pinned release.
    app.launch(theme=gr.themes.Soft(), share=False)