# DEPENDENCIES
import math
from typing import Any
from typing import Dict
from typing import Optional
from config.settings import get_settings
from config.logging_config import get_logger
from utils.error_handler import handle_errors
from config.models import TemperatureStrategy
from utils.error_handler import TemperatureControlError


# Setup Settings and Logging
settings = get_settings()
logger = get_logger(__name__)


class TemperatureController:
    """
    Intelligent temperature control for LLM generation: Implements adaptive temperature
    strategies based on query type, complexity, and desired output characteristics
    """

    def __init__(self, base_temperature: Optional[float] = None, strategy: Optional[TemperatureStrategy] = None):
        """
        Initialize temperature controller

        Arguments:
        ----------
            base_temperature { float }               : Base temperature value; when None, falls back to
                                                       settings.DEFAULT_TEMPERATURE

            strategy         { TemperatureStrategy } : Temperature control strategy (defaults to ADAPTIVE)

        Raises:
        -------
            TemperatureControlError : If the resolved base temperature is outside [0.0, 1.0]
        """
        self.logger = logger
        self.settings = get_settings()

        # Compare against None instead of relying on truthiness: an explicit 0.0 is a valid
        # temperature (the range check below accepts it) and must not be silently replaced
        # by the configured default.
        if base_temperature is None:
            base_temperature = self.settings.DEFAULT_TEMPERATURE

        self.base_temperature = base_temperature
        self.strategy = strategy or TemperatureStrategy.ADAPTIVE

        # Validate base temperature
        if not (0.0 <= self.base_temperature <= 1.0):
            raise TemperatureControlError(f"Temperature must be between 0 and 1: {self.base_temperature}")

        # Strategy configurations: "range" is the clamp interval applied by _clamp_temperature
        self.strategy_configs = {
            TemperatureStrategy.FIXED: {
                "description": "Fixed temperature for all queries",
                "range": (0.0, 1.0),
            },
            TemperatureStrategy.ADAPTIVE: {
                "description": "Adapt temperature based on query complexity",
                "range": (0.1, 0.8),
                "complexity_threshold": 0.6,
            },
            TemperatureStrategy.CONFIDENCE: {
                "description": "Adjust temperature based on retrieval confidence",
                "range": (0.1, 0.9),
                "high_confidence_temp": 0.1,
                "low_confidence_temp": 0.7,
            },
            TemperatureStrategy.PROGRESSIVE: {
                "description": "Progressively increase temperature for creative tasks",
                "range": (0.1, 0.9),
                "creative_threshold": 0.7,
            },
        }

        self.logger.info(f"Initialized TemperatureController: base={self.base_temperature}, strategy={self.strategy}")

    def get_temperature(self, query: str = "", context: str = "", retrieval_scores: Optional[list] = None, query_type: str = "qa") -> float:
        """
        Get appropriate temperature for generation

        Arguments:
        ----------
            query            { str }  : User query

            context          { str }  : Retrieved context

            retrieval_scores { list } : Scores of retrieved chunks

            query_type       { str }  : Type of query ('qa', 'creative', 'analytical', 'summary')

        Returns:
        --------
            { float } : Temperature value (0.0 - 1.0)
        """
        if self.strategy == TemperatureStrategy.FIXED:
            return self._fixed_temperature()

        elif self.strategy == TemperatureStrategy.ADAPTIVE:
            return self._adaptive_temperature(query=query, context=context, query_type=query_type)

        elif self.strategy == TemperatureStrategy.CONFIDENCE:
            return self._confidence_based_temperature(retrieval_scores=retrieval_scores, query_type=query_type)

        elif self.strategy == TemperatureStrategy.PROGRESSIVE:
            return self._progressive_temperature(query_type=query_type, query=query)

        else:
            # Unrecognised strategy: fall back to the unmodified base temperature
            self.logger.warning(f"Unknown strategy: {self.strategy}, using fixed")
            return self.base_temperature

    def _fixed_temperature(self) -> float:
        """
        Fixed temperature strategy: always returns the base temperature (no clamping applied)
        """
        return self.base_temperature

    def _adaptive_temperature(self, query: str, context: str, query_type: str) -> float:
        """
        Adaptive temperature based on query complexity and type

        Starts from the base temperature, then applies additive adjustments for the query type,
        the query complexity score and (when context is provided) the context quality score.
        The result is clamped to the active strategy's range.
        """
        base_temp = self.base_temperature

        # Adjust based on query type
        type_adjustments = {
            "qa": -0.2,          # More deterministic for Q&A
            "creative": 0.3,     # More creative for creative tasks
            "analytical": -0.1,  # Slightly deterministic for analysis
            "summary": -0.15,    # Deterministic for summarization
            "comparison": 0.1,   # Slightly creative for comparisons
        }

        # Unknown query types get no adjustment
        adjustment = type_adjustments.get(query_type, 0.0)
        temp = base_temp + adjustment

        # Adjust based on query complexity
        complexity = self._calculate_query_complexity(query=query)

        if complexity > 0.7:
            # High complexity
            temp += 0.1

        elif complexity < 0.3:
            # Low complexity
            temp -= 0.1

        # Adjust based on context quality
        if context:
            context_quality = self._calculate_context_quality(context=context)

            # Poor context
            if context_quality < 0.5:
                # More creative when context is poor
                temp += 0.15

        return self._clamp_temperature(temperature=temp)

    def _confidence_based_temperature(self, retrieval_scores: Optional[list], query_type: str) -> float:
        """
        Temperature based on retrieval confidence

        Uses the mean of retrieval_scores: above 0.8 yields the configured high-confidence
        temperature, below 0.3 the low-confidence temperature, and values in between are
        linearly interpolated. Without scores the base temperature is returned as-is.
        """
        if not retrieval_scores:
            self.logger.debug("No retrieval scores, using base temperature")
            return self.base_temperature

        # Calculate average confidence
        avg_confidence = sum(retrieval_scores) / len(retrieval_scores)

        config = self.strategy_configs[TemperatureStrategy.CONFIDENCE]
        high_temp = config["high_confidence_temp"]
        low_temp = config["low_confidence_temp"]

        # High confidence -> low temperature (deterministic) & Low confidence -> high temperature (creative)
        if avg_confidence > 0.8:
            temperature = high_temp

        elif avg_confidence < 0.3:
            temperature = low_temp

        else:
            # Linear interpolation between high and low temps
            normalized_confidence = (avg_confidence - 0.3) / (0.8 - 0.3)
            temperature = high_temp + (low_temp - high_temp) * (1 - normalized_confidence)

        # Adjust for query type
        if query_type == "creative":
            temperature = min(0.9, temperature + 0.2)

        elif query_type == "qa":
            temperature = max(0.1, temperature - 0.1)

        return self._clamp_temperature(temperature=temperature)

    def _progressive_temperature(self, query_type: str, query: str) -> float:
        """
        Progressive temperature based on task requirements

        Maps each known query type to a fixed temperature (clamped to the strategy range);
        unknown query types use the base temperature.
        """
        base_temp = self.base_temperature

        # Task-based progression
        if query_type == "creative":
            # High creativity
            return self._clamp_temperature(temperature=0.8)

        elif query_type == "analytical":
            # Balanced
            return self._clamp_temperature(temperature=0.3)

        elif query_type == "qa":
            # For factual Q&A, use lower temperature
            if self._is_factual_query(query):
                return self._clamp_temperature(temperature=0.1)

            else:
                return self._clamp_temperature(temperature=0.4)

        elif query_type == "summary":
            # Deterministic summaries
            return self._clamp_temperature(temperature=0.2)

        else:
            return self._clamp_temperature(temperature=base_temp)

    def _calculate_query_complexity(self, query: str) -> float:
        """
        Simple, predictable complexity score

        Returns one of a small set of fixed scores: 0.5 for an empty query, 0.8 when a
        comparison keyword is present, 0.7 for a why/how/explain query longer than 15 words,
        0.6 for any query longer than 20 words, otherwise 0.3.
        """
        if not query:
            return 0.5

        # Count words and questions
        words = len(query.split())

        # Lower-case once; keyword checks are substring matches (so 'how' also matches "show")
        query_lower = query.lower()
        has_why_how = any(word in query_lower for word in ['why', 'how', 'explain'])
        has_compare = any(word in query_lower for word in ['compare', 'contrast', 'difference'])

        # Simple rules
        if has_compare:
            # Complex
            return 0.8

        elif has_why_how and (words > 15):
            return 0.7

        elif words > 20:
            return 0.6

        else:
            # Simple
            return 0.3

    def _calculate_context_quality(self, context: str) -> float:
        """
        Calculate context quality (0.0 - 1.0)

        Averages up to three heuristic factors: length (word count relative to 500),
        diversity (count of '[' characters relative to 5) and, when the context contains at
        least one '.', coherence (closeness of average sentence length to 20 words).
        """
        if not context:
            return 0.0

        factors = list()

        # Length factor (adequate context)
        words = len(context.split())

        # Normalize
        length_factor = min(words / 500, 1.0)
        factors.append(length_factor)

        # Diversity factor (multiple sources/citations)
        citation_count = context.count('[')
        diversity_factor = min(citation_count / 5, 1.0)
        factors.append(diversity_factor)

        # Coherence factor (simple measure)
        sentence_count = context.count('.')

        if sentence_count > 0:
            avg_sentence_length = words / sentence_count

            # Ideal ~20 words/sentence
            coherence_factor = 1.0 - min(abs(avg_sentence_length - 20) / 50, 1.0)
            factors.append(coherence_factor)

        # factors always holds at least the length and diversity entries, so no zero division
        return sum(factors) / len(factors)

    def _is_factual_query(self, query: str) -> bool:
        """
        Check if query is factual (requires precise answers)

        True when the lower-cased query contains any of the factual indicator phrases.
        """
        factual_indicators = ['what is', 'who is', 'when did', 'where is', 'how many', 'how much', 'definition of', 'meaning of', 'calculate', 'number of']

        query_lower = query.lower()

        return any(indicator in query_lower for indicator in factual_indicators)

    def _clamp_temperature(self, temperature: float) -> float:
        """
        Clamp temperature to valid range

        The range comes from the active strategy's configuration, defaulting to (0.0, 1.0)
        when the strategy has no entry. The result is rounded to 2 decimal places.
        """
        strategy_config = self.strategy_configs.get(self.strategy, {})
        temp_range = strategy_config.get("range", (0.0, 1.0))

        clamped = max(temp_range[0], min(temperature, temp_range[1]))

        # Round to 2 decimal places
        clamped = round(clamped, 2)

        return clamped

    def get_temperature_parameters(self, temperature: float) -> Dict[str, Any]:
        """
        Get additional parameters based on temperature

        Arguments:
        ----------
            temperature { float } : Temperature value

        Returns:
        --------
            { dict } : Additional generation parameters ("temperature", "top_p", "presence_penalty")
        """
        params = {
            "temperature": temperature,
            "top_p": 0.9,
        }

        # Adjust top_p based on temperature
        if temperature < 0.3:
            # Broader distribution for low temp
            params["top_p"] = 0.95

        elif temperature > 0.7:
            # Narrower distribution for high temp
            params["top_p"] = 0.7

        # Adjust presence_penalty based on temperature
        if temperature > 0.5:
            # Encourage novelty for creative tasks
            params["presence_penalty"] = 0.1

        else:
            params["presence_penalty"] = 0.0

        return params

    def explain_temperature_choice(self, query: str, context: str, retrieval_scores: Optional[list], query_type: str, final_temperature: float) -> Dict[str, Any]:
        """
        Explain why a particular temperature was chosen

        Arguments:
        ----------
            query             { str }   : User query

            context           { str }   : Retrieved context

            retrieval_scores  { list }  : Retrieval scores

            query_type        { str }   : Query type

            final_temperature { float } : Chosen temperature

        Returns:
        --------
            { dict } : Explanation dictionary with "strategy", "final_temperature",
                       "base_temperature" and strategy-specific "factors"
        """
        explanation = {
            "strategy": self.strategy.value,
            "final_temperature": final_temperature,
            "base_temperature": self.base_temperature,
            "factors": {},
        }

        if self.strategy == TemperatureStrategy.ADAPTIVE:
            complexity = self._calculate_query_complexity(query=query)
            context_quality = self._calculate_context_quality(context=context)

            explanation["factors"] = {
                "query_complexity": round(complexity, 3),
                "context_quality": round(context_quality, 3),
                "query_type": query_type,
            }

        elif self.strategy == TemperatureStrategy.CONFIDENCE:
            # Without retrieval scores there is nothing to report; "factors" stays empty
            if retrieval_scores:
                avg_confidence = sum(retrieval_scores) / len(retrieval_scores)

                explanation["factors"] = {
                    "average_retrieval_confidence": round(avg_confidence, 3),
                    "query_type": query_type,
                }

        elif self.strategy == TemperatureStrategy.PROGRESSIVE:
            is_factual = self._is_factual_query(query)

            explanation["factors"] = {
                "query_type": query_type,
                "is_factual_query": is_factual,
            }

        return explanation


# Global temperature controller instance
_temperature_controller = None


def get_temperature_controller() -> TemperatureController:
    """
    Get global temperature controller instance (singleton)

    The instance is created with default arguments on first call and reused afterwards.

    Returns:
    --------
        { TemperatureController } : TemperatureController instance
    """
    global _temperature_controller

    if _temperature_controller is None:
        _temperature_controller = TemperatureController()

    return _temperature_controller


@handle_errors(error_type=TemperatureControlError, log_error=True, reraise=False)
def get_adaptive_temperature(query: str = "", context: str = "", retrieval_scores: Optional[list] = None, query_type: str = "qa") -> float:
    """
    Convenience function for getting adaptive temperature

    Delegates to the global controller's get_temperature, so the strategy actually applied
    is whichever one that controller was initialized with.

    Arguments:
    ----------
        query            { str }  : User query

        context          { str }  : Retrieved context

        retrieval_scores { list } : Retrieval scores

        query_type       { str }  : Query type

    Returns:
    --------
        { float } : Temperature value
    """
    controller = get_temperature_controller()

    return controller.get_temperature(query=query, context=context, retrieval_scores=retrieval_scores, query_type=query_type)