# aviation-disruption-intelligence — src/models/sentiment_analysis.py
"""
src/models/sentiment_analysis.py
==================================
Sentiment analysis pipeline:
- Loads GDELT sentiment data
- Computes sentiment score and momentum per region
- Runs lag correlation analysis vs disruption index
- Produces output tables and charts
"""
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from datetime import datetime
from src.utils.logger import get_logger
from src.utils.io_utils import save_json
from config.settings import (
BASE_DIR, DERIVED_DIR, PROCESSED_DIR, FIGURES_DIR, METRICS_DIR, SENTIMENT_LAG_HOURS
)
logger = get_logger(__name__)
def compute_lag_correlations(
sentiment_df: pd.DataFrame,
disruption_df: pd.DataFrame,
) -> pd.DataFrame:
"""
Compute Pearson correlation between sentiment at t-lag and disruption at t.
For each region and each lag (6h, 12h, 24h, 0h baseline).
Returns wide DataFrame: region × lag_hour → correlation
"""
results = []
sentiment_df = sentiment_df.copy()
disruption_df = disruption_df.copy()
sentiment_df["timestamp"] = pd.to_datetime(sentiment_df["timestamp"], errors="coerce")
disruption_df["timestamp"] = pd.to_datetime(disruption_df["timestamp"], errors="coerce")
regions = sentiment_df["region"].unique() if "region" in sentiment_df.columns else ["Global"]
for region in regions:
sent_reg = sentiment_df[sentiment_df.get("region", pd.Series()) == region] \
if "region" in sentiment_df.columns else sentiment_df
disrupt_reg = disruption_df[disruption_df.get("region", pd.Series()) == region] \
if "region" in disruption_df.columns else disruption_df
if len(sent_reg) < 10 or "disruption_index" not in disrupt_reg.columns:
continue
row = {"region": region}
lags = [0] + SENTIMENT_LAG_HOURS
for lag_h in lags:
sent_shifted = sent_reg.copy()
if lag_h > 0:
sent_shifted["timestamp"] = sent_shifted["timestamp"] + pd.Timedelta(hours=lag_h)
sent_for_merge = (
sent_shifted[["timestamp", "sentiment_score"]]
.dropna(subset=["timestamp"])
.sort_values("timestamp")
)
disrupt_for_merge = disrupt_reg.sort_values("timestamp").dropna(subset=["timestamp"])
if sent_for_merge.empty or disrupt_for_merge.empty:
row[f"corr_lag_{lag_h}h"] = 0.0
continue
merged = pd.merge_asof(
disrupt_for_merge,
sent_for_merge,
on="timestamp",
tolerance=pd.Timedelta("4h"),
direction="nearest",
)
valid = merged[["sentiment_score", "disruption_index"]].dropna()
if len(valid) >= 5:
corr = valid["sentiment_score"].corr(valid["disruption_index"])
row[f"corr_lag_{lag_h}h"] = round(float(corr), 4) if not np.isnan(corr) else 0.0
else:
row[f"corr_lag_{lag_h}h"] = 0.0
results.append(row)
return pd.DataFrame(results)
def run_sentiment_analysis() -> dict:
    """
    Full sentiment analysis run:
      1. Load and enrich sentiment data (base dir, falling back to derived)
      2. Compute lag correlations vs the disruption index
      3. Generate charts
      4. Save output tables and a JSON summary

    Returns
    -------
    dict
        Summary stats. When input data is missing the dict contains only
        ``regions_analysed`` (0) and an ``error`` key — no stat fields.
    """
    from datetime import timezone  # local: avoids touching the module import block

    logger.info("=" * 60)
    logger.info("Running Sentiment Analysis Pipeline")
    logger.info("=" * 60)

    # Load sentiment from base (real data) with fallback to derived
    from src.utils.io_utils import load_csv_safe
    sentiment_df = load_csv_safe(BASE_DIR / "sentiment.csv")
    if sentiment_df.empty:
        sentiment_df = load_csv_safe(DERIVED_DIR / "sentiment.csv")
    if sentiment_df.empty:
        logger.warning("No sentiment data available — skipping sentiment analysis")
        return {"regions_analysed": 0, "error": "no_sentiment_data"}

    # Load disruption features from derived (computed by feature pipeline)
    from src.processing.base_loader import build_airport_daily_features
    airport_df = build_airport_daily_features()
    if airport_df.empty:
        logger.warning("No airport features available — skipping sentiment analysis")
        return {"regions_analysed": 0, "error": "no_airport_data"}

    # Build a disruption_df with timestamp + region + disruption_index
    airport_df["timestamp"] = pd.to_datetime(airport_df["date"], errors="coerce")
    disruption_df = (
        airport_df[["timestamp", "region", "disruption_index"]]
        .dropna(subset=["timestamp", "disruption_index"])
        .copy()
    )
    logger.info("Sentiment rows: %d | Disruption rows: %d",
                len(sentiment_df), len(disruption_df))

    # Aggregate disruption by region + 6h bucket.  (timestamp is already
    # datetime from the to_datetime above — no re-conversion needed, and the
    # old strftime→to_datetime roundtrip was a no-op and has been removed.)
    disruption_agg = (
        disruption_df.groupby(["region", disruption_df["timestamp"].dt.floor("6h")])
        .agg(disruption_index=("disruption_index", "mean"))
        .reset_index()
    )

    # Add momentum to sentiment
    from src.features.sentiment import compute_sentiment_features
    sentiment_df = compute_sentiment_features(sentiment_df)

    # Lag correlation analysis
    lag_corr_df = compute_lag_correlations(sentiment_df, disruption_agg)
    lag_path = METRICS_DIR / "sentiment_lag_correlations.csv"
    lag_corr_df.to_csv(lag_path, index=False)
    logger.info("Lag correlations saved → %s", lag_path.name)

    # Summary stats
    summary = {
        "regions_analysed": len(lag_corr_df),
        "avg_sentiment_score": round(sentiment_df["sentiment_score"].mean(), 3),
        "max_sentiment_score": round(sentiment_df["sentiment_score"].max(), 3),
        "avg_sentiment_momentum": round(sentiment_df["sentiment_momentum"].mean(), 3),
        "lag_correlations": lag_corr_df.to_dict(orient="records"),
        # timezone-aware replacement for the deprecated datetime.utcnow()
        "analysed_at": datetime.now(timezone.utc).isoformat(),
    }
    save_json(summary, METRICS_DIR / "sentiment_analysis_summary.json")

    # Plots
    _plot_sentiment_timeseries(sentiment_df)
    _plot_lag_correlations(lag_corr_df)
    logger.info("✓ Sentiment analysis complete")
    return summary
