| """ |
| src/models/sentiment_analysis.py |
| ================================== |
| Sentiment analysis pipeline: |
| - Loads GDELT sentiment data |
| - Computes sentiment score and momentum per region |
| - Runs lag correlation analysis vs disruption index |
| - Produces output tables and charts |
| """ |
|
|
| import numpy as np |
| import pandas as pd |
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| from datetime import datetime |
| from src.utils.logger import get_logger |
| from src.utils.io_utils import save_json |
| from config.settings import ( |
| BASE_DIR, DERIVED_DIR, PROCESSED_DIR, FIGURES_DIR, METRICS_DIR, SENTIMENT_LAG_HOURS |
| ) |
|
|
| logger = get_logger(__name__) |
|
|
|
|
| def compute_lag_correlations( |
| sentiment_df: pd.DataFrame, |
| disruption_df: pd.DataFrame, |
| ) -> pd.DataFrame: |
| """ |
| Compute Pearson correlation between sentiment at t-lag and disruption at t. |
| For each region and each lag (6h, 12h, 24h, 0h baseline). |
| |
| Returns wide DataFrame: region × lag_hour → correlation |
| """ |
| results = [] |
| sentiment_df = sentiment_df.copy() |
| disruption_df = disruption_df.copy() |
|
|
| sentiment_df["timestamp"] = pd.to_datetime(sentiment_df["timestamp"], errors="coerce") |
| disruption_df["timestamp"] = pd.to_datetime(disruption_df["timestamp"], errors="coerce") |
|
|
| regions = sentiment_df["region"].unique() if "region" in sentiment_df.columns else ["Global"] |
|
|
| for region in regions: |
| sent_reg = sentiment_df[sentiment_df.get("region", pd.Series()) == region] \ |
| if "region" in sentiment_df.columns else sentiment_df |
| disrupt_reg = disruption_df[disruption_df.get("region", pd.Series()) == region] \ |
| if "region" in disruption_df.columns else disruption_df |
|
|
| if len(sent_reg) < 10 or "disruption_index" not in disrupt_reg.columns: |
| continue |
|
|
| row = {"region": region} |
| lags = [0] + SENTIMENT_LAG_HOURS |
|
|
| for lag_h in lags: |
| sent_shifted = sent_reg.copy() |
| if lag_h > 0: |
| sent_shifted["timestamp"] = sent_shifted["timestamp"] + pd.Timedelta(hours=lag_h) |
|
|
| sent_for_merge = ( |
| sent_shifted[["timestamp", "sentiment_score"]] |
| .dropna(subset=["timestamp"]) |
| .sort_values("timestamp") |
| ) |
| disrupt_for_merge = disrupt_reg.sort_values("timestamp").dropna(subset=["timestamp"]) |
| if sent_for_merge.empty or disrupt_for_merge.empty: |
| row[f"corr_lag_{lag_h}h"] = 0.0 |
| continue |
| merged = pd.merge_asof( |
| disrupt_for_merge, |
| sent_for_merge, |
| on="timestamp", |
| tolerance=pd.Timedelta("4h"), |
| direction="nearest", |
| ) |
|
|
| valid = merged[["sentiment_score", "disruption_index"]].dropna() |
| if len(valid) >= 5: |
| corr = valid["sentiment_score"].corr(valid["disruption_index"]) |
| row[f"corr_lag_{lag_h}h"] = round(float(corr), 4) if not np.isnan(corr) else 0.0 |
| else: |
| row[f"corr_lag_{lag_h}h"] = 0.0 |
|
|
| results.append(row) |
|
|
| return pd.DataFrame(results) |
|
|
|
|
def run_sentiment_analysis() -> dict:
    """
    Full sentiment analysis run:
      1. Load sentiment data (BASE_DIR first, DERIVED_DIR as fallback)
      2. Build a 6-hourly disruption series from airport daily features
      3. Enrich sentiment data with derived features
      4. Compute lag correlations vs disruption; save CSV + JSON outputs
      5. Generate charts

    Returns
    -------
    dict summary with region count, sentiment statistics, lag correlations and
    an "analysed_at" timestamp. If inputs are missing, returns
    {"regions_analysed": 0, "error": "<reason>"} instead.
    """
    logger.info("=" * 60)
    logger.info("Running Sentiment Analysis Pipeline")
    logger.info("=" * 60)

    # Imported lazily, matching the file's pattern for project-internal helpers.
    from src.utils.io_utils import load_csv_safe
    sentiment_df = load_csv_safe(BASE_DIR / "sentiment.csv")
    if sentiment_df.empty:
        sentiment_df = load_csv_safe(DERIVED_DIR / "sentiment.csv")
    if sentiment_df.empty:
        logger.warning("No sentiment data available — skipping sentiment analysis")
        return {"regions_analysed": 0, "error": "no_sentiment_data"}

    from src.processing.base_loader import build_airport_daily_features
    airport_df = build_airport_daily_features()
    if airport_df.empty:
        logger.warning("No airport features available — skipping sentiment analysis")
        return {"regions_analysed": 0, "error": "no_airport_data"}

    airport_df["timestamp"] = pd.to_datetime(airport_df["date"], errors="coerce")
    disruption_df = airport_df[["timestamp", "region", "disruption_index"]].copy()
    disruption_df = disruption_df.dropna(subset=["timestamp", "disruption_index"])

    logger.info("Sentiment rows: %d | Disruption rows: %d",
                len(sentiment_df), len(disruption_df))

    # Aggregate disruption into 6-hour buckets per region so it can be
    # asof-merged against sentiment timestamps at a comparable resolution.
    disruption_agg = (
        disruption_df.groupby(["region", disruption_df["timestamp"].dt.floor("6h")])
        .agg(disruption_index=("disruption_index", "mean"))
        .reset_index()
    )

    from src.features.sentiment import compute_sentiment_features
    sentiment_df = compute_sentiment_features(sentiment_df)

    lag_corr_df = compute_lag_correlations(sentiment_df, disruption_agg)
    lag_path = METRICS_DIR / "sentiment_lag_correlations.csv"
    lag_corr_df.to_csv(lag_path, index=False)
    logger.info("Lag correlations saved → %s", lag_path.name)

    from datetime import timezone
    summary = {
        "regions_analysed": len(lag_corr_df),
        "avg_sentiment_score": round(sentiment_df["sentiment_score"].mean(), 3),
        "max_sentiment_score": round(sentiment_df["sentiment_score"].max(), 3),
        "avg_sentiment_momentum": round(sentiment_df["sentiment_momentum"].mean(), 3),
        "lag_correlations": lag_corr_df.to_dict(orient="records"),
        # Naive-UTC ISO string, matching the format of the (deprecated) utcnow().
        "analysed_at": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
    }
    save_json(summary, METRICS_DIR / "sentiment_analysis_summary.json")

    _plot_sentiment_timeseries(sentiment_df)
    _plot_lag_correlations(lag_corr_df)

    logger.info("✓ Sentiment analysis complete")
    return summary
