File size: 4,047 Bytes
f6c54d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# scripts/fetch_posts.py
import os
from datetime import datetime, timedelta, timezone

import praw
import requests
import pandas as pd

from config import settings
from utils.helpers import logger, ensure_folder

def fetch_posts(days=settings.FETCH_DAYS, limit=None):
    """Fetch recent posts from the configured subreddits via the Reddit API.

    Iterates ``settings.SUBREDDITS``, pulls each subreddit's newest
    submissions, keeps those created within the last *days* days, and
    saves the combined result as ``reddit_posts.csv`` under
    ``settings.RAW_DATA_PATH``.

    Args:
        days: Look-back window in days (default ``settings.FETCH_DAYS``).
        limit: Max submissions to request per subreddit; ``None`` lets
            PRAW fetch as many as the API allows.

    Returns:
        pandas.DataFrame: One row per post (may be empty if nothing
        matched or every subreddit failed).
    """
    import os  # local fallback; also added to the file-level imports
    from datetime import timezone

    ensure_folder(settings.RAW_DATA_PATH)
    reddit = praw.Reddit(
        client_id=settings.REDDIT_CLIENT_ID,
        client_secret=settings.REDDIT_CLIENT_SECRET,
        user_agent=settings.REDDIT_USER_AGENT,
    )

    posts_data = []
    # Timezone-aware UTC timestamps: datetime.utcnow() is deprecated
    # (3.12+) and produces naive datetimes that are easy to misuse.
    end_time = datetime.now(timezone.utc)
    start_time = end_time - timedelta(days=days)

    for subreddit_name in settings.SUBREDDITS:
        try:
            subreddit = reddit.subreddit(subreddit_name)
            logger.info(f"Fetching posts from r/{subreddit_name}")
            for post in subreddit.new(limit=limit):
                # created_utc is a Unix epoch; convert to aware UTC so it
                # compares correctly with start_time.
                post_time = datetime.fromtimestamp(post.created_utc, tz=timezone.utc)
                if post_time >= start_time:
                    posts_data.append({
                        "id": post.id,
                        "subreddit": subreddit_name,
                        "title": post.title,
                        "text": post.selftext,
                        "author": str(post.author),  # author may be None (deleted)
                        "created_utc": post_time,
                        "score": post.score,
                        "num_comments": post.num_comments,
                        "permalink": f"https://reddit.com{post.permalink}"
                    })
        except Exception as e:
            # Best-effort per subreddit: one failure must not abort the rest.
            logger.error(f"Failed to fetch posts from r/{subreddit_name}: {e}")

    df = pd.DataFrame(posts_data)
    # os.path.join is safe whether or not RAW_DATA_PATH ends with a separator
    # (the original f-string concat required a trailing slash).
    file_path = os.path.join(settings.RAW_DATA_PATH, "reddit_posts.csv")
    df.to_csv(file_path, index=False)
    logger.info(f"Saved {len(df)} posts to {file_path}")
    return df

def fetch_posts_pushshift(subreddit, start_epoch, end_epoch, limit=500):
    """Fetch posts from one subreddit via the Pushshift API (fallback path).

    Args:
        subreddit: Subreddit name without the ``r/`` prefix.
        start_epoch: Window start as a Unix timestamp (UTC seconds).
        end_epoch: Window end as a Unix timestamp (UTC seconds).
        limit: Maximum number of submissions to request (default 500).

    Returns:
        list[dict]: Post records in the same shape as ``fetch_posts``
        produces; empty list on any error (best-effort, logged).
    """
    from datetime import timezone

    url = "https://api.pushshift.io/reddit/submission/search/"
    # Let requests build/encode the query string instead of hand-rolled
    # f-string interpolation (handles any special characters safely).
    params = {
        "subreddit": subreddit,
        "after": start_epoch,
        "before": end_epoch,
        "size": limit,
    }
    try:
        # Explicit timeout: requests.get has none by default and would
        # hang indefinitely if Pushshift stalls.
        resp = requests.get(url, params=params, timeout=30)
        resp.raise_for_status()
        data = resp.json()["data"]
        posts = []
        for p in data:
            created = p.get("created_utc")
            posts.append({
                "id": p.get("id"),
                "subreddit": subreddit,
                "title": p.get("title"),
                "text": p.get("selftext"),
                "author": p.get("author"),
                # Aware UTC datetime; guard against a missing timestamp so
                # one malformed record doesn't discard the whole batch.
                "created_utc": (
                    datetime.fromtimestamp(created, tz=timezone.utc)
                    if created is not None else None
                ),
                "score": p.get("score"),
                "num_comments": p.get("num_comments"),
                "permalink": f"https://reddit.com{p.get('permalink')}"
            })
        return posts
    except Exception as e:
        logger.error(f"Pushshift fetch failed for {subreddit}: {e}")
        return []

def fetch_posts_with_fallback(subreddit, days=settings.FETCH_DAYS, limit=None):
    """Fetch posts via the Reddit API, topping up from Pushshift if short.

    Runs ``fetch_posts`` first; if *limit* is set and fewer posts came
    back, fetches the remainder for *subreddit* from Pushshift,
    deduplicates by post id (later entries win), and saves the combined
    set to ``reddit_posts.csv``.

    Args:
        subreddit: Subreddit used for the Pushshift fallback query.
        days: Look-back window in days (default ``settings.FETCH_DAYS``).
        limit: Target number of posts; ``None`` disables the fallback.

    Returns:
        pandas.DataFrame: Deduplicated combined posts.
    """
    import os  # local fallback; also added to the file-level imports
    from datetime import timezone

    # BUG FIX: the original used datetime.utcnow(), which is naive, and
    # .timestamp() on a naive datetime interprets it as *local* time —
    # shifting the Pushshift epoch window by the machine's UTC offset.
    # An aware UTC datetime makes .timestamp() exact.
    end_time = datetime.now(timezone.utc)
    start_time = end_time - timedelta(days=days)
    start_epoch = int(start_time.timestamp())
    end_epoch = int(end_time.timestamp())

    df_posts = fetch_posts(days=days, limit=limit)
    posts = df_posts.to_dict("records")

    if limit and len(posts) < limit:
        remaining = limit - len(posts)
        logger.info(f"Reddit API returned {len(posts)} posts, fetching {remaining} more from Pushshift...")
        pushshift_posts = fetch_posts_pushshift(subreddit, start_epoch, end_epoch, limit=remaining)
        posts += pushshift_posts

    # Deduplicate by id; dict insertion order means a Pushshift record
    # replaces a Reddit record with the same id.
    posts_dict = {p['id']: p for p in posts}
    posts = list(posts_dict.values())

    df_posts = pd.DataFrame(posts)
    ensure_folder(settings.RAW_DATA_PATH)
    # os.path.join works with or without a trailing separator on RAW_DATA_PATH.
    file_path = os.path.join(settings.RAW_DATA_PATH, "reddit_posts.csv")
    df_posts.to_csv(file_path, index=False)
    logger.info(f"Saved {len(df_posts)} combined posts to {file_path}")

    return df_posts