Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import regex as re | |
| import torch | |
| import nltk | |
| import pandas as pd | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| from nltk.tokenize import sent_tokenize | |
| import plotly.express as px | |
| import time | |
| import tqdm | |
# Fetch the sentence-tokenizer data used by `sent_tokenize`.
# Older NLTK releases load the pickled "punkt" models, while recent releases
# look up the "punkt_tab" resource instead and raise LookupError when it is
# missing. Downloading both keeps the app working whichever NLTK is installed.
nltk.download('punkt')
nltk.download('punkt_tab')

# Define the model and tokenizer (both are loaded once at import time and
# shared by every request handler below)
checkpoint = "sadickam/sdg-classification-bert"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
# Define the function for preprocessing text
def prep_text(text):
    """Return a lower-cased, whitespace-normalised copy of *text*.

    The input is coerced to ``str``, split into sentences with
    ``sent_tokenize``, and every sentence is split on whitespace, lower-cased
    and re-joined with single spaces. Backticks and double quotes are then
    removed.

    Args:
        text: Any object; it is converted with ``str()`` before processing.

    Returns:
        The cleaned string. It is empty when the input holds nothing but
        whitespace, backticks and double quotes.
    """
    clean_sents = [
        ' '.join(word_token.strip().lower() for word_token in sent_token.split())
        for sent_token in sent_tokenize(str(text))
    ]
    joined = ' '.join(clean_sents)
    joined = joined.replace('`', '').replace('"', '')
    # Deleting the quote characters can leave leading/trailing or doubled
    # spaces behind (e.g. '" "' -> ' '), so collapse whitespace once more.
    # Without this, quote-only input is not recognised as empty by callers.
    return ' '.join(joined.split())
# APP INFO
def app_info():
    """Return the fixed hint that points users to the two prediction tabs."""
    return """
    Please go to either the "Single-Text-Prediction" or "Multi-Text-Prediction" tab to analyse your text.
    """