|
|
|
|
def _plot_sentiment_timeseries(df: pd.DataFrame):
    """
    Plot sentiment score and momentum over time per region.

    Uses a 7-day rolling mean on a daily-resampled series so that sparse GDELT
    coverage (gaps of 3-6 days are common) does not produce fragmented lines.
    Raw daily values are shown as faint scatter points so the underlying data
    density is still visible. The rolling window uses min_periods=2 to avoid
    producing NaN on short series.

    Writes sentiment_timeseries.png to FIGURES_DIR. No-op when the frame has
    no "region" column or contains no parseable timestamps.
    """
    df = df.copy()
    df["timestamp"] = pd.to_datetime(df["timestamp"], errors="coerce")

    if "region" not in df.columns:
        return
    # Guard: min()/max() of an all-NaT column is NaT, which would make
    # pd.date_range below raise.
    if df["timestamp"].notna().sum() == 0:
        return

    ROLLING_DAYS = 7
    fig, axes = plt.subplots(2, 1, figsize=(14, 9), sharex=True)
    fig.subplots_adjust(hspace=0.35)

    # Fixed region/color pairing keeps chart colors stable across runs.
    regions = ["Middle East", "Eastern Europe", "Global"]
    colors = ["#d62728", "#1f77b4", "#7f7f7f"]

    # One shared daily index so every region's series lines up on the x-axis.
    full_date_range = pd.date_range(
        df["timestamp"].min().normalize(),
        df["timestamp"].max().normalize(),
        freq="D",
    )

    def _smooth(series: pd.Series) -> pd.Series:
        """Reindex to full date range, forward-fill gaps ≤3 days, then 7-day roll."""
        s = series.reindex(full_date_range)
        # limit=3 bridges short gaps only; longer outages stay NaN and are
        # rendered as visible breaks in the line.
        s = s.ffill(limit=3)
        return s.rolling(window=ROLLING_DAYS, min_periods=2, center=True).mean()

    for ax_idx, (col, title, ylabel) in enumerate([
        ("sentiment_score", "Sentiment Score Over Time (7-day rolling mean, shaded = ±1 std raw)", "Sentiment Score"),
        ("sentiment_momentum", "Sentiment Momentum (7-day rolling mean, dashed = zero line)", "Momentum"),
    ]):
        ax = axes[ax_idx]
        for region, color in zip(regions, colors):
            subset = df[df["region"] == region].sort_values("timestamp")
            if subset.empty or col not in subset.columns:
                continue

            # Collapse to one mean value per calendar day.
            daily = subset.set_index("timestamp")[col].resample("D").mean()

            # Faint scatter of raw daily values shows the data density.
            raw_reindexed = daily.reindex(full_date_range)
            ax.scatter(raw_reindexed.index, raw_reindexed.values,
                       color=color, alpha=0.18, s=10, zorder=2)

            smoothed = _smooth(daily)

            # Draw each contiguous non-NaN run as its own segment so long gaps
            # show as breaks rather than being bridged by a straight line.
            mask = smoothed.notna()
            segment_starts = mask & (~mask.shift(1, fill_value=False))
            segment_ids = segment_starts.cumsum()
            for seg_id in segment_ids[mask].unique():
                seg = smoothed[mask & (segment_ids == seg_id)]
                if len(seg) >= 2:
                    ax.plot(seg.index, seg.values, color=color, linewidth=2.2,
                            alpha=0.85, zorder=3,
                            # Label only the first segment on the top panel so
                            # the legend gets exactly one entry per region.
                            label=region if ax_idx == 0 and seg_id == segment_ids[mask].unique()[0] else "")

            # ±1 std band (of raw daily values) on the score panel only.
            if ax_idx == 0:
                std = daily.reindex(full_date_range).rolling(
                    window=ROLLING_DAYS, min_periods=2, center=True).std()
                ax.fill_between(smoothed.index,
                                smoothed - std, smoothed + std,
                                alpha=0.08, color=color)

        # Momentum panel: zero reference line.
        if ax_idx == 1:
            ax.axhline(0, color="black", linestyle="--", alpha=0.45, linewidth=1)

        ax.set_title(title, fontsize=11, fontweight="bold")
        ax.set_ylabel(ylabel, fontsize=10)
        ax.grid(True, alpha=0.25, linestyle=":")
        if ax_idx == 0:
            ax.legend(fontsize=9, framealpha=0.7)

    axes[-1].set_xlabel("Date", fontsize=10)
    fig.text(0.5, 0.01,
             f"Note: dots = daily raw values. Lines = {ROLLING_DAYS}-day centred rolling mean. "
             "Gaps >3 consecutive missing days are intentionally not bridged.",
             ha="center", fontsize=8, color="#666666", style="italic")

    plt.tight_layout(rect=[0, 0.03, 1, 1])
    fig.savefig(FIGURES_DIR / "sentiment_timeseries.png", dpi=150, bbox_inches="tight")
    plt.close(fig)
    logger.info("Saved sentiment timeseries plot (7-day rolling mean)")
