Spaces:
Sleeping
Sleeping
gary-boon
Claude Opus 4.5
committed on
Commit
·
959074d
1
Parent(s):
bb689ce
Fix RAM exhaustion for large token generation
Browse files
Add memory management to MatrixCache:
- Track request IDs for each cache entry
- Clear old request cache entries before starting new analysis
- Force garbage collection after clearing cache
- Clear GPU/MPS cache on Apple Silicon
This prevents memory accumulation when running multiple
analyses, particularly for large models like Devstral
with 40 layers × 32 heads per step.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- backend/model_service.py +45 -0
backend/model_service.py
CHANGED
|
@@ -8,6 +8,7 @@ from fastapi.responses import StreamingResponse
|
|
| 8 |
from fastapi.middleware.cors import CORSMiddleware
|
| 9 |
from pydantic import BaseModel
|
| 10 |
import asyncio
|
|
|
|
| 11 |
import json
|
| 12 |
import os
|
| 13 |
import time
|
|
@@ -69,15 +70,53 @@ class MatrixCache:
|
|
| 69 |
def __init__(self, ttl_seconds: int = 3600):
|
| 70 |
self._cache: Dict[str, Dict] = {}
|
| 71 |
self._timestamps: Dict[str, float] = {}
|
|
|
|
| 72 |
self._lock = Lock()
|
| 73 |
self._ttl = ttl_seconds
|
| 74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
def store(self, request_id: str, step: int, layer: int, head: int, data: dict):
|
| 76 |
"""Store matrix data for a specific head."""
|
| 77 |
key = f"{request_id}:{step}:{layer}:{head}"
|
| 78 |
with self._lock:
|
| 79 |
self._cache[key] = data
|
| 80 |
self._timestamps[key] = time_now()
|
|
|
|
| 81 |
|
| 82 |
def get(self, request_id: str, step: int, layer: int, head: int) -> Optional[dict]:
|
| 83 |
"""Retrieve matrix data, returning None if expired or not found."""
|
|
@@ -1570,6 +1609,9 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
|
|
| 1570 |
# Generate unique request ID for matrix cache lookup
|
| 1571 |
request_id = str(uuid.uuid4())
|
| 1572 |
|
|
|
|
|
|
|
|
|
|
| 1573 |
# Get parameters
|
| 1574 |
prompt = request.get("prompt", "def quicksort(arr):")
|
| 1575 |
max_tokens = request.get("max_tokens", 8)
|
|
@@ -2094,6 +2136,9 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
|
|
| 2094 |
# Generate unique request ID for matrix cache lookup
|
| 2095 |
request_id = str(uuid.uuid4())
|
| 2096 |
|
|
|
|
|
|
|
|
|
|
| 2097 |
# Get parameters
|
| 2098 |
prompt = request.get("prompt", "def quicksort(arr):")
|
| 2099 |
max_tokens = request.get("max_tokens", 8)
|
|
|
|
| 8 |
from fastapi.middleware.cors import CORSMiddleware
|
| 9 |
from pydantic import BaseModel
|
| 10 |
import asyncio
|
| 11 |
+
import gc
|
| 12 |
import json
|
| 13 |
import os
|
| 14 |
import time
|
|
|
|
| 70 |
def __init__(self, ttl_seconds: int = 3600):
    """Initialize an empty matrix cache.

    Args:
        ttl_seconds: Lifetime of a cache entry in seconds (default: one hour).
    """
    # Keys are "request_id:step:layer:head" (see store()); values are the
    # raw matrix payload dicts.
    self._cache: Dict[str, Dict] = {}
    # Per-key insertion timestamps, used for TTL expiry in get().
    self._timestamps: Dict[str, float] = {}
    self._request_ids: set = set()  # Track active request IDs
    # Guards all mutations of the three structures above.
    self._lock = Lock()
    self._ttl = ttl_seconds
