| | import os |
| | import praw |
| | import gradio as gr |
| | from transformers import TextClassificationPipeline, AutoModelForSequenceClassification, AutoTokenizer |
| |
|
| |
|
| |
|
# Reddit API credentials come from environment variables; a missing one
# raises KeyError when this module is imported.
client_id, client_secret, user_agent = (
    os.environ[key] for key in ("client_id", "client_secret", "user_agent")
)

# PRAW client shared by reddit_analysis().
reddit = praw.Reddit(
    client_id=client_id,
    client_secret=client_secret,
    user_agent=user_agent,
)
| |
|
| |
|
# FinBERT text classifier. Callers in this file only act on the
# "positive" and "negative" labels; any other label is ignored.
model_name = "ProsusAI/finbert"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)

# Extra keyword options handed to the pipeline. They are presumably forwarded
# to the tokenizer, so every input is truncated/padded to 64 tokens
# (TODO confirm against the installed transformers version).
_pipe_options = {"max_length": 64, "truncation": True, "padding": "max_length"}
pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer, **_pipe_options)
| |
|
| |
|
def reddit_analysis(subreddit_name, num_posts, max_posts=15):
    """Score the sentiment of the newest posts in a subreddit.

    Each post title is classified with the module-level ``pipe``.
    A "positive" prediction adds its confidence score to the total, and a
    "negative" prediction subtracts it. Any other label leaves the total
    unchanged.

    Args:
        subreddit_name: Name of the subreddit to read.
        num_posts: How many of the newest posts to analyse. Anything ``int()``
            accepts works, e.g. the raw string from a Gradio Textbox.
        max_posts: Largest accepted value for ``num_posts`` (inclusive).

    Returns:
        A ``(score, titles)`` pair.
        ``score`` is the net sentiment: a sum of signed confidences, not an
        average. ``titles`` is the analysed titles joined with newlines.
        On invalid ``num_posts``, ``score`` is an error message and
        ``titles`` is ``""``. The pair still fills both Gradio output boxes.
    """
    try:
        n = int(num_posts)
    except (TypeError, ValueError):
        return "Number of posts must be a whole number", ""
    if not 1 <= n <= max_posts:
        # Always return two values: the UI event is wired to two outputs.
        return f"Number of posts should be between 1 and {max_posts}", ""

    score = 0
    titles = []
    for post in reddit.subreddit(subreddit_name).new(limit=n):
        prediction = pipe(post.title)[0]
        titles.append(post.title)
        if prediction["label"] == "negative":
            score -= prediction["score"]
        elif prediction["label"] == "positive":
            score += prediction["score"]

    return score, "\n".join(titles)
| | |
| | |
| | |
| |
|
| | |
| | |
# Running sentiment total updated by every manual_analysis() call. It lives
# at module level, so it is shared by every user of the app.
total_score = 0
# Every text submitted to manual_analysis(), in submission order.
# NOTE(review): nothing in this file reads it and it is never trimmed, so it
# grows for the life of the process.
text_list = []

# Sign applied to a prediction's confidence; labels not listed are ignored.
_LABEL_SIGN = {"positive": 1, "negative": -1}


def manual_analysis(text):
    """Classify one text and fold its sentiment into ``total_score``.

    Returns:
        ``(prediction, total_score)``.
        ``prediction`` is the raw pipeline output for *text*: a list whose
        first element holds "label" and "score" keys.
        ``total_score`` is the updated running total. "positive" predictions
        add their score, "negative" ones subtract it, and other labels leave
        the total unchanged.
    """
    global total_score

    result = pipe(text)
    text_list.append(text)

    best = result[0]
    sign = _LABEL_SIGN.get(best["label"])
    if sign is not None:
        total_score += sign * best["score"]

    return result, total_score
| |
|
| |
|
# Two-tab Gradio UI:
# - "Separate Analysis" scores individual texts with manual_analysis().
# - "Mass Analysis" scores the newest subreddit posts with reddit_analysis().
# Both score boxes show a signed sum of classifier confidences, not an average.
with gr.Blocks() as demo:
    with gr.Tab("Separate Analysis"):
        first_title = """<p><h1 align="center" style="font-size: 24px;">Analyse texts manually</h1></p>"""
        gr.HTML(first_title)
        with gr.Row():
            with gr.Column():
                text = gr.Textbox(label="text")
                manual_button = gr.Button("Analyse")
            with gr.Column():
                label_score = gr.Textbox(label="Label/Score")
                # Running total across all submissions (see manual_analysis).
                cumulative_score = gr.Textbox(label="Cumulative Sentiment Score")
        manual_button.click(
            fn=manual_analysis,
            inputs=text,
            outputs=[label_score, cumulative_score],
            api_name="Calc1",
        )

    with gr.Tab("Mass Analysis"):
        second_title = """<p><h1 align="center" style="font-size: 24px;">Analyse latest posts from Reddit</h1></p>"""
        gr.HTML(second_title)
        with gr.Row():
            with gr.Column():
                subreddit_name = gr.Textbox(label="Subreddit Name")
                num_post = gr.Textbox(label="Number of Posts")
                reddit_button = gr.Button("Analyse")
            with gr.Column():
                # Sum of signed confidences over the fetched posts.
                net_score = gr.Textbox(label="Net Sentiment Score")
                post_titles = gr.Textbox(label="Post Titles")
        reddit_button.click(
            fn=reddit_analysis,
            inputs=[subreddit_name, num_post],
            outputs=[net_score, post_titles],
            api_name="Calc2",
        )

# Only start the server when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()