Spaces:
Runtime error
Runtime error
Factor Studios
commited on
Update http_storage.py
Browse files- http_storage.py +47 -152
http_storage.py
CHANGED
|
@@ -158,85 +158,71 @@ class HTTPGPUStorage:
|
|
| 158 |
return False
|
| 159 |
|
| 160 |
def load_tensor(self, tensor_id: str) -> Optional[np.ndarray]:
|
| 161 |
-
"""Load tensor data
|
| 162 |
try:
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
logging.warning(f"Tensor {tensor_id} not registered in VRAM")
|
| 166 |
-
# Still try to load it in case it exists on server
|
| 167 |
-
|
| 168 |
-
response = self._make_request('GET', f'/vram/blocks/{tensor_id}')
|
| 169 |
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
if data is None:
|
| 175 |
-
logging.error(f"No data found for tensor {tensor_id}")
|
| 176 |
-
return None
|
| 177 |
-
|
| 178 |
-
try:
|
| 179 |
-
# Convert to numpy array with correct dtype
|
| 180 |
-
expected_dtype = metadata.get('dtype', 'float32')
|
| 181 |
-
expected_shape = metadata.get('shape')
|
| 182 |
-
|
| 183 |
-
arr = np.array(data, dtype=np.dtype(expected_dtype))
|
| 184 |
-
if expected_shape and arr.shape != tuple(expected_shape):
|
| 185 |
-
arr = arr.reshape(expected_shape)
|
| 186 |
-
|
| 187 |
-
# Update registry if not present
|
| 188 |
-
if tensor_id not in self.tensor_registry:
|
| 189 |
-
with self.lock:
|
| 190 |
-
self.tensor_registry[tensor_id] = metadata
|
| 191 |
-
|
| 192 |
-
return arr
|
| 193 |
-
|
| 194 |
-
except Exception as e:
|
| 195 |
-
logging.error(f"Error converting tensor data: {str(e)}")
|
| 196 |
-
return None
|
| 197 |
-
else:
|
| 198 |
-
logging.error(f"Failed to load tensor {tensor_id}: {response.get('message', 'Unknown error')}")
|
| 199 |
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
|
| 201 |
except Exception as e:
|
| 202 |
logging.error(f"Error loading tensor {tensor_id}: {str(e)}")
|
| 203 |
return None
|
| 204 |
|
| 205 |
def store_state(self, component: str, state_id: str, state_data: Dict[str, Any]) -> bool:
|
| 206 |
-
"""Store component state
|
| 207 |
try:
|
| 208 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
"data": state_data,
|
| 210 |
"timestamp": time.time()
|
| 211 |
}
|
| 212 |
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
f'/state/{component}/{state_id}',
|
| 216 |
-
json=request_data
|
| 217 |
-
)
|
| 218 |
|
| 219 |
-
|
| 220 |
-
return True
|
| 221 |
-
else:
|
| 222 |
-
logging.error(f"Failed to store state for {component}/{state_id}: {response.get('message', 'Unknown error')}")
|
| 223 |
-
return False
|
| 224 |
|
| 225 |
except Exception as e:
|
| 226 |
logging.error(f"Error storing state for {component}/{state_id}: {str(e)}")
|
| 227 |
return False
|
| 228 |
|
| 229 |
def load_state(self, component: str, state_id: str) -> Optional[Dict[str, Any]]:
|
| 230 |
-
"""Load component state
|
| 231 |
try:
|
| 232 |
-
|
| 233 |
|
| 234 |
-
if
|
| 235 |
-
|
| 236 |
-
else:
|
| 237 |
-
logging.error(f"Failed to load state for {component}/{state_id}: {response.get('message', 'Unknown error')}")
|
| 238 |
return None
|
| 239 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
except Exception as e:
|
| 241 |
logging.error(f"Error loading state for {component}/{state_id}: {str(e)}")
|
| 242 |
return None
|
|
@@ -271,23 +257,6 @@ class HTTPGPUStorage:
|
|
| 271 |
logging.error(f"Error getting cached data for key {key}: {str(e)}")
|
| 272 |
return None
|
| 273 |
|
| 274 |
-
def is_model_loaded(self, model_name: str) -> bool:
|
| 275 |
-
"""Check if a model is loaded via HTTP API"""
|
| 276 |
-
try:
|
| 277 |
-
response = self._make_request(
|
| 278 |
-
"GET",
|
| 279 |
-
f"/models/{model_name}/status",
|
| 280 |
-
timeout=60
|
| 281 |
-
)
|
| 282 |
-
|
| 283 |
-
if response and response.get('status') == 'loaded':
|
| 284 |
-
return True
|
| 285 |
-
return False
|
| 286 |
-
|
| 287 |
-
except Exception as e:
|
| 288 |
-
logging.error(f"Error checking model status for {model_name}: {str(e)}")
|
| 289 |
-
return False
|
| 290 |
-
|
| 291 |
def load_model(self, model_name: str, model_path: Optional[str] = None, model_data: Optional[Dict] = None) -> bool:
|
| 292 |
"""Load a model from local storage"""
|
| 293 |
try:
|
|
@@ -349,92 +318,18 @@ class HTTPGPUStorage:
|
|
| 349 |
logging.error(f"Error calculating model hash: {str(e)}")
|
| 350 |
return ""
|
| 351 |
|
| 352 |
-
def start_inference(self, model_name: str, input_data: np.ndarray) -> Optional[Dict[str, Any]]:
|
| 353 |
-
"""Start inference with a loaded model via HTTP API"""
|
| 354 |
-
try:
|
| 355 |
-
if not self.is_model_loaded(model_name):
|
| 356 |
-
logging.error(f"Model {model_name} not loaded. Please load the model first.")
|
| 357 |
-
return None
|
| 358 |
-
|
| 359 |
-
request_data = {
|
| 360 |
-
"input_data": input_data.tolist() if isinstance(input_data, np.ndarray) else input_data
|
| 361 |
-
}
|
| 362 |
-
|
| 363 |
-
response = self._make_request(
|
| 364 |
-
'POST',
|
| 365 |
-
f'/models/{model_name}/inference',
|
| 366 |
-
json=request_data
|
| 367 |
-
)
|
| 368 |
-
|
| 369 |
-
if response and response.get('status') == 'success':
|
| 370 |
-
return {
|
| 371 |
-
'output': np.array(response['output']) if 'output' in response else None,
|
| 372 |
-
'metrics': response.get('metrics', {}),
|
| 373 |
-
'model_info': self.model_registry.get(model_name, {})
|
| 374 |
-
}
|
| 375 |
-
else:
|
| 376 |
-
logging.error(f"Inference failed for model {model_name}: {response.get('message', 'Unknown error')}")
|
| 377 |
-
return None
|
| 378 |
-
|
| 379 |
-
except Exception as e:
|
| 380 |
-
logging.error(f"Error during inference for model {model_name}: {str(e)}")
|
| 381 |
-
return None
|
| 382 |
|
| 383 |
def ping(self) -> bool:
|
| 384 |
-
"""
|
| 385 |
try:
|
| 386 |
-
|
| 387 |
-
|
|
|
|
|
|
|
|
|
|
| 388 |
except Exception as e:
|
| 389 |
-
logging.error(f"
|
| 390 |
return False
|
| 391 |
-
|
| 392 |
-
def is_connected(self) -> bool:
|
| 393 |
-
"""Check if the client is connected to the server."""
|
| 394 |
-
return self.ping()
|
| 395 |
-
|
| 396 |
-
def get_connection_status(self) -> Dict[str, Any]:
|
| 397 |
-
"""Get detailed connection status."""
|
| 398 |
-
if self.is_connected():
|
| 399 |
-
return {"status": "connected", "session_id": self.session_id}
|
| 400 |
-
else:
|
| 401 |
-
return {"status": "disconnected", "error_count": self.error_count}
|
| 402 |
-
|
| 403 |
-
def set_keep_alive(self, interval: int):
|
| 404 |
-
"""Set keep-alive interval (compatibility method)."""
|
| 405 |
-
logging.info(f"Keep-alive interval set to {interval} seconds (HTTP client does not use websockets).")
|
| 406 |
-
|
| 407 |
-
def reconnect(self):
|
| 408 |
-
"""Attempt to reconnect (compatibility method)."""
|
| 409 |
-
logging.info("Attempting to reconnect HTTP client...")
|
| 410 |
-
self._create_session()
|
| 411 |
-
|
| 412 |
-
def wait_for_connection(self, timeout: float = 30.0) -> bool:
|
| 413 |
-
"""Wait for HTTP connection to be established (compatibility method)"""
|
| 414 |
-
start_time = time.time()
|
| 415 |
-
while time.time() - start_time < timeout:
|
| 416 |
-
if self.is_connected():
|
| 417 |
-
logging.info("HTTP connection established.")
|
| 418 |
-
return True
|
| 419 |
-
time.sleep(1) # Wait for 1 second before retrying
|
| 420 |
-
logging.error("HTTP connection not established within timeout.")
|
| 421 |
-
return False
|
| 422 |
-
|
| 423 |
-
def close(self):
|
| 424 |
-
"""Close HTTP client"""
|
| 425 |
-
self._closing = True
|
| 426 |
-
logging.info("HTTP client is closing.")
|
| 427 |
-
# Invalidate session on server side if possible
|
| 428 |
-
if self.session_token:
|
| 429 |
-
try:
|
| 430 |
-
self.http_session.post(f"{self.api_base}/sessions/invalidate",
|
| 431 |
-
headers={'Authorization': f'Bearer {self.session_token}'},
|
| 432 |
-
timeout=5)
|
| 433 |
-
except Exception as e:
|
| 434 |
-
logging.warning(f"Failed to invalidate session on server: {e}")
|
| 435 |
-
self.http_session.close()
|
| 436 |
-
HTTPGPUStorage._instance = None # Clear singleton instance
|
| 437 |
-
|
| 438 |
# Compatibility alias for existing code
|
| 439 |
WebSocketGPUStorage = HTTPGPUStorage
|
| 440 |
|
|
|
|
| 158 |
return False
|
| 159 |
|
| 160 |
def load_tensor(self, tensor_id: str) -> Optional[np.ndarray]:
|
| 161 |
+
"""Load tensor data from local storage"""
|
| 162 |
try:
|
| 163 |
+
tensor_path = self.vram_path / f"{tensor_id}.npy"
|
| 164 |
+
metadata_path = self.vram_path / f"{tensor_id}_meta.json"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
+
# Check if tensor files exist
|
| 167 |
+
if not tensor_path.exists() or not metadata_path.exists():
|
| 168 |
+
logging.warning(f"Tensor {tensor_id} not found in local storage")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
return None
|
| 170 |
+
|
| 171 |
+
# Load metadata
|
| 172 |
+
with open(metadata_path, 'r') as f:
|
| 173 |
+
metadata = json.load(f)
|
| 174 |
+
|
| 175 |
+
# Load tensor data
|
| 176 |
+
arr = np.load(str(tensor_path))
|
| 177 |
+
|
| 178 |
+
# Update registry if not present
|
| 179 |
+
if tensor_id not in self.tensor_registry:
|
| 180 |
+
with self.lock:
|
| 181 |
+
self.tensor_registry[tensor_id] = metadata
|
| 182 |
+
|
| 183 |
+
return arr
|
| 184 |
|
| 185 |
except Exception as e:
|
| 186 |
logging.error(f"Error loading tensor {tensor_id}: {str(e)}")
|
| 187 |
return None
|
| 188 |
|
| 189 |
def store_state(self, component: str, state_id: str, state_data: Dict[str, Any]) -> bool:
|
| 190 |
+
"""Store component state in local storage"""
|
| 191 |
try:
|
| 192 |
+
# Create component directory if needed
|
| 193 |
+
component_dir = self.state_path / component
|
| 194 |
+
component_dir.mkdir(parents=True, exist_ok=True)
|
| 195 |
+
|
| 196 |
+
# Save state data with timestamp
|
| 197 |
+
state_file = component_dir / f"{state_id}.json"
|
| 198 |
+
data_to_save = {
|
| 199 |
"data": state_data,
|
| 200 |
"timestamp": time.time()
|
| 201 |
}
|
| 202 |
|
| 203 |
+
with open(state_file, 'w') as f:
|
| 204 |
+
json.dump(data_to_save, f, indent=2)
|
|
|
|
|
|
|
|
|
|
| 205 |
|
| 206 |
+
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
|
| 208 |
except Exception as e:
|
| 209 |
logging.error(f"Error storing state for {component}/{state_id}: {str(e)}")
|
| 210 |
return False
|
| 211 |
|
| 212 |
def load_state(self, component: str, state_id: str) -> Optional[Dict[str, Any]]:
|
| 213 |
+
"""Load component state from local storage"""
|
| 214 |
try:
|
| 215 |
+
state_file = self.state_path / component / f"{state_id}.json"
|
| 216 |
|
| 217 |
+
if not state_file.exists():
|
| 218 |
+
logging.warning(f"State file not found for {component}/{state_id}")
|
|
|
|
|
|
|
| 219 |
return None
|
| 220 |
|
| 221 |
+
with open(state_file, 'r') as f:
|
| 222 |
+
saved_data = json.load(f)
|
| 223 |
+
|
| 224 |
+
return saved_data.get('data')
|
| 225 |
+
|
| 226 |
except Exception as e:
|
| 227 |
logging.error(f"Error loading state for {component}/{state_id}: {str(e)}")
|
| 228 |
return None
|
|
|
|
| 257 |
logging.error(f"Error getting cached data for key {key}: {str(e)}")
|
| 258 |
return None
|
| 259 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
def load_model(self, model_name: str, model_path: Optional[str] = None, model_data: Optional[Dict] = None) -> bool:
|
| 261 |
"""Load a model from local storage"""
|
| 262 |
try:
|
|
|
|
| 318 |
logging.error(f"Error calculating model hash: {str(e)}")
|
| 319 |
return ""
|
| 320 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
|
| 322 |
def ping(self) -> bool:
|
| 323 |
+
"""Check if local storage is accessible"""
|
| 324 |
try:
|
| 325 |
+
# Check if all storage directories exist and are accessible
|
| 326 |
+
for path in [self.vram_path, self.models_path, self.cache_path, self.state_path]:
|
| 327 |
+
if not path.exists() or not os.access(str(path), os.R_OK | os.W_OK):
|
| 328 |
+
return False
|
| 329 |
+
return True
|
| 330 |
except Exception as e:
|
| 331 |
+
logging.error(f"Storage check failed: {e}")
|
| 332 |
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
# Compatibility alias for existing code
|
| 334 |
WebSocketGPUStorage = HTTPGPUStorage
|
| 335 |
|