Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| from transformers import AutoModelForSequenceClassification, pipeline, AutoTokenizer, DistilBertForSequenceClassification, DistilBertTokenizerFast | |
| import pandas as pd | |
| import comments | |
| from random import randint | |
| import requests | |
def predict_cyberbullying_probability(sentence, tokenizer, model):
    """Return the per-logit sigmoid probabilities the model assigns to a text.

    Args:
        sentence: The text to score.
        tokenizer: Callable that turns ``sentence`` into a mapping holding
            ``'input_ids'`` and ``'attention_mask'`` tensors (it is invoked
            with Hugging Face tokenizer keyword arguments).
        model: Callable taking the input ids plus an ``attention_mask``
            keyword and returning an object with a ``logits`` tensor.

    Returns:
        A flat list of Python floats, one ``sigmoid(logit)`` per element of
        ``logits``, in the order the model emits them.
    """
    # Pad/truncate to a fixed 512-token window and request PyTorch tensors.
    encoded = tokenizer(
        sentence,
        padding='max_length',
        return_token_type_ids=False,
        return_attention_mask=True,
        truncation=True,
        max_length=512,
        return_tensors='pt',
    )

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(encoded['input_ids'], attention_mask=encoded['attention_mask'])
        # Each logit gets its own independent sigmoid; flatten() collapses
        # the batch dimension so the caller receives a plain 1-D list.
        probs = torch.sigmoid(outputs.logits).flatten()

    return probs.tolist()
# NOTE: result caching is currently disabled (the ``@st.cache`` decorator was
# commented out), so the model and tokenizer are loaded on every call.
def perform_cyberbullying_analysis(tweet):
    """Score one text with the fine-tuned cyberbullying classifier.

    Loads the ``kingsotn/finetuned_cyberbullying`` model and the
    ``distilbert-base-uncased`` tokenizer, runs ``tweet`` through
    ``predict_cyberbullying_probability`` and arranges the result as a table.

    Args:
        tweet: The text to score.

    Returns:
        A one-row ``pandas.DataFrame``: the ``comment`` column holds
        ``tweet``; each name in the module-level ``labels[1:]`` becomes a
        column holding the probability found at the same position.
    """
    with st.spinner(text="loading model, wait until spinner ends..."):
        classifier = AutoModelForSequenceClassification.from_pretrained('kingsotn/finetuned_cyberbullying')
        fast_tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
        probabilities = predict_cyberbullying_probability(tweet, fast_tokenizer, classifier)

    # Index-based lookup on purpose: a probability list shorter than the
    # label list raises instead of silently dropping columns.
    score_by_label = {
        name: probabilities[position]
        for position, name in enumerate(labels[1:])
    }
    return pd.DataFrame({'comment': [tweet], **score_by_label})
def perform_default_analysis(model_name):
    """Render a text box and classify its content with ``model_name``.

    Creates a text area and an "Analyze" form submit button. When the button
    has been pressed, the tokenizer and sequence-classification model named
    ``model_name`` are loaded, wrapped in a ``sentiment-analysis`` pipeline
    and applied to the entered text. The raw pipeline output is shown as
    JSON, followed by a success message (with balloons) for a ``POSITIVE`` /
    ``POS`` label or an error message for any other label.

    Args:
        model_name: Model identifier handed to ``from_pretrained``.

    Returns:
        None. All results are written to the Streamlit page.
    """
    tweet = st.text_area(label="Enter Text:", value="I'm nice at ping pong")
    submitted = st.form_submit_button("Analyze")
    if not submitted:
        # Nothing to classify yet; skip loading the model so the widgets
        # render immediately instead of waiting for it.
        return

    with st.spinner(text="loading..."):
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        clf = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer, framework="pt")
        out = clf(tweet)

    st.json(out)
    # Different checkpoints spell the positive class differently.
    if out[0]["label"] in ("POSITIVE", "POS"):
        st.balloons()
        st.success("nice tweet!")
    else:
        st.error("bad tweet!")
# main -->
# Page header.
st.title("Toxic Tweets Analyzer")
st.write("π‘ Toxic Tweets Analyzer uses AI with kingsotn/finetuned_cyberbullying (distilbert) to score tweets for toxicity, threat, and insult.")
image = "kanye_loves_tweet.jpg"
st.image(image, use_column_width=True)

# Column names of the result table. The first entry is the text column; the
# remaining entries are read by perform_cyberbullying_analysis as the names of
# the score columns.
labels = ['comment', 'toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

with st.form("my_form"):
    # select model
    model_name = st.selectbox(
        "Enter a text and select a pre-trained model to get the sentiment analysis",
        [
            "kingsotn/finetuned_cyberbullying",
            "distilbert-base-uncased-finetuned-sst-2-english",
            "finiteautomata/bertweet-base-sentiment-analysis",
            "distilbert-base-uncased",
        ],
    )

    if model_name == "kingsotn/finetuned_cyberbullying":
        default = "I'm not even going to lie to you. I love me so much right now."
        tweet = st.text_area(label="Enter Text:", value=default)
        submitted = st.form_submit_button("Analyze textbox")
        random = st.form_submit_button("Get a random πππ tweet (warning!!)")
        kanye = st.form_submit_button("Get a ye quote π»π€π§πΆ")

        if random:
            # Derive the upper bound from the data instead of hard-coding it,
            # so every entry is reachable and no index can fall out of range.
            tweet = comments.comments[randint(0, len(comments.comments) - 1)]
            st.write(tweet)
            submitted = True

        if kanye:
            # On any failure below, ``tweet`` keeps the text-box content and
            # that text is analysed instead.
            try:
                # A timeout keeps an unresponsive API from hanging the app.
                response = requests.get('https://api.kanye.rest/', timeout=10)
            except requests.RequestException as exc:
                st.error("Error getting Kanye quote | " + str(exc))
            else:
                if response.status_code == 200:
                    data = response.json()
                    tweet = data['quote']
                else:
                    st.error("Error getting Kanye quote | status code: " + str(response.status_code))
            st.write(tweet)
            submitted = True

        if submitted:
            df = perform_cyberbullying_analysis(tweet)
            st.table(df)
    else:
        perform_default_analysis(model_name)