| import logging |
| import torch |
| from utils.registry import MODEL_REGISTRY |
| from utils.agent_logger import AgentLogger |
|
|
|
|
# Standard-library logger for this module, named after the module itself.
logger = logging.getLogger(name=__name__)
|
|
class ContextualWeightOverrideAgent:
    """Supply per-model weight multipliers for context tags.

    ``context_overrides`` maps a context tag (e.g. ``"outdoor"``) to a
    ``{model_id: multiplier}`` dict. When several tags are active, the
    multipliers that target the same model are multiplied together.
    """

    def __init__(self, context_overrides: "dict[str, dict[str, float]] | None" = None):
        """Create the agent.

        Args:
            context_overrides: Optional ``{tag: {model_id: multiplier}}`` table.
                When ``None`` (the default), the built-in table below is used.
                The table is copied, so later mutation by the caller does not
                affect this agent.
        """
        # Keep the logger on the instance: get_overrides() needs it, and a
        # plain local variable here is not visible there (it raised NameError).
        self.agent_logger = AgentLogger()
        self.agent_logger.log("weight_optimization", "info", "Initializing ContextualWeightOverrideAgent.")
        if context_overrides is None:
            context_overrides = {
                "outdoor": {
                    "model_1": 0.8,
                    "model_5": 1.2,
                },
                "low_light": {
                    "model_2": 0.7,
                    "model_7": 1.3,
                },
                "sunny": {
                    "model_3": 0.9,
                    "model_4": 1.1,
                },
            }
        self.context_overrides = {
            tag: dict(multipliers) for tag, multipliers in context_overrides.items()
        }

    def get_overrides(self, context_tags: "list[str] | None") -> dict:
        """Combine the multipliers of every known tag in ``context_tags``.

        Args:
            context_tags: Active context tags. Tags with no entry in
                ``self.context_overrides`` are ignored. ``None`` or an empty
                list yields no overrides. Each occurrence of a tag is applied,
                so a repeated tag compounds.

        Returns:
            ``{model_id: multiplier}`` covering only models named by at least
            one active tag. A model named by several tags gets the product of
            their multipliers.
        """
        self.agent_logger.log("weight_optimization", "info", f"Getting weight overrides for context tags: {context_tags}")
        combined_overrides = {}
        for tag in context_tags or ():
            for model_id, multiplier in self.context_overrides.get(tag, {}).items():
                # Start from the neutral multiplier 1.0 so multiple tags compound.
                combined_overrides[model_id] = combined_overrides.get(model_id, 1.0) * multiplier
        self.agent_logger.log("weight_optimization", "info", f"Combined context overrides: {combined_overrides}")
        return combined_overrides