def _plot_sentiment_timeseries(df: pd.DataFrame):
    """
    Plot sentiment score and momentum over time per region.

    Uses a 7-day rolling mean on a daily-resampled series so that sparse GDELT
    coverage (gaps of 3-6 days are common) does not produce fragmented lines.
    Raw daily values are shown as faint scatter points so the underlying data
    density is still visible. The rolling window uses min_periods=2 to avoid
    producing NaN on short series.
    """
    df = df.copy()
    df["timestamp"] = pd.to_datetime(df["timestamp"], errors="coerce")
    if "region" not in df.columns:
        return
    # Guard: date_range below crashes on NaT bounds — drop unparseable rows
    # and bail out if nothing usable is left.
    df = df.dropna(subset=["timestamp"])
    if df.empty:
        return

    ROLLING_DAYS = 7
    fig, axes = plt.subplots(2, 1, figsize=(14, 9), sharex=True)
    fig.subplots_adjust(hspace=0.35)
    # NOTE(review): regions are hard-coded — regions outside this list are
    # silently not plotted; confirm this matches the intended dashboard scope.
    regions = ["Middle East", "Eastern Europe", "Global"]
    colors = ["#d62728", "#1f77b4", "#7f7f7f"]
    full_date_range = pd.date_range(
        df["timestamp"].min().normalize(),
        df["timestamp"].max().normalize(),
        freq="D",
    )

    def _smooth(series: pd.Series) -> pd.Series:
        """Reindex to full date range, forward-fill gaps ≤3 days, then 7-day roll."""
        s = series.reindex(full_date_range)
        # Forward-fill gaps that are 3 days or shorter (preserves long gaps
        # as NaN).  .ffill() replaces the deprecated fillna(method="ffill").
        s = s.ffill(limit=3)
        return s.rolling(window=ROLLING_DAYS, min_periods=2, center=True).mean()

    for ax_idx, (col, title, ylabel) in enumerate([
        ("sentiment_score", "Sentiment Score Over Time (7-day rolling mean, shaded = ±1 std raw)", "Sentiment Score"),
        ("sentiment_momentum", "Sentiment Momentum (7-day rolling mean, dashed = zero line)", "Momentum"),
    ]):
        ax = axes[ax_idx]
        for region, color in zip(regions, colors):
            subset = df[df["region"] == region].sort_values("timestamp")
            if subset.empty or col not in subset.columns:
                continue
            # Daily resample → align to full date range
            daily = subset.set_index("timestamp")[col].resample("D").mean()
            # Scatter of raw daily points (low alpha = background context)
            raw_reindexed = daily.reindex(full_date_range)
            ax.scatter(raw_reindexed.index, raw_reindexed.values,
                       color=color, alpha=0.18, s=10, zorder=2)
            # Smoothed rolling mean line — only draw segments where data exists
            smoothed = _smooth(daily)
            # Split into contiguous non-NaN segments to avoid bridging real gaps
            mask = smoothed.notna()
            segment_starts = mask & (~mask.shift(1, fill_value=False))
            segment_ids = segment_starts.cumsum()
            # Hoisted: unique() was recomputed once per segment in the old code
            seg_order = segment_ids[mask].unique()
            for seg_id in seg_order:
                seg = smoothed[mask & (segment_ids == seg_id)]
                if len(seg) >= 2:
                    # Only the first segment carries the legend label
                    ax.plot(seg.index, seg.values, color=color, linewidth=2.2,
                            alpha=0.85, zorder=3,
                            label=region if ax_idx == 0 and seg_id == seg_order[0] else "")
            # ±1 std shaded band (sentiment score only, for readability)
            if ax_idx == 0:
                std = daily.reindex(full_date_range).rolling(
                    window=ROLLING_DAYS, min_periods=2, center=True).std()
                ax.fill_between(smoothed.index,
                                smoothed - std, smoothed + std,
                                alpha=0.08, color=color)
        if ax_idx == 1:
            ax.axhline(0, color="black", linestyle="--", alpha=0.45, linewidth=1)
        ax.set_title(title, fontsize=11, fontweight="bold")
        ax.set_ylabel(ylabel, fontsize=10)
        ax.grid(True, alpha=0.25, linestyle=":")
        if ax_idx == 0:
            ax.legend(fontsize=9, framealpha=0.7)
    axes[-1].set_xlabel("Date", fontsize=10)
    fig.text(0.5, 0.01,
             f"Note: dots = daily raw values. Lines = {ROLLING_DAYS}-day centred rolling mean. "
             "Gaps >3 consecutive missing days are intentionally not bridged.",
             ha="center", fontsize=8, color="#666666", style="italic")
    plt.tight_layout(rect=[0, 0.03, 1, 1])
    fig.savefig(FIGURES_DIR / "sentiment_timeseries.png", dpi=150, bbox_inches="tight")
    plt.close(fig)
    logger.info("Saved sentiment timeseries plot (7-day rolling mean)")
