File size: 16,116 Bytes
0a4529c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
# DEPENDENCIES
import re
import math
from typing import Any
from typing import Dict
from typing import Optional

from config.settings import get_settings
from config.logging_config import get_logger
from utils.error_handler import handle_errors
from config.models import TemperatureStrategy
from utils.error_handler import TemperatureControlError


# Setup Settings and Logging
# Module-level settings object and logger. TemperatureController stores this `logger` on each instance
# and calls get_settings() again in its own __init__.
# NOTE(review): the module-level `settings` name is not referenced elsewhere in this module — confirm before removing
settings = get_settings()
logger   = get_logger(__name__)


class TemperatureController:
    """
    Intelligent temperature control for LLM generation: Implements adaptive temperature strategies based on query type, complexity, and desired output characteristics

    Strategies (selected via `TemperatureStrategy`):
        FIXED       : always return the base temperature
        ADAPTIVE    : base temperature shifted by query type, query complexity and context quality
        CONFIDENCE  : interpolate between a low and a high temperature from the average retrieval score
        PROGRESSIVE : preset temperature per query type (lowest for factual Q&A)
    """
    # Keyword detectors used by `_calculate_query_complexity`. The leading word boundary stops matches inside
    # unrelated words ("show", "somehow", "however" no longer count as "how"), while the trailing `\w*` keeps
    # the inflected forms that plain substring matching accepted ("explained", "differences", "compared")
    _WHY_HOW_PATTERN = re.compile(r"\b(?:why|how|explain\w*)\b")
    _COMPARE_PATTERN = re.compile(r"\b(?:compare|contrast|difference)\w*")


    def __init__(self, base_temperature: Optional[float] = None, strategy: Optional["TemperatureStrategy"] = None):
        """
        Initialize temperature controller
        
        Arguments:
        ----------
            base_temperature { float }               : Base temperature value; None falls back to settings.DEFAULT_TEMPERATURE
            
            strategy         { TemperatureStrategy } : Temperature control strategy; None falls back to ADAPTIVE

        Raises:
        -------
            TemperatureControlError : If the resolved base temperature is outside [0.0, 1.0]
        """
        self.logger           = logger
        self.settings         = get_settings()

        # Compare against None instead of using `or`: 0.0 is a valid (fully deterministic) temperature
        # and must not be silently replaced by the settings default
        self.base_temperature = self.settings.DEFAULT_TEMPERATURE if (base_temperature is None) else base_temperature
        self.strategy         = TemperatureStrategy.ADAPTIVE if (strategy is None) else strategy
        
        # Validate base temperature
        if not (0.0 <= self.base_temperature <= 1.0):
            raise TemperatureControlError(f"Temperature must be between 0 and 1: {self.base_temperature}")
        
        # Strategy configurations: "range" bounds every value returned through `_clamp_temperature`.
        # Note: "complexity_threshold" and "creative_threshold" are currently not read anywhere in this class
        self.strategy_configs = {TemperatureStrategy.FIXED       : {"description" : "Fixed temperature for all queries", "range" : (0.0, 1.0)},
                                 TemperatureStrategy.ADAPTIVE    : {"description" : "Adapt temperature based on query complexity", "range" : (0.1, 0.8), "complexity_threshold" : 0.6},
                                 TemperatureStrategy.CONFIDENCE  : {"description" : "Adjust temperature based on retrieval confidence", "range" : (0.1, 0.9), "high_confidence_temp" : 0.1, "low_confidence_temp" : 0.7},
                                 TemperatureStrategy.PROGRESSIVE : {"description" : "Progressively increase temperature for creative tasks", "range" : (0.1, 0.9), "creative_threshold" : 0.7}
                                }
        
        self.logger.info(f"Initialized TemperatureController: base={self.base_temperature}, strategy={self.strategy}")
    

    def get_temperature(self, query: str = "", context: str = "", retrieval_scores: Optional[list] = None, query_type: str = "qa") -> float:
        """
        Get appropriate temperature for generation by dispatching on the configured strategy
        
        Arguments:
        ----------
            query            { str }  : User query
            
            context          { str }  : Retrieved context
            
            retrieval_scores { list } : Scores of retrieved chunks (only used by the CONFIDENCE strategy)
            
            query_type       { str }  : Type of query ('qa', 'creative', 'analytical', 'summary')
        
        Returns:
        --------
                   { float }          : Temperature value (0.0 - 1.0)
        """
        if (self.strategy == TemperatureStrategy.FIXED):
            return self._fixed_temperature()
        
        elif (self.strategy == TemperatureStrategy.ADAPTIVE):
            return self._adaptive_temperature(query      = query, 
                                              context    = context, 
                                              query_type = query_type,
                                             )
        
        elif (self.strategy == TemperatureStrategy.CONFIDENCE):
            return self._confidence_based_temperature(retrieval_scores = retrieval_scores, 
                                                      query_type       = query_type,
                                                     )
        
        elif (self.strategy == TemperatureStrategy.PROGRESSIVE):
            return self._progressive_temperature(query_type = query_type, 
                                                 query      = query,
                                                )
        
        else:
            # Unrecognised strategy: fall back to the (unclamped) base temperature
            self.logger.warning(f"Unknown strategy: {self.strategy}, using fixed")
            return self.base_temperature
    

    def _fixed_temperature(self) -> float:
        """
        Fixed temperature strategy: return the base temperature unchanged (not clamped)
        """
        return self.base_temperature
    

    def _adaptive_temperature(self, query: str, context: str, query_type: str) -> float:
        """
        Adaptive temperature based on query type, query complexity and context quality

        The base temperature is shifted by a per-type offset, then by +/-0.1 for high/low query complexity,
        then by +0.15 when non-empty context scores below 0.5 on `_calculate_context_quality`.
        The result is clamped to the strategy's range and rounded to 2 decimals
        """
        base_temp        = self.base_temperature
        
        # Adjust based on query type
        type_adjustments = {"qa"         : -0.2,      # More deterministic for Q&A
                            "creative"   : 0.3,       # More creative for creative tasks
                            "analytical" : -0.1,      # Slightly deterministic for analysis
                            "summary"    : -0.15,     # Deterministic for summarization
                            "comparison" : 0.1,       # Slightly creative for comparisons
                           }
        
        adjustment       = type_adjustments.get(query_type, 0.0)
        temp             = base_temp + adjustment
        
        # Adjust based on query complexity. `_calculate_query_complexity` returns exactly 0.3 for simple
        # queries, so the low-complexity bound must be inclusive (with `< 0.3` this branch could never fire)
        complexity       = self._calculate_query_complexity(query = query)
        
        if (complexity > 0.7): 
            # High complexity (comparison-style queries score 0.8)
            temp += 0.1

        elif (complexity <= 0.3):  
            # Low complexity
            temp -= 0.1
        
        # Adjust based on context quality (skipped entirely when no context is supplied)
        if context:
            context_quality = self._calculate_context_quality(context = context)

            # Poor context
            if (context_quality < 0.5):  
                # More creative when context is poor
                temp += 0.15  
        
        return self._clamp_temperature(temperature = temp)
    

    def _confidence_based_temperature(self, retrieval_scores: Optional[list], query_type: str) -> float:
        """
        Temperature based on retrieval confidence

        Average score > 0.8 gives `high_confidence_temp`, < 0.3 gives `low_confidence_temp`, and values in between
        are linearly interpolated. The thresholds assume scores roughly in [0, 1] — TODO confirm against the retriever.
        Returns the (unclamped) base temperature when no scores are given
        """
        if not retrieval_scores:
            self.logger.debug("No retrieval scores, using base temperature")
            return self.base_temperature
        
        # Calculate average confidence
        avg_confidence = sum(retrieval_scores) / len(retrieval_scores)
        
        config         = self.strategy_configs[TemperatureStrategy.CONFIDENCE]
        high_temp      = config["high_confidence_temp"]
        low_temp       = config["low_confidence_temp"]
        
        # High confidence -> low temperature (deterministic) & Low confidence -> high temperature (creative)
        if (avg_confidence > 0.8):
            temperature = high_temp

        elif (avg_confidence < 0.3):
            temperature = low_temp

        else:
            # Linear interpolation between high and low temps
            normalized_confidence = (avg_confidence - 0.3) / (0.8 - 0.3)
            temperature           = high_temp + (low_temp - high_temp) * (1 - normalized_confidence)
        
        # Adjust for query type
        if (query_type == "creative"):
            temperature = min(0.9, temperature + 0.2)

        elif (query_type == "qa"):
            temperature = max(0.1, temperature - 0.1)
        
        return self._clamp_temperature(temperature = temperature)
    

    def _progressive_temperature(self, query_type: str, query: str) -> float:
        """
        Progressive temperature based on task requirements: a preset value per query type,
        falling back to the base temperature for unknown types (always clamped)
        """
        base_temp = self.base_temperature
        
        # Task-based progression
        if (query_type == "creative"):
            # High creativity
            return self._clamp_temperature(temperature = 0.8)  
        
        elif (query_type == "analytical"):
            # Balanced
            return self._clamp_temperature(temperature = 0.3)  
        
        elif (query_type == "qa"):
            # For factual Q&A, use lower temperature
            if self._is_factual_query(query):
                return self._clamp_temperature(temperature = 0.1)
            
            else:
                return self._clamp_temperature(temperature = 0.4)
        
        elif (query_type == "summary"):
            # Deterministic summaries
            return self._clamp_temperature(temperature = 0.2)  
        
        else:
            return self._clamp_temperature(temperature = base_temp)
    

    def _calculate_query_complexity(self, query: str) -> float:
        """
        Simple, predictable complexity score

        Returns one of: 0.5 (empty query), 0.8 (comparison keywords), 0.7 (why/how/explain with more than 15 words),
        0.6 (more than 20 words), otherwise 0.3 (simple)
        """
        if not query:
            return 0.5
        
        # Count words and detect keywords on whole-word boundaries
        query_lower = query.lower()
        words       = len(query.split())
        has_why_how = self._WHY_HOW_PATTERN.search(query_lower) is not None
        has_compare = self._COMPARE_PATTERN.search(query_lower) is not None
        
        # Simple rules
        if has_compare:
            # Complex
            return 0.8  
        
        elif (has_why_how and (words > 15)):
            return 0.7
        
        elif (words > 20):
            return 0.6
        
        else:
            # Simple
            return 0.3  
    

    def _calculate_context_quality(self, context: str) -> float:
        """
        Calculate context quality (0.0 - 1.0) as the mean of simple heuristic factors:
        length (words / 500), diversity ('[' count / 5) and, when any '.' is present, sentence-length coherence
        """
        if not context:
            return 0.0
        
        factors          = list()
        
        # Length factor (adequate context)
        words            = len(context.split())

        # Normalize
        length_factor    = min(words / 500, 1.0)  

        factors.append(length_factor)
        
        # Diversity factor: presumably citations are rendered as "[...]" markers — confirm with the context builder
        citation_count   = context.count('[')
        diversity_factor = min(citation_count / 5, 1.0)
        
        factors.append(diversity_factor)
        
        # Coherence factor (simple measure; every '.' counts as a sentence end, including decimals)
        sentence_count   = context.count('.')
        
        if (sentence_count > 0):
            avg_sentence_length = words / sentence_count
            # Ideal ~20 words/sentence
            coherence_factor    = 1.0 - min(abs(avg_sentence_length - 20) / 50, 1.0)  

            factors.append(coherence_factor)
        
        return sum(factors) / len(factors)
    

    def _is_factual_query(self, query: str) -> bool:
        """
        Check if query is factual (requires precise answers) via case-insensitive substring indicators
        """
        factual_indicators = ['what is', 'who is', 'when did', 'where is', 'how many', 'how much', 'definition of', 'meaning of', 'calculate', 'number of']
        
        query_lower        = query.lower()

        return any(indicator in query_lower for indicator in factual_indicators)
    

    def _clamp_temperature(self, temperature: float) -> float:
        """
        Clamp temperature to the current strategy's "range" (default (0.0, 1.0)) and round to 2 decimals
        """
        strategy_config = self.strategy_configs.get(self.strategy, {})
        temp_range      = strategy_config.get("range", (0.0, 1.0))
        
        clamped         = max(temp_range[0], min(temperature, temp_range[1]))

        # Round to 2 decimal places
        clamped         = round(clamped, 2)  
        
        return clamped
    

    def get_temperature_parameters(self, temperature: float) -> Dict[str, Any]:
        """
        Get additional parameters based on temperature
        
        Arguments:
        ----------
            temperature { float } : Temperature value
        
        Returns:
        --------
                   { dict }       : Generation parameters with keys "temperature", "top_p" and "presence_penalty"
        """
        params = {"temperature" : temperature,
                  "top_p"       : 0.9,
                 }
        
        # Adjust top_p based on temperature
        if (temperature < 0.3):
            # Broader distribution for low temp
            params["top_p"] = 0.95  

        elif (temperature > 0.7):
            # Narrower distribution for high temp
            params["top_p"] = 0.7  
        
        # Adjust presence_penalty based on temperature
        if (temperature > 0.5):
            # Encourage novelty for creative tasks
            params["presence_penalty"] = 0.1  

        else:
            params["presence_penalty"] = 0.0
        
        return params
    

    def explain_temperature_choice(self, query: str, context: str, retrieval_scores: Optional[list], query_type: str, final_temperature: float) -> Dict[str, Any]:
        """
        Explain why a particular temperature was chosen
        
        Arguments:
        ----------
            query             { str }   : User query
            
            context           { str }   : Retrieved context
            
            retrieval_scores  { list }  : Retrieval scores (may be None/empty)
            
            query_type        { str }   : Query type
            
            final_temperature { float } : Chosen temperature
        
        Returns:
        --------
                      { dict }          : Explanation dictionary with "strategy", "final_temperature",
                                          "base_temperature" and a strategy-specific "factors" dict
        """
        explanation = {"strategy"          : self.strategy.value,
                       "final_temperature" : final_temperature,
                       "base_temperature"  : self.base_temperature,
                       "factors"           : {},
                      }
        
        if (self.strategy == TemperatureStrategy.ADAPTIVE):
            # Note: context_quality is reported (as 0.0) even for empty context, which
            # `_adaptive_temperature` ignores when picking the temperature
            complexity             = self._calculate_query_complexity(query = query)
            context_quality        = self._calculate_context_quality(context = context)
            
            explanation["factors"] = {"query_complexity" : round(complexity, 3),
                                      "context_quality"  : round(context_quality, 3),
                                      "query_type"       : query_type,
                                     }
        
        elif (self.strategy == TemperatureStrategy.CONFIDENCE):
            if retrieval_scores:
                avg_confidence         = sum(retrieval_scores) / len(retrieval_scores)
                explanation["factors"] = {"average_retrieval_confidence" : round(avg_confidence, 3),
                                          "query_type"                   : query_type,
                                         }
        
        elif (self.strategy == TemperatureStrategy.PROGRESSIVE):
            is_factual             = self._is_factual_query(query)
            explanation["factors"] = {"query_type"       : query_type,
                                      "is_factual_query" : is_factual,
                                     }
        
        return explanation


