gary-boon Claude Opus 4.5 committed on
Commit
959074d
·
1 Parent(s): bb689ce

Fix RAM exhaustion for large token generation

Browse files

Add memory management to MatrixCache:
- Track request IDs for each cache entry
- Clear old request cache entries before starting new analysis
- Force garbage collection after clearing cache
- Clear GPU/MPS cache on Apple Silicon

This prevents memory accumulation when running multiple
analyses, particularly for large models like Devstral
with 40 layers × 32 heads per step.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

Files changed (1) hide show
  1. backend/model_service.py +45 -0
backend/model_service.py CHANGED
@@ -8,6 +8,7 @@ from fastapi.responses import StreamingResponse
8
  from fastapi.middleware.cors import CORSMiddleware
9
  from pydantic import BaseModel
10
  import asyncio
 
11
  import json
12
  import os
13
  import time
@@ -69,15 +70,53 @@ class MatrixCache:
69
  def __init__(self, ttl_seconds: int = 3600):
70
  self._cache: Dict[str, Dict] = {}
71
  self._timestamps: Dict[str, float] = {}
 
72
  self._lock = Lock()
73
  self._ttl = ttl_seconds
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  def store(self, request_id: str, step: int, layer: int, head: int, data: dict):
76
  """Store matrix data for a specific head."""
77
  key = f"{request_id}:{step}:{layer}:{head}"
78
  with self._lock:
79
  self._cache[key] = data
80
  self._timestamps[key] = time_now()
 
81
 
82
  def get(self, request_id: str, step: int, layer: int, head: int) -> Optional[dict]:
83
  """Retrieve matrix data, returning None if expired or not found."""
@@ -1570,6 +1609,9 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
1570
  # Generate unique request ID for matrix cache lookup
1571
  request_id = str(uuid.uuid4())
1572
 
 
 
 
1573
  # Get parameters
1574
  prompt = request.get("prompt", "def quicksort(arr):")
1575
  max_tokens = request.get("max_tokens", 8)
@@ -2094,6 +2136,9 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
2094
  # Generate unique request ID for matrix cache lookup
2095
  request_id = str(uuid.uuid4())
2096
 
 
 
 
2097
  # Get parameters
2098
  prompt = request.get("prompt", "def quicksort(arr):")
2099
  max_tokens = request.get("max_tokens", 8)
 
8
  from fastapi.middleware.cors import CORSMiddleware
9
  from pydantic import BaseModel
10
  import asyncio
11
+ import gc
12
  import json
13
  import os
14
  import time
 
70
    def __init__(self, ttl_seconds: int = 3600):
        """Initialize an empty matrix cache.

        Args:
            ttl_seconds: Seconds before a stored entry is treated as expired
                (default 3600 = one hour).
        """
        # Entries keyed by "request_id:step:layer:head" (see store()).
        self._cache: Dict[str, Dict] = {}
        # Per-key insertion time, used for TTL expiry decisions.
        self._timestamps: Dict[str, float] = {}
        self._request_ids: set = set()  # Track active request IDs
        # Guards all reads/writes of the three structures above.
        self._lock = Lock()
        self._ttl = ttl_seconds
76
 
77
+ def clear_request(self, request_id: str):
78
+ """Clear all cache entries for a specific request."""
79
+ with self._lock:
80
+ keys_to_delete = [k for k in self._cache.keys() if k.startswith(f"{request_id}:")]
81
+ for k in keys_to_delete:
82
+ del self._cache[k]
83
+ if k in self._timestamps:
84
+ del self._timestamps[k]
85
+ self._request_ids.discard(request_id)
86
+ if keys_to_delete:
87
+ logger.info(f"MatrixCache: cleared {len(keys_to_delete)} entries for request {request_id[:8]}")
88
+
89
+ def clear_old_requests(self, keep_request_id: str = None):
90
+ """Clear all requests except the specified one to free memory."""
91
+ with self._lock:
92
+ request_ids_to_clear = self._request_ids - {keep_request_id} if keep_request_id else self._request_ids.copy()
93
+ total_cleared = 0
94
+ for rid in request_ids_to_clear:
95
+ keys_to_delete = [k for k in self._cache.keys() if k.startswith(f"{rid}:")]
96
+ for k in keys_to_delete:
97
+ del self._cache[k]
98
+ if k in self._timestamps:
99
+ del self._timestamps[k]
100
+ total_cleared += len(keys_to_delete)
101
+ self._request_ids = {keep_request_id} if keep_request_id else set()
102
+ if total_cleared:
103
+ logger.info(f"MatrixCache: cleared {total_cleared} entries from old requests")
104
+ # Force garbage collection to release memory back to system
105
+ gc.collect()
106
+ # Also clear any GPU cache if using CUDA
107
+ if torch.cuda.is_available():
108
+ torch.cuda.empty_cache()
109
+ elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
110
+ # For Apple Silicon MPS, trigger garbage collection
111
+ torch.mps.empty_cache() if hasattr(torch.mps, 'empty_cache') else None
112
+
113
  def store(self, request_id: str, step: int, layer: int, head: int, data: dict):
114
  """Store matrix data for a specific head."""
115
  key = f"{request_id}:{step}:{layer}:{head}"
116
  with self._lock:
117
  self._cache[key] = data
118
  self._timestamps[key] = time_now()
119
+ self._request_ids.add(request_id)
120
 
121
  def get(self, request_id: str, step: int, layer: int, head: int) -> Optional[dict]:
122
  """Retrieve matrix data, returning None if expired or not found."""
 
1609
  # Generate unique request ID for matrix cache lookup
1610
  request_id = str(uuid.uuid4())
1611
 
1612
+ # Clear old cached matrices to free memory before starting new analysis
1613
+ matrix_cache.clear_old_requests(request_id)
1614
+
1615
  # Get parameters
1616
  prompt = request.get("prompt", "def quicksort(arr):")
1617
  max_tokens = request.get("max_tokens", 8)
 
2136
  # Generate unique request ID for matrix cache lookup
2137
  request_id = str(uuid.uuid4())
2138
 
2139
+ # Clear old cached matrices to free memory before starting new analysis
2140
+ matrix_cache.clear_old_requests(request_id)
2141
+
2142
  # Get parameters
2143
  prompt = request.get("prompt", "def quicksort(arr):")
2144
  max_tokens = request.get("max_tokens", 8)