|
|
|
|
|
|
class ModelWeightManager:
    """Compute normalized per-model weights from base weights and runtime signals.

    Base weights are derived from ``MODEL_REGISTRY`` at construction time.
    :meth:`adjust_weights` scales a copy of them by contextual overrides,
    prediction agreement/conflict and per-model confidence, then normalizes
    the result to sum to 1.
    """

    def __init__(self, strongest_model_id: str = None, strongest_weight_share: float = 0.5):
        """Initialize base weights for every model in ``MODEL_REGISTRY``.

        Args:
            strongest_model_id: Optional registry key of a model to favor.
                - If it is registered, it receives ``strongest_weight_share``.
                  The remainder is split equally among the other models, or it
                  receives 1.0 when it is the only model.
                - If it is given but not registered, a warning is logged and
                  all models get equal weight.
            strongest_weight_share: Fraction of the total weight given to the
                strongest model. Defaults to 0.5, the previously hard-coded value.

        Raises:
            ValueError: If ``strongest_weight_share`` is outside ``[0, 1]``.
        """
        # Instance attribute, not a local: every other method logs through it.
        self.agent_logger = AgentLogger()
        self.agent_logger.log("weight_optimization", "info", f"Initializing ModelWeightManager. Strongest model: {strongest_model_id}")
        if not 0.0 <= strongest_weight_share <= 1.0:
            raise ValueError(
                f"strongest_weight_share must be within [0, 1], got {strongest_weight_share}"
            )

        model_ids = list(MODEL_REGISTRY.keys())
        if not model_ids:
            self.base_weights = {}
        elif strongest_model_id and strongest_model_id in MODEL_REGISTRY:
            logger.info(f"Designating '{strongest_model_id}' as the strongest model.")
            remaining_models = [mid for mid in model_ids if mid != strongest_model_id]
            if remaining_models:
                other_models_weight_share = (1.0 - strongest_weight_share) / len(remaining_models)
                self.base_weights = {strongest_model_id: strongest_weight_share}
                self.base_weights.update(dict.fromkeys(remaining_models, other_models_weight_share))
            else:
                # The strongest model is the only model, so it takes all the weight.
                self.base_weights = {strongest_model_id: 1.0}
        else:
            if strongest_model_id:
                logger.warning(f"Strongest model ID '{strongest_model_id}' not found in MODEL_REGISTRY. Distributing weights equally.")
            self.base_weights = dict.fromkeys(model_ids, 1.0 / len(model_ids))
        logger.info(f"Base weights initialized: {self.base_weights}")

        # Multipliers applied by adjust_weights() for each detected situation.
        self.situation_weights = {
            "high_confidence": 1.2,
            "low_confidence": 0.8,
            "conflict": 0.5,
            "consensus": 1.5,
        }
        self.context_override_agent = ContextualWeightOverrideAgent()

    def adjust_weights(self, predictions, confidence_scores, context_tags: list[str] = None):
        """Dynamically adjust weights based on prediction patterns and optional context.

        Args:
            predictions: Mapping of model id to prediction. Dict entries with a
                usable ``"Label"`` take part in consensus/conflict detection.
                ``None`` is treated as no predictions.
            confidence_scores: Mapping of model id to numeric confidence.
                A score above 0.8 boosts the weight and a score below 0.5
                reduces it. Unknown model ids and ``None`` scores are skipped.
            context_tags: Optional context tags forwarded to the override agent.

        Returns:
            Dict of model id to weight, normalized to sum to 1. See
            ``_normalize_weights`` for the all-zero fallback.
        """
        self.agent_logger.log("weight_optimization", "info", "Adjusting model weights.")
        adjusted_weights = self.base_weights.copy()
        self.agent_logger.log("weight_optimization", "info", f"Initial adjusted weights (copy of base): {adjusted_weights}")

        if context_tags:
            logger.info(f"Applying contextual overrides for tags: {context_tags}")
            overrides = self.context_override_agent.get_overrides(context_tags)
            for model_id, multiplier in overrides.items():
                # Only adjust models that have a base weight. Previously,
                # unknown ids were inserted with weight 0.0 and leaked into
                # the result.
                if model_id in adjusted_weights:
                    adjusted_weights[model_id] *= multiplier
            self.agent_logger.log("weight_optimization", "info", f"Adjusted weights after context overrides: {adjusted_weights}")

        # NOTE(review): the consensus and conflict multipliers scale EVERY
        # weight by the same factor. _normalize_weights() cancels that factor,
        # so these two steps do not change the returned weights (only the
        # logged intermediates). Kept as-is; revisit if per-model agreement
        # weighting was intended.
        if self._has_consensus(predictions):
            self.agent_logger.log("weight_optimization", "info", "Consensus detected. Boosting weights for consensus.")
            for model in adjusted_weights:
                adjusted_weights[model] *= self.situation_weights["consensus"]
            self.agent_logger.log("weight_optimization", "info", f"Adjusted weights after consensus boost: {adjusted_weights}")

        if self._has_conflicts(predictions):
            self.agent_logger.log("weight_optimization", "info", "Conflicts detected. Reducing weights for conflict.")
            for model in adjusted_weights:
                adjusted_weights[model] *= self.situation_weights["conflict"]
            self.agent_logger.log("weight_optimization", "info", f"Adjusted weights after conflict reduction: {adjusted_weights}")

        logger.info("Adjusting weights based on model confidence scores.")
        for model, confidence in (confidence_scores or {}).items():
            if model not in adjusted_weights or confidence is None:
                # Previously raised KeyError / TypeError here.
                logger.debug(f"Skipping confidence adjustment for '{model}' (unknown model or missing score).")
                continue
            if confidence > 0.8:
                adjusted_weights[model] *= self.situation_weights["high_confidence"]
                self.agent_logger.log("weight_optimization", "info", f"Model '{model}' has high confidence ({confidence:.2f}). Weight boosted.")
            elif confidence < 0.5:
                adjusted_weights[model] *= self.situation_weights["low_confidence"]
                self.agent_logger.log("weight_optimization", "info", f"Model '{model}' has low confidence ({confidence:.2f}). Weight reduced.")
        logger.info(f"Adjusted weights before normalization: {adjusted_weights}")

        normalized_weights = self._normalize_weights(adjusted_weights)
        logger.info(f"Final normalized adjusted weights: {normalized_weights}")
        return normalized_weights

    @staticmethod
    def _valid_labels(predictions) -> list:
        """Return the usable ``"Label"`` values from a ``{model_id: prediction}`` mapping.

        Non-dict entries (including ``None``), entries with a missing or
        ``None`` label, and the ``"Error"`` sentinel label are dropped.
        A falsy ``predictions`` yields an empty list.
        """
        if not predictions:
            return []
        return [
            p["Label"]
            for p in predictions.values()
            if isinstance(p, dict) and p.get("Label") is not None and p.get("Label") != "Error"
        ]

    def _has_consensus(self, predictions):
        """Return True when all usable labels exist and are identical."""
        self.agent_logger.log("weight_optimization", "info", "Checking for consensus among model predictions.")
        non_none_predictions = self._valid_labels(predictions)
        logger.debug(f"Non-none predictions for consensus check: {non_none_predictions}")
        # Exactly one distinct label implies there is at least one label.
        result = len(set(non_none_predictions)) == 1
        logger.info(f"Consensus detected: {result}")
        return result

    def _has_conflicts(self, predictions):
        """Return True when the usable labels contain at least two distinct values."""
        self.agent_logger.log("weight_optimization", "info", "Checking for conflicts among model predictions.")
        non_none_predictions = self._valid_labels(predictions)
        logger.debug(f"Non-none predictions for conflict check: {non_none_predictions}")
        # Two or more distinct labels implies two or more labels.
        result = len(set(non_none_predictions)) > 1
        logger.info(f"Conflicts detected: {result}")
        return result

    def _normalize_weights(self, weights):
        """Normalize weights to sum to 1.

        If the weights sum to zero, fall back to equal weights over every
        model in ``MODEL_REGISTRY``, or ``{}`` when the registry is empty.
        """
        self.agent_logger.log("weight_optimization", "info", "Normalizing weights.")
        total = sum(weights.values())
        if total == 0:
            self.agent_logger.log("weight_optimization", "warning", "All weights became zero after adjustments. Reverting to equal base weights for registered models.")
            num_registered_models = len(MODEL_REGISTRY)
            if num_registered_models > 0:
                return {k: 1.0 / num_registered_models for k in MODEL_REGISTRY.keys()}
            return {}
        normalized = {k: v / total for k, v in weights.items()}
        self.agent_logger.log("weight_optimization", "info", f"Weights normalized. Total sum: {sum(normalized.values()):.2f}")
        return normalized