# Create Gradio interface for the general information tab.
# It takes no input; `app_info` only returns a hint pointing at the other tabs.
iface1 = gr.Interface(
    fn=app_info,
    inputs=None,
    outputs=['text'],
    title="General-Information",
    description='''
This app powered by the sdgBERT model (sadickam/sdg-classification-bert) is for automatic classification of text with respect to
the UN Sustainable Development Goals (SDG). Note that 16 out of the 17 SDGs labels are covered. This app is for sustainability
assessment and benchmarking and is not limited to a specific industry. The model powering this app was developed using the
OSDG Community Dataset (OSDG-CD) [Link - https://zenodo.org/record/5550238#.Y8Sd5f5ByF5].
This app has two analysis modules summarised below:
- Single-Text-Prediction - Analyses text pasted in a text box and return SDG prediction.
- Multi-Text-Prediction - Analyses multiple rows of texts in an uploaded CSV file and returns a downloadable CSV file with SDG prediction for each row of text.
This app runs on a free server and may therefore not be suitable for analysing large CSV and PDF files.
If you need assistance with analysing large CSV or PDF files, do get in touch using the contact information in the Contact section.
<h3>Contact</h3>
<p>We would be happy to receive your feedback regarding this app. If you would also like to collaborate with us to explore some use cases for the model
powering this app, we are happy to hear from you.</p>
<p>App contact: s.sadick@deakin.edu.au</p>
''',
)
# SINGLE TEXT
# Define the prediction function
def predict_sdg(text):
    """Classify one piece of text against the SDG labels.

    Args:
        text: Raw text from the input textbox.

    Returns:
        A 2-tuple ``(top, fig)`` where ``top`` is a one-entry dict mapping the
        highest-scoring label to its softmax score rounded to 3 decimals, and
        ``fig`` is a horizontal plotly bar chart of the scores for all labels,
        sorted from highest to lowest.

    Raises:
        gr.Error: If the text is empty after preprocessing.
    """
    # Preprocess the input text and reject empty input up front
    cleaned_text = prep_text(text)
    if cleaned_text == "":
        raise gr.Error('This model needs some text input to return a prediction')

    # Tokenize the preprocessed text (inputs longer than 512 tokens are truncated)
    tokenized_text = tokenizer(cleaned_text, return_tensors="pt", truncation=True, max_length=512, padding=True)

    # Predict. Gradients are never needed here, so skip autograd bookkeeping.
    with torch.no_grad():
        text_logits = model(**tokenized_text).logits
    predictions = torch.softmax(text_logits, dim=1).tolist()[0]

    # SDG labels
    # NOTE(review): assumes the model's output index i corresponds to the
    # i-th entry of this list - confirm against the checkpoint's label mapping.
    label_list = [
        'GOAL 1: No Poverty',
        'GOAL 2: Zero Hunger',
        'GOAL 3: Good Health and Well-being',
        'GOAL 4: Quality Education',
        'GOAL 5: Gender Equality',
        'GOAL 6: Clean Water and Sanitation',
        'GOAL 7: Affordable and Clean Energy',
        'GOAL 8: Decent Work and Economic Growth',
        'GOAL 9: Industry, Innovation and Infrastructure',
        'GOAL 10: Reduced Inequality',
        'GOAL 11: Sustainable Cities and Communities',
        'GOAL 12: Responsible Consumption and Production',
        'GOAL 13: Climate Action',
        'GOAL 14: Life Below Water',
        'GOAL 15: Life on Land',
        'GOAL 16: Peace, Justice and Strong Institutions',
    ]

    # (label, score) pairs sorted by score, highest at index 0
    sorted_preds = sorted(zip(label_list, predictions), key=lambda x: x[1], reverse=True)

    # Make dataframe for plotly bar chart
    df2 = pd.DataFrame(sorted_preds, columns=['SDG', 'Likelihood'])

    # plot graph of predictions
    fig = px.bar(df2, x="Likelihood", y="SDG", orientation="h")
    fig.update_layout(
        template='seaborn',
        font=dict(family="Arial", size=12, color="black"),
        autosize=True,
        xaxis_title="Likelihood of SDG",
        yaxis_title="Sustainable development goals (SDG)",
    )
    fig.update_xaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
    fig.update_yaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
    fig.update_annotations(font_size=12)

    # Return the top prediction together with the chart
    top_label, top_score = sorted_preds[0]
    return {top_label: round(top_score, 3)}, fig
