"""
Prompt Attention Caching
Caches CLIP embeddings for repeated prompts to avoid re-encoding them.
A training-free, lossless optimization that can give a 5-15% speedup
when prompts are reused across generations.
"""

import logging
from functools import lru_cache

import torch

# Global cache enabled flag
_cache_enabled = True

def enable_prompt_cache(enabled: bool = True):
    """Enable or disable prompt caching globally.
    
    Args:
        enabled (bool): Whether to enable caching. Defaults to True.
    """
    global _cache_enabled
    _cache_enabled = enabled
    if not enabled:
        clear_prompt_cache()
    logging.info(f"Prompt caching {'enabled' if enabled else 'disabled'}")


def is_prompt_cache_enabled() -> bool:
    """Check if prompt caching is enabled.
    
    Returns:
        bool: True if caching is enabled.
    """
    return _cache_enabled


def get_prompt_hash(prompt: str) -> int:
    """Generate a fast hash for a prompt.
    
    Uses Python's built-in hash(), which is much faster than MD5 and
    sufficient for cache keying (not cryptographic). Note that string
    hashes are randomized per process (PYTHONHASHSEED), so keys are only
    stable within a single run, which is fine for an in-memory cache.
    
    Args:
        prompt (str): The text prompt.
    
    Returns:
        int: Hash of the prompt.
    """
    return hash(prompt)


def _get_clip_identity(clip) -> str:
    """Get a stable identity string for a CLIP model instance.
    
    Uses the model's checkpoint path or class name instead of id(clip)
    which changes when a model is reloaded at the same logical identity.
    
    Args:
        clip: CLIP model instance.
        
    Returns:
        str: Stable identity string.
    """
    # Try to get a stable path-based identifier
    if hasattr(clip, 'model_path') and clip.model_path:
        return f"clip:{clip.model_path}"
    if hasattr(clip, 'patcher') and hasattr(clip.patcher, 'model_path'):
        return f"clip:{clip.patcher.model_path}"
    # Fall back to class name + parameter count for stability
    try:
        if hasattr(clip, 'parameters'):
            param_count = sum(p.numel() for p in clip.parameters())
            return f"clip:{clip.__class__.__name__}:{param_count}"
    except Exception:
        pass
    # Last resort: use id() (not ideal but better than crashing)
    return f"clip:id:{id(clip)}"


# LRU wrapper with 128 slots. Note: the dict-based cache further below is
# what actually stores tensors (each entry is roughly 100-500 KB depending
# on the model); this wrapper is kept so clear_prompt_cache() can reset it.
@lru_cache(maxsize=128)
def _cached_encode_impl(prompt_hash: int, prompt: str, clip_id: str):
    """Internal LRU-backed placeholder.
    
    Note: this is not called on the hot path; the actual encoding and
    caching happen in get_cached_encoding()/cache_encoding(). It exists so
    that clear_prompt_cache() can also clear the LRU wrapper.
    
    Args:
        prompt_hash (int): Hash of the prompt (see get_prompt_hash).
        prompt (str): The actual prompt text.
        clip_id (str): Identity string of the CLIP model instance
            (see _get_clip_identity).
    
    Returns:
        None (the actual encoding happens in the caller).
    """
    pass


class PromptCacheEntry:
    """Container for cached prompt encoding results."""
    
    def __init__(self, cond: torch.Tensor, pooled: torch.Tensor):
        """Initialize cache entry.
        
        Args:
            cond (torch.Tensor): Conditional embedding tensor.
            pooled (torch.Tensor): Pooled output tensor.
        """
        # We don't clone here because these tensors are treated as read-only
        # by consumers, and the producer (CLIP) creates fresh tensors
        # for each encoding. This reduces memory pressure and latency.
        self.cond = cond
        self.pooled = pooled
        self.hits = 0
    
    def get(self) -> tuple:
        """Get cached tensors (returns references for performance).
        
        Returns:
            tuple: (cond, pooled) tensors.
        """
        self.hits += 1
        # Returns direct references. Tensors are assumed to be read-only.
        return (self.cond, self.pooled)


# Secondary cache using dict for more control
_prompt_cache_dict = {}
_cache_stats = {"hits": 0, "misses": 0, "size_mb": 0.0}


def get_cached_encoding(clip, prompt: str):
    """Look up a cached encoding for a prompt.
    
    This function does not encode on a miss; the caller is expected to run
    the CLIP encode itself and then store the result via cache_encoding().
    
    Args:
        clip: CLIP model instance.
        prompt (str): Text prompt.
    
    Returns:
        tuple or None: (cond, pooled) on a cache hit, otherwise None
        (also None when caching is disabled).
    """
    if not _cache_enabled:
        return None
    
    prompt_hash = get_prompt_hash(prompt)
    clip_key = _get_clip_identity(clip)
    cache_key = f"{clip_key}_{prompt_hash}"
    
    # Check if we have it cached
    if cache_key in _prompt_cache_dict:
        _cache_stats["hits"] += 1
        entry = _prompt_cache_dict[cache_key]
        cond, pooled = entry.get()
        
        if _cache_stats["hits"] % 10 == 0:  # Log every 10 hits
            hit_rate = _cache_stats["hits"] / max(1, _cache_stats["hits"] + _cache_stats["misses"])
            logging.debug(f"Prompt cache hit rate: {hit_rate:.1%} (size: {len(_prompt_cache_dict)} entries)")
        
        return (cond, pooled)
    
    # Cache miss
    _cache_stats["misses"] += 1
    return None


def cache_encoding(clip, prompt: str, cond: torch.Tensor, pooled: torch.Tensor):
    """Cache an encoding result.
    
    Args:
        clip: CLIP model instance.
        prompt (str): Text prompt.
        cond (torch.Tensor): Conditional embedding.
        pooled (torch.Tensor): Pooled output.
    """
    if not _cache_enabled:
        return
    
    prompt_hash = get_prompt_hash(prompt)
    clip_key = _get_clip_identity(clip)
    cache_key = f"{clip_key}_{prompt_hash}"
    
    # Don't cache if already present
    if cache_key in _prompt_cache_dict:
        return
    
    # Store in cache
    entry = PromptCacheEntry(cond, pooled)
    _prompt_cache_dict[cache_key] = entry
    
    # Update size estimate (rough; assumes all entries are about this size)
    if cond is not None:
        entry_bytes = cond.numel() * cond.element_size()
        if pooled is not None:
            entry_bytes += pooled.numel() * pooled.element_size()
        _cache_stats["size_mb"] = len(_prompt_cache_dict) * entry_bytes / 1024 / 1024
    
    # Limit cache size to prevent memory issues
    max_entries = 256
    if len(_prompt_cache_dict) > max_entries:
        # Remove oldest 25% of entries (simple FIFO)
        remove_count = max_entries // 4
        keys_to_remove = list(_prompt_cache_dict.keys())[:remove_count]
        for key in keys_to_remove:
            del _prompt_cache_dict[key]
        logging.debug(f"Prompt cache pruned: removed {remove_count} old entries")
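

# Usage sketch (illustrative, not part of the original API): the two
# functions above are meant to be used in a check-then-store pattern around
# the real CLIP encode call. `encode_fn` below is a stand-in for whatever
# callable actually performs the encoding in the host application.
def encode_with_cache(clip, prompt: str, encode_fn):
    """Illustrative helper: return a cached (cond, pooled) pair if present,
    otherwise encode via encode_fn(clip, prompt) and cache the result."""
    cached = get_cached_encoding(clip, prompt)
    if cached is not None:
        return cached
    cond, pooled = encode_fn(clip, prompt)
    cache_encoding(clip, prompt, cond, pooled)
    return (cond, pooled)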


def clear_prompt_cache():
    """Clear the entire prompt cache."""
    global _prompt_cache_dict, _cache_stats
    
    old_size = len(_prompt_cache_dict)
    _prompt_cache_dict.clear()
    _cached_encode_impl.cache_clear()  # Clear LRU cache too
    _cache_stats = {"hits": 0, "misses": 0, "size_mb": 0.0}
    
    if old_size > 0:
        logging.info(f"Prompt cache cleared ({old_size} entries removed)")


def get_cache_stats() -> dict:
    """Get cache statistics.
    
    Returns:
        dict: Stats including hits, misses, hit rate, size.
    """
    total_requests = _cache_stats["hits"] + _cache_stats["misses"]
    hit_rate = _cache_stats["hits"] / max(1, total_requests)
    
    return {
        "enabled": _cache_enabled,
        "hits": _cache_stats["hits"],
        "misses": _cache_stats["misses"],
        "total_requests": total_requests,
        "hit_rate": hit_rate,
        "cache_entries": len(_prompt_cache_dict),
        "estimated_size_mb": _cache_stats["size_mb"],
    }


def print_cache_stats():
    """Print cache statistics to console."""
    stats = get_cache_stats()
    print("\n" + "="*60)
    print("Prompt Cache Statistics")
    print("="*60)
    print(f"  Status: {'Enabled' if stats['enabled'] else 'Disabled'}")
    print(f"  Entries: {stats['cache_entries']}")
    print(f"  Size: ~{stats['estimated_size_mb']:.1f} MB")
    print(f"  Requests: {stats['total_requests']} (hits: {stats['hits']}, misses: {stats['misses']})")
    print(f"  Hit Rate: {stats['hit_rate']:.1%}")
    print("="*60 + "\n")
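

# Minimal smoke test (illustrative). _FakeClip is a stand-in for a real CLIP
# wrapper so the cache round-trip can be exercised without loading a model;
# the attribute name, path, and tensor shapes below are arbitrary examples.
if __name__ == "__main__":
    class _FakeClip:
        model_path = "example/clip_l.safetensors"  # hypothetical path

        def encode(self, prompt):
            # Deterministic fake embeddings so repeated prompts match.
            torch.manual_seed(abs(hash(prompt)) % (2 ** 31))
            return torch.randn(1, 77, 768), torch.randn(1, 768)

    fake_clip = _FakeClip()
    for text in ["a photo of a cat", "a photo of a cat", "a photo of a dog"]:
        result = get_cached_encoding(fake_clip, text)
        if result is None:
            cond, pooled = fake_clip.encode(text)
            cache_encoding(fake_clip, text, cond, pooled)

    print_cache_stats()  # expect 1 hit, 2 misses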