CrisisAI-Sentinal / app /models /topic_model.py
karthik119's picture
Deploy CrisisAI Sentinel - Deep Learning NLP Dashboard
6328796 verified
"""
Topic modeling for disaster tweets via deep-learning zero-shot classification.

Tweets are assigned to a fixed set of disaster topics by semantic entailment
against natural-language topic descriptions, so no corpus-specific training
is needed. If torch/transformers cannot be imported or initialised,
``TRANSFORMERS_AVAILABLE`` is False and no pipeline is constructed.
"""
import warnings

import numpy as np

# NOTE(review): this silences *every* warning for the whole process as a side
# effect of importing this module, not only torch/transformers noise.
warnings.filterwarnings('ignore')

try:
    import torch

    # Allocating a tiny tensor forces torch's native libraries to load, so a
    # broken install (e.g. a DLL load failure raising OSError) is detected
    # here rather than at first inference.
    torch.zeros(1)
    from transformers import pipeline
except Exception as e:
    print(f"[Warning] Transformers/Torch unavailable (falling back to keyword rules): {e}")
    TRANSFORMERS_AVAILABLE = False
else:
    TRANSFORMERS_AVAILABLE = True
# Display label for each topic id. Topic i in TOPIC_LABEL_MAP, TOPIC_COLORS and
# ZERO_SHOT_CANDIDATES refers to the same topic, so the three tables must stay
# index-aligned.
TOPIC_LABEL_MAP = {
    0: "🔥 Wildfire & Fire Emergency",
    1: "🌊 Flood & Water Disaster",
    2: "🌀 Hurricane & Storm",
    3: "🌍 Earthquake & Seismic",
    4: "🌪️ Tornado & Severe Weather",
    5: "🏥 Emergency Response & Aid",
    6: "📢 Evacuation & Safety",
    7: "💬 General Community Updates",
}

# Hex display color per topic id (same indexing as TOPIC_LABEL_MAP).
TOPIC_COLORS = {
    0: '#FF6B35', 1: '#4ECDC4', 2: '#45B7D1', 3: '#96CEB4',
    4: '#FFEAA7', 5: '#DDA0DD', 6: '#98D8C8', 7: '#87CEEB',
}