# Create Gradio interface for single text:
# one textbox in, the top label plus a plot of every label's score out.
iface2 = gr.Interface(
    fn=predict_sdg,
    inputs=gr.Textbox(lines=7, label="Paste or type text here"),
    outputs=[
        gr.Label(label="Top SDG Predicted", show_label=True),
        gr.Plot(label="Likelihood of all SDG", show_label=True),
    ],
    title="Single Text Prediction",
    article="**Note:** The quality of model predictions may depend on the quality of information provided.",
)
# UPLOAD CSV
# Define the prediction function
def predict_sdg_from_csv(file, progress=gr.Progress()):
    """Classify every row of an uploaded CSV file and summarise the results.

    Args:
        file: The uploaded CSV, either as a path string or as an object whose
            ``name`` attribute holds the path. The file must contain a column
            titled ``text_inputs``.
        progress: Gradio progress tracker used to report per-row progress.

    Returns:
        A 2-tuple ``(fig, output_path)`` where ``fig`` is a plotly histogram
        of the predicted labels and ``output_path`` is the path of a CSV file
        holding the input rows plus the ``SDG_predicted`` and
        ``prediction_score`` columns.

    Raises:
        gr.Error: If no file was uploaded or the ``text_inputs`` column is
            missing.
    """
    import tempfile

    if file is None:
        raise gr.Error('Please upload a CSV file to analyse')

    # Read the CSV file; accept both a plain path and a temp-file wrapper
    df_docs = pd.read_csv(getattr(file, "name", file))
    if "text_inputs" not in df_docs.columns:
        raise gr.Error('The uploaded CSV file must contain a column titled "text_inputs"')
    text_list = df_docs["text_inputs"].tolist()

    # SDG labels list
    # NOTE(review): assumes the model's output index i corresponds to the
    # i-th entry of this list - confirm against the checkpoint's label mapping.
    label_list = [
        'GOAL 1: No Poverty',
        'GOAL 2: Zero Hunger',
        'GOAL 3: Good Health and Well-being',
        'GOAL 4: Quality Education',
        'GOAL 5: Gender Equality',
        'GOAL 6: Clean Water and Sanitation',
        'GOAL 7: Affordable and Clean Energy',
        'GOAL 8: Decent Work and Economic Growth',
        'GOAL 9: Industry, Innovation and Infrastructure',
        'GOAL 10: Reduced Inequality',
        'GOAL 11: Sustainable Cities and Communities',
        'GOAL 12: Responsible Consumption and Production',
        'GOAL 13: Climate Action',
        'GOAL 14: Life Below Water',
        'GOAL 15: Life on Land',
        'GOAL 16: Peace, Justice and Strong Institutions',
    ]

    # Lists for appending predictions
    predicted_labels = []
    prediction_score = []

    # Preprocess text and make predictions.
    # Gradients are never needed for inference, so disable autograd tracking.
    with torch.no_grad():
        for text_input in progress.tqdm(text_list, desc="Analysing data"):
            cleaned_text = prep_text(text_input)
            tokenized_text = tokenizer(cleaned_text, return_tensors="pt", truncation=True, max_length=512, padding=True)
            text_logits = model(**tokenized_text).logits
            predictions = torch.softmax(text_logits, dim=1).tolist()[0]
            # Highest-scoring label; on ties the earlier label in the list wins
            top_label, top_score = max(zip(label_list, predictions), key=lambda g: g[1])
            predicted_labels.append(top_label)
            prediction_score.append(top_score)

    # Append predictions to the DataFrame
    df_docs['SDG_predicted'] = predicted_labels
    df_docs['prediction_score'] = prediction_score

    # Write to a uniquely named temp file so that simultaneous requests do not
    # overwrite each other's results. The file is intentionally not deleted
    # here because it must outlive this call to be downloadable.
    with tempfile.NamedTemporaryFile(prefix='sdg_predictions_', suffix='.csv', delete=False) as tmp_file:
        output_csv = tmp_file.name
    df_docs.to_csv(output_csv)

    # Create the histogram
    fig = px.histogram(df_docs, y="SDG_predicted")
    fig.update_layout(
        template='seaborn',
        font=dict(family="Arial", size=12, color="black"),
        autosize=True,
        xaxis_title="SDG counts",
        yaxis_title="Sustainable development goals (SDG)",
    )
    fig.update_xaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
    fig.update_yaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
    fig.update_annotations(font_size=12)
    return fig, output_csv
# Define the input component (uploads are restricted to .csv files)
file_input = gr.File(label="Upload CSV file here", show_label=True, file_types=[".csv"])

# Create the Gradio interface for the CSV tab
iface3 = gr.Interface(
    fn=predict_sdg_from_csv,
    inputs=file_input,
    outputs=[
        gr.Plot(label='Frequency of SDGs', show_label=True),
        gr.File(label='Download output CSV', show_label=True),
    ],
    title="Multi-text Prediction (CSV)",
    description='**NOTE:** The column to be analysed must be titled ***text_inputs***',
)

# Combine the three interfaces into one tabbed app; tab names are listed in
# the same order as the interfaces.
demo = gr.TabbedInterface(
    interface_list=[iface1, iface2, iface3],
    tab_names=["General-App-Info", "Single-Text-Prediction", "Multi-Text-Prediction (CSV)"],
    title="Sustainable Development Goals (SDG) Text Classifier App",
    theme='soft',
)

# Run the interface (queueing is enabled before launch)
demo.queue().launch()