Spaces:
Runtime error
Runtime error
| from transformers import pipeline | |
| import gradio as gr | |
| import praw | |
| from praw.models import MoreComments | |
| import os | |
| from statistics import mean | |
| import torch | |
| import time | |
# Default for the "Minimum Karma" input in the UI.
MIN_SCORE_THRESHOLD = 50
# How many of the top-scoring vibes the output label displays.
NUM_VIBE_LABELS = 5

# Device 0 is the first CUDA GPU; -1 runs the pipeline on the CPU.
if torch.cuda.is_available():
    device = 0
else:
    print("GPU isn't available. Running on CPU.")
    device = -1

# Zero-shot classifier used to score comment text against the vibe labels.
pipe = pipeline(model="facebook/bart-large-mnli", device=device)

# Reddit API client; credentials are read from environment variables.
reddit = praw.Reddit(
    client_id=os.environ.get("reddit_client_id"),
    client_secret=os.environ.get("reddit_client_secret"),
    user_agent="Hugging Face Vibe Checker",
)
# Candidate labels passed to the zero-shot classifier.
vibe_names = [
    "wholesome", "chill", "funny", "inspiring", "aesthetic", "nerdy",
    "supportive", "informative", "activism", "nostalgic", "creative",
    "memorable", "cryptic", "dark", "whimsical", "spiritual",
    "intellectual", "meme",
]
# CSS color used when rendering each vibe name (every vibe is currently red).
vibe_name_color_map = dict.fromkeys(vibe_names, "red")
def _average_label_scores(classes, labels):
    """Return the mean zero-shot score of each label across all results.

    Each pipeline result lists its "labels" and "scores" sorted by descending
    score, so a score must be matched to its label by name, not by its
    position in ``labels``.
    """
    per_result = [dict(zip(r["labels"], r["scores"])) for r in classes]
    return {label: mean(scores[label] for scores in per_result) for label in labels}


def vibe_check(url, min_karma=None, max_comments=None):
    """Score a Reddit thread's comments against every vibe label.

    Args:
        url: Reddit submission URL.
        min_karma: Only comments whose score is strictly greater than this are
            classified. Accepts anything ``int()`` parses (the UI passes the
            textbox string). ``None`` falls back to ``MIN_SCORE_THRESHOLD``.
        max_comments: Optional cap on how many of the highest-scoring
            qualifying comments are classified. ``None`` classifies all of them.

    Returns:
        dict mapping each name in ``vibe_names`` to its mean classifier score,
        the format ``gr.Label`` displays.

    Raises:
        gr.Error: if ``min_karma`` is not a whole number or no comment passes
            the karma filter.
    """
    if min_karma is None:
        min_karma = MIN_SCORE_THRESHOLD
    try:
        threshold = int(min_karma)
    except (TypeError, ValueError) as err:
        raise gr.Error(f"Minimum Karma must be a whole number, got {min_karma!r}.") from err

    # Highest-karma comments first, so a max_comments cap keeps the top-voted ones.
    comments = sorted(get_comments(url), key=lambda c: c["Score"], reverse=True)
    comment_bodies = [c["Comment"] for c in comments if c["Score"] > threshold]
    if max_comments is not None:
        comment_bodies = comment_bodies[:max_comments]
    print("Total comments: " + str(len(comment_bodies)))
    if not comment_bodies:
        # Averaging zero results would raise StatisticsError; report it readably instead.
        raise gr.Error(f"No comments with more than {threshold} karma were found.")

    print("Starting comment classification.")
    start = time.time()
    classes = pipe(
        comment_bodies,
        candidate_labels=vibe_names,
    )
    end = time.time()
    # time.time() differences are in seconds.
    print(f"Comment classification took: {end - start:.2f}s")
    # Defensive: a single input may come back as a bare dict instead of a one-item list.
    if isinstance(classes, dict):
        classes = [classes]
    return _average_label_scores(classes, vibe_names)
def get_vibes_html(vibes, color_map=None):
    """Render vibe names as space-separated HTML spans of increasing size.

    Args:
        vibes: iterable of vibe names. The n-th name (counting from 1) is drawn
            at ``12 * n`` px.
        color_map: mapping of vibe name -> CSS color. Defaults to
            ``vibe_name_color_map``. A name missing from it raises KeyError.

    Returns:
        The HTML string (empty when ``vibes`` is empty).
    """
    if color_map is None:
        color_map = vibe_name_color_map
    # start=1 so the first vibe gets a visible 12px font rather than 0px.
    return " ".join(
        f'<span style="color:{color_map[vibe]};font-size:{12 * i}px">{vibe}</span>'
        for i, vibe in enumerate(vibes, start=1)
    )
def get_comments(url):
    """Fetch the comments directly under a Reddit submission.

    Iterating ``submission.comments`` yields top-level comments only; replies
    are not flattened in. "Load more comments" stubs and comments whose body
    is a deletion/removal placeholder are skipped.

    Args:
        url: Reddit submission URL.

    Returns:
        list of dicts with keys "Comment" (body text), "Author",
        "Date Posted" (the comment's ``created_utc``) and "Score".
    """
    submission = reddit.submission(url=url)
    comments = []
    for comment in submission.comments:
        # MoreComments entries are unexpanded stubs with no body. Reddit shows
        # "[deleted]" / "[removed]" in place of text that no longer exists,
        # which would otherwise be classified as if it were real content.
        if isinstance(comment, MoreComments) or comment.body in ("[deleted]", "[removed]"):
            continue
        comments.append(
            {
                "Comment": comment.body,
                "Author": comment.author,
                "Date Posted": comment.created_utc,
                "Score": comment.score,
            }
        )
    return comments
def comment_compare(comment):
    """Sort key: return the karma score stored in a comment dict."""
    score = comment["Score"]
    return score
# Pass the theme to the constructor: Blocks builds its theme CSS when it is
# created, so assigning demo.theme afterwards does not restyle the app.
with gr.Blocks(theme=gr.themes.Base()) as demo:
    url = gr.Textbox(label="Url")
    min_karma = gr.Textbox(label="Minimum Karma", value=MIN_SCORE_THRESHOLD)
    output = gr.Label(label="Output", num_top_classes=NUM_VIBE_LABELS)
    submit_button = gr.Button("Submit")
    submit_button.click(
        fn=vibe_check, inputs=[url, min_karma], outputs=output, api_name="vibe_check"
    )
    # Clicking an example only fills the Url box. Examples are not cached,
    # since caching would call vibe_check with the url input alone.
    gr.Examples(
        examples=[
            "https://www.reddit.com/r/AskReddit/comments/yiazab/would_you_support_a_mandatory_retirement_age_of/",
            "https://www.reddit.com/r/politics/comments/yqa3cg/john_fetterman_wins_pennsylvania_senate_race/",
            "https://www.reddit.com/r/pics/comments/zyj3ll/andrew_and_tristan_tate_were_arrested_they_are/",
        ],
        inputs=url,
        outputs=output,
        fn=vibe_check,
    )
demo.launch()