# HarshaBattula
# updated colors
# e88d283
from pyabsa import available_checkpoints
from pyabsa import ATEPCCheckpointManager
import gradio as gr
def assign_sentiment_marker(sentiment_list, aspect_index):
    """
    Map the sentiment of one aspect to the short label used for highlighting.

    The label at ``sentiment_list[aspect_index]`` is translated as follows:
    "Positive" -> "pos", "Negative" -> "neg", and any other value -> "neu".
    These short labels match the keys of the ``color_map`` given to the
    Gradio HighlightedText outputs in this module.

    Parameters:
    sentiment_list (list): Sentiment labels, one per detected aspect.
    aspect_index (int): Index into sentiment_list of the aspect to label.

    Returns:
    str: "pos", "neg", or "neu".

    Raises:
    IndexError: If aspect_index is out of range for sentiment_list.
    """
    # Known sentiment labels and their highlight markers.
    sentiment_markers = {
        "Positive": "pos",
        "Negative": "neg",
    }
    sentiment_label = sentiment_list[aspect_index]
    # Any label not listed above (e.g. "Neutral") falls back to "neu".
    return sentiment_markers.get(sentiment_label, "neu")
def annotate_sentiment(tokens, aspect_positions, sentiment_list):
    """
    Turn the aspect extractor's tokens into (text, label) pairs for highlighting.

    Each token is iterated element by element (character by character for
    string tokens). Every element of a token whose index appears in
    ``aspect_positions`` is paired with that aspect's marker from
    ``assign_sentiment_marker`` ("pos", "neg" or "neu"). Elements of all
    other tokens are paired with None. A (' ', None) pair is inserted between
    consecutive tokens, but not after the last one.

    :param tokens: list of tokens to be processed
    :param aspect_positions: token indices that refer to aspects; the k-th
        entry corresponds to ``sentiment_list[k]``. The order does not matter.
        If the same index appears more than once, the first occurrence wins.
        Indices that match no token are ignored.
    :param sentiment_list: list of sentiments corresponding to each aspect
    :return: list of (element_or_space, marker_or_None) tuples
    """
    # Map each token index to the index of its aspect. A dict lookup, rather
    # than walking aspect_positions in lockstep with tokens, means unsorted,
    # duplicated or out-of-range positions cannot stop later aspects from
    # being annotated.
    aspect_by_position = {}
    for aspect_index, position in enumerate(aspect_positions):
        aspect_by_position.setdefault(position, aspect_index)

    annotated_tokens = []
    last_index = len(tokens) - 1
    for i, token_group in enumerate(tokens):
        aspect_index = aspect_by_position.get(i)
        if aspect_index is None:
            sentiment = None
        else:
            sentiment = assign_sentiment_marker(sentiment_list, aspect_index)
        annotated_tokens.extend((token, sentiment) for token in token_group)
        # Add space between groups of tokens, if it's not the last group
        if i != last_index:
            annotated_tokens.append((' ', None))
    return annotated_tokens
def annotate_text_sentiment(text):
    """
    Gradio handler: extract aspects from ``text`` and return highlight pairs.

    Runs the module-level ``aspect_extractor`` with sentiment prediction
    enabled on ``text``. It takes the first reported position of each aspect,
    shifted by -1, as a token index, then delegates to ``annotate_sentiment``.

    :param text: string to be processed. None, empty or whitespace-only input
        returns an empty list without invoking the model.
    :return: list of (element_or_space, marker_or_None) tuples as produced by
        ``annotate_sentiment``, for display in gr.HighlightedText
    """
    # Nothing to analyse: skip the model call entirely.
    if not text or not text.strip():
        return []

    aspect_extraction_result = aspect_extractor.extract_aspect(
        inference_source=[text], pred_sentiment=True
    )[0]
    tokens = aspect_extraction_result["tokens"]
    positions = aspect_extraction_result["position"]

    if positions:
        # Only the first position of each aspect is used.
        # NOTE(review): the -1 assumes pyabsa reports 1-based positions
        # relative to `tokens` — confirm.
        aspect_positions = [position[0] - 1 for position in positions]
        sentiment_list = aspect_extraction_result["sentiment"]
    else:
        # If no aspects found, aspect_positions and sentiment_list are empty
        aspect_positions = []
        sentiment_list = []

    return annotate_sentiment(tokens, aspect_positions, sentiment_list)
# Initializing the aspect extractor model.
# NOTE(review): `checkpoint_map` is not referenced by any code visible here;
# the call is kept in case other code or its side effects depend on it.
checkpoint_map = available_checkpoints()
aspect_extractor = ATEPCCheckpointManager.get_aspect_extractor(
    checkpoint='english',
    auto_device=True,  # False means load model on CPU
)

# Placeholder hint shown under both input text boxes.
_EXAMPLE_HINT = "Example: The food was good, but the service was terrible."


def _build_demo(input_label, output_label):
    """Build one Gradio tab: a text box routed to highlighted aspect sentiments.

    The output highlights use the "pos"/"neg"/"neu" markers produced by
    annotate_text_sentiment, coloured green/red/blue respectively.
    """
    text_box = gr.Textbox(
        label=input_label,
        info=_EXAMPLE_HINT,
        lines=3,
    )
    highlighted = gr.HighlightedText(
        label=output_label,
        combine_adjacent=True,
        show_legend=True,
    ).style(color_map={"pos": "green", "neg": "red", "neu": "blue"})
    return gr.Interface(
        annotate_text_sentiment,
        [text_box],
        highlighted,
        theme=gr.themes.Base(),
    )


demo1 = _build_demo(
    "Enter the text for aspect extraction and polarity detection",
    "Aspect Detector based on DeBERTa",
)
# NOTE(review): this tab calls the same annotate_text_sentiment function (and
# therefore the same aspect_extractor) as demo1, despite its different label.
# Confirm that this is intended.
demo2 = _build_demo(
    "Enter the text for polarity detection",
    "Aspect Detector based on Relational Graph Attention Networks, and BERT",
)

demo = gr.TabbedInterface(
    [demo1, demo2],
    [
        "Aspect Polarity Detection with Extraction",
        "Aspect Polarity Detection without Extraction",
    ],
)
demo.launch()