gary-boon Claude Opus 4.5 committed on
Commit
929ba88
·
1 Parent(s): 66a46b6

feat: implement lazy-loading for attention matrices

Browse files

- Add MatrixCache class with 60-min TTL for storing attention/QKV matrices
- Modify response builder to cache matrices instead of including in payload
- Add new endpoint /analyze/research/attention/matrix for on-demand retrieval
- Include requestId in responses for cache lookup
- Reduces initial response from 400MB+ to ~500KB-1MB

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

Files changed (1) hide show
  1. backend/model_service.py +135 -11
backend/model_service.py CHANGED
@@ -18,6 +18,9 @@ import numpy as np
18
  import logging
19
  from datetime import datetime
20
  import traceback
 
 
 
21
  from .auth import verify_api_key
22
  from .instrumentation import ModelInstrumentor, InstrumentationData, TokenMetadata
23
  from .storage import ZarrStorage, generate_run_id
@@ -56,6 +59,63 @@ def sanitize_for_json(obj):
56
  else:
57
  return obj
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  app = FastAPI(title="Visualisable.ai Model Service", version="0.1.0")
60
 
61
  # CORS configuration for local development and production
@@ -1507,6 +1567,9 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
1507
  import time
1508
  start_time = time.time()
1509
 
 
 
 
1510
  # Get parameters
1511
  prompt = request.get("prompt", "def quicksort(arr):")
1512
  max_tokens = request.get("max_tokens", 8)
@@ -1777,15 +1840,21 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
1777
  k_matrix = qkv_captures[layer_idx]['k'][:, head_idx, :].float().numpy().tolist()
1778
  v_matrix = qkv_captures[layer_idx]['v'][:, head_idx, :].float().numpy().tolist()
