"""
Conflict Resolution for Multi-Source Data
This module handles incomplete or conflicting data from multiple sources
by implementing intelligent conflict resolution strategies and data fusion.
"""
import logging
import re
import statistics
from datetime import datetime
from typing import Dict, List, Optional, Any, Set
from dataclasses import dataclass, field
from enum import Enum
from collections import defaultdict
from .data_quality_scoring import QualityScore
class ConflictType(str, Enum):
"""Types of conflicts in data."""
MISSING_DATA = "missing_data" # Some sources have missing fields
CONFLICTING_VALUES = "conflicting_values" # Different values for same field
INCONSISTENT_FORMATS = "inconsistent_formats" # Same data in different formats
TEMPORAL_DISCREPANCY = "temporal_discrepancy" # Data from different time periods
SOURCE_RELIABILITY = "source_reliability" # Sources have different reliability
PRICE_DISCREPANCY = "price_discrepancy" # Significant price differences
SCHEDULE_CONFLICT = "schedule_conflict" # Conflicting time/schedule data
class ResolutionStrategy(str, Enum):
"""Strategies for resolving conflicts."""
MOST_RECENT = "most_recent" # Use most recent data
MOST_RELIABLE = "most_reliable" # Use most reliable source
CONSENSUS = "consensus" # Use majority consensus
WEIGHTED_AVERAGE = "weighted_average" # Weighted average based on reliability
HIGHEST_QUALITY = "highest_quality" # Use highest quality data
USER_PREFERENCE = "user_preference" # Use user-specified preference
INTELLIGENT_FUSION = "intelligent_fusion" # Combine using ML/statistical methods
@dataclass
class DataConflict:
"""Represents a conflict between data sources."""
conflict_type: ConflictType
field_name: str
conflicting_values: Dict[str, Any] # source -> value
source_qualities: Dict[str, QualityScore] # source -> quality score
severity: float # 0.0 to 1.0
suggested_resolution: Any = None
resolution_strategy: ResolutionStrategy = ResolutionStrategy.INTELLIGENT_FUSION
confidence: float = 0.0
metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class ConflictResolution:
"""Result of conflict resolution."""
resolved_data: Dict[str, Any]
conflicts_resolved: List[DataConflict]
resolution_strategies_used: Dict[str, ResolutionStrategy]
confidence_score: float # Overall confidence in resolution
warnings: List[str] = field(default_factory=list)
metadata: Dict[str, Any] = field(default_factory=dict)
class ConflictResolver:
"""
Intelligent conflict resolver that handles conflicting data from multiple
sources using various resolution strategies.
"""
def __init__(self):
self.logger = logging.getLogger(__name__)
# Resolution strategy preferences by conflict type
self._strategy_preferences = {
ConflictType.MISSING_DATA: ResolutionStrategy.MOST_RELIABLE,
ConflictType.CONFLICTING_VALUES: ResolutionStrategy.WEIGHTED_AVERAGE,
ConflictType.INCONSISTENT_FORMATS: ResolutionStrategy.HIGHEST_QUALITY,
ConflictType.TEMPORAL_DISCREPANCY: ResolutionStrategy.MOST_RECENT,
ConflictType.SOURCE_RELIABILITY: ResolutionStrategy.MOST_RELIABLE,
ConflictType.PRICE_DISCREPANCY: ResolutionStrategy.WEIGHTED_AVERAGE,
ConflictType.SCHEDULE_CONFLICT: ResolutionStrategy.CONSENSUS
}
# Field-specific resolution strategies
self._field_strategies = {
"price": ResolutionStrategy.WEIGHTED_AVERAGE,
"cost": ResolutionStrategy.WEIGHTED_AVERAGE,
"departure_time": ResolutionStrategy.MOST_RELIABLE,
"arrival_time": ResolutionStrategy.MOST_RELIABLE,
"airline": ResolutionStrategy.CONSENSUS,
"flight_number": ResolutionStrategy.MOST_RELIABLE,
"hotel_name": ResolutionStrategy.CONSENSUS,
"rating": ResolutionStrategy.WEIGHTED_AVERAGE,
"location": ResolutionStrategy.MOST_RELIABLE,
"availability": ResolutionStrategy.MOST_RECENT
}
# Thresholds for conflict detection
self._conflict_thresholds = {
"price_variance": 0.3, # 30% price difference triggers conflict
"time_difference_minutes": 60, # 1 hour time difference
"rating_difference": 1.0, # 1 star rating difference
"quality_threshold": 0.2 # 20% quality score difference
}
def resolve_conflicts(self,
multi_source_data: Dict[str, Dict[str, Any]],
source_qualities: Optional[Dict[str, QualityScore]] = None,
user_preferences: Optional[Dict[str, Any]] = None) -> ConflictResolution:
"""
Resolve conflicts in multi-source data.
Args:
multi_source_data: Dict mapping source names to their data
source_qualities: Optional quality scores for each source
user_preferences: Optional user preferences for resolution
Returns:
ConflictResolution with resolved data and conflict details
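        Example (illustrative; the sources and fields here are hypothetical):
            >>> resolver = ConflictResolver()
            >>> result = resolver.resolve_conflicts({
            ...     "source_a": {"price": "$120", "airline": "Acme Air"},
            ...     "source_b": {"price": "$150", "airline": "Acme Air"},
            ... })
            >>> sorted(result.resolved_data)
            ['airline', 'price']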
"""
if not multi_source_data:
return ConflictResolution(
resolved_data={},
conflicts_resolved=[],
resolution_strategies_used={},
confidence_score=0.0,
warnings=["No data provided for conflict resolution"]
)
if len(multi_source_data) == 1:
# Single source - no conflicts to resolve
source_name = list(multi_source_data.keys())[0]
return ConflictResolution(
resolved_data=multi_source_data[source_name],
conflicts_resolved=[],
resolution_strategies_used={},
confidence_score=1.0,
metadata={"single_source": source_name}
)
# Detect conflicts
conflicts = self._detect_conflicts(multi_source_data, source_qualities)
# Resolve conflicts
resolved_data = {}
resolution_strategies = {}
warnings = []
# Get all unique field names
all_fields = set()
for source_data in multi_source_data.values():
all_fields.update(source_data.keys())
        for field_name in all_fields:
            field_conflicts = [c for c in conflicts if c.field_name == field_name]
            if field_conflicts:
                # Resolve the detected conflict for this field
                resolution = self._resolve_field_conflict(
                    field_name, field_conflicts, user_preferences
                )
                resolved_data[field_name] = resolution["value"]
                resolution_strategies[field_name] = resolution["strategy"]
                if resolution.get("warning"):
                    warnings.append(resolution["warning"])
            else:
                # No conflict - take the field from the most reliable source
                best_source = self._get_most_reliable_source_for_field(
                    field_name, multi_source_data, source_qualities
                )
                if best_source:
                    resolved_data[field_name] = multi_source_data[best_source].get(field_name)
                    resolution_strategies[field_name] = ResolutionStrategy.MOST_RELIABLE
# Calculate overall confidence
confidence_score = self._calculate_resolution_confidence(
conflicts, resolution_strategies, source_qualities
)
return ConflictResolution(
resolved_data=resolved_data,
conflicts_resolved=conflicts,
resolution_strategies_used=resolution_strategies,
confidence_score=confidence_score,
warnings=warnings,
metadata={
"sources_processed": list(multi_source_data.keys()),
"total_fields": len(all_fields),
"conflicts_detected": len(conflicts),
"resolution_timestamp": datetime.now().isoformat()
}
)
def _detect_conflicts(self,
multi_source_data: Dict[str, Dict[str, Any]],
source_qualities: Optional[Dict[str, QualityScore]]) -> List[DataConflict]:
"""Detect conflicts between data sources."""
conflicts = []
# Get all unique field names
all_fields = set()
for source_data in multi_source_data.values():
all_fields.update(source_data.keys())
        for field_name in all_fields:
            field_values = {}
            field_qualities = {}
            # Collect values and quality scores from the sources that report this field
            for source, data in multi_source_data.items():
                if field_name in data:
                    field_values[source] = data[field_name]
                    if source_qualities and source in source_qualities:
                        field_qualities[source] = source_qualities[source]
            # Detect conflicts for this field, passing the full source set so
            # missing-data conflicts can be identified
            field_conflicts = self._detect_field_conflicts(
                field_name, field_values, field_qualities, set(multi_source_data.keys())
            )
            conflicts.extend(field_conflicts)
        return conflicts
    def _detect_field_conflicts(self,
                                field_name: str,
                                field_values: Dict[str, Any],
                                field_qualities: Dict[str, QualityScore],
                                all_sources: Set[str]) -> List[DataConflict]:
        """Detect conflicts for a specific field."""
        conflicts = []
        # Check for missing data: sources that do not report this field at all.
        # field_qualities only covers sources that have the field, so the full
        # source set is needed for this comparison.
        missing_sources = all_sources - set(field_values.keys())
        if missing_sources and field_values:
            conflicts.append(DataConflict(
                conflict_type=ConflictType.MISSING_DATA,
                field_name=field_name,
                conflicting_values=field_values,
                source_qualities=field_qualities,
                severity=0.3,
                suggested_resolution=self._resolve_most_reliable(field_values, field_qualities),
                resolution_strategy=ResolutionStrategy.MOST_RELIABLE,
                metadata={"missing_sources": sorted(missing_sources)}
            ))
        if len(field_values) <= 1:
            return conflicts
# Check for conflicting values
if self._has_conflicting_values(field_name, field_values):
conflict_type = self._get_conflict_type_for_field(field_name)
severity = self._calculate_conflict_severity(field_name, field_values)
conflicts.append(DataConflict(
conflict_type=conflict_type,
field_name=field_name,
conflicting_values=field_values,
source_qualities=field_qualities,
severity=severity,
suggested_resolution=self._suggest_resolution(field_name, field_values, field_qualities),
metadata={"conflict_detected": True}
))
return conflicts
def _has_conflicting_values(self, field_name: str, field_values: Dict[str, Any]) -> bool:
"""Check if field values conflict with each other."""
values = list(field_values.values())
if not values:
return False
# Check for exact matches
if len(set(str(v) for v in values)) == 1:
return False
# Field-specific conflict detection
if field_name in ["price", "cost"]:
return self._has_price_conflict(values)
elif field_name in ["departure_time", "arrival_time"]:
return self._has_time_conflict(values)
elif field_name in ["rating", "score"]:
return self._has_rating_conflict(values)
elif field_name in ["airline", "hotel_name"]:
return self._has_text_conflict(values)
else:
# Generic conflict detection
return len(set(str(v).lower() for v in values)) > 1
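    # Shared numeric coercion used by the price and rating checks below
    @staticmethod
    def _coerce_to_float(value: Any) -> Optional[float]:
        """Best-effort conversion of a raw field value to float.
        Strips common currency formatting from strings, falls back to a regex
        scan for an embedded number, and returns None when the value cannot be
        interpreted numerically.
        """
        if isinstance(value, (int, float)):
            return float(value)
        if isinstance(value, str):
            cleaned = value.replace('$', '').replace(',', '').strip()
            try:
                return float(cleaned)
            except ValueError:
                match = re.search(r'(\d+\.?\d*)', value)
                return float(match.group(1)) if match else None
        return None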
    def _has_price_conflict(self, values: List[Any]) -> bool:
        """Check for price conflicts."""
        numeric_values = [n for n in (self._coerce_to_float(v) for v in values) if n is not None]
        if len(numeric_values) < 2:
            return False
        # Flag a conflict when the largest relative deviation from the mean
        # price exceeds the configured variance threshold
        mean_price = statistics.mean(numeric_values)
        if mean_price == 0:
            return False
        max_variance = max(abs(v - mean_price) / mean_price for v in numeric_values)
        return max_variance > self._conflict_thresholds["price_variance"]
def _has_time_conflict(self, values: List[Any]) -> bool:
"""Check for time conflicts."""
parsed_times = []
for value in values:
try:
if isinstance(value, str):
# Try to parse various time formats
parsed = datetime.fromisoformat(value.replace('Z', '+00:00'))
elif isinstance(value, datetime):
parsed = value
else:
continue
parsed_times.append(parsed)
except (ValueError, TypeError):
continue
if len(parsed_times) < 2:
return False
        # Check if any pairwise time difference exceeds the threshold
        time_diffs = []
        for i in range(len(parsed_times)):
            for j in range(i + 1, len(parsed_times)):
                try:
                    diff_minutes = abs((parsed_times[i] - parsed_times[j]).total_seconds() / 60)
                except TypeError:
                    # Mixing naive and timezone-aware datetimes; skip this pair
                    continue
                time_diffs.append(diff_minutes)
        max_diff = max(time_diffs) if time_diffs else 0
        return max_diff > self._conflict_thresholds["time_difference_minutes"]
    def _has_rating_conflict(self, values: List[Any]) -> bool:
        """Check for rating conflicts."""
        numeric_values = [n for n in (self._coerce_to_float(v) for v in values) if n is not None]
        if len(numeric_values) < 2:
            return False
        # Check if the rating spread exceeds the threshold
        rating_diff = max(numeric_values) - min(numeric_values)
        return rating_diff > self._conflict_thresholds["rating_difference"]
    def _has_text_conflict(self, values: List[Any]) -> bool:
        """Check for text conflicts."""
        # Normalize all values for a case- and whitespace-insensitive comparison
        normalized_values = {str(value).lower().strip() for value in values}
        return len(normalized_values) > 1
    def _get_conflict_type_for_field(self, field_name: str) -> ConflictType:
        """Get the conflict type for a specific field."""
        if field_name in ["price", "cost"]:
            return ConflictType.PRICE_DISCREPANCY
        elif field_name in ["departure_time", "arrival_time"]:
            return ConflictType.SCHEDULE_CONFLICT
        else:
            return ConflictType.CONFLICTING_VALUES
def _calculate_conflict_severity(self, field_name: str, field_values: Dict[str, Any]) -> float:
"""Calculate the severity of a conflict."""
values = list(field_values.values())
if field_name in ["price", "cost"]:
return self._calculate_price_conflict_severity(values)
elif field_name in ["departure_time", "arrival_time"]:
return self._calculate_time_conflict_severity(values)
elif field_name in ["rating", "score"]:
return self._calculate_rating_conflict_severity(values)
else:
# Generic severity calculation
unique_values = len(set(str(v).lower() for v in values))
return min(1.0, (unique_values - 1) / len(values))
    def _calculate_price_conflict_severity(self, values: List[Any]) -> float:
        """Calculate price conflict severity."""
        numeric_values = [n for n in (self._coerce_to_float(v) for v in values) if n is not None]
        if len(numeric_values) < 2:
            return 0.0
        mean_price = statistics.mean(numeric_values)
        if mean_price == 0:
            return 0.0
        max_variance = max(abs(v - mean_price) / mean_price for v in numeric_values)
        # Convert variance to severity on a 0-1 scale (100% variance = 0.5 severity)
        return min(1.0, max_variance / 2.0)
def _calculate_time_conflict_severity(self, values: List[Any]) -> float:
"""Calculate time conflict severity."""
parsed_times = []
for value in values:
try:
if isinstance(value, str):
parsed = datetime.fromisoformat(value.replace('Z', '+00:00'))
elif isinstance(value, datetime):
parsed = value
else:
continue
parsed_times.append(parsed)
except (ValueError, TypeError):
continue
if len(parsed_times) < 2:
return 0.0
        # Calculate the maximum pairwise time difference in hours
        max_diff_hours = 0.0
        for i in range(len(parsed_times)):
            for j in range(i + 1, len(parsed_times)):
                try:
                    diff_hours = abs((parsed_times[i] - parsed_times[j]).total_seconds() / 3600)
                except TypeError:
                    # Mixing naive and timezone-aware datetimes; skip this pair
                    continue
                max_diff_hours = max(max_diff_hours, diff_hours)
# Convert hours to severity (0-1 scale)
return min(1.0, max_diff_hours / 24) # 24 hours = 1.0 severity
    def _calculate_rating_conflict_severity(self, values: List[Any]) -> float:
        """Calculate rating conflict severity."""
        numeric_values = [n for n in (self._coerce_to_float(v) for v in values) if n is not None]
        if len(numeric_values) < 2:
            return 0.0
        rating_diff = max(numeric_values) - min(numeric_values)
        # Convert rating difference to severity (0-1 scale); a 5-point spread = 1.0
        return min(1.0, rating_diff / 5.0)
def _suggest_resolution(self,
field_name: str,
field_values: Dict[str, Any],
field_qualities: Dict[str, QualityScore]) -> Any:
"""Suggest a resolution for a field conflict."""
strategy = self._get_resolution_strategy(field_name)
if strategy == ResolutionStrategy.MOST_RELIABLE:
return self._resolve_most_reliable(field_values, field_qualities)
elif strategy == ResolutionStrategy.WEIGHTED_AVERAGE:
return self._resolve_weighted_average(field_name, field_values, field_qualities)
elif strategy == ResolutionStrategy.CONSENSUS:
return self._resolve_consensus(field_values)
elif strategy == ResolutionStrategy.MOST_RECENT:
return self._resolve_most_recent(field_values, field_qualities)
else:
return self._resolve_intelligent_fusion(field_name, field_values, field_qualities)
def _get_resolution_strategy(self, field_name: str) -> ResolutionStrategy:
"""Get the resolution strategy for a field."""
# Check field-specific strategy first
if field_name in self._field_strategies:
return self._field_strategies[field_name]
# Default strategy
return ResolutionStrategy.INTELLIGENT_FUSION
def _resolve_most_reliable(self,
field_values: Dict[str, Any],
field_qualities: Dict[str, QualityScore]) -> Any:
"""Resolve using most reliable source."""
if not field_qualities:
# If no quality info, return first value
return next(iter(field_values.values()))
best_source = max(field_qualities.keys(),
key=lambda s: field_qualities[s].overall_score)
return field_values.get(best_source)
def _resolve_weighted_average(self,
field_name: str,
field_values: Dict[str, Any],
field_qualities: Dict[str, QualityScore]) -> Any:
"""Resolve using weighted average."""
if field_name in ["price", "cost", "rating", "score"]:
return self._resolve_numeric_weighted_average(field_values, field_qualities)
else:
# For non-numeric fields, fall back to most reliable
return self._resolve_most_reliable(field_values, field_qualities)
    def _resolve_numeric_weighted_average(self,
                                          field_values: Dict[str, Any],
                                          field_qualities: Dict[str, QualityScore]) -> float:
        """Resolve numeric fields using a quality-weighted average."""
        weighted_sum = 0.0
        total_weight = 0.0
        for source, value in field_values.items():
            numeric_value = self._coerce_to_float(value)
            if numeric_value is None:
                continue
            # Weight each value by its source's overall quality score,
            # defaulting to a neutral 0.5 when no score is available
            weight = (field_qualities[source].overall_score
                      if source in field_qualities else 0.5)
            weighted_sum += numeric_value * weight
            total_weight += weight
        if total_weight == 0:
            # Fall back to a simple unweighted average
            numeric_values = [n for n in (self._coerce_to_float(v) for v in field_values.values())
                              if n is not None]
            return statistics.mean(numeric_values) if numeric_values else 0.0
        return weighted_sum / total_weight
def _resolve_consensus(self, field_values: Dict[str, Any]) -> Any:
"""Resolve using majority consensus."""
# Count occurrences of each value
value_counts = defaultdict(int)
for value in field_values.values():
# Normalize value for comparison
normalized = str(value).lower().strip()
value_counts[normalized] += 1
# Return most common value
if value_counts:
most_common = max(value_counts.items(), key=lambda x: x[1])
# Find original value that matches most common
for value in field_values.values():
if str(value).lower().strip() == most_common[0]:
return value
# Fallback to first value
return next(iter(field_values.values()))
def _resolve_most_recent(self,
field_values: Dict[str, Any],
field_qualities: Dict[str, QualityScore]) -> Any:
"""Resolve using most recent data."""
if not field_qualities:
return next(iter(field_values.values()))
        # Find the source with the most recent quality timestamp, if any
        most_recent_source = None
        most_recent_time = None
        for source, quality in field_qualities.items():
            timestamp = getattr(quality, 'timestamp', None)
            if timestamp is not None and (most_recent_time is None or timestamp > most_recent_time):
                most_recent_time = timestamp
                most_recent_source = source
if most_recent_source:
return field_values.get(most_recent_source)
# Fallback to first value
return next(iter(field_values.values()))
def _resolve_intelligent_fusion(self,
field_name: str,
field_values: Dict[str, Any],
field_qualities: Dict[str, QualityScore]) -> Any:
"""Resolve using intelligent fusion of multiple strategies."""
# For now, use weighted average for numeric fields, most reliable for others
if field_name in ["price", "cost", "rating", "score"]:
return self._resolve_numeric_weighted_average(field_values, field_qualities)
else:
return self._resolve_most_reliable(field_values, field_qualities)
def _resolve_field_conflict(self,
field_name: str,
field_conflicts: List[DataConflict],
user_preferences: Optional[Dict[str, Any]]) -> Dict[str, Any]:
"""Resolve conflicts for a specific field."""
if not field_conflicts:
return {"value": None, "strategy": None}
# Use the most severe conflict for resolution
primary_conflict = max(field_conflicts, key=lambda c: c.severity)
# Check if user has preference for this field
if user_preferences and field_name in user_preferences:
return {
"value": user_preferences[field_name],
"strategy": ResolutionStrategy.USER_PREFERENCE
}
# Use suggested resolution if available
if primary_conflict.suggested_resolution is not None:
return {
"value": primary_conflict.suggested_resolution,
"strategy": primary_conflict.resolution_strategy
}
# Fallback resolution
return {
"value": next(iter(primary_conflict.conflicting_values.values())),
"strategy": ResolutionStrategy.MOST_RELIABLE,
"warning": f"Used fallback resolution for {field_name}"
}
def _get_most_reliable_source_for_field(self,
field_name: str,
multi_source_data: Dict[str, Dict[str, Any]],
source_qualities: Optional[Dict[str, QualityScore]]) -> Optional[str]:
"""Get the most reliable source for a field."""
if not source_qualities:
# Return first source that has this field
for source, data in multi_source_data.items():
if field_name in data:
return source
return None
# Find source with highest quality that has this field
best_source = None
best_quality = -1
for source, quality in source_qualities.items():
if source in multi_source_data and field_name in multi_source_data[source]:
if quality.overall_score > best_quality:
best_quality = quality.overall_score
best_source = source
return best_source
def _calculate_resolution_confidence(self,
conflicts: List[DataConflict],
resolution_strategies: Dict[str, ResolutionStrategy],
source_qualities: Optional[Dict[str, QualityScore]]) -> float:
"""Calculate overall confidence in the resolution."""
if not conflicts:
return 1.0
# Base confidence
base_confidence = 1.0
# Reduce confidence based on conflict severity
total_severity = sum(conflict.severity for conflict in conflicts)
avg_severity = total_severity / len(conflicts)
# Reduce confidence based on severity
severity_penalty = avg_severity * 0.3
# Reduce confidence based on number of conflicts
conflict_penalty = min(0.2, len(conflicts) * 0.05)
# Boost confidence based on source quality
quality_boost = 0.0
if source_qualities:
avg_quality = sum(q.overall_score for q in source_qualities.values()) / len(source_qualities)
quality_boost = avg_quality * 0.1
confidence = base_confidence - severity_penalty - conflict_penalty + quality_boost
return max(0.0, min(1.0, confidence))
# Global conflict resolver instance
_global_conflict_resolver: Optional[ConflictResolver] = None
def get_global_conflict_resolver() -> ConflictResolver:
"""Get the global conflict resolver instance."""
global _global_conflict_resolver
if _global_conflict_resolver is None:
_global_conflict_resolver = ConflictResolver()
return _global_conflict_resolver
def resolve_data_conflicts(multi_source_data: Dict[str, Dict[str, Any]],
source_qualities: Optional[Dict[str, QualityScore]] = None,
user_preferences: Optional[Dict[str, Any]] = None) -> ConflictResolution:
"""Convenience function to resolve data conflicts."""
resolver = get_global_conflict_resolver()
return resolver.resolve_conflicts(multi_source_data, source_qualities, user_preferences)
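# Minimal usage sketch (illustrative only): the provider names, fields, and
# values below are hypothetical. source_qualities is omitted because
# QualityScore instances come from data_quality_scoring. Run in package
# context (e.g. `python -m wanderlust_ai.core.conflict_resolution`) so the
# relative imports above resolve.
if __name__ == "__main__":
    sample_data = {
        "provider_a": {"price": "$450", "airline": "Acme Air",
                       "departure_time": "2024-06-01T08:00:00"},
        "provider_b": {"price": "$890", "airline": "Acme Air",
                       "departure_time": "2024-06-01T10:30:00"},
    }
    result = resolve_data_conflicts(sample_data)
    print("Resolved:", result.resolved_data)
    print("Conflicts detected:", len(result.conflicts_resolved))
    print("Confidence: %.2f" % result.confidence_score)
    for warning in result.warnings:
        print("Warning:", warning)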