def _plot_lag_correlations(lag_corr_df: pd.DataFrame):
    """
    Bar chart of lag correlations per region.

    One bar group per lag column (``corr_lag_*``), one bar per region.
    No-op when the frame is empty or has no lag columns.
    """
    if lag_corr_df.empty:
        return
    lag_cols = [c for c in lag_corr_df.columns if c.startswith("corr_lag_")]
    if not lag_cols:
        return
    fig, ax = plt.subplots(figsize=(10, 6))
    x = np.arange(len(lag_cols))
    # Adaptive width: identical to the old fixed 0.2 for ≤4 regions, but
    # shrinks for more regions so the bar groups no longer overlap.
    width = min(0.2, 0.8 / len(lag_corr_df))
    colors = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728"]
    for i, (_, row) in enumerate(lag_corr_df.iterrows()):
        vals = [row[c] for c in lag_cols]
        ax.bar(x + i * width, vals, width, label=row.get("region", f"Region {i}"),
               color=colors[i % len(colors)], alpha=0.8)
    # Centre the tick labels under each bar group
    ax.set_xticks(x + width * (len(lag_corr_df) - 1) / 2)
    # Removed the no-op .replace("h", "h") from the old label code
    lag_labels = [c.replace("corr_lag_", "Lag ") for c in lag_cols]
    ax.set_xticklabels(lag_labels)
    ax.axhline(0, color="black", linestyle="--", alpha=0.5)
    ax.set_ylabel("Pearson Correlation")
    ax.set_title("Sentiment → Disruption Lag Correlation by Region")
    ax.legend()
    ax.grid(True, alpha=0.3, axis="y")
    plt.tight_layout()
    fig.savefig(FIGURES_DIR / "sentiment_lag_correlations.png", dpi=150)
    plt.close(fig)
    logger.info("Saved lag correlation plot")
if __name__ == "__main__":
    result = run_sentiment_analysis()
    print(f"\nRegions analysed: {result['regions_analysed']}")
    # The pipeline returns {"regions_analysed": 0, "error": ...} when input
    # data is missing — in that case the stat keys do not exist, so indexing
    # them directly (as the old code did) raised KeyError.
    if "error" in result:
        print(f"Skipped: {result['error']}")
    else:
        print(f"Avg sentiment score: {result['avg_sentiment_score']}")
        if result["lag_correlations"]:
            print("\nLag correlations:")
            print(pd.DataFrame(result["lag_correlations"]).to_string(index=False))