|
|
|
|
def _plot_lag_correlations(lag_corr_df: pd.DataFrame):
    """Grouped bar chart: one bar group per lag, one colored bar per region."""
    if lag_corr_df.empty:
        return

    lag_cols = [c for c in lag_corr_df.columns if c.startswith("corr_lag_")]
    if not lag_cols:
        return

    n_regions = len(lag_corr_df)
    bar_width = 0.2
    positions = np.arange(len(lag_cols))
    palette = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728"]

    fig, ax = plt.subplots(figsize=(10, 6))
    for idx, record in enumerate(lag_corr_df.to_dict(orient="records")):
        heights = [record[c] for c in lag_cols]
        ax.bar(
            positions + idx * bar_width,
            heights,
            bar_width,
            label=record.get("region", f"Region {idx}"),
            color=palette[idx % len(palette)],
            alpha=0.8,
        )

    # Centre the tick under each group of region bars.
    ax.set_xticks(positions + bar_width * (n_regions - 1) / 2)
    ax.set_xticklabels([c.replace("corr_lag_", "Lag ") for c in lag_cols])
    ax.axhline(0, color="black", linestyle="--", alpha=0.5)
    ax.set_ylabel("Pearson Correlation")
    ax.set_title("Sentiment → Disruption Lag Correlation by Region")
    ax.legend()
    ax.grid(True, alpha=0.3, axis="y")
    plt.tight_layout()
    fig.savefig(FIGURES_DIR / "sentiment_lag_correlations.png", dpi=150)
    plt.close(fig)
    logger.info("Saved lag correlation plot")
|
|
|
|
if __name__ == "__main__":
    result = run_sentiment_analysis()
    print(f"\nRegions analysed: {result['regions_analysed']}")
    if "error" in result:
        # Early-exit summaries carry only regions_analysed + error; the score
        # and lag keys are absent, so printing them would raise KeyError.
        print(f"Skipped: {result['error']}")
    else:
        print(f"Avg sentiment score: {result['avg_sentiment_score']}")
        if result["lag_correlations"]:
            print("\nLag correlations:")
            print(pd.DataFrame(result["lag_correlations"]).to_string(index=False))
|
|