Spaces:
Runtime error
Runtime error
| import time | |
| import pandas as pd | |
| import streamlit as st | |
| from transformers import pipeline | |
| from constants import tweet_generator_prompt, absa_prompt | |
| # Adjust the layout for wider containers | |
| st.set_page_config(layout="wide") | |
@st.cache_resource
def load_model():
    """Build the three Hugging Face pipelines used by the dashboard.

    Streamlit re-executes the whole script on every widget interaction, so
    the result is cached with ``st.cache_resource``: the models are loaded
    once per server process and the same pipeline objects are reused on
    later reruns instead of being re-instantiated each time.

    Returns:
        A tuple ``(classification_pipe, absa_pipe, tweet_generation_pipe)``:

        * ``classification_pipe`` -- ``text-classification`` pipeline for
          ``tweetpie/toxic-content-detector``.
        * ``absa_pipe`` -- ``text2text-generation`` pipeline for
          ``tweetpie/stance-aware-absa``.
        * ``tweet_generation_pipe`` -- ``text2text-generation`` pipeline for
          ``tweetpie/stance-directed-tweet-generator``.
    """
    # top_k=None makes the classifier return a score for every label rather
    # than only the best one; the caller picks the labels it wants to show.
    classification_pipe = pipeline(
        "text-classification",
        model="tweetpie/toxic-content-detector",
        top_k=None,
    )
    absa_pipe = pipeline(
        "text2text-generation",
        model="tweetpie/stance-aware-absa",
    )
    tweet_generation_pipe = pipeline(
        "text2text-generation",
        model="tweetpie/stance-directed-tweet-generator",
    )
    return classification_pipe, absa_pipe, tweet_generation_pipe
# Page heading.
st.title("Towards a Programmable Humanizing AI through Scalable Stance-Directed Architecture Dashboard")

# Two rows of columns:
#   row 1 -> ideology selector (wide) | generate button (narrow)
#   row 2 -> entity inputs | aspect inputs | output area (filled in later)
col11, col12 = st.columns([6, 3])
col21, col22, col23 = st.columns([3, 3, 3])

# Row 1, left: which ideology the generated tweet should follow.
model_selection = col11.selectbox(
    "Select an ideology",
    options=['Left', 'Right'],
    index=0,  # 'Left' is preselected
)

# Row 1, right: the button sits in the narrow column next to the selector.
generate_button = col12.button("Generate tweet and classify toxicity")

# Row 2, first column: entities grouped by stance.
col21.header("Entities")
pro_entities = col21.text_input(
    "Pro Entities", help="Enter pro entities separated by commas")
anti_entities = col21.text_input(
    "Anti Entities", help="Enter anti entities separated by commas")
neutral_entities = col21.text_input(
    "Neutral Entities", help="Enter neutral entities separated by commas")

# Row 2, second column: aspects grouped by stance.
col22.header("Aspects")
pro_aspects = col22.text_input(
    "Pro Aspects", help="Enter pro aspects separated by commas")
anti_aspects = col22.text_input(
    "Anti Aspects", help="Enter anti aspects separated by commas")
neutral_aspects = col22.text_input(
    "Neutral Aspects", help="Enter neutral aspects separated by commas")

# Obtain the toxicity classifier, the stance-aware ABSA model and the tweet
# generator. This runs on every script execution, before the button check.
classifier, absa, generator = load_model()
# Display captions for the classifier labels that are shown to the user.
# Any other label returned by the classifier is ignored.
_TOXICITY_CAPTIONS = {
    'LABEL_0': "Non-Toxic Content",
    'LABEL_2': "Toxic Content",
}


def _parse_stances(absa_text):
    """Convert comma-separated ``aspect:sentiment`` text into table rows.

    Each comma-separated item is split on ``':'``; the first piece becomes
    ``'Aspect'`` and the second ``'Sentiment'``. Blank items (for example
    from a trailing comma) are skipped, and an item without a colon gets an
    empty sentiment instead of raising ``IndexError``.
    """
    rows = []
    for item in absa_text.split(','):
        item = item.strip()
        if not item:
            continue
        pieces = item.split(':')
        rows.append({
            'Aspect': pieces[0],
            'Sentiment': pieces[1] if len(pieces) > 1 else '',
        })
    return rows


# Process the inputs and render the results once the button is pressed.
if generate_button:
    with col23:  # all results go into the third column of the second row
        with st.spinner('Generating the tweet...'):
            prompt = tweet_generator_prompt.format(
                ideology=model_selection.lower(),
                pro_entities=pro_entities,
                anti_entities=anti_entities,
                neutral_entities=neutral_entities,
                pro_aspects=pro_aspects,
                anti_aspects=anti_aspects,
                neutral_aspects=neutral_aspects,
            )
            # Sample three candidate tweets for the same prompt.
            generated_tweet = generator(
                prompt,
                max_new_tokens=80,
                do_sample=True,
                num_return_sequences=3,
            )
            # Show every candidate that came back, numbered from 1.
            for number, candidate in enumerate(generated_tweet, start=1):
                st.write(f"Generated Tweet-{number}: {candidate['generated_text']}")

        # Only the first candidate is analysed further (ABSA + toxicity).
        first_tweet = generated_tweet[0]['generated_text']

        with st.spinner('Generating the Stance-Aware ABSA output...'):
            absa_output = absa(absa_prompt.format(generated_tweet=first_tweet))
            stances = _parse_stances(absa_output[0]['generated_text'])
            # Explicit columns keep the table header even when nothing parsed.
            stances_df = pd.DataFrame(stances, columns=['Aspect', 'Sentiment'])
            stances_df.index = stances_df.index + 1  # number rows from 1
            st.write("Stance-Aware ABSA Output:")
            st.table(stances_df)

        with st.spinner('Classifying the toxicity...'):
            model_output = classifier(first_tweet)
            # One input string -> the scores for it are the first element.
            label_scores = model_output[0]
            st.write("Toxicity Classifier Output:")
            for entry in label_scores:
                caption = _TOXICITY_CAPTIONS.get(entry['label'])
                if caption is not None:
                    st.write(f"{caption}: {entry['score'] * 100:.1f}%")