Spaces:
Sleeping
Sleeping
gary-boon
Claude Opus 4.5
committed on
Commit
·
929ba88
1
Parent(s):
66a46b6
feat: implement lazy-loading for attention matrices
Browse files
- Add MatrixCache class with 60-min TTL for storing attention/QKV matrices
- Modify response builder to cache matrices instead of including in payload
- Add new endpoint /analyze/research/attention/matrix for on-demand retrieval
- Include requestId in responses for cache lookup
- Reduces initial response from 400MB+ to ~500KB-1MB
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- backend/model_service.py +135 -11
backend/model_service.py
CHANGED
|
@@ -18,6 +18,9 @@ import numpy as np
|
|
| 18 |
import logging
|
| 19 |
from datetime import datetime
|
| 20 |
import traceback
|
|
|
|
|
|
|
|
|
|
| 21 |
from .auth import verify_api_key
|
| 22 |
from .instrumentation import ModelInstrumentor, InstrumentationData, TokenMetadata
|
| 23 |
from .storage import ZarrStorage, generate_run_id
|
|
@@ -56,6 +59,63 @@ def sanitize_for_json(obj):
|
|
| 56 |
else:
|
| 57 |
return obj
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
app = FastAPI(title="Visualisable.ai Model Service", version="0.1.0")
|
| 60 |
|
| 61 |
# CORS configuration for local development and production
|
|
@@ -1507,6 +1567,9 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
|
|
| 1507 |
import time
|
| 1508 |
start_time = time.time()
|
| 1509 |
|
|
|
|
|
|
|
|
|
|
| 1510 |
# Get parameters
|
| 1511 |
prompt = request.get("prompt", "def quicksort(arr):")
|
| 1512 |
max_tokens = request.get("max_tokens", 8)
|
|
@@ -1777,15 +1840,21 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
|
|
| 1777 |
k_matrix = qkv_captures[layer_idx]['k'][:, head_idx, :].float().numpy().tolist()
|
| 1778 |
v_matrix = qkv_captures[layer_idx]['v'][:, head_idx, :].float().numpy().tolist()
|
| 1779 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1780 |
critical_heads.append({
|
| 1781 |
"head_idx": head_idx,
|
| 1782 |
"entropy": entropy,
|
| 1783 |
"avg_entropy": avg_entropy, # Averaged over all query positions
|
| 1784 |
"max_weight": max_weight,
|
| 1785 |
-
"
|
| 1786 |
-
"q_matrix": q_matrix, # [seq_len, head_dim]
|
| 1787 |
-
"k_matrix": k_matrix,
|
| 1788 |
-
"v_matrix": v_matrix,
|
| 1789 |
"pattern": {
|
| 1790 |
"type": pattern_type,
|
| 1791 |
"confidence": confidence
|
|
@@ -1915,6 +1984,7 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
|
|
| 1915 |
|
| 1916 |
# Build response
|
| 1917 |
response = {
|
|
|
|
| 1918 |
"prompt": prompt,
|
| 1919 |
"promptTokens": build_token_data(prompt_token_ids, prompt_tokens, "prompt"),
|
| 1920 |
"generatedTokens": build_token_data(generated_token_ids, generated_tokens, "generated"),
|
|
@@ -1922,7 +1992,7 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
|
|
| 1922 |
"tokenAlternatives": token_alternatives_by_step, # Top-k alternatives for each token
|
| 1923 |
"layersDataByStep": layer_data_by_token, # Layer data for ALL generation steps
|
| 1924 |
"layersData": layer_data_by_token[-1] if layer_data_by_token else [], # Keep for backward compatibility
|
| 1925 |
-
"qkvData":
|
| 1926 |
"modelInfo": {
|
| 1927 |
"numLayers": n_layers,
|
| 1928 |
"numHeads": n_heads,
|
|
@@ -1969,12 +2039,15 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
|
|
| 1969 |
import time
|
| 1970 |
start_time = time.time()
|
| 1971 |
|
|
|
|
|
|
|
|
|
|
| 1972 |
# Get parameters
|
| 1973 |
prompt = request.get("prompt", "def quicksort(arr):")
|
| 1974 |
max_tokens = request.get("max_tokens", 8)
|
| 1975 |
temperature = request.get("temperature", 0.7)
|
| 1976 |
|
| 1977 |
-
logger.info(f"[SSE] Research attention analysis: prompt_len={len(prompt)}, max_tokens={max_tokens}")
|
| 1978 |
|
| 1979 |
# === STAGE 1: TOKENIZING ===
|
| 1980 |
yield sse_event('tokenizing', stage=1, totalStages=5, progress=2,
|
|
@@ -2233,15 +2306,21 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
|
|
| 2233 |
k_matrix = qkv_captures[layer_idx]['k'][:, head_idx, :].float().numpy().tolist()
|
| 2234 |
v_matrix = qkv_captures[layer_idx]['v'][:, head_idx, :].float().numpy().tolist()
|
| 2235 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2236 |
critical_heads.append({
|
| 2237 |
"head_idx": head_idx,
|
| 2238 |
"entropy": entropy,
|
| 2239 |
"avg_entropy": avg_entropy, # Averaged over all query positions
|
| 2240 |
"max_weight": max_weight,
|
| 2241 |
-
"
|
| 2242 |
-
"q_matrix": q_matrix,
|
| 2243 |
-
"k_matrix": k_matrix,
|
| 2244 |
-
"v_matrix": v_matrix,
|
| 2245 |
"pattern": {"type": pattern_type, "confidence": confidence} if pattern_type else None
|
| 2246 |
})
|
| 2247 |
|
|
@@ -2364,6 +2443,7 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
|
|
| 2364 |
|
| 2365 |
# Build response
|
| 2366 |
response = {
|
|
|
|
| 2367 |
"prompt": prompt,
|
| 2368 |
"promptTokens": build_token_data(prompt_token_ids, prompt_tokens, "prompt"),
|
| 2369 |
"generatedTokens": build_token_data(generated_token_ids, generated_tokens, "generated"),
|
|
@@ -2371,7 +2451,7 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
|
|
| 2371 |
"tokenAlternatives": token_alternatives_by_step,
|
| 2372 |
"layersDataByStep": layer_data_by_token,
|
| 2373 |
"layersData": layer_data_by_token[-1] if layer_data_by_token else [],
|
| 2374 |
-
"qkvData":
|
| 2375 |
"modelInfo": {
|
| 2376 |
"numLayers": n_layers,
|
| 2377 |
"numHeads": n_heads,
|
|
@@ -2418,6 +2498,50 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
|
|
| 2418 |
)
|
| 2419 |
|
| 2420 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2421 |
@app.post("/analyze/study")
|
| 2422 |
async def analyze_study(request: StudyRequest, authenticated: bool = Depends(verify_api_key)):
|
| 2423 |
"""
|
|
|
|
| 18 |
import logging
|
| 19 |
from datetime import datetime
|
| 20 |
import traceback
|
| 21 |
+
import uuid
|
| 22 |
+
from threading import Lock
|
| 23 |
+
from time import time as time_now
|
| 24 |
from .auth import verify_api_key
|
| 25 |
from .instrumentation import ModelInstrumentor, InstrumentationData, TokenMetadata
|
| 26 |
from .storage import ZarrStorage, generate_run_id
|
|
|
|
| 59 |
else:
|
| 60 |
return obj
|
| 61 |
|
| 62 |
+
|
| 63 |
+
# Matrix cache for lazy loading (60 min TTL)
|
| 64 |
+
class MatrixCache:
    """
    Thread-safe in-memory cache for attention matrices.

    Stores Q/K/V and attention weights keyed by (request_id, step, layer, head).
    Entries expire after ``ttl_seconds``. In addition to lazy eviction on lookup
    and the explicit ``cleanup_expired()`` hook, ``store()`` runs a periodic
    sweep so memory stays bounded even if no external cleanup is scheduled.
    """

    # Run a full expiry sweep once every this many store() calls.
    _SWEEP_INTERVAL = 1024

    def __init__(self, ttl_seconds: int = 3600):
        self._cache: Dict[str, Dict] = {}          # key -> matrix payload
        self._timestamps: Dict[str, float] = {}    # key -> insertion time
        self._lock = Lock()                        # non-reentrant; helpers below assume it is held
        self._ttl = ttl_seconds
        self._store_count = 0                      # drives the periodic sweep in store()

    @staticmethod
    def _key(request_id: str, step: int, layer: int, head: int) -> str:
        """Build the composite cache key for one attention head."""
        return f"{request_id}:{step}:{layer}:{head}"

    def _evict_expired_locked(self) -> int:
        """Drop all expired entries. Caller MUST hold self._lock. Returns count removed."""
        now = time_now()
        expired = [k for k, t in self._timestamps.items() if now - t >= self._ttl]
        for k in expired:
            del self._cache[k]
            del self._timestamps[k]
        return len(expired)

    def store(self, request_id: str, step: int, layer: int, head: int, data: dict):
        """Store matrix data for a specific head (overwrites any prior entry)."""
        key = self._key(request_id, step, layer, head)
        with self._lock:
            self._cache[key] = data
            self._timestamps[key] = time_now()
            self._store_count += 1
            # Opportunistic sweep: without this, entries for requests whose
            # matrices are never fetched would accumulate forever.
            if self._store_count % self._SWEEP_INTERVAL == 0:
                self._evict_expired_locked()

    def get(self, request_id: str, step: int, layer: int, head: int) -> Optional[dict]:
        """Retrieve matrix data, returning None if expired or not found."""
        key = self._key(request_id, step, layer, head)
        with self._lock:
            if key in self._cache:
                if time_now() - self._timestamps[key] < self._ttl:
                    return self._cache[key]
                # Expired - clean up this entry eagerly.
                del self._cache[key]
                del self._timestamps[key]
            return None

    def cleanup_expired(self):
        """Remove all expired entries from cache (safe to call from a scheduler)."""
        with self._lock:
            removed = self._evict_expired_locked()
        if removed:
            logger.info(f"MatrixCache: cleaned up {removed} expired entries")

    def get_stats(self) -> dict:
        """Return cache statistics (live entry count and configured TTL)."""
        with self._lock:
            return {
                "entries": len(self._cache),
                "ttl_seconds": self._ttl
            }
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
# Global matrix cache instance
|
| 116 |
+
matrix_cache = MatrixCache(ttl_seconds=3600) # 60 min TTL
|
| 117 |
+
|
| 118 |
+
|
| 119 |
app = FastAPI(title="Visualisable.ai Model Service", version="0.1.0")
|
| 120 |
|
| 121 |
# CORS configuration for local development and production
|
|
|
|
| 1567 |
import time
|
| 1568 |
start_time = time.time()
|
| 1569 |
|
| 1570 |
+
# Generate unique request ID for matrix cache lookup
|
| 1571 |
+
request_id = str(uuid.uuid4())
|
| 1572 |
+
|
| 1573 |
# Get parameters
|
| 1574 |
prompt = request.get("prompt", "def quicksort(arr):")
|
| 1575 |
max_tokens = request.get("max_tokens", 8)
|
|
|
|
| 1840 |
k_matrix = qkv_captures[layer_idx]['k'][:, head_idx, :].float().numpy().tolist()
|
| 1841 |
v_matrix = qkv_captures[layer_idx]['v'][:, head_idx, :].float().numpy().tolist()
|
| 1842 |
|
| 1843 |
+
# Store matrices in cache for lazy loading (reduces response size)
|
| 1844 |
+
matrix_cache.store(request_id, step, layer_idx, head_idx, {
|
| 1845 |
+
"attention_weights": attention_matrix,
|
| 1846 |
+
"q_matrix": q_matrix,
|
| 1847 |
+
"k_matrix": k_matrix,
|
| 1848 |
+
"v_matrix": v_matrix
|
| 1849 |
+
})
|
| 1850 |
+
|
| 1851 |
+
# Return only metadata (matrices fetched on-demand via /matrix endpoint)
|
| 1852 |
critical_heads.append({
|
| 1853 |
"head_idx": head_idx,
|
| 1854 |
"entropy": entropy,
|
| 1855 |
"avg_entropy": avg_entropy, # Averaged over all query positions
|
| 1856 |
"max_weight": max_weight,
|
| 1857 |
+
"has_matrices": attention_matrix is not None, # Flag for frontend
|
|
|
|
|
|
|
|
|
|
| 1858 |
"pattern": {
|
| 1859 |
"type": pattern_type,
|
| 1860 |
"confidence": confidence
|
|
|
|
| 1984 |
|
| 1985 |
# Build response
|
| 1986 |
response = {
|
| 1987 |
+
"requestId": request_id, # For lazy-loading matrices via /matrix endpoint
|
| 1988 |
"prompt": prompt,
|
| 1989 |
"promptTokens": build_token_data(prompt_token_ids, prompt_tokens, "prompt"),
|
| 1990 |
"generatedTokens": build_token_data(generated_token_ids, generated_tokens, "generated"),
|
|
|
|
| 1992 |
"tokenAlternatives": token_alternatives_by_step, # Top-k alternatives for each token
|
| 1993 |
"layersDataByStep": layer_data_by_token, # Layer data for ALL generation steps
|
| 1994 |
"layersData": layer_data_by_token[-1] if layer_data_by_token else [], # Keep for backward compatibility
|
| 1995 |
+
"qkvData": {}, # Deprecated: matrices now lazy-loaded via /matrix endpoint
|
| 1996 |
"modelInfo": {
|
| 1997 |
"numLayers": n_layers,
|
| 1998 |
"numHeads": n_heads,
|
|
|
|
| 2039 |
import time
|
| 2040 |
start_time = time.time()
|
| 2041 |
|
| 2042 |
+
# Generate unique request ID for matrix cache lookup
|
| 2043 |
+
request_id = str(uuid.uuid4())
|
| 2044 |
+
|
| 2045 |
# Get parameters
|
| 2046 |
prompt = request.get("prompt", "def quicksort(arr):")
|
| 2047 |
max_tokens = request.get("max_tokens", 8)
|
| 2048 |
temperature = request.get("temperature", 0.7)
|
| 2049 |
|
| 2050 |
+
logger.info(f"[SSE] Research attention analysis: prompt_len={len(prompt)}, max_tokens={max_tokens}, request_id={request_id}")
|
| 2051 |
|
| 2052 |
# === STAGE 1: TOKENIZING ===
|
| 2053 |
yield sse_event('tokenizing', stage=1, totalStages=5, progress=2,
|
|
|
|
| 2306 |
k_matrix = qkv_captures[layer_idx]['k'][:, head_idx, :].float().numpy().tolist()
|
| 2307 |
v_matrix = qkv_captures[layer_idx]['v'][:, head_idx, :].float().numpy().tolist()
|
| 2308 |
|
| 2309 |
+
# Store matrices in cache for lazy loading (reduces response size)
|
| 2310 |
+
matrix_cache.store(request_id, step, layer_idx, head_idx, {
|
| 2311 |
+
"attention_weights": attention_matrix,
|
| 2312 |
+
"q_matrix": q_matrix,
|
| 2313 |
+
"k_matrix": k_matrix,
|
| 2314 |
+
"v_matrix": v_matrix
|
| 2315 |
+
})
|
| 2316 |
+
|
| 2317 |
+
# Return only metadata (matrices fetched on-demand via /matrix endpoint)
|
| 2318 |
critical_heads.append({
|
| 2319 |
"head_idx": head_idx,
|
| 2320 |
"entropy": entropy,
|
| 2321 |
"avg_entropy": avg_entropy, # Averaged over all query positions
|
| 2322 |
"max_weight": max_weight,
|
| 2323 |
+
"has_matrices": attention_matrix is not None, # Flag for frontend
|
|
|
|
|
|
|
|
|
|
| 2324 |
"pattern": {"type": pattern_type, "confidence": confidence} if pattern_type else None
|
| 2325 |
})
|
| 2326 |
|
|
|
|
| 2443 |
|
| 2444 |
# Build response
|
| 2445 |
response = {
|
| 2446 |
+
"requestId": request_id, # For lazy-loading matrices via /matrix endpoint
|
| 2447 |
"prompt": prompt,
|
| 2448 |
"promptTokens": build_token_data(prompt_token_ids, prompt_tokens, "prompt"),
|
| 2449 |
"generatedTokens": build_token_data(generated_token_ids, generated_tokens, "generated"),
|
|
|
|
| 2451 |
"tokenAlternatives": token_alternatives_by_step,
|
| 2452 |
"layersDataByStep": layer_data_by_token,
|
| 2453 |
"layersData": layer_data_by_token[-1] if layer_data_by_token else [],
|
| 2454 |
+
"qkvData": {}, # Deprecated: matrices now lazy-loaded via /matrix endpoint
|
| 2455 |
"modelInfo": {
|
| 2456 |
"numLayers": n_layers,
|
| 2457 |
"numHeads": n_heads,
|
|
|
|
| 2498 |
)
|
| 2499 |
|
| 2500 |
|
| 2501 |
+
@app.get("/analyze/research/attention/matrix")
|
| 2502 |
+
async def get_attention_matrix(
|
| 2503 |
+
request_id: str,
|
| 2504 |
+
step: int,
|
| 2505 |
+
layer: int,
|
| 2506 |
+
head: int,
|
| 2507 |
+
authenticated: bool = Depends(verify_api_key)
|
| 2508 |
+
):
|
| 2509 |
+
"""
|
| 2510 |
+
Retrieve cached attention/QKV matrices for a specific head.
|
| 2511 |
+
|
| 2512 |
+
Used for lazy-loading matrix data when user clicks "View Matrix" in the frontend.
|
| 2513 |
+
Matrices are cached during the initial analysis and available for 60 minutes.
|
| 2514 |
+
|
| 2515 |
+
Parameters:
|
| 2516 |
+
- request_id: UUID from the original analysis response
|
| 2517 |
+
- step: Generation step (0 = first generated token)
|
| 2518 |
+
- layer: Layer index (0-based)
|
| 2519 |
+
- head: Head index (0-based)
|
| 2520 |
+
|
| 2521 |
+
Returns:
|
| 2522 |
+
- attention_weights: [seq_len, seq_len] attention matrix
|
| 2523 |
+
- q_matrix: [seq_len, head_dim] query projections
|
| 2524 |
+
- k_matrix: [seq_len, head_dim] key projections
|
| 2525 |
+
- v_matrix: [seq_len, head_dim] value projections
|
| 2526 |
+
"""
|
| 2527 |
+
data = matrix_cache.get(request_id, step, layer, head)
|
| 2528 |
+
if data is None:
|
| 2529 |
+
logger.warning(f"Matrix cache miss: request_id={request_id}, step={step}, layer={layer}, head={head}")
|
| 2530 |
+
raise HTTPException(
|
| 2531 |
+
status_code=404,
|
| 2532 |
+
detail="Matrix data not found. Cache may have expired (60 min TTL). Please re-analyze."
|
| 2533 |
+
)
|
| 2534 |
+
|
| 2535 |
+
logger.info(f"Matrix cache hit: request_id={request_id}, step={step}, layer={layer}, head={head}")
|
| 2536 |
+
return data
|
| 2537 |
+
|
| 2538 |
+
|
| 2539 |
+
@app.get("/analyze/research/attention/matrix/stats")
|
| 2540 |
+
async def get_matrix_cache_stats(authenticated: bool = Depends(verify_api_key)):
|
| 2541 |
+
"""Return matrix cache statistics for monitoring."""
|
| 2542 |
+
return matrix_cache.get_stats()
|
| 2543 |
+
|
| 2544 |
+
|
| 2545 |
@app.post("/analyze/study")
|
| 2546 |
async def analyze_study(request: StudyRequest, authenticated: bool = Depends(verify_api_key)):
|
| 2547 |
"""
|