|
| 76 |
|
| 77 |
+
def clear_request(self, request_id: str):
    """Clear all cache entries for a specific request."""
    # Cache keys are "request_id:step:layer:head", so a simple prefix
    # match selects every entry belonging to this request.
    prefix = f"{request_id}:"
    with self._lock:
        stale_keys = [key for key in self._cache if key.startswith(prefix)]
        for key in stale_keys:
            del self._cache[key]
            # Timestamp may already be gone; pop() tolerates that.
            self._timestamps.pop(key, None)
        self._request_ids.discard(request_id)
        if stale_keys:
            logger.info(f"MatrixCache: cleared {len(stale_keys)} entries for request {request_id[:8]}")
|
| 88 |
+
|
| 89 |
+
def clear_old_requests(self, keep_request_id: Optional[str] = None):
    """Clear all requests except the specified one to free memory.

    Args:
        keep_request_id: Request ID whose cache entries are retained.
            When None, every tracked request is cleared.
    """
    with self._lock:
        request_ids_to_clear = self._request_ids - {keep_request_id} if keep_request_id else self._request_ids.copy()
        # Single pass over the cache instead of one scan per request id:
        # a key belongs to a request iff it starts with "<request_id>:"
        # (the key format produced by store()).
        prefixes = tuple(f"{rid}:" for rid in request_ids_to_clear)
        keys_to_delete = [k for k in self._cache if prefixes and k.startswith(prefixes)]
        for k in keys_to_delete:
            del self._cache[k]
            self._timestamps.pop(k, None)
        total_cleared = len(keys_to_delete)
        self._request_ids = {keep_request_id} if keep_request_id else set()
        if total_cleared:
            logger.info(f"MatrixCache: cleared {total_cleared} entries from old requests")
    # Release memory OUTSIDE the lock: garbage collection and device-cache
    # clearing can be slow and must not block concurrent store()/get().
    # Force garbage collection to release memory back to system.
    gc.collect()
    # Also clear any GPU cache if using CUDA.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        # For Apple Silicon MPS; empty_cache only exists on newer torch builds.
        if hasattr(torch.mps, 'empty_cache'):
            torch.mps.empty_cache()
|
| 112 |
+
|
| 113 |
def store(self, request_id: str, step: int, layer: int, head: int, data: dict):
    """Store matrix data for a specific head."""
    # Composite key ties the entry to one (request, step, layer, head);
    # the "request_id:" prefix is what clear_request() matches on.
    cache_key = f"{request_id}:{step}:{layer}:{head}"
    with self._lock:
        self._cache[cache_key] = data
        self._timestamps[cache_key] = time_now()
        # Remember the owning request so it can be evicted wholesale later.
        self._request_ids.add(request_id)
|
| 120 |
|
| 121 |
def get(self, request_id: str, step: int, layer: int, head: int) -> Optional[dict]:
|
| 122 |
"""Retrieve matrix data, returning None if expired or not found."""
|
|
|
|
| 1609 |
# Generate unique request ID for matrix cache lookup
|
| 1610 |
request_id = str(uuid.uuid4())
|
| 1611 |
|
| 1612 |
+
# Clear old cached matrices to free memory before starting new analysis
|
| 1613 |
+
matrix_cache.clear_old_requests(request_id)
|
| 1614 |
+
|
| 1615 |
# Get parameters
|
| 1616 |
prompt = request.get("prompt", "def quicksort(arr):")
|
| 1617 |
max_tokens = request.get("max_tokens", 8)
|
|
|
|
| 2136 |
# Generate unique request ID for matrix cache lookup
|
| 2137 |
request_id = str(uuid.uuid4())
|
| 2138 |
|
| 2139 |
+
# Clear old cached matrices to free memory before starting new analysis
|
| 2140 |
+
matrix_cache.clear_old_requests(request_id)
|
| 2141 |
+
|
| 2142 |
# Get parameters
|
| 2143 |
prompt = request.get("prompt", "def quicksort(arr):")
|
| 2144 |
max_tokens = request.get("max_tokens", 8)
|