Spaces:
Sleeping
Sleeping
Commit ·
3c5755f
1
Parent(s): b3179c4
experiment5
Browse files- absa_evaluator.py +113 -1
absa_evaluator.py
CHANGED
|
@@ -3,8 +3,10 @@ from typing import Dict, List
|
|
| 3 |
import evaluate
|
| 4 |
from datasets import Features, Sequence, Value
|
| 5 |
from sklearn.metrics import accuracy_score
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
-
from preprocessing import absa_term_preprocess
|
| 8 |
|
| 9 |
_CITATION = """
|
| 10 |
"""
|
|
@@ -164,3 +166,113 @@ class AbsaEvaluatorTest(evaluate.Metric):
|
|
| 164 |
"retrieved": retrieved,
|
| 165 |
"relevant": relevant,
|
| 166 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
import evaluate
|
| 4 |
from datasets import Features, Sequence, Value
|
| 5 |
from sklearn.metrics import accuracy_score
|
| 6 |
+
from itertools import chain
|
| 7 |
+
from random import choice
|
| 8 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 9 |
|
|
|
|
| 10 |
|
| 11 |
_CITATION = """
|
| 12 |
"""
|
|
|
|
| 166 |
"retrieved": retrieved,
|
| 167 |
"relevant": relevant,
|
| 168 |
}
|
| 169 |
+
|
| 170 |
+
def adjust_predictions(refs, preds, choices):
|
| 171 |
+
"""Adjust predictions to match the length of references with either a special token or random choice."""
|
| 172 |
+
adjusted_preds = []
|
| 173 |
+
for ref, pred in zip(refs, preds):
|
| 174 |
+
if len(pred) < len(ref):
|
| 175 |
+
missing_count = len(ref) - len(pred)
|
| 176 |
+
pred.extend([choice(choices) for _ in range(missing_count)])
|
| 177 |
+
adjusted_preds.append(pred)
|
| 178 |
+
return adjusted_preds
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def extract_aspects(data, specific_key, specific_val):
|
| 182 |
+
"""Extracts and returns a list of specified aspect details from the nested 'aspects' data."""
|
| 183 |
+
return [item[specific_key][specific_val] for item in data]
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def absa_term_preprocess(references, predictions, subtask_key, subtask_value):
|
| 187 |
+
"""
|
| 188 |
+
Preprocess the terms and polarities for aspect-based sentiment analysis.
|
| 189 |
+
|
| 190 |
+
Args:
|
| 191 |
+
references (List[Dict]): A list of dictionaries containing the actual terms and polarities under 'aspects'.
|
| 192 |
+
predictions (List[Dict]): A list of dictionaries containing predicted aspect categories to terms and their sentiments.
|
| 193 |
+
|
| 194 |
+
Returns:
|
| 195 |
+
Tuple[List[str], List[str], List[str], List[str]]: A tuple containing lists of true aspect terms,
|
| 196 |
+
adjusted predicted aspect terms, true polarities, and adjusted predicted polarities.
|
| 197 |
+
"""
|
| 198 |
+
|
| 199 |
+
# Extract aspect terms and polarities
|
| 200 |
+
truth_aspect_terms = extract_aspects(references, subtask_key, subtask_value)
|
| 201 |
+
pred_aspect_terms = extract_aspects(predictions, subtask_key, subtask_value)
|
| 202 |
+
truth_polarities = extract_aspects(references, subtask_key, "polarity")
|
| 203 |
+
pred_polarities = extract_aspects(predictions, subtask_key, "polarity")
|
| 204 |
+
|
| 205 |
+
# Define adjustment parameters
|
| 206 |
+
special_token = "NONE" # For missing aspect terms
|
| 207 |
+
sentiment_choices = [
|
| 208 |
+
"positive",
|
| 209 |
+
"negative",
|
| 210 |
+
"neutral",
|
| 211 |
+
"conflict",
|
| 212 |
+
] # For missing polarities
|
| 213 |
+
|
| 214 |
+
# Adjust the predictions to match the length of references
|
| 215 |
+
adjusted_pred_terms = adjust_predictions(
|
| 216 |
+
truth_aspect_terms, pred_aspect_terms, [special_token]
|
| 217 |
+
)
|
| 218 |
+
adjusted_pred_polarities = adjust_predictions(
|
| 219 |
+
truth_polarities, pred_polarities, sentiment_choices
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
return (
|
| 223 |
+
flatten_list(truth_aspect_terms),
|
| 224 |
+
flatten_list(adjusted_pred_terms),
|
| 225 |
+
flatten_list(truth_polarities),
|
| 226 |
+
flatten_list(adjusted_pred_polarities),
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def flatten_list(nested_list):
|
| 231 |
+
"""Flatten a nested list into a single-level list."""
|
| 232 |
+
return list(chain.from_iterable(nested_list))
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def extract_pred_terms(
|
| 236 |
+
all_predictions: List[Dict[str, Dict[str, str]]]
|
| 237 |
+
) -> List[List]:
|
| 238 |
+
"""Extract and organize predicted terms from the sentiment analysis results."""
|
| 239 |
+
pred_aspect_terms = []
|
| 240 |
+
for pred in all_predictions:
|
| 241 |
+
terms = [term for cat in pred.values() for term in cat.keys()]
|
| 242 |
+
pred_aspect_terms.append(terms)
|
| 243 |
+
return pred_aspect_terms
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def merge_aspects_and_categories(aspects, categories):
|
| 247 |
+
result = []
|
| 248 |
+
|
| 249 |
+
# Assuming both lists are of the same length and corresponding indices match
|
| 250 |
+
for aspect, category in zip(aspects, categories):
|
| 251 |
+
combined_entry = {
|
| 252 |
+
"aspects": {"term": [], "polarity": []},
|
| 253 |
+
"category": {"category": [], "polarity": []},
|
| 254 |
+
}
|
| 255 |
+
|
| 256 |
+
# Process aspect entries
|
| 257 |
+
for cat_key, terms_dict in aspect.items():
|
| 258 |
+
for term, polarity in terms_dict.items():
|
| 259 |
+
combined_entry["aspects"]["term"].append(term)
|
| 260 |
+
combined_entry["aspects"]["polarity"].append(polarity)
|
| 261 |
+
|
| 262 |
+
# Add category details based on the aspect's key if available in categories
|
| 263 |
+
if cat_key in category:
|
| 264 |
+
combined_entry["category"]["category"].append(cat_key)
|
| 265 |
+
combined_entry["category"]["polarity"].append(
|
| 266 |
+
category[cat_key]
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
# Ensure all keys in category are accounted for
|
| 270 |
+
for cat_key, polarity in category.items():
|
| 271 |
+
if cat_key not in combined_entry["category"]["category"]:
|
| 272 |
+
combined_entry["category"]["category"].append(cat_key)
|
| 273 |
+
combined_entry["category"]["polarity"].append(polarity)
|
| 274 |
+
|
| 275 |
+
result.append(combined_entry)
|
| 276 |
+
|
| 277 |
+
return result
|
| 278 |
+
|