Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python | |
| """ | |
| Score Reddit posts and comments using Replicate. | |
| CLI examples | |
| ------------ | |
| # Score data for a specific date | |
| python -m reddit_analysis.inference.score --date 2025-04-20 | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import logging | |
| from datetime import date, timedelta | |
| from pathlib import Path | |
| from typing import Optional, List, Dict, Any | |
| import pandas as pd | |
| import pyarrow.parquet as pq | |
| from huggingface_hub import ( | |
| hf_hub_download, | |
| list_repo_files, | |
| login, | |
| upload_file, | |
| HfApi | |
| ) | |
| import replicate | |
| import json | |
| import httpx | |
| import re | |
| from reddit_analysis.config_utils import setup_config | |
| import json | |
| import time | |
| from typing import List, Dict | |
| import httpx | |
| import replicate | |
| def setup_logging(logs_dir: Path) -> logging.Logger: | |
| """Set up logging configuration using logs_dir from config.""" | |
| logs_dir.mkdir(parents=True, exist_ok=True) | |
| # Create log filename with current date | |
| log_file = logs_dir / f"reddit_scorer_{date.today().strftime('%Y-%m-%d')}.log" | |
| # Configure logging | |
| logging.basicConfig( | |
| level=logging.INFO, | |
| format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', | |
| handlers=[ | |
| logging.FileHandler(log_file, encoding="utf-8") | |
| ] | |
| ) | |
| logger = logging.getLogger(__name__) | |
| logger.info(f"Logging initialized. Log file: {log_file}") | |
| return logger | |
class ReplicateAPI:
    """Thin wrapper around ``replicate.Client`` for batch sentiment scoring.

    Transient transport failures (dropped HTTP connections and any httpx
    timeout) are retried with exponential backoff.
    """

    def __init__(self, api_token: str, model: str, timeout_s: int = 1200, retries: int = 3):
        """Create the Replicate client.

        Args:
            api_token: Replicate API token.
            model: Model identifier passed to ``Client.run``.
            timeout_s: One limit applied to connect/read/write/pool timeouts.
            retries: Total number of attempts per batch; must be >= 1.

        Raises:
            ValueError: If ``retries`` is less than 1.
        """
        if retries < 1:
            raise ValueError(f"retries must be >= 1, got {retries}")
        # Replicate accepts an httpx.Timeout via the `timeout=` kwarg.
        self.client = replicate.Client(
            api_token=api_token,
            timeout=httpx.Timeout(timeout_s),  # same limit for connect/read/write/pool
        )
        self.model = model
        self.retries = retries  # total attempts per batch
        self.logger = logging.getLogger(__name__)

    def predict(self, texts: List[str]) -> Dict[str, List[float]]:
        """Run sentiment analysis on a batch of texts.

        The batch is sent as a single JSON-encoded string under the
        ``"texts"`` input key, not as a list.

        Args:
            texts: Texts to score.

        Returns:
            ``{"predicted_labels": [...], "confidences": [...]}``. A key
            missing from the model output yields an empty list.

        Raises:
            httpx.RemoteProtocolError, httpx.TimeoutException: If the final
                attempt still fails with a transient transport error.
            RuntimeError: If no attempt was made (``retries`` < 1).
        """
        payload = {"texts": json.dumps(texts)}  # keep JSON string
        for attempt in range(self.retries):
            try:
                result = self.client.run(self.model, input=payload)
            # TimeoutException covers Read/Connect/Write/Pool timeouts, all of
            # which the shared httpx.Timeout above can trigger.
            except (httpx.RemoteProtocolError, httpx.TimeoutException) as err:
                if attempt == self.retries - 1:
                    raise  # re-raise on final failure
                backoff = 2 ** attempt  # 1 s, 2 s, 4 s …
                self.logger.warning(f"{err!s} – retrying in {backoff}s")
                time.sleep(backoff)
            else:
                return {
                    "predicted_labels": result.get("predicted_labels", []),
                    "confidences": result.get("confidences", []),
                }
        # Only reachable when self.retries < 1 (e.g. mutated after __init__).
        raise RuntimeError("predict() made no attempts; retries must be >= 1")
| class FileManager: | |
| """Wrapper class for file operations that can be mocked for testing.""" | |
| def __init__(self, base_dir: Path): | |
| self.base_dir = base_dir | |
| self.base_dir.mkdir(parents=True, exist_ok=True) | |
| def save_parquet(self, df: pd.DataFrame, filename: str) -> Path: | |
| path = self.base_dir / f"{filename}.parquet" | |
| df.to_parquet(path, index=False) | |
| return path | |
| def read_parquet(self, filename: str) -> pd.DataFrame: | |
| path = self.base_dir / f"{filename}" | |
| return pd.read_parquet(path) | |
class HuggingFaceManager:
    """Mockable facade over the Hugging Face Hub calls used by the scorer."""

    def __init__(self, token: str, repo_id: str, repo_type: str = "dataset"):
        """Remember the repo coordinates and build an authenticated ``HfApi``."""
        self.token, self.repo_id, self.repo_type = token, repo_id, repo_type
        self.api = HfApi(token=token)

    def download_file(self, path_in_repo: str) -> Path:
        """Download ``path_in_repo`` via ``hf_hub_download`` and return its local path."""
        cached_location = hf_hub_download(
            repo_id=self.repo_id,
            repo_type=self.repo_type,
            filename=path_in_repo,
            token=self.token,
        )
        return Path(cached_location)

    def upload_file(self, local_path: str, path_in_repo: str):
        """Upload the file at ``local_path`` to ``path_in_repo`` in the configured repo."""
        self.api.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=path_in_repo,
            repo_id=self.repo_id,
            repo_type=self.repo_type,
            token=self.token,
        )

    def list_files(self, prefix: str) -> List[str]:
        """Return every repo file path that starts with ``prefix`` (plain string match)."""
        every_file = self.api.list_repo_files(
            repo_id=self.repo_id,
            repo_type=self.repo_type,
        )
        return list(filter(lambda name: name.startswith(prefix), every_file))
class SentimentScorer:
    """Score one day of raw Reddit rows via Replicate and publish per-subreddit files.

    The raw file ``<hf_raw_dir>/<date>.parquet`` is downloaded from the Hub and
    every row's ``text`` is scored. The scored rows are then saved locally and
    uploaded as ``data_scored_subreddit/<date>__<sanitized subreddit>.parquet``,
    one file per subreddit.
    """

    def __init__(
        self,
        cfg: Dict[str, Any],
        replicate_api: Optional["ReplicateAPI"] = None,
        file_manager: Optional["FileManager"] = None,
        hf_manager: Optional["HuggingFaceManager"] = None
    ):
        """Build the scorer.

        Args:
            cfg: Mapping with ``'config'``, ``'secrets'`` and ``'paths'`` sections.
            replicate_api, file_manager, hf_manager: Optional collaborators
                (dependency injection for tests). Any left as ``None`` is
                constructed from ``cfg``.
        """
        self.config = cfg['config']
        self.secrets = cfg['secrets']
        self.paths = cfg['paths']
        self.logger = logging.getLogger(__name__)
        self.replicate_api = replicate_api or ReplicateAPI(
            api_token=self.secrets['REPLICATE_API_TOKEN'],
            model=self.config['replicate_model']
        )
        # Path() so a plain-string config value works with FileManager's Path API.
        self.file_manager = file_manager or FileManager(Path(self.paths['scored_dir']))
        self.hf_manager = hf_manager or HuggingFaceManager(
            token=self.secrets['HF_TOKEN'],
            repo_id=self.config['repo_id'],
            repo_type=self.config.get('repo_type', 'dataset')
        )

    def process_batch(self, texts: List[str]) -> tuple[List[float], List[float]]:
        """Score one batch; return ``(predicted_labels, confidences)`` from the model."""
        result = self.replicate_api.predict(texts)
        return result['predicted_labels'], result['confidences']

    def get_existing_subreddits(self, date_str: str) -> set:
        """Return subreddits that already have a scored file on the Hub for ``date_str``.

        The names are taken from filenames, so they are in *sanitized* form
        (see ``_sanitize``), not the raw subreddit names in the data.
        """
        prefix = f"data_scored_subreddit/{date_str}__"
        existing = set()
        for fn in self.hf_manager.list_files("data_scored_subreddit/"):
            if fn.startswith(prefix) and fn.endswith('.parquet'):
                # data_scored_subreddit/{date}__{subreddit}.parquet -> {subreddit}
                existing.add(Path(fn).stem.split('__', 1)[1])
        return existing

    def _sanitize(self, name: str) -> str:
        """Make a subreddit name safe for filenames.

        Lower-cases the name and replaces every character except word
        characters, ``-`` and ``.`` with ``_``.
        """
        name = str(name).strip().lower()
        return re.sub(r"[^\w\-\.]", "_", name)

    def score_date(self, date_str: str, overwrite: bool = False) -> None:
        """Download, score, save and upload one file per subreddit for ``date_str``.

        Args:
            date_str: Date whose raw file is processed.
            overwrite: If False, subreddits that already have a scored file on
                the Hub are skipped.

        Raises:
            ValueError: If the raw file lacks required columns, or the model
                returns fewer results than texts in a batch.
        """
        self.logger.info(f"Scoring date: {date_str}")
        existing_subreddits = set()
        if not overwrite:
            existing_subreddits = self.get_existing_subreddits(date_str)
            if existing_subreddits:
                self.logger.info(f"Found {len(existing_subreddits)} existing subreddit files for {date_str}")

        df = self._load_raw_frame(date_str)

        if not overwrite and existing_subreddits:
            # Existing names are sanitized (from filenames), so compare in that form;
            # comparing raw names would re-score any mixed-case subreddit every run.
            pending = [
                s for s in df['subreddit'].unique()
                if self._sanitize(s) not in existing_subreddits
            ]
            if not pending:
                self.logger.info(f"All subreddits already processed for {date_str}")
                return
            df = df[df['subreddit'].isin(pending)].copy()
            self.logger.info(f"Processing {len(pending)} new subreddits for {date_str}")

        if df.empty:
            self.logger.info(f"No rows to score for {date_str}")
            return

        sentiments, confidences = self._score_texts(df['text'].tolist())
        df['sentiment'] = sentiments
        df['confidence'] = confidences
        self._save_and_upload(df, date_str)

    def _load_raw_frame(self, date_str: str) -> pd.DataFrame:
        """Download the raw parquet for ``date_str`` and check its required columns."""
        raw_path = f"{self.paths['hf_raw_dir']}/{date_str}.parquet"
        local_path = self.hf_manager.download_file(raw_path)
        df = self.file_manager.read_parquet(str(local_path))
        missing_columns = {'text', 'score', 'post_id', 'subreddit'} - set(df.columns)
        if missing_columns:
            raise ValueError(f"Missing required columns: {', '.join(sorted(missing_columns))}")
        return df

    def _score_texts(self, texts: List[str]) -> tuple[List[Any], List[Any]]:
        """Score ``texts`` in batches of ``config['batch_size']`` (default 16).

        Returns:
            ``(sentiments, confidences)``, each aligned 1:1 with ``texts``.
        """
        batch_size = self.config.get('batch_size', 16)
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        sentiments = []
        confidences = []
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            labels, scores = self.process_batch(batch)
            # Too few results would misalign rows; fail here with a clear message.
            if len(labels) < len(batch) or len(scores) < len(batch):
                raise ValueError(
                    f"Model returned {len(labels)} labels and {len(scores)} confidences "
                    f"for a batch of {len(batch)} texts (offset {start})"
                )
            # Extra results beyond the batch size are ignored.
            sentiments.extend(labels[:len(batch)])
            confidences.extend(scores[:len(batch)])
        return sentiments, confidences

    def _save_and_upload(self, df: pd.DataFrame, date_str: str) -> None:
        """Save one parquet per subreddit locally and upload it to the Hub."""
        subreddits = df['subreddit'].unique()
        self.logger.info(f"Found {len(subreddits)} subreddits to process for {date_str}")
        for subreddit in subreddits:
            subreddit_df = df[df['subreddit'] == subreddit].copy()
            safe_sub = self._sanitize(subreddit)
            filename = f"{date_str}__{safe_sub}"
            scored_path = self.file_manager.save_parquet(subreddit_df, filename)
            path_in_repo = f"data_scored_subreddit/{filename}.parquet"
            self.hf_manager.upload_file(str(scored_path), path_in_repo)
            self.logger.info(f"Uploaded scored file for {date_str}/{subreddit} ({len(subreddit_df)} rows) to {self.config['repo_id']}/{path_in_repo}")
def main(date_arg: str = None, overwrite: bool = False) -> None:
    """Score one day of raw Reddit data (entry point used by the CLI).

    Args:
        date_arg: Date to process, ``YYYY-MM-DD``. Required.
        overwrite: Re-score subreddits that already have scored files.

    Raises:
        ValueError: If ``date_arg`` is missing or not an ISO date, or if
            ``REPLICATE_API_TOKEN`` is missing or empty in the secrets.
    """
    if date_arg is None:
        raise ValueError("Date argument is required")
    # Validate before touching config or the network; the date is also used in repo paths.
    try:
        date.fromisoformat(date_arg)
    except ValueError as err:
        raise ValueError(f"Invalid date {date_arg!r}; expected YYYY-MM-DD") from err

    cfg = setup_config()
    logger = setup_logging(cfg['paths']['logs_dir'])

    # An empty token is as unusable as a missing one.
    if not cfg['secrets'].get('REPLICATE_API_TOKEN'):
        raise ValueError("REPLICATE_API_TOKEN is required for scoring")

    scorer = SentimentScorer(cfg)

    # Check for exactly the file score_date will download; a loose prefix match
    # would also accept files in sibling directories sharing the prefix.
    raw_dir = scorer.paths['hf_raw_dir']
    raw_path = f"{raw_dir}/{date_arg}.parquet"
    if raw_path not in scorer.hf_manager.list_files(raw_dir):
        logger.warning(f"No raw file found for date {date_arg}")
        return

    # score_date itself skips subreddits that already have scored files (unless
    # overwrite is set) and returns early when none remain, so no pre-check here.
    scorer.score_date(date_arg, overwrite)
if __name__ == '__main__':
    # Imported only when run as a script, not when this module is imported.
    from reddit_analysis.common_metrics import run_with_metrics

    cli = argparse.ArgumentParser(description='Score raw HF dataset files via Replicate.')
    cli.add_argument('--date', type=str, required=True, help='YYYY-MM-DD date to process')
    cli.add_argument('--overwrite', action='store_true', help='Overwrite existing scored file')
    cli_args = cli.parse_args()

    # Delegate to main() under the "score" metrics label.
    run_with_metrics("score", main, cli_args.date, cli_args.overwrite)