# Global temperature controller instance (created lazily on the first get_temperature_controller() call)
_temperature_controller = None


def get_temperature_controller() -> TemperatureController:
    """
    Return the shared TemperatureController, constructing it on first use (singleton)

    The instance is built with no arguments, so it uses settings.DEFAULT_TEMPERATURE and the
    controller's default strategy; every later call returns that same object.
    NOTE(review): creation is not guarded by a lock, so concurrent first calls could each build an instance

    Returns:
    --------
        { TemperatureController }    : Shared TemperatureController instance
    """
    global _temperature_controller

    # Already initialised: hand back the existing instance
    if _temperature_controller is not None:
        return _temperature_controller

    _temperature_controller = TemperatureController()

    return _temperature_controller


@handle_errors(error_type=TemperatureControlError, log_error=True, reraise=False)
def get_adaptive_temperature(query: str = "", context: str = "", retrieval_scores: Optional[list] = None, query_type: str = "qa") -> float:
    """
    Convenience function for getting a temperature from the shared controller

    Delegates to `get_temperature_controller().get_temperature(...)`, so the strategy applied is whichever
    one the shared controller was created with.
    NOTE(review): TemperatureControlError is passed to `handle_errors(..., reraise=False)`; what this function
    returns in that case depends on the decorator (possibly not a float) — confirm callers handle it
    
    Arguments:
    ----------
        query            { str }   : User query
        
        context          { str }   : Retrieved context
         
        retrieval_scores { list }  : Retrieval scores (optional; None means "no scores")
        
        query_type       { str }   : Query type
    
    Returns:
    --------
               { float }           : Temperature value
    """
    controller = get_temperature_controller()
    
    return controller.get_temperature(query            = query, 
                                      context          = context, 
                                      retrieval_scores = retrieval_scores, 
                                      query_type       = query_type,
                                     )