1779
 
 
 
 
 
 
 
 
 
 
1780
  critical_heads.append({
1781
  "head_idx": head_idx,
1782
  "entropy": entropy,
1783
  "avg_entropy": avg_entropy, # Averaged over all query positions
1784
  "max_weight": max_weight,
1785
- "attention_weights": attention_matrix, # Full attention matrix for spreadsheet
1786
- "q_matrix": q_matrix, # [seq_len, head_dim]
1787
- "k_matrix": k_matrix,
1788
- "v_matrix": v_matrix,
1789
  "pattern": {
1790
  "type": pattern_type,
1791
  "confidence": confidence
@@ -1915,6 +1984,7 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
1915
 
1916
  # Build response
1917
  response = {
 
1918
  "prompt": prompt,
1919
  "promptTokens": build_token_data(prompt_token_ids, prompt_tokens, "prompt"),
1920
  "generatedTokens": build_token_data(generated_token_ids, generated_tokens, "generated"),
@@ -1922,7 +1992,7 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
1922
  "tokenAlternatives": token_alternatives_by_step, # Top-k alternatives for each token
1923
  "layersDataByStep": layer_data_by_token, # Layer data for ALL generation steps
1924
  "layersData": layer_data_by_token[-1] if layer_data_by_token else [], # Keep for backward compatibility
1925
- "qkvData": qkv_by_layer_head,
1926
  "modelInfo": {
1927
  "numLayers": n_layers,
1928
  "numHeads": n_heads,
@@ -1969,12 +2039,15 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
1969
  import time
1970
  start_time = time.time()
1971
 
 
 
 
1972
  # Get parameters
1973
  prompt = request.get("prompt", "def quicksort(arr):")
1974
  max_tokens = request.get("max_tokens", 8)
1975
  temperature = request.get("temperature", 0.7)
1976
 
1977
- logger.info(f"[SSE] Research attention analysis: prompt_len={len(prompt)}, max_tokens={max_tokens}")
1978
 
1979
  # === STAGE 1: TOKENIZING ===
1980
  yield sse_event('tokenizing', stage=1, totalStages=5, progress=2,
@@ -2233,15 +2306,21 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
2233
  k_matrix = qkv_captures[layer_idx]['k'][:, head_idx, :].float().numpy().tolist()
2234
  v_matrix = qkv_captures[layer_idx]['v'][:, head_idx, :].float().numpy().tolist()
2235
 
 
 
 
 
 
 
 
 
 
2236
  critical_heads.append({
2237
  "head_idx": head_idx,
2238
  "entropy": entropy,
2239
  "avg_entropy": avg_entropy, # Averaged over all query positions
2240
  "max_weight": max_weight,
2241
- "attention_weights": attention_matrix,
2242
- "q_matrix": q_matrix,
2243
- "k_matrix": k_matrix,
2244
- "v_matrix": v_matrix,
2245
  "pattern": {"type": pattern_type, "confidence": confidence} if pattern_type else None
2246
  })
2247
 
@@ -2364,6 +2443,7 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
2364
 
2365
  # Build response
2366
  response = {
 
2367
  "prompt": prompt,
2368
  "promptTokens": build_token_data(prompt_token_ids, prompt_tokens, "prompt"),
2369
  "generatedTokens": build_token_data(generated_token_ids, generated_tokens, "generated"),
@@ -2371,7 +2451,7 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
2371
  "tokenAlternatives": token_alternatives_by_step,
2372
  "layersDataByStep": layer_data_by_token,
2373
  "layersData": layer_data_by_token[-1] if layer_data_by_token else [],
2374
- "qkvData": qkv_by_layer_head,
2375
  "modelInfo": {
2376
  "numLayers": n_layers,
2377
  "numHeads": n_heads,
@@ -2418,6 +2498,50 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
2418
  )
2419
 
2420
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2421
  @app.post("/analyze/study")
2422
  async def analyze_study(request: StudyRequest, authenticated: bool = Depends(verify_api_key)):
2423
  """
 
18
  import logging
19
  from datetime import datetime
20
  import traceback
21
+ import uuid
22
+ from threading import Lock
23
+ from time import time as time_now
24
  from .auth import verify_api_key
25
  from .instrumentation import ModelInstrumentor, InstrumentationData, TokenMetadata
26
  from .storage import ZarrStorage, generate_run_id
 
59
  else:
60
  return obj
61
 
62
+
63
# Matrix cache for lazy loading (60 min TTL)
class MatrixCache:
    """
    Thread-safe in-memory cache for attention matrices.

    Entries are keyed by (request_id, step, layer, head) and hold the
    attention weights plus Q/K/V projections for a single head. An entry
    older than the TTL is treated as absent and pruned when touched.
    """

    def __init__(self, ttl_seconds: int = 3600):
        # Expiry horizon in seconds; entries at or beyond it are stale.
        self._ttl = ttl_seconds
        # One lock guards both dicts so they never drift out of sync.
        self._lock = Lock()
        self._cache: Dict[str, Dict] = {}
        self._timestamps: Dict[str, float] = {}

    @staticmethod
    def _make_key(request_id: str, step: int, layer: int, head: int) -> str:
        """Build the flat string key shared by both internal dicts."""
        return f"{request_id}:{step}:{layer}:{head}"

    def store(self, request_id: str, step: int, layer: int, head: int, data: dict):
        """Store matrix data for a specific head."""
        key = self._make_key(request_id, step, layer, head)
        with self._lock:
            self._cache[key] = data
            self._timestamps[key] = time_now()

    def get(self, request_id: str, step: int, layer: int, head: int) -> Optional[dict]:
        """Retrieve matrix data, returning None if expired or not found."""
        key = self._make_key(request_id, step, layer, head)
        with self._lock:
            stored_at = self._timestamps.get(key)
            if stored_at is None:
                return None
            if time_now() - stored_at < self._ttl:
                return self._cache[key]
            # Stale hit: evict eagerly so memory is reclaimed on access.
            del self._cache[key]
            del self._timestamps[key]
            return None

    def cleanup_expired(self):
        """Remove all expired entries from cache."""
        with self._lock:
            cutoff = time_now() - self._ttl
            stale = [key for key, stamp in self._timestamps.items() if stamp <= cutoff]
            for key in stale:
                del self._cache[key]
                del self._timestamps[key]
            if stale:
                logger.info(f"MatrixCache: cleaned up {len(stale)} expired entries")

    def get_stats(self) -> dict:
        """Return cache statistics."""
        with self._lock:
            return {
                "entries": len(self._cache),
                "ttl_seconds": self._ttl,
            }


# Global matrix cache instance
matrix_cache = MatrixCache(ttl_seconds=3600)  # 60 min TTL
117
+
118
+
119
  app = FastAPI(title="Visualisable.ai Model Service", version="0.1.0")
120
 
121
  # CORS configuration for local development and production
 
1567
  import time
1568
  start_time = time.time()
1569
 
1570
+ # Generate unique request ID for matrix cache lookup
1571
+ request_id = str(uuid.uuid4())
1572
+
1573
  # Get parameters
1574
  prompt = request.get("prompt", "def quicksort(arr):")
1575
  max_tokens = request.get("max_tokens", 8)
 
1840
  k_matrix = qkv_captures[layer_idx]['k'][:, head_idx, :].float().numpy().tolist()
1841
  v_matrix = qkv_captures[layer_idx]['v'][:, head_idx, :].float().numpy().tolist()
1842
 
1843
+ # Store matrices in cache for lazy loading (reduces response size)
1844
+ matrix_cache.store(request_id, step, layer_idx, head_idx, {
1845
+ "attention_weights": attention_matrix,
1846
+ "q_matrix": q_matrix,
1847
+ "k_matrix": k_matrix,
1848
+ "v_matrix": v_matrix
1849
+ })
1850
+
1851
+ # Return only metadata (matrices fetched on-demand via /matrix endpoint)
1852
  critical_heads.append({
1853
  "head_idx": head_idx,
1854
  "entropy": entropy,
1855
  "avg_entropy": avg_entropy, # Averaged over all query positions
1856
  "max_weight": max_weight,
1857
+ "has_matrices": attention_matrix is not None, # Flag for frontend
 
 
 
1858
  "pattern": {
1859
  "type": pattern_type,
1860
  "confidence": confidence
 
1984
 
1985
  # Build response
1986
  response = {
1987
+ "requestId": request_id, # For lazy-loading matrices via /matrix endpoint
1988
  "prompt": prompt,
1989
  "promptTokens": build_token_data(prompt_token_ids, prompt_tokens, "prompt"),
1990
  "generatedTokens": build_token_data(generated_token_ids, generated_tokens, "generated"),
 
1992
  "tokenAlternatives": token_alternatives_by_step, # Top-k alternatives for each token
1993
  "layersDataByStep": layer_data_by_token, # Layer data for ALL generation steps
1994
  "layersData": layer_data_by_token[-1] if layer_data_by_token else [], # Keep for backward compatibility
1995
+ "qkvData": {}, # Deprecated: matrices now lazy-loaded via /matrix endpoint
1996
  "modelInfo": {
1997
  "numLayers": n_layers,
1998
  "numHeads": n_heads,
 
2039
  import time
2040
  start_time = time.time()
2041
 
2042
+ # Generate unique request ID for matrix cache lookup
2043
+ request_id = str(uuid.uuid4())
2044
+
2045
  # Get parameters
2046
  prompt = request.get("prompt", "def quicksort(arr):")
2047
  max_tokens = request.get("max_tokens", 8)
2048
  temperature = request.get("temperature", 0.7)
2049
 
2050
+ logger.info(f"[SSE] Research attention analysis: prompt_len={len(prompt)}, max_tokens={max_tokens}, request_id={request_id}")
2051
 
2052
  # === STAGE 1: TOKENIZING ===
2053
  yield sse_event('tokenizing', stage=1, totalStages=5, progress=2,
 
2306
  k_matrix = qkv_captures[layer_idx]['k'][:, head_idx, :].float().numpy().tolist()
2307
  v_matrix = qkv_captures[layer_idx]['v'][:, head_idx, :].float().numpy().tolist()
2308
 
2309
+ # Store matrices in cache for lazy loading (reduces response size)
2310
+ matrix_cache.store(request_id, step, layer_idx, head_idx, {
2311
+ "attention_weights": attention_matrix,
2312
+ "q_matrix": q_matrix,
2313
+ "k_matrix": k_matrix,
2314
+ "v_matrix": v_matrix
2315
+ })
2316
+
2317
+ # Return only metadata (matrices fetched on-demand via /matrix endpoint)
2318
  critical_heads.append({
2319
  "head_idx": head_idx,
2320
  "entropy": entropy,
2321
  "avg_entropy": avg_entropy, # Averaged over all query positions
2322
  "max_weight": max_weight,
2323
+ "has_matrices": attention_matrix is not None, # Flag for frontend
 
 
 
2324
  "pattern": {"type": pattern_type, "confidence": confidence} if pattern_type else None
2325
  })
2326
 
 
2443
 
2444
  # Build response
2445
  response = {
2446
+ "requestId": request_id, # For lazy-loading matrices via /matrix endpoint
2447
  "prompt": prompt,
2448
  "promptTokens": build_token_data(prompt_token_ids, prompt_tokens, "prompt"),
2449
  "generatedTokens": build_token_data(generated_token_ids, generated_tokens, "generated"),
 
2451
  "tokenAlternatives": token_alternatives_by_step,
2452
  "layersDataByStep": layer_data_by_token,
2453
  "layersData": layer_data_by_token[-1] if layer_data_by_token else [],
2454
+ "qkvData": {}, # Deprecated: matrices now lazy-loaded via /matrix endpoint
2455
  "modelInfo": {
2456
  "numLayers": n_layers,
2457
  "numHeads": n_heads,
 
2498
  )
2499
 
2500
 
2501
@app.get("/analyze/research/attention/matrix")
async def get_attention_matrix(
    request_id: str,
    step: int,
    layer: int,
    head: int,
    authenticated: bool = Depends(verify_api_key)
):
    """
    Serve the cached attention/QKV matrices for one attention head.

    Backs the frontend's on-demand "View Matrix" action: the matrices are
    written to the cache during the original analysis request and remain
    retrievable for 60 minutes.

    Parameters:
    - request_id: UUID returned by the original analysis response
    - step: Generation step (0 = first generated token)
    - layer: Layer index (0-based)
    - head: Head index (0-based)

    Returns a dict with:
    - attention_weights: [seq_len, seq_len] attention matrix
    - q_matrix: [seq_len, head_dim] query projections
    - k_matrix: [seq_len, head_dim] key projections
    - v_matrix: [seq_len, head_dim] value projections

    Raises HTTP 404 when the entry is missing or the cache TTL has elapsed.
    """
    data = matrix_cache.get(request_id, step, layer, head)

    # Found and fresh — hand the matrices straight back.
    if data is not None:
        logger.info(f"Matrix cache hit: request_id={request_id}, step={step}, layer={layer}, head={head}")
        return data

    # Missing or expired — the client must re-run the analysis.
    logger.warning(f"Matrix cache miss: request_id={request_id}, step={step}, layer={layer}, head={head}")
    raise HTTPException(
        status_code=404,
        detail="Matrix data not found. Cache may have expired (60 min TTL). Please re-analyze."
    )
2537
+
2538
+
2539
@app.get("/analyze/research/attention/matrix/stats")
async def get_matrix_cache_stats(authenticated: bool = Depends(verify_api_key)):
    """Expose matrix-cache statistics (entry count, TTL) for monitoring."""
    stats = matrix_cache.get_stats()
    return stats
2543
+
2544
+
2545
  @app.post("/analyze/study")
2546
  async def analyze_study(request: StudyRequest, authenticated: bool = Depends(verify_api_key)):
2547
  """