# Natural-language topic descriptions passed to the zero-shot classifier as
# its candidate labels; list position i is topic id i.
ZERO_SHOT_CANDIDATES = [
    "wildfire and fire emergency",
    "flood and water disaster",
    "hurricane and storm",
    "earthquake and seismic",
    "tornado and severe weather",
    "emergency response and aid",
    "evacuation and safety",
    "general community updates",
]
class TopicModeler:
    """
    Deep Learning Zero-Shot topic modeling for disaster tweets.

    Each text is scored against the natural-language labels in
    ``ZERO_SHOT_CANDIDATES`` by a Hugging Face ``zero-shot-classification``
    pipeline (semantic entailment), so no training is required. Topic ``i``
    corresponds to ``ZERO_SHOT_CANDIDATES[i]``, ``TOPIC_LABEL_MAP[i]`` and
    ``TOPIC_COLORS[i]``; only the first ``n_topics`` candidates are scored.

    When no pipeline could be loaded, or the pipeline fails on a particular
    text, that text is assigned entirely to the last topic (``n_topics - 1``).
    """

    # Lighter alternative to facebook/bart-large-mnli.
    DEFAULT_MODEL = "typeform/distilbert-base-uncased-mnli"

    def __init__(self, n_topics=8, model_name=DEFAULT_MODEL):
        """
        Args:
            n_topics: Number of topics to score; must be >= 1. Values above
                ``len(ZERO_SHOT_CANDIDATES)`` add topics that are never
                scored (their score is always 0).
            model_name: Hugging Face model id used for the zero-shot pipeline.

        Raises:
            ValueError: if ``n_topics`` is less than 1.
        """
        if n_topics < 1:
            raise ValueError(f"n_topics must be >= 1, got {n_topics!r}")
        self.n_topics = n_topics
        self.is_fitted = True  # Zero-shot doesn't need fitting.
        self.zero_shot = None
        if TRANSFORMERS_AVAILABLE:
            # Report the model that is actually loaded.
            print(f"[TopicModeler] Loading Zero-Shot Classification model: {model_name}")
            try:
                self.zero_shot = pipeline("zero-shot-classification", model=model_name)
            except Exception as e:
                print(f"[TopicModeler] Failed to load zero-shot model: {e}")

    def fit(self, texts: list):
        """No fitting required for zero-shot models; returns ``self`` for chaining."""
        return self

    def _candidate_labels(self) -> list:
        """Return the zero-shot labels scored for this model's ``n_topics``."""
        return ZERO_SHOT_CANDIDATES[:self.n_topics]

    def _fallback_distribution(self) -> np.ndarray:
        """Return a one-hot distribution on the last topic, used when a text cannot be scored."""
        dist = np.zeros(self.n_topics)
        dist[-1] = 1.0
        return dist

    def _classify(self, text: str):
        """
        Score ``text`` against the candidate labels.

        Returns a length-``n_topics`` array of pipeline scores indexed by
        topic id, or ``None`` when no pipeline is loaded or the call fails.
        """
        if self.zero_shot is None:
            return None
        try:
            result = self.zero_shot(text, candidate_labels=self._candidate_labels())
            dist = np.zeros(self.n_topics)
            # The pipeline reorders labels by score; map each back to its
            # fixed topic index.
            for label, score in zip(result['labels'], result['scores']):
                dist[ZERO_SHOT_CANDIDATES.index(label)] = score
        except Exception:
            # Best-effort: one bad text must not abort a whole batch; callers
            # substitute the fallback topic instead.
            return None
        return dist

    def transform(self, texts: list) -> np.ndarray:
        """
        Get topic distribution for each document.

        Returns an array of shape ``(len(texts), n_topics)``. Row ``j`` holds
        the zero-shot scores for ``texts[j]``, or a one-hot vector on the
        last topic when that text could not be scored.
        """
        rows = []
        for text in texts:
            dist = self._classify(text)
            rows.append(self._fallback_distribution() if dist is None else dist)
        if not rows:
            # Keep the result 2-D even for empty input.
            return np.zeros((0, self.n_topics))
        return np.vstack(rows)

    def get_dominant_topic(self, text: str) -> dict:
        """
        Get the dominant topic for a single tweet.

        Returns a dict with ``topic_id``, ``label``, ``confidence`` (score of
        the winning topic, rounded to 4 places), ``distribution`` (per-topic
        scores, rounded to 4 places), ``top_words`` (words of the winning
        candidate label) and ``color``. If the text cannot be scored, the
        last topic is returned with confidence 1.0, a one-hot distribution
        and empty ``top_words``.
        """
        dist = self._classify(text)
        if dist is None:
            last_idx = self.n_topics - 1
            return {
                'topic_id': last_idx,
                'label': TOPIC_LABEL_MAP.get(last_idx, 'General'),
                'confidence': 1.0,
                'distribution': self._fallback_distribution().tolist(),
                'top_words': [],
                'color': TOPIC_COLORS.get(last_idx, '#888888'),
            }
        # Unscored slots (n_topics > number of candidates) stay 0, so the
        # argmax always lands on a scored candidate.
        topic_id = int(np.argmax(dist))
        best_label = ZERO_SHOT_CANDIDATES[topic_id]
        return {
            'topic_id': topic_id,
            'label': TOPIC_LABEL_MAP.get(topic_id, f'Topic {topic_id}'),
            'confidence': round(float(dist[topic_id]), 4),
            'distribution': [round(float(score), 4) for score in dist],
            'top_words': best_label.split(),
            'color': TOPIC_COLORS.get(topic_id, '#888888'),
        }

    def get_all_topics(self) -> list:
        """
        Get every scored topic -> word list for display.

        ``words`` holds the candidate label's words longer than 3 characters,
        each paired with a placeholder weight of 1.0 (zero-shot has no learned
        word weights); ``word_string`` is the full candidate label.
        """
        return [
            {
                'topic_id': i,
                'label': TOPIC_LABEL_MAP.get(i, f'Topic {i}'),
                'color': TOPIC_COLORS.get(i, '#888888'),
                'words': [(w, 1.0) for w in label.split() if len(w) > 3],
                'word_string': label,
            }
            for i, label in enumerate(self._candidate_labels())
        ]

    @staticmethod
    def _topic_entry(topic_id: int, proportion: float) -> dict:
        """Build one per-topic row of the corpus distribution."""
        return {
            'topic_id': topic_id,
            'label': TOPIC_LABEL_MAP.get(topic_id, f'Topic {topic_id}'),
            'proportion': proportion,
            'color': TOPIC_COLORS.get(topic_id, '#888888'),
        }

    def get_corpus_topic_distribution(self, texts: list, max_samples: int = 20) -> list:
        """
        Get aggregate topic proportions across ``texts``.

        Args:
            texts: Documents to aggregate.
            max_samples: Only the last ``max_samples`` texts are scored, since
                each text costs one pipeline call. ``None`` or a value < 1
                scores every text.

        Returns:
            One dict per topic with ``topic_id``, ``label``, ``proportion``
            (mean score, rounded to 4 places) and ``color``. For empty input
            every topic gets the uniform proportion ``1 / n_topics``. Those
            entries also carry the legacy ``topic`` key.
        """
        if not texts:
            uniform = 1.0 / self.n_topics
            return [dict(self._topic_entry(i, uniform), topic=i) for i in range(self.n_topics)]
        if max_samples is not None and 0 < max_samples < len(texts):
            texts = texts[-max_samples:]
        mean_dist = self.transform(texts).mean(axis=0)
        return [
            self._topic_entry(i, round(float(mean_dist[i]), 4))
            for i in range(self.n_topics)
        ]

    @property
    def perplexity(self) -> float:
        """Perplexity is not defined for zero-shot classification; always -1.0."""
        return -1.0