# scripts/fetch_posts.py
import os
from datetime import datetime, timedelta, timezone

import pandas as pd
import praw
import requests

from config import settings
from utils.helpers import logger, ensure_folder
def fetch_posts(days=settings.FETCH_DAYS, limit=None):
    """Fetch recent posts from the configured subreddits via the Reddit API.

    Args:
        days: Look-back window in days (defaults to ``settings.FETCH_DAYS``).
        limit: Maximum posts to request per subreddit (``None`` = PRAW default).

    Returns:
        pandas.DataFrame of collected posts; also written to
        ``reddit_posts.csv`` under ``settings.RAW_DATA_PATH``.

    Failures for an individual subreddit are logged and skipped so one bad
    subreddit does not abort the whole run.
    """
    ensure_folder(settings.RAW_DATA_PATH)
    reddit = praw.Reddit(
        client_id=settings.REDDIT_CLIENT_ID,
        client_secret=settings.REDDIT_CLIENT_SECRET,
        user_agent=settings.REDDIT_USER_AGENT
    )
    posts_data = []
    # Naive UTC datetimes keep the CSV output format identical to earlier
    # runs while avoiding the deprecated datetime.utcnow().
    end_time = datetime.now(timezone.utc).replace(tzinfo=None)
    start_time = end_time - timedelta(days=days)
    for subreddit_name in settings.SUBREDDITS:
        try:
            subreddit = reddit.subreddit(subreddit_name)
            logger.info(f"Fetching posts from r/{subreddit_name}")
            for post in subreddit.new(limit=limit):
                post_time = datetime.fromtimestamp(
                    post.created_utc, tz=timezone.utc
                ).replace(tzinfo=None)
                if post_time < start_time:
                    # new() yields newest-first, so every remaining post is
                    # also outside the window — stop paging early.
                    break
                posts_data.append({
                    "id": post.id,
                    "subreddit": subreddit_name,
                    "title": post.title,
                    "text": post.selftext,
                    "author": str(post.author),
                    "created_utc": post_time,
                    "score": post.score,
                    "num_comments": post.num_comments,
                    "permalink": f"https://reddit.com{post.permalink}"
                })
        except Exception as e:
            logger.error(f"Failed to fetch posts from r/{subreddit_name}: {e}")
    df = pd.DataFrame(posts_data)
    # os.path.join works whether or not RAW_DATA_PATH has a trailing slash;
    # the old f-string concatenation produced "rawreddit_posts.csv" without one.
    file_path = os.path.join(settings.RAW_DATA_PATH, "reddit_posts.csv")
    df.to_csv(file_path, index=False)
    logger.info(f"Saved {len(df)} posts to {file_path}")
    return df
def fetch_posts_pushshift(subreddit, start_epoch, end_epoch, limit=500):
    """Fetch posts for one subreddit via the Pushshift API (fallback path).

    Args:
        subreddit: Subreddit name without the ``r/`` prefix.
        start_epoch: Unix timestamp — only posts after this are returned.
        end_epoch: Unix timestamp — only posts before this are returned.
        limit: Maximum number of submissions to request.

    Returns:
        List of post dicts matching the schema used by ``fetch_posts``;
        empty list on any error (best-effort fallback, errors are logged).
    """
    url = (
        "https://api.pushshift.io/reddit/submission/search/"
        f"?subreddit={subreddit}&after={start_epoch}&before={end_epoch}&size={limit}"
    )
    try:
        # Explicit timeout so a stalled Pushshift endpoint cannot hang the
        # whole pipeline; the except below turns it into an empty result.
        resp = requests.get(url, timeout=30)
        resp.raise_for_status()
        data = resp.json().get("data", [])
        posts = []
        for p in data:
            created = p.get("created_utc")
            posts.append({
                "id": p.get("id"),
                "subreddit": subreddit,
                "title": p.get("title"),
                "text": p.get("selftext"),
                "author": p.get("author"),
                # Guard records that lack created_utc instead of crashing
                # the whole batch on fromtimestamp(None).
                "created_utc": (
                    datetime.fromtimestamp(created, tz=timezone.utc).replace(tzinfo=None)
                    if created is not None else None
                ),
                "score": p.get("score"),
                "num_comments": p.get("num_comments"),
                "permalink": f"https://reddit.com{p.get('permalink')}"
            })
        return posts
    except Exception as e:
        logger.error(f"Pushshift fetch failed for {subreddit}: {e}")
        return []
def fetch_posts_with_fallback(subreddit, days=settings.FETCH_DAYS, limit=None):
    """Fetch posts via the Reddit API first, topping up from Pushshift.

    Args:
        subreddit: Subreddit used for the Pushshift top-up query.
        days: Look-back window in days (defaults to ``settings.FETCH_DAYS``).
        limit: Target number of posts; if the Reddit API returns fewer,
            the shortfall is requested from Pushshift.

    Returns:
        pandas.DataFrame of combined, de-duplicated posts; also written to
        ``reddit_posts.csv`` under ``settings.RAW_DATA_PATH``.
    """
    # Use an *aware* UTC datetime: naive utcnow().timestamp() is interpreted
    # as local time, skewing the epochs by the machine's UTC offset.
    end_time = datetime.now(timezone.utc)
    start_time = end_time - timedelta(days=days)
    start_epoch = int(start_time.timestamp())
    end_epoch = int(end_time.timestamp())
    df_posts = fetch_posts(days=days, limit=limit)
    posts = df_posts.to_dict("records")
    if limit and len(posts) < limit:
        remaining = limit - len(posts)
        logger.info(f"Reddit API returned {len(posts)} posts, fetching {remaining} more from Pushshift...")
        pushshift_posts = fetch_posts_pushshift(subreddit, start_epoch, end_epoch, limit=remaining)
        posts += pushshift_posts
        # Deduplicate by post id; dict insertion order keeps the first
        # (Reddit API) copy when Pushshift returns the same post.
        posts = list({p['id']: p for p in posts}.values())
    df_posts = pd.DataFrame(posts)
    ensure_folder(settings.RAW_DATA_PATH)
    # os.path.join tolerates RAW_DATA_PATH with or without a trailing slash.
    file_path = os.path.join(settings.RAW_DATA_PATH, "reddit_posts.csv")
    df_posts.to_csv(file_path, index=False)
    logger.info(f"Saved {len(df_posts)} combined posts to {file_path}")
    return df_posts
|