import logging

logger = logging.getLogger(__name__)

class ContextualWeightOverrideAgent:
    def __init__(self):
        # Per-tag weight multipliers: values above 1.0 boost a model in that
        # context, values below 1.0 dampen it.
        self.context_overrides = {
            "outdoor": {
                "model_1": 0.8,
                "model_5": 1.2,
            },
            "low_light": {
                "model_2": 0.7,
                "model_7": 1.3,
            },
            "sunny": {
                "model_3": 0.9,
                "model_4": 1.1,
            },
        }

    def get_overrides(self, context_tags: list[str]) -> dict:
        """Return the combined weight overrides for the given context tags."""
        combined_overrides = {}
        for tag in context_tags:
            if tag in self.context_overrides:
                for model_id, multiplier in self.context_overrides[tag].items():
                    # Multipliers from different tags compound multiplicatively;
                    # models not seen yet start from a neutral 1.0.
                    combined_overrides[model_id] = combined_overrides.get(model_id, 1.0) * multiplier
        return combined_overrides
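
# Illustrative only: multipliers from multiple active tags compound, e.g.
#   ContextualWeightOverrideAgent().get_overrides(["outdoor", "sunny"])
#   -> {"model_1": 0.8, "model_5": 1.2, "model_3": 0.9, "model_4": 1.1}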


class ModelWeightManager:
    def __init__(self):
        # Prior ensemble weights; they sum to 1.0 before any adjustment.
        self.base_weights = {
            "model_1": 0.15,
            "model_2": 0.15,
            "model_3": 0.15,
            "model_4": 0.15,
            "model_5": 0.15,
            "model_5b": 0.10,
            "model_6": 0.10,
            "model_7": 0.05,
        }
        # Situational multipliers applied on top of the base weights.
        self.situation_weights = {
            "high_confidence": 1.2,
            "low_confidence": 0.8,
            "conflict": 0.5,
            "consensus": 1.5,
        }
        self.context_override_agent = ContextualWeightOverrideAgent()

    def adjust_weights(self, predictions, confidence_scores, context_tags: list[str] | None = None):
        """Dynamically adjust weights based on prediction patterns and optional context."""
        adjusted_weights = self.base_weights.copy()

        # 1. Apply contextual overrides first. Models without a base weight
        #    default to 0.0, so an override alone cannot introduce a new model.
        if context_tags:
            overrides = self.context_override_agent.get_overrides(context_tags)
            for model_id, multiplier in overrides.items():
                adjusted_weights[model_id] = adjusted_weights.get(model_id, 0.0) * multiplier

        # 2. Scale for consensus or conflict. These multiply every weight by
        #    the same factor, so they cancel out after normalization and only
        #    matter if the unnormalized weights are consumed elsewhere.
        if self._has_consensus(predictions):
            for model in adjusted_weights:
                adjusted_weights[model] *= self.situation_weights["consensus"]

        if self._has_conflicts(predictions):
            for model in adjusted_weights:
                adjusted_weights[model] *= self.situation_weights["conflict"]

        # 3. Per-model confidence scaling; skip scores for unknown models
        #    rather than raising a KeyError.
        for model, confidence in confidence_scores.items():
            if model not in adjusted_weights:
                continue
            if confidence > 0.8:
                adjusted_weights[model] *= self.situation_weights["high_confidence"]
            elif confidence < 0.5:
                adjusted_weights[model] *= self.situation_weights["low_confidence"]

        return self._normalize_weights(adjusted_weights)

    def _has_consensus(self, predictions):
        """Check whether all usable predictions agree on a single label."""
        labels = self._valid_labels(predictions)
        return len(labels) > 0 and len(set(labels)) == 1

    def _has_conflicts(self, predictions):
        """Check whether usable predictions disagree with one another."""
        labels = self._valid_labels(predictions)
        return len(labels) > 1 and len(set(labels)) > 1

    @staticmethod
    def _valid_labels(predictions):
        """Collect labels, skipping missing, malformed, or error results."""
        return [
            p.get("Label")
            for p in predictions.values()
            if isinstance(p, dict) and p.get("Label") not in (None, "Error")
        ]
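    # Illustrative prediction shapes these helpers expect (hypothetical labels):
    #   {"model_1": {"Label": "fake"}, "model_2": {"Label": "fake"}}  -> consensus
    #   {"model_1": {"Label": "fake"}, "model_2": {"Label": "real"}}  -> conflict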

    def _normalize_weights(self, weights):
        """Normalize weights so they sum to 1."""
        total = sum(weights.values())
        if total == 0:
            # Every weight was zeroed out; fall back to a uniform distribution
            # over the base models rather than dividing by zero.
            logger.warning("All weights became zero after adjustments; falling back to uniform weights.")
            return {k: 1.0 / len(self.base_weights) for k in self.base_weights}
        return {k: v / total for k, v in weights.items()}
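

# A minimal smoke test sketching intended usage. The model IDs and context tags
# come from the dictionaries above; the labels and confidence values below are
# hypothetical, and the {"Label": ...} payload shape follows what
# _has_consensus/_has_conflicts read via p.get("Label").
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    manager = ModelWeightManager()
    predictions = {
        "model_1": {"Label": "fake"},
        "model_2": {"Label": "real"},
        "model_3": {"Label": "fake"},
    }
    confidence_scores = {"model_1": 0.9, "model_2": 0.4, "model_3": 0.85}

    weights = manager.adjust_weights(predictions, confidence_scores, context_tags=["outdoor", "sunny"])
    logger.info("Adjusted weights: %s", weights)
    assert abs(sum(weights.values()) - 1.0) < 1e-9  # normalized to sum to 1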