import logging
from typing import Optional

import torch
|
|
|
# Module-level logger; ModelWeightManager uses it to warn when all adjusted weights collapse to zero.
logger = logging.getLogger(__name__)
|
|
|
class ContextualWeightOverrideAgent:
    """Translate context tags into per-model weight multipliers.

    Each known tag maps to a ``{model_id: multiplier}`` table. When several
    tags are active, the multipliers for the same model are multiplied
    together. Tags without an entry are ignored.
    """

    def __init__(self, context_overrides: Optional[dict[str, dict[str, float]]] = None):
        """Create the agent.

        Args:
            context_overrides: Optional custom ``{tag: {model_id: multiplier}}``
                table. When omitted, the built-in table below is used. The
                table is copied so later mutation by the caller does not leak
                into the agent.
        """
        if context_overrides is None:
            context_overrides = {
                "outdoor": {"model_1": 0.8, "model_5": 1.2},
                "low_light": {"model_2": 0.7, "model_7": 1.3},
                "sunny": {"model_3": 0.9, "model_4": 1.1},
            }
        # Copy each inner table so this instance owns its configuration.
        self.context_overrides = {
            tag: dict(multipliers) for tag, multipliers in context_overrides.items()
        }

    def get_overrides(self, context_tags: Optional[list[str]]) -> dict[str, float]:
        """Return combined weight overrides for the given context tags.

        Args:
            context_tags: Active context tags. ``None`` or an empty list yields
                no overrides; unknown tags are ignored.
                NOTE(review): a tag listed twice is applied twice, so its
                multipliers compound -- confirm this is intended.

        Returns:
            ``{model_id: multiplier}`` for every model touched by at least one
            known tag. Multipliers from several tags are multiplied together,
            starting from a neutral 1.0.
        """
        combined_overrides: dict[str, float] = {}
        for tag in context_tags or ():
            for model_id, multiplier in self.context_overrides.get(tag, {}).items():
                combined_overrides[model_id] = combined_overrides.get(model_id, 1.0) * multiplier
        return combined_overrides
|
|
|
|
|
class ModelWeightManager:
    """Derive normalized per-model ensemble weights.

    Starting from ``base_weights``, weights are multiplied by context-tag
    overrides (from ``ContextualWeightOverrideAgent``), a consensus/conflict
    factor and a per-model confidence factor, then normalized to sum to 1.
    """

    def __init__(self):
        # Prior weight per model id (the defaults sum to 1.0).
        self.base_weights = {
            "model_1": 0.15,
            "model_2": 0.15,
            "model_3": 0.15,
            "model_4": 0.15,
            "model_5": 0.15,
            "model_5b": 0.10,
            "model_6": 0.10,
            "model_7": 0.05,
        }
        # Multiplier applied when the named situation is detected.
        self.situation_weights = {
            "high_confidence": 1.2,
            "low_confidence": 0.8,
            "conflict": 0.5,
            "consensus": 1.5,
        }
        self.context_override_agent = ContextualWeightOverrideAgent()

    def adjust_weights(
        self,
        predictions,
        confidence_scores,
        context_tags: Optional[list[str]] = None,
    ):
        """Dynamically adjust weights based on prediction patterns and optional context.

        Args:
            predictions: Mapping whose values are prediction dicts carrying a
                ``"Label"`` key, or ``None`` for models without a result
                (presumably keyed by model id -- confirm with callers).
            confidence_scores: Mapping of model id to a numeric confidence.
                Scores above 0.8 boost a model and scores below 0.5 damp it.
                Entries for models without a weight, or with a ``None``
                score, are ignored; ``None`` is treated as an empty mapping.
            context_tags: Optional context tags forwarded to the override agent.

        Returns:
            A new ``{model_id: weight}`` dict whose values sum to 1.
        """
        adjusted_weights = self.base_weights.copy()

        if context_tags:
            overrides = self.context_override_agent.get_overrides(context_tags)
            for model_id, multiplier in overrides.items():
                adjusted_weights[model_id] = adjusted_weights.get(model_id, 0.0) * multiplier

        # NOTE(review): consensus/conflict scale *every* weight by the same
        # factor, so _normalize_weights cancels them and they do not change the
        # returned weights. Confirm the intent (e.g. scale only models that
        # agree or disagree with the majority).
        if self._has_consensus(predictions):
            for model in adjusted_weights:
                adjusted_weights[model] *= self.situation_weights["consensus"]

        if self._has_conflicts(predictions):
            for model in adjusted_weights:
                adjusted_weights[model] *= self.situation_weights["conflict"]

        for model, confidence in (confidence_scores or {}).items():
            # An unknown model would raise KeyError and a None score would fail
            # the comparison; neither carries usable signal, so skip them.
            if model not in adjusted_weights or confidence is None:
                continue
            if confidence > 0.8:
                adjusted_weights[model] *= self.situation_weights["high_confidence"]
            elif confidence < 0.5:
                adjusted_weights[model] *= self.situation_weights["low_confidence"]

        return self._normalize_weights(adjusted_weights)

    @staticmethod
    def _valid_labels(predictions):
        """Return the usable labels from ``predictions.values()`` in iteration order.

        Entries that are not dicts (including ``None``), dicts without a
        ``"Label"`` value, and entries labelled ``"Error"`` are skipped.
        """
        labels = []
        for prediction in predictions.values():
            if not isinstance(prediction, dict):
                continue
            label = prediction.get("Label")
            if label is not None and label != "Error":
                labels.append(label)
        return labels

    def _has_consensus(self, predictions):
        """Return True when there is at least one valid label and all valid labels agree.

        A single valid prediction counts as consensus.
        """
        return len(set(self._valid_labels(predictions))) == 1

    def _has_conflicts(self, predictions):
        """Return True when the valid labels contain at least two distinct values."""
        return len(set(self._valid_labels(predictions))) > 1

    def _normalize_weights(self, weights):
        """Normalize weights to sum to 1.

        If the weights sum to zero, log a warning and return the normalized
        base weights instead. Uniform weights are returned only when the base
        weights also sum to zero, and an empty dict when there are none.
        """
        total = sum(weights.values())
        if total == 0:
            logger.warning("All weights became zero after adjustments. Reverting to base weights.")
            base_total = sum(self.base_weights.values())
            if base_total > 0:
                return {k: v / base_total for k, v in self.base_weights.items()}
            # Degenerate base configuration: spread weight evenly (avoids 0/0).
            count = len(self.base_weights)
            return {k: 1.0 / count for k in self.base_weights} if count else {}
        return {k: v / total for k, v in weights.items()}