#!/usr/bin/env python
"""
Summarise scored shards into one daily_summary.csv

CLI examples
------------
# Summarize data for a specific date
python -m reddit_analysis.summarizer.summarize --date 2025-04-20

# Re-summarize a date, replacing any rows already on the Hub
python -m reddit_analysis.summarizer.summarize --date 2025-04-20 --overwrite
"""
from __future__ import annotations

import argparse
from datetime import date
from pathlib import Path
from typing import Optional, List, Dict, Any, Set, Tuple

import pandas as pd
from huggingface_hub import hf_hub_download, HfApi

from reddit_analysis.config_utils import setup_config
from reddit_analysis.summarizer.aggregator import summary_from_df

# --------------------------------------------------------------------------- #
# Utilities                                                                    #
# --------------------------------------------------------------------------- #
class FileManager:
    """Wrapper class for simple local file I/O that can be mocked for testing."""

    def __init__(self, base_dir: Path):
        self.base_dir = base_dir
        self.base_dir.mkdir(parents=True, exist_ok=True)

    # ---------- CSV helpers ------------------------------------------------- #
    def read_csv(self, path: Path) -> pd.DataFrame:
        if not path.exists() or path.stat().st_size == 0:
            # Missing or empty file: return an empty frame with the summary schema
            return pd.DataFrame(
                columns=["date", "subreddit",
                         "mean_sentiment", "community_weighted_sentiment", "count"]
            )
        return pd.read_csv(path)

    def write_csv(self, df: pd.DataFrame, path: Path) -> Path:
        df.to_csv(path, index=False)
        return path

    # ---------- Parquet helper --------------------------------------------- #
    def read_parquet(self, path: Path) -> pd.DataFrame:
        return pd.read_parquet(path)
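
# Illustrative usage of FileManager (the paths here are invented for the
# example; read_csv degrades gracefully to an empty frame if the file is absent):
#   fm = FileManager(Path("/tmp/reddit_analysis"))
#   df = fm.read_csv(fm.base_dir / "daily_summary.csv")
#   fm.write_csv(df, fm.base_dir / "daily_summary.csv")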

class HuggingFaceManager:
    """Thin wrapper around Hugging Face Hub file ops (mock-friendly)."""

    def __init__(self, token: str, repo_id: str, repo_type: str = "dataset"):
        self.token = token
        self.repo_id = repo_id
        self.repo_type = repo_type
        self.api = HfApi(token=token)

    def download_file(self, path_in_repo: str) -> Path:
        # hf_hub_download returns the path of the file in the local HF cache
        return Path(
            hf_hub_download(
                repo_id=self.repo_id,
                repo_type=self.repo_type,
                filename=path_in_repo,
                token=self.token,
            )
        )

    def upload_file(self, local_path: str, path_in_repo: str):
        self.api.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=path_in_repo,
            repo_id=self.repo_id,
            repo_type=self.repo_type,
            token=self.token,
        )

    def list_files(self, prefix: str) -> List[str]:
        """List files in the HF repo filtered by prefix."""
        files = self.api.list_repo_files(
            repo_id=self.repo_id,
            repo_type=self.repo_type,
        )
        return [f for f in files if f.startswith(prefix)]
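
# Illustrative test double (hypothetical; not used in production). Because all
# Hub traffic goes through this wrapper, tests can subclass it with no token:
#   class FakeHFManager(HuggingFaceManager):
#       def __init__(self):            # skip real token/HfApi setup
#           pass
#       def list_files(self, prefix):  # pretend one shard exists
#           return [f"{prefix}/2025-04-20__000.parquet"]
#       def download_file(self, path_in_repo):
#           return Path("tests/fixtures") / path_in_repo
#       def upload_file(self, local_path, path_in_repo):
#           pass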

# --------------------------------------------------------------------------- #
# Core manager                                                                 #
# --------------------------------------------------------------------------- #
class SummaryManager:
    def __init__(
        self,
        cfg: Dict[str, Any],
        file_manager: Optional[FileManager] = None,
        hf_manager: Optional[HuggingFaceManager] = None,
    ):
        self.config = cfg["config"]
        self.secrets = cfg["secrets"]
        self.paths = cfg["paths"]

        # I/O helpers (injectable so tests can pass fakes)
        self.file_manager = file_manager or FileManager(self.paths["root"])
        self.hf_manager = hf_manager or HuggingFaceManager(
            token=self.secrets["HF_TOKEN"],
            repo_id=self.config["repo_id"],
            repo_type=self.config.get("repo_type", "dataset"),
        )

        # Cache path for the combined summary file on disk
        self.local_summary_path: Path = self.paths["summary_file"]

    # --------------------------------------------------------------------- #
    # Remote summary helpers                                                 #
    # --------------------------------------------------------------------- #
    def _load_remote_summary(self) -> pd.DataFrame:
        """
        Ensure `daily_summary.csv` is present locally by downloading the
        latest version from HF Hub (if it exists) and return it as a DataFrame.
        """
        remote_name = self.paths["summary_file"].name
        try:
            cached_path = self.hf_manager.download_file(remote_name)
        except Exception:
            # First run: the file doesn't exist on the Hub yet
            return pd.DataFrame(
                columns=["date", "subreddit",
                         "mean_sentiment", "community_weighted_sentiment", "count"]
            )
        return pd.read_csv(cached_path)

    def _save_and_push_summary(self, df: pd.DataFrame):
        """Persist the updated summary both locally and back to HF Hub."""
        self.file_manager.write_csv(df, self.local_summary_path)
        self.hf_manager.upload_file(str(self.local_summary_path),
                                    self.local_summary_path.name)
    # --------------------------------------------------------------------- #
    # Public helpers                                                         #
    # --------------------------------------------------------------------- #
    def get_processed_combinations(self) -> Set[Tuple[date, str]]:
        """
        Return a set of (date, subreddit) pairs that are *already* present
        in the remote summary so we can de-duplicate.
        """
        df_summary = self._load_remote_summary()
        if df_summary.empty:
            return set()
        df_summary["date"] = pd.to_datetime(df_summary["date"]).dt.date
        return {
            (row["date"], row["subreddit"])
            for _, row in df_summary.iterrows()
        }
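
    # Shape of the returned set, with illustrative subreddit names:
    #   {(date(2025, 4, 19), "python"), (date(2025, 4, 20), "dataisbeautiful")}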
    # --------------------------------------------------------------------- #
    # Main workflow                                                          #
    # --------------------------------------------------------------------- #
    def process_date(self, date_str: str, overwrite: bool = False) -> None:
        """Download scored data for `date_str`, aggregate, and append/upload."""
        # ---------- Pull scored shards for the given date ------------------ #
        prefix = f"{self.paths['hf_scored_dir']}/{date_str}__"

        # List all remote shards
        try:
            all_files = self.hf_manager.list_files(self.paths["hf_scored_dir"])
        except Exception as err:
            print(f"Error: could not list scored shards in {self.paths['hf_scored_dir']}: {err}")
            return

        # Filter to the shards matching this date
        try:
            shards = [fn for fn in all_files
                      if fn.startswith(prefix) and fn.endswith(".parquet")]
        except TypeError:
            # Fall back in case list_files returned a non-iterable (e.g., a mock)
            shards = [all_files]
        if not shards:
            print(f"No scored shards found for {date_str} under {self.paths['hf_scored_dir']}")
            return

        # Download and concatenate all shards for the day
        dfs: List[pd.DataFrame] = []
        for shard in shards:
            try:
                local_path = self.hf_manager.download_file(shard)
            except Exception as err:
                print(f"Error: could not download scored shard {shard}: {err}")
                return
            dfs.append(self.file_manager.read_parquet(local_path))
        df_day = pd.concat(dfs, ignore_index=True)

        # Sanity check: the aggregator needs all of these columns
        required_cols = {"retrieved_at", "subreddit", "sentiment", "score"}
        missing_cols = required_cols - set(df_day.columns)
        if missing_cols:
            raise ValueError(
                f"Scored data for {date_str} is missing columns {sorted(missing_cols)}"
            )
        # ---------- Aggregate ------------------------------------------------ #
        df_summary_day = summary_from_df(df_day)

        # ---------- De-duplication / overwrite ------------------------------ #
        existing_pairs = self.get_processed_combinations()
        if not overwrite:
            df_summary_day = df_summary_day[
                ~df_summary_day.apply(
                    lambda r: (r["date"], r["subreddit"]) in existing_pairs,
                    axis=1,
                )
            ]
        if df_summary_day.empty:
            print("Nothing new to summarise for this date.")
            return

        # ---------- Combine with historical summary ------------------------- #
        df_summary = self._load_remote_summary()
        if overwrite:
            df_summary = df_summary[df_summary["date"] != date_str]
        # Remove the weighted_sentiment column if it exists
        if "weighted_sentiment" in df_summary.columns:
            df_summary = df_summary.drop(columns=["weighted_sentiment"])
        df_out = (
            pd.concat([df_summary, df_summary_day], ignore_index=True)
            if not df_summary.empty
            else df_summary_day
        )
        df_out["date"] = pd.to_datetime(df_out["date"]).dt.date
        df_out.sort_values(["date", "subreddit"], inplace=True)

        # Ensure the weighted_sentiment column is dropped from the final output
        if "weighted_sentiment" in df_out.columns:
            df_out = df_out.drop(columns=["weighted_sentiment"])

        # Round floating-point columns to 4 decimal places
        if "mean_sentiment" in df_out.columns:
            df_out["mean_sentiment"] = df_out["mean_sentiment"].round(4)
        if "community_weighted_sentiment" in df_out.columns:
            df_out["community_weighted_sentiment"] = df_out["community_weighted_sentiment"].round(4)

        # ---------- Save & upload ------------------------------------------- #
        self._save_and_push_summary(df_out)
        print(f"Updated {self.local_summary_path.name} → {len(df_out)} rows")

# --------------------------------------------------------------------------- #
# CLI entry-point                                                              #
# --------------------------------------------------------------------------- #
def main(date_str: str, overwrite: bool = False) -> None:
    if not date_str:
        raise ValueError("--date is required (YYYY-MM-DD)")
    # Confirm the date is a valid ISO date before doing any work
    try:
        date.fromisoformat(date_str)
    except ValueError:
        raise ValueError(f"Invalid date: {date_str} (expected YYYY-MM-DD)")
    cfg = setup_config()
    SummaryManager(cfg).process_date(date_str, overwrite)


if __name__ == "__main__":
    from reddit_analysis.common_metrics import run_with_metrics

    parser = argparse.ArgumentParser(
        description="Summarize scored Reddit data for a specific date."
    )
    parser.add_argument("--date", required=True,
                        help="YYYY-MM-DD date to process")
    parser.add_argument("--overwrite", action="store_true",
                        help="Replace any existing rows for this date")
    args = parser.parse_args()
    run_with_metrics("summarize", main, args.date, args.overwrite)