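"""Tests for reddit_analysis.scraper.scrape.

Covers RedditScraper.get_posts with a mocked Reddit API, deduplication in
_upload_to_hf with mocked file and Hugging Face managers, and the CLI's
handling of missing environment variables.
"""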
import os
from datetime import datetime, date
from pathlib import Path
from unittest.mock import Mock, patch

import pandas as pd
import pytest
import pytz

from reddit_analysis.scraper.scrape import (
    FileManager,
    HuggingFaceManager,
    RedditAPI,
    RedditScraper,
)

@pytest.fixture
def mock_config():
    """Create a mock configuration dictionary."""
    return {
        'config': {
            'repo_id': 'test/repo',
            'repo_type': 'dataset',
            'subreddits': [
                {'name': 'test1', 'post_limit': 2, 'comment_limit': 2},
                {'name': 'test2', 'post_limit': 2, 'comment_limit': 2}
            ],
            'post_limit': 100,
            'timezone': 'UTC'
        },
        'paths': {
            'raw_dir': Path('data/raw'),
            'logs_dir': Path('logs'),
            'hf_raw_dir': 'data/raw'
        },
        'secrets': {
            'HF_TOKEN': 'test_token',
            'REDDIT_CLIENT_ID': 'test_id',
            'REDDIT_CLIENT_SECRET': 'test_secret',
            'REDDIT_USER_AGENT': 'test_agent'
        }
    }

@pytest.fixture
def mock_reddit_api():
    """Create a mock RedditAPI."""
    mock = Mock(spec=RedditAPI)

    # Create mock submission objects
    mock_submissions = []
    for i in range(2):
        submission = Mock()
        submission.id = f'post{i}'
        submission.title = f'Test Post {i}'
        submission.selftext = f'Test content {i}'
        submission.score = i + 1
        submission.created_utc = datetime.now(pytz.UTC).timestamp()
        submission.url = f'https://reddit.com/test{i}'
        submission.num_comments = i * 10

        # Mock a single comment attached to the submission
        comment = Mock()
        comment.id = f'comment{i}'
        comment.body = f'Test comment {i}'
        comment.score = i + 5
        comment.created_utc = datetime.now(pytz.UTC).timestamp()
        comment.parent_id = submission.id

        # Expose the comment through the mocked comment forest
        submission.comments = Mock()
        submission.comments._comments = [comment]
        submission.comments.replace_more = Mock(return_value=None)

        mock_submissions.append(submission)

    # Set up the mock subreddit to return the mock submissions
    mock_subreddit = Mock()
    mock_subreddit.top.return_value = mock_submissions
    mock.get_subreddit.return_value = mock_subreddit
    return mock

@pytest.fixture
def mock_file_manager():
    """Create a mock FileManager."""
    return Mock(spec=FileManager)


@pytest.fixture
def mock_hf_manager():
    """Create a mock HuggingFaceManager."""
    return Mock(spec=HuggingFaceManager)

def test_get_posts(mock_config, mock_reddit_api):
    """Test the get_posts method."""
    # Initialize scraper with a mocked RedditAPI
    scraper = RedditScraper(mock_config, reddit_api=mock_reddit_api)

    # Get posts for a test subreddit
    df = scraper.get_posts({'name': 'test1', 'post_limit': 2, 'comment_limit': 2})

    # Verify DataFrame structure and content
    assert isinstance(df, pd.DataFrame)
    assert len(df) == 4  # 2 posts + 2 comments

    # Verify posts
    posts_df = df[df['type'] == 'post']
    assert len(posts_df) == 2
    assert posts_df['subreddit'].iloc[0] == 'test1'
    assert posts_df['post_id'].iloc[0] == 'post0'
    assert posts_df['post_id'].iloc[1] == 'post1'

    # Verify comments
    comments_df = df[df['type'] == 'comment']
    assert len(comments_df) == 2
    assert comments_df['subreddit'].iloc[0] == 'test1'
    assert comments_df['post_id'].iloc[0] == 'comment0'
    assert comments_df['parent_id'].iloc[0] == 'post0'

def test_upload_to_hf_deduplication(mock_config, mock_file_manager, mock_hf_manager):
    """Test the _upload_to_hf method with deduplication."""
    # Previously uploaded data already on the Hub
    prev_df = pd.DataFrame({
        'post_id': ['post0', 'post1'],
        'title': ['Old Post 0', 'Old Post 1'],
        'text': ['Old content 0', 'Old content 1'],
        'score': [1, 2],
        'subreddit': ['test1', 'test1'],
        'created_utc': [datetime.now(pytz.UTC)] * 2,
        'url': ['https://reddit.com/old0', 'https://reddit.com/old1'],
        'num_comments': [10, 20]
    })

    # Newly scraped data; 'post1' overlaps with the previous upload
    new_df = pd.DataFrame({
        'post_id': ['post1', 'post2'],
        'title': ['New Post 1', 'New Post 2'],
        'text': ['New content 1', 'New content 2'],
        'score': [3, 4],
        'subreddit': ['test1', 'test1'],
        'created_utc': [datetime.now(pytz.UTC)] * 2,
        'url': ['https://reddit.com/new1', 'https://reddit.com/new2'],
        'num_comments': [30, 40]
    })

    # Mock file operations
    mock_hf_manager.download_file.return_value = Path('test.parquet')
    mock_file_manager.read_parquet.return_value = prev_df

    # Initialize scraper with mocked dependencies
    scraper = RedditScraper(
        mock_config,
        file_manager=mock_file_manager,
        hf_manager=mock_hf_manager
    )

    # Upload new data
    scraper._upload_to_hf(new_df, '2025-04-20')

    # Verify file operations
    mock_file_manager.save_parquet.assert_called_once()
    mock_hf_manager.upload_file.assert_called_once()

def test_cli_missing_env(monkeypatch, tmp_path):
    """Test the CLI with missing environment variables."""
    # Create a temporary .env file without the required variables
    env_path = tmp_path / '.env'
    env_path.write_text('')

    # Point the app at our test .env
    monkeypatch.setenv('REDDIT_ANALYSIS_ENV', str(env_path))

    # Remove any existing Reddit API credentials from the environment
    for key in ['REDDIT_CLIENT_ID', 'REDDIT_CLIENT_SECRET', 'REDDIT_USER_AGENT']:
        monkeypatch.delenv(key, raising=False)

    # Ensure HF_TOKEN is present so only the Reddit client vars are missing
    monkeypatch.setenv('HF_TOKEN', 'dummy_hf_token')

    # Pretend Streamlit is available and running
    monkeypatch.setattr('reddit_analysis.config_utils.HAS_STREAMLIT', True)
    monkeypatch.setattr('reddit_analysis.config_utils.is_running_streamlit', lambda: True)

    # Mock Streamlit secrets so no credentials come from there either
    mock_secrets = Mock()
    mock_secrets.get.return_value = None
    monkeypatch.setattr('streamlit.secrets', mock_secrets)

    # Debug output for diagnosing environment leakage
    print('DEBUG: REDDIT_CLIENT_ID value before main:', os.environ.get('REDDIT_CLIENT_ID'))

    # Run the entry point with a date argument and expect it to fail
    with pytest.raises(ValueError) as exc_info:
        from reddit_analysis.scraper.scrape import main
        main('2025-04-20')
    assert "Missing required environment variables: REDDIT_CLIENT_ID" in str(exc_info.value)