Spaces:
Running
Running
| """Defuzzification: convert fuzzy rule outputs back to crisp 28-dim probabilities.""" | |
| from __future__ import annotations | |
| import numpy as np | |
| from constants import GOEMOTIONS_LABELS, NUM_GOEMOTIONS | |
| from membership import LEVEL_CENTROIDS, FUZZY_LEVELS | |
| from rule_base import FuzzyRule | |
class Defuzzifier:
    """Convert fuzzy inference output to crisp emotion probabilities.

    ``defuzzify`` runs three steps:

    1. For every emotion named in a fired rule's consequent, keep the rule
       with the largest effective target (level centroid x activation x
       weight modifier). Blend the crisp base score toward that target, using
       the rule's activation as the blend weight.
    2. Refine every emotion in ``base_fuzzy`` by blending it toward a value
       defuzzified from its level memberships:
       ``"centroid"`` uses the membership-weighted mean of level centroids;
       ``"mom"`` uses the mean of the centroids at maximum membership;
       ``"lom"`` uses the centroid of the last maximal level in
       ``FUZZY_LEVELS``.
    3. Clamp negatives to 0 and renormalize to sum to 1, falling back to a
       uniform distribution if nothing positive remains.
    """

    # Weight given to the defuzzified level value when refining a crisp
    # score; the remaining (1 - weight) is kept from the rule-adjusted score.
    _REFINE_WEIGHT = 0.3
    # Memberships within this distance of the maximum count as maximal for
    # MOM/LOM, and a maximum below it is treated as "no evidence".
    _MAX_TOLERANCE = 0.01
    # Centroid assumed for a rule target level missing from LEVEL_CENTROIDS.
    _DEFAULT_CENTROID = 0.35

    def __init__(self, method: str = "centroid"):
        """Create a defuzzifier.

        Args:
            method: Refinement operator: ``"centroid"``, ``"mom"`` or ``"lom"``.

        Raises:
            ValueError: If ``method`` is not one of the supported names.
        """
        if method not in ("centroid", "mom", "lom"):
            raise ValueError(f"Unknown defuzzification method: {method}")
        self.method = method
        # Emotion label -> position in the crisp probability vector.
        self._label_idx = {label: i for i, label in enumerate(GOEMOTIONS_LABELS)}

    def defuzzify(
        self,
        base_fuzzy: dict[str, dict[str, float]],
        fired_rules: list[tuple[FuzzyRule, float]],
        base_crisp: np.ndarray,
    ) -> tuple[np.ndarray, list[str]]:
        """Combine fired rules with the crisp base into a probability vector.

        Args:
            base_fuzzy: Per-emotion-label mapping of fuzzy level name to
                membership degree; only used by the refinement step.
            fired_rules: ``(rule, activation)`` pairs for the rules that fired.
            base_crisp: Crisp per-emotion scores, indexed like
                ``GOEMOTIONS_LABELS``. It is copied, never modified.

        Returns:
            ``(probs, fired_names)``. ``probs`` is non-negative and sums to 1
            (uniform if everything clamps to zero). ``fired_names`` lists the
            names of ``fired_rules`` in order. When no rules fired, the base
            is only normalized and ``base_fuzzy`` is not consulted.
        """
        # Float copy: the blends below must not truncate on integer input.
        result = np.array(base_crisp, dtype=float)
        # Materialize once so an iterator argument is not consumed twice.
        fired_rules = list(fired_rules)
        if not fired_rules:
            # Same clamp/normalize/uniform fallback as the main path, so
            # callers always receive a valid distribution.
            return self._normalize(result), []

        fired_names = [rule.name for rule, _ in fired_rules]
        targets, weights = self._aggregate_rules(fired_rules)

        # Pull each rule-touched emotion toward its target, weighted by the
        # winning rule's activation.
        touched = weights > 0
        result[touched] = (
            (1.0 - weights[touched]) * result[touched]
            + weights[touched] * targets[touched]
        )

        if self.method == "centroid":
            result = self._centroid_refine(result, base_fuzzy)
        elif self.method == "mom":
            result = self._mom_refine(result, base_fuzzy)
        elif self.method == "lom":
            result = self._lom_refine(result, base_fuzzy)
        return self._normalize(result), fired_names

    def _aggregate_rules(self, fired_rules):
        """Return per-emotion ``(targets, weights)`` arrays from fired rules.

        For each emotion, the rule with the largest effective target wins
        (first one on ties). Its target and activation are stored. Emotions
        that no rule mentions keep weight 0, and labels outside
        ``GOEMOTIONS_LABELS`` are ignored.
        """
        targets = np.zeros(NUM_GOEMOTIONS)
        weights = np.zeros(NUM_GOEMOTIONS)
        # Tracked separately from the target value, so a rule that targets 0
        # is not mistaken for "no rule" and can still suppress an emotion.
        seen = np.zeros(NUM_GOEMOTIONS, dtype=bool)
        for rule, activation in fired_rules:
            for label, (level, modifier) in rule.consequent.items():
                idx = self._label_idx.get(label)
                if idx is None:
                    continue
                centroid = LEVEL_CENTROIDS.get(level, self._DEFAULT_CENTROID)
                # NOTE(review): activation scales the target here and is also
                # the blend weight in defuzzify; confirm that double weighting
                # is intended. Floored at 0 because it is a probability level.
                effective_target = max(centroid * activation * modifier, 0.0)
                if not seen[idx] or effective_target > targets[idx]:
                    targets[idx] = effective_target
                    weights[idx] = activation
                    seen[idx] = True
        return targets, weights

    @staticmethod
    def _normalize(scores):
        """Clamp negatives to 0 and rescale to sum to 1 (uniform if all zero)."""
        clipped = np.maximum(scores, 0)
        total = clipped.sum()
        if total > 0:
            return clipped / total
        return np.ones(NUM_GOEMOTIONS) / NUM_GOEMOTIONS

    def _refine_with(self, crisp, fuzzy, level_value):
        """Blend each known emotion in ``fuzzy`` toward a defuzzified value.

        ``level_value(memberships)`` returns the defuzzified level value for
        one emotion, or ``None`` to leave that emotion unchanged. ``crisp`` is
        updated in place and returned.
        """
        keep = 1.0 - self._REFINE_WEIGHT
        for label, memberships in fuzzy.items():
            idx = self._label_idx.get(label)
            if idx is None or not memberships:
                continue
            value = level_value(memberships)
            if value is not None:
                crisp[idx] = keep * crisp[idx] + self._REFINE_WEIGHT * value
        return crisp

    def _centroid_refine(self, crisp, fuzzy):
        """Refine toward the membership-weighted mean of level centroids."""

        def centre_of_gravity(memberships):
            weighted = [
                (mu, LEVEL_CENTROIDS[level])
                for level, mu in memberships.items()
                if mu > 0 and level in LEVEL_CENTROIDS
            ]
            total_mu = sum(mu for mu, _ in weighted)
            if total_mu <= 0:
                return None
            return sum(mu * centroid for mu, centroid in weighted) / total_mu

        return self._refine_with(crisp, fuzzy, centre_of_gravity)

    def _mom_refine(self, crisp, fuzzy):
        """Refine toward the mean centroid of the maximal-membership levels."""
        tol = self._MAX_TOLERANCE

        def mean_of_maxima(memberships):
            # Only levels with a centroid can contribute, so only they may
            # define the maximum.
            known = {
                level: mu
                for level, mu in memberships.items()
                if level in LEVEL_CENTROIDS
            }
            if not known:
                return None
            max_mu = max(known.values())
            if max_mu < tol:
                return None
            # Non-empty: the maximal level itself always qualifies.
            return np.mean([
                LEVEL_CENTROIDS[level]
                for level, mu in known.items()
                if mu >= max_mu - tol
            ])

        return self._refine_with(crisp, fuzzy, mean_of_maxima)

    def _lom_refine(self, crisp, fuzzy):
        """Refine toward the centroid of the last maximal level.

        "Last" follows ``FUZZY_LEVELS`` order. That order is presumably low to
        high, making this the largest maximal level; confirm against the
        ``membership`` module.
        """
        tol = self._MAX_TOLERANCE

        def largest_of_maxima(memberships):
            candidates = [
                (level, memberships[level])
                for level in FUZZY_LEVELS
                if level in memberships and level in LEVEL_CENTROIDS
            ]
            if not candidates:
                # No usable level: skip rather than defaulting to a level
                # whose membership was never maximal.
                return None
            max_mu = max(mu for _, mu in candidates)
            if max_mu < tol:
                return None
            # Always found: the maximal candidate itself satisfies the test.
            best_level = next(
                level for level, mu in reversed(candidates) if mu >= max_mu - tol
            )
            return LEVEL_CENTROIDS[best_level]

        return self._refine_with(crisp, fuzzy, largest_of_maxima)