Factor Studios committed on
Commit
aef4f5c
·
verified ·
1 Parent(s): 520d6cf

Update http_storage.py

Browse files
Files changed (1) hide show
  1. http_storage.py +441 -443
http_storage.py CHANGED
@@ -1,443 +1,441 @@
1
- import json
2
- import numpy as np
3
- from typing import Dict, Any, Optional, Union
4
- import threading
5
- import time
6
- import hashlib
7
- import logging
8
- import os
9
- import shutil
10
- import uuid
11
- from pathlib import Path
12
-
13
- class HTTPGPUStorage:
14
- """
15
- HTTP-based GPU storage client that replaces WebSocket functionality.
16
- Maintains the same interface as WebSocketGPUStorage for backward compatibility.
17
- """
18
-
19
- # Singleton instance
20
- _instance = None
21
- _lock = threading.Lock()
22
-
23
- def __new__(cls, storage_path: str = "storage"):
24
- with cls._lock:
25
- if cls._instance is None:
26
- cls._instance = super().__new__(cls)
27
- # Convert to absolute path if relative
28
- if not os.path.isabs(storage_path):
29
- storage_path = os.path.abspath(storage_path)
30
- cls._instance._init_singleton(storage_path)
31
- return cls._instance
32
-
33
- def _init_singleton(self, storage_path: str):
34
- """Initialize the singleton instance with local storage"""
35
- if hasattr(self, 'initialized'):
36
- return
37
-
38
- # Setup storage paths
39
- self.base_path = Path(storage_path)
40
- self.vram_path = self.base_path / "vram_blocks"
41
- self.models_path = self.base_path / "models"
42
- self.cache_path = self.base_path / "cache"
43
- self.state_path = self.base_path / "states"
44
-
45
- # Create directories
46
- for path in [self.vram_path, self.models_path, self.cache_path, self.state_path]:
47
- path.mkdir(parents=True, exist_ok=True)
48
-
49
- self.lock = threading.Lock()
50
- self._closing = False
51
- self.error_count = 0
52
- self.last_error_time = 0
53
- self.session_id = str(uuid.uuid4())
54
-
55
- # Tensor and model registries (maintained for compatibility)
56
- self.tensor_registry: Dict[str, Dict[str, Any]] = {}
57
- self.model_registry: Dict[str, Dict[str, Any]] = {}
58
- self.resource_monitor = {
59
- 'vram_used': 0,
60
- 'active_tensors': 0,
61
- 'loaded_models': set()
62
- }
63
-
64
- # Initialize local storage monitoring
65
- self.storage_monitor = {
66
- 'total_size': 0,
67
- 'last_access': time.time(),
68
- 'disk_usage': os.path.getsize(str(self.base_path)) if os.path.exists(str(self.base_path)) else 0
69
- }
70
-
71
- # Initialize session
72
- self._create_session()
73
- self.initialized = True
74
-
75
- def __init__(self, storage_path: str = "storage"):
76
- """This will actually just return the singleton instance.
77
- The actual initialization happens in __new__ and _init_singleton"""
78
- pass
79
-
80
- def _create_session(self):
81
- """Initialize local storage session"""
82
- try:
83
- # Create status file to track session
84
- status_path = self.base_path / "session_status.json"
85
- status_data = {
86
- "session_id": self.session_id,
87
- "created_at": time.time(),
88
- "resource_limits": {
89
- "max_vram_gb": 40, # A100 size
90
- "max_models": 5,
91
- "max_batch_size": 32
92
- }
93
- }
94
-
95
- with open(status_path, 'w') as f:
96
- json.dump(status_data, f, indent=2)
97
-
98
- logging.info(f"Local storage session created: {self.session_id}")
99
- return True
100
-
101
- except Exception as e:
102
- logging.error(f"Failed to create HTTP session: {e}")
103
- self.error_count += 1
104
- self.last_error_time = time.time()
105
- return False
106
-
107
- def _check_storage(self) -> Dict[str, Any]:
108
- """Check local storage status and usage"""
109
- try:
110
- # Update storage monitoring
111
- self.storage_monitor.update({
112
- 'total_size': sum(f.stat().st_size for f in self.base_path.rglob('*') if f.is_file()),
113
- 'last_access': time.time(),
114
- 'disk_usage': os.path.getsize(str(self.base_path)) if os.path.exists(str(self.base_path)) else 0
115
- })
116
- return {"status": "ok", "monitor": self.storage_monitor}
117
- except Exception as e:
118
- logging.error(f"Error checking storage: {e}")
119
- return {"status": "error", "message": str(e)}
120
-
121
- def store_tensor(self, tensor_id: str, data: np.ndarray, model_size: Optional[int] = None) -> bool:
122
- """Store tensor data in local storage"""
123
- try:
124
- if data is None:
125
- raise ValueError("Cannot store None tensor")
126
-
127
- # Calculate tensor metadata
128
- tensor_shape = data.shape
129
- tensor_dtype = str(data.dtype)
130
- tensor_size = data.nbytes
131
-
132
- # Save tensor data
133
- tensor_path = self.vram_path / f"{tensor_id}.npy"
134
- np.save(str(tensor_path), data)
135
-
136
- # Save metadata
137
- metadata = {
138
- 'shape': tensor_shape,
139
- 'dtype': tensor_dtype,
140
- 'size': tensor_size,
141
- 'timestamp': time.time(),
142
- 'model_size': model_size if model_size is not None else -1
143
- }
144
-
145
- metadata_path = self.vram_path / f"{tensor_id}_meta.json"
146
- with open(metadata_path, 'w') as f:
147
- json.dump(metadata, f)
148
-
149
- # Update tensor registry
150
- with self.lock:
151
- self.tensor_registry[tensor_id] = metadata
152
- self.resource_monitor['vram_used'] += tensor_size
153
- self.resource_monitor['active_tensors'] += 1
154
- return True
155
-
156
- except Exception as e:
157
- logging.error(f"Error storing tensor {tensor_id}: {str(e)}")
158
- return False
159
-
160
- def load_tensor(self, tensor_id: str) -> Optional[np.ndarray]:
161
- """Load tensor data via HTTP API"""
162
- try:
163
- # Check tensor registry first
164
- if tensor_id not in self.tensor_registry:
165
- logging.warning(f"Tensor {tensor_id} not registered in VRAM")
166
- # Still try to load it in case it exists on server
167
-
168
- response = self._make_request('GET', f'/vram/blocks/{tensor_id}')
169
-
170
- if response and response.get('status') == 'success':
171
- data = response.get('data')
172
- metadata = response.get('metadata', {})
173
-
174
- if data is None:
175
- logging.error(f"No data found for tensor {tensor_id}")
176
- return None
177
-
178
- try:
179
- # Convert to numpy array with correct dtype
180
- expected_dtype = metadata.get('dtype', 'float32')
181
- expected_shape = metadata.get('shape')
182
-
183
- arr = np.array(data, dtype=np.dtype(expected_dtype))
184
- if expected_shape and arr.shape != tuple(expected_shape):
185
- arr = arr.reshape(expected_shape)
186
-
187
- # Update registry if not present
188
- if tensor_id not in self.tensor_registry:
189
- with self.lock:
190
- self.tensor_registry[tensor_id] = metadata
191
-
192
- return arr
193
-
194
- except Exception as e:
195
- logging.error(f"Error converting tensor data: {str(e)}")
196
- return None
197
- else:
198
- logging.error(f"Failed to load tensor {tensor_id}: {response.get('message', 'Unknown error')}")
199
- return None
200
-
201
- except Exception as e:
202
- logging.error(f"Error loading tensor {tensor_id}: {str(e)}")
203
- return None
204
-
205
- def store_state(self, component: str, state_id: str, state_data: Dict[str, Any]) -> bool:
206
- """Store component state via HTTP API"""
207
- try:
208
- request_data = {
209
- "data": state_data,
210
- "timestamp": time.time()
211
- }
212
-
213
- response = self._make_request(
214
- 'POST',
215
- f'/state/{component}/{state_id}',
216
- json=request_data
217
- )
218
-
219
- if response and response.get('status') == 'success':
220
- return True
221
- else:
222
- logging.error(f"Failed to store state for {component}/{state_id}: {response.get('message', 'Unknown error')}")
223
- return False
224
-
225
- except Exception as e:
226
- logging.error(f"Error storing state for {component}/{state_id}: {str(e)}")
227
- return False
228
-
229
- def load_state(self, component: str, state_id: str) -> Optional[Dict[str, Any]]:
230
- """Load component state via HTTP API"""
231
- try:
232
- response = self._make_request("GET", f"/api/v1/state/{component}/{state_id}")
233
-
234
- if response and response.get('status') == 'success':
235
- return response.get('data')
236
- else:
237
- logging.error(f"Failed to load state for {component}/{state_id}: {response.get('message', 'Unknown error')}")
238
- return None
239
-
240
- except Exception as e:
241
- logging.error(f"Error loading state for {component}/{state_id}: {str(e)}")
242
- return None
243
-
244
- def cache_data(self, key: str, data: Any) -> bool:
245
- """Cache data via HTTP API"""
246
- try:
247
- request_data = {"data": data}
248
-
249
- response = self._make_request(
250
- 'POST',
251
- f'/cache/{key}',
252
- json=request_data
253
- )
254
-
255
- return response and response.get('status') == 'success'
256
-
257
- except Exception as e:
258
- logging.error(f"Error caching data for key {key}: {str(e)}")
259
- return False
260
-
261
- def get_cached_data(self, key: str) -> Optional[Any]:
262
- """Get cached data via HTTP API"""
263
- try:
264
- response = self._make_request("GET", f"/cache/{key}")
265
-
266
- if response and response.get('status') == 'success':
267
- return response.get('data')
268
- return None
269
-
270
- except Exception as e:
271
- logging.error(f"Error getting cached data for key {key}: {str(e)}")
272
- return None
273
-
274
- def is_model_loaded(self, model_name: str) -> bool:
275
- """Check if a model is loaded via HTTP API"""
276
- try:
277
- response = self._make_request(
278
- "GET",
279
- f"/models/{model_name}/status",
280
- timeout=60
281
- )
282
-
283
- if response and response.get('status') == 'loaded':
284
- return True
285
- return False
286
-
287
- except Exception as e:
288
- logging.error(f"Error checking model status for {model_name}: {str(e)}")
289
- return False
290
-
291
- def load_model(self, model_name: str, model_path: Optional[str] = None, model_data: Optional[Dict] = None) -> bool:
292
- """Load a model from local storage"""
293
- try:
294
- # Check if model is already loaded
295
- if self.is_model_loaded(model_name):
296
- logging.info(f"Model {model_name} already loaded")
297
- return True
298
-
299
- # Generate model directory path
300
- model_dir = self.models_path / model_name.replace('/', '_')
301
- model_dir.mkdir(parents=True, exist_ok=True)
302
-
303
- # Save model data if provided
304
- if model_data:
305
- model_config_path = model_dir / "config.json"
306
- with open(model_config_path, 'w') as f:
307
- json.dump(model_data, f, indent=2)
308
-
309
- # Update model registry
310
- with self.lock:
311
- self.model_registry[model_name] = {
312
- 'path': str(model_dir),
313
- 'config': model_data,
314
- 'loaded_at': time.time()
315
- }
316
-
317
- # Copy model files if path provided
318
- if model_path and os.path.exists(model_path):
319
- model_file_path = model_dir / "model.bin"
320
- shutil.copy2(model_path, model_file_path)
321
-
322
- logging.info(f"Successfully loaded model {model_name} to local storage")
323
- return True
324
-
325
- except Exception as e:
326
- logging.error(f"Error loading model {model_name}: {str(e)}")
327
- return False
328
-
329
- # Clean up any existing model files
330
- for existing_file in model_dir.glob('*'):
331
- try:
332
- existing_file.unlink()
333
- except Exception as e:
334
- logging.warning(f"Could not remove existing file {existing_file}: {e}")
335
-
336
- return True
337
-
338
- except Exception as e:
339
- logging.error(f"Error loading model {model_name}: {e}")
340
- return False
341
-
342
- def _calculate_model_hash(self, model_path: str) -> str:
343
- """Calculate SHA256 hash of model file"""
344
- try:
345
- sha256_hash = hashlib.sha256()
346
- with open(model_path, "rb") as f:
347
- for byte_block in iter(lambda: f.read(4096), b""):
348
- sha256_hash.update(byte_block)
349
- return sha256_hash.hexdigest()
350
- except Exception as e:
351
- logging.error(f"Error calculating model hash: {str(e)}")
352
- return ""
353
-
354
- def start_inference(self, model_name: str, input_data: np.ndarray) -> Optional[Dict[str, Any]]:
355
- """Start inference with a loaded model via HTTP API"""
356
- try:
357
- if not self.is_model_loaded(model_name):
358
- logging.error(f"Model {model_name} not loaded. Please load the model first.")
359
- return None
360
-
361
- request_data = {
362
- "input_data": input_data.tolist() if isinstance(input_data, np.ndarray) else input_data
363
- }
364
-
365
- response = self._make_request(
366
- 'POST',
367
- f'/models/{model_name}/inference',
368
- json=request_data
369
- )
370
-
371
- if response and response.get('status') == 'success':
372
- return {
373
- 'output': np.array(response['output']) if 'output' in response else None,
374
- 'metrics': response.get('metrics', {}),
375
- 'model_info': self.model_registry.get(model_name, {})
376
- }
377
- else:
378
- logging.error(f"Inference failed for model {model_name}: {response.get('message', 'Unknown error')}")
379
- return None
380
-
381
- except Exception as e:
382
- logging.error(f"Error during inference for model {model_name}: {str(e)}")
383
- return None
384
-
385
- def ping(self) -> bool:
386
- """Ping the server to check connection status."""
387
- try:
388
- response = self._make_request('GET', '/status')
389
- return response and response.get('status') == 'ok'
390
- except Exception as e:
391
- logging.error(f"Ping failed: {e}")
392
- return False
393
-
394
- def is_connected(self) -> bool:
395
- """Check if the client is connected to the server."""
396
- return self.ping()
397
-
398
- def get_connection_status(self) -> Dict[str, Any]:
399
- """Get detailed connection status."""
400
- if self.is_connected():
401
- return {"status": "connected", "session_id": self.session_id}
402
- else:
403
- return {"status": "disconnected", "error_count": self.error_count}
404
-
405
- def set_keep_alive(self, interval: int):
406
- """Set keep-alive interval (compatibility method)."""
407
- logging.info(f"Keep-alive interval set to {interval} seconds (HTTP client does not use websockets).")
408
-
409
- def reconnect(self):
410
- """Attempt to reconnect (compatibility method)."""
411
- logging.info("Attempting to reconnect HTTP client...")
412
- self._create_session()
413
-
414
- def wait_for_connection(self, timeout: float = 30.0) -> bool:
415
- """Wait for HTTP connection to be established (compatibility method)"""
416
- start_time = time.time()
417
- while time.time() - start_time < timeout:
418
- if self.is_connected():
419
- logging.info("HTTP connection established.")
420
- return True
421
- time.sleep(1) # Wait for 1 second before retrying
422
- logging.error("HTTP connection not established within timeout.")
423
- return False
424
-
425
- def close(self):
426
- """Close HTTP client"""
427
- self._closing = True
428
- logging.info("HTTP client is closing.")
429
- # Invalidate session on server side if possible
430
- if self.session_token:
431
- try:
432
- self.http_session.post(f"{self.api_base}/sessions/invalidate",
433
- headers={'Authorization': f'Bearer {self.session_token}'},
434
- timeout=5)
435
- except Exception as e:
436
- logging.warning(f"Failed to invalidate session on server: {e}")
437
- self.http_session.close()
438
- HTTPGPUStorage._instance = None # Clear singleton instance
439
-
440
- # Compatibility alias for existing code
441
- WebSocketGPUStorage = HTTPGPUStorage
442
-
443
-
 
1
+ import json
2
+ import numpy as np
3
+ from typing import Dict, Any, Optional, Union
4
+ import threading
5
+ import time
6
+ import hashlib
7
+ import logging
8
+ import os
9
+ import shutil
10
+ import uuid
11
+ from pathlib import Path
12
+
13
+ class HTTPGPUStorage:
14
+ """
15
+ HTTP-based GPU storage client that replaces WebSocket functionality.
16
+ Maintains the same interface as WebSocketGPUStorage for backward compatibility.
17
+ """
18
+
19
+ # Singleton instance
20
+ _instance = None
21
+ _lock = threading.Lock()
22
+
23
    def __new__(cls, storage_path: str = "storage"):
        """Create or return the process-wide singleton instance.

        Thread-safe: instance creation is guarded by the class-level lock.

        NOTE(review): storage_path is only honoured by the very first
        instantiation; later calls return the existing instance and silently
        ignore a different path — confirm callers expect this.
        """
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                # Convert to absolute path if relative
                if not os.path.isabs(storage_path):
                    storage_path = os.path.abspath(storage_path)
                cls._instance._init_singleton(storage_path)
        return cls._instance
32
+
33
    def _init_singleton(self, storage_path: str):
        """Initialize the singleton instance with local storage.

        Idempotent: returns immediately once ``self.initialized`` is set, so
        repeated ``__new__`` calls never re-run the setup.

        Args:
            storage_path: Path (made absolute by ``__new__``) under which all
                storage subdirectories are created.
        """
        if hasattr(self, 'initialized'):
            return

        # Setup storage paths
        self.base_path = Path(storage_path)
        self.vram_path = self.base_path / "vram_blocks"   # tensor .npy files + JSON sidecars
        self.models_path = self.base_path / "models"      # per-model directories
        self.cache_path = self.base_path / "cache"        # generic cached values
        self.state_path = self.base_path / "states"       # component state blobs

        # Create directories
        for path in [self.vram_path, self.models_path, self.cache_path, self.state_path]:
            path.mkdir(parents=True, exist_ok=True)

        self.lock = threading.Lock()  # guards registries and resource counters
        self._closing = False
        self.error_count = 0
        self.last_error_time = 0
        self.session_id = str(uuid.uuid4())

        # Tensor and model registries (maintained for compatibility)
        self.tensor_registry: Dict[str, Dict[str, Any]] = {}
        self.model_registry: Dict[str, Dict[str, Any]] = {}
        self.resource_monitor = {
            'vram_used': 0,        # bytes, accumulated by store_tensor()
            'active_tensors': 0,
            'loaded_models': set()
        }

        # Initialize local storage monitoring
        self.storage_monitor = {
            'total_size': 0,
            'last_access': time.time(),
            # NOTE(review): getsize() on a directory returns the size of the
            # directory entry itself, not the tree's disk usage — confirm intent
            # (cf. the recursive sum used in _check_storage()).
            'disk_usage': os.path.getsize(str(self.base_path)) if os.path.exists(str(self.base_path)) else 0
        }

        # Initialize session
        self._create_session()
        self.initialized = True
74
+
75
+ def __init__(self, storage_path: str = "storage"):
76
+ """This will actually just return the singleton instance.
77
+ The actual initialization happens in __new__ and _init_singleton"""
78
+ pass
79
+
80
    def _create_session(self):
        """Initialize local storage session.

        Writes a ``session_status.json`` marker under ``base_path`` recording
        the session id, creation time, and fixed resource limits.

        Returns:
            bool: True when the status file was written; False on error, in
            which case ``error_count`` and ``last_error_time`` are updated.
        """
        try:
            # Create status file to track session
            status_path = self.base_path / "session_status.json"
            status_data = {
                "session_id": self.session_id,
                "created_at": time.time(),
                "resource_limits": {
                    "max_vram_gb": 40,  # A100 size
                    "max_models": 5,
                    "max_batch_size": 32
                }
            }

            with open(status_path, 'w') as f:
                json.dump(status_data, f, indent=2)

            logging.info(f"Local storage session created: {self.session_id}")
            return True

        except Exception as e:
            # NOTE(review): message says "HTTP session" but this method only
            # does local-file setup — likely leftover WebSocket-client wording.
            logging.error(f"Failed to create HTTP session: {e}")
            self.error_count += 1
            self.last_error_time = time.time()
            return False
106
+
107
    def _check_storage(self) -> Dict[str, Any]:
        """Check local storage status and usage.

        Returns:
            ``{"status": "ok", "monitor": <dict>}`` on success, or
            ``{"status": "error", "message": <str>}`` on failure.
        """
        try:
            # Update storage monitoring
            self.storage_monitor.update({
                # Recursive sum of all regular-file sizes under base_path.
                'total_size': sum(f.stat().st_size for f in self.base_path.rglob('*') if f.is_file()),
                'last_access': time.time(),
                # NOTE(review): getsize() on a directory is the directory
                # entry's size, not the tree's disk usage — confirm intent.
                'disk_usage': os.path.getsize(str(self.base_path)) if os.path.exists(str(self.base_path)) else 0
            })
            return {"status": "ok", "monitor": self.storage_monitor}
        except Exception as e:
            logging.error(f"Error checking storage: {e}")
            return {"status": "error", "message": str(e)}
120
+
121
    def store_tensor(self, tensor_id: str, data: np.ndarray, model_size: Optional[int] = None) -> bool:
        """Store tensor data in local storage.

        Writes the array to ``<vram_path>/<tensor_id>.npy`` plus a JSON
        metadata sidecar, then updates the in-memory tensor registry and
        resource counters under the instance lock.

        Args:
            tensor_id: Key used for the on-disk filenames and the registry.
            data: Array to persist; must not be None.
            model_size: Optional size hint recorded in the metadata
                (stored as -1 when not provided).

        Returns:
            bool: True on success, False on any failure (logged).
        """
        try:
            if data is None:
                raise ValueError("Cannot store None tensor")

            # Calculate tensor metadata
            tensor_shape = data.shape
            tensor_dtype = str(data.dtype)
            tensor_size = data.nbytes

            # Save tensor data
            tensor_path = self.vram_path / f"{tensor_id}.npy"
            np.save(str(tensor_path), data)

            # Save metadata
            metadata = {
                'shape': tensor_shape,  # tuple; json.dump serializes it as a list
                'dtype': tensor_dtype,
                'size': tensor_size,
                'timestamp': time.time(),
                'model_size': model_size if model_size is not None else -1
            }

            metadata_path = self.vram_path / f"{tensor_id}_meta.json"
            with open(metadata_path, 'w') as f:
                json.dump(metadata, f)

            # Update tensor registry
            with self.lock:
                self.tensor_registry[tensor_id] = metadata
                self.resource_monitor['vram_used'] += tensor_size
                self.resource_monitor['active_tensors'] += 1
            return True

        except Exception as e:
            logging.error(f"Error storing tensor {tensor_id}: {str(e)}")
            return False
159
+
160
    def load_tensor(self, tensor_id: str) -> Optional[np.ndarray]:
        """Load tensor data via HTTP API.

        NOTE(review): store_tensor() writes tensors to the *local* filesystem,
        but this method fetches them through ``self._make_request``, which is
        not defined anywhere in this file — confirm the intended read path.

        Returns:
            The tensor as an ndarray (cast/reshaped per server metadata), or
            None on any failure (logged).
        """
        try:
            # Check tensor registry first
            if tensor_id not in self.tensor_registry:
                logging.warning(f"Tensor {tensor_id} not registered in VRAM")
                # Still try to load it in case it exists on server

            response = self._make_request('GET', f'/vram/blocks/{tensor_id}')

            if response and response.get('status') == 'success':
                data = response.get('data')
                metadata = response.get('metadata', {})

                if data is None:
                    logging.error(f"No data found for tensor {tensor_id}")
                    return None

                try:
                    # Convert to numpy array with correct dtype
                    expected_dtype = metadata.get('dtype', 'float32')
                    expected_shape = metadata.get('shape')

                    arr = np.array(data, dtype=np.dtype(expected_dtype))
                    if expected_shape and arr.shape != tuple(expected_shape):
                        arr = arr.reshape(expected_shape)

                    # Update registry if not present
                    if tensor_id not in self.tensor_registry:
                        with self.lock:
                            self.tensor_registry[tensor_id] = metadata

                    return arr

                except Exception as e:
                    logging.error(f"Error converting tensor data: {str(e)}")
                    return None
            else:
                # NOTE(review): when response is None, response.get() raises
                # AttributeError here; the outer except returns None as well.
                logging.error(f"Failed to load tensor {tensor_id}: {response.get('message', 'Unknown error')}")
                return None

        except Exception as e:
            logging.error(f"Error loading tensor {tensor_id}: {str(e)}")
            return None
204
+
205
+ def store_state(self, component: str, state_id: str, state_data: Dict[str, Any]) -> bool:
206
+ """Store component state via HTTP API"""
207
+ try:
208
+ request_data = {
209
+ "data": state_data,
210
+ "timestamp": time.time()
211
+ }
212
+
213
+ response = self._make_request(
214
+ 'POST',
215
+ f'/state/{component}/{state_id}',
216
+ json=request_data
217
+ )
218
+
219
+ if response and response.get('status') == 'success':
220
+ return True
221
+ else:
222
+ logging.error(f"Failed to store state for {component}/{state_id}: {response.get('message', 'Unknown error')}")
223
+ return False
224
+
225
+ except Exception as e:
226
+ logging.error(f"Error storing state for {component}/{state_id}: {str(e)}")
227
+ return False
228
+
229
+ def load_state(self, component: str, state_id: str) -> Optional[Dict[str, Any]]:
230
+ """Load component state via HTTP API"""
231
+ try:
232
+ response = self._make_request("GET", f"/api/v1/state/{component}/{state_id}")
233
+
234
+ if response and response.get('status') == 'success':
235
+ return response.get('data')
236
+ else:
237
+ logging.error(f"Failed to load state for {component}/{state_id}: {response.get('message', 'Unknown error')}")
238
+ return None
239
+
240
+ except Exception as e:
241
+ logging.error(f"Error loading state for {component}/{state_id}: {str(e)}")
242
+ return None
243
+
244
+ def cache_data(self, key: str, data: Any) -> bool:
245
+ """Cache data via HTTP API"""
246
+ try:
247
+ request_data = {"data": data}
248
+
249
+ response = self._make_request(
250
+ 'POST',
251
+ f'/cache/{key}',
252
+ json=request_data
253
+ )
254
+
255
+ return response and response.get('status') == 'success'
256
+
257
+ except Exception as e:
258
+ logging.error(f"Error caching data for key {key}: {str(e)}")
259
+ return False
260
+
261
+ def get_cached_data(self, key: str) -> Optional[Any]:
262
+ """Get cached data via HTTP API"""
263
+ try:
264
+ response = self._make_request("GET", f"/cache/{key}")
265
+
266
+ if response and response.get('status') == 'success':
267
+ return response.get('data')
268
+ return None
269
+
270
+ except Exception as e:
271
+ logging.error(f"Error getting cached data for key {key}: {str(e)}")
272
+ return None
273
+
274
+ def is_model_loaded(self, model_name: str) -> bool:
275
+ """Check if a model is loaded via HTTP API"""
276
+ try:
277
+ response = self._make_request(
278
+ "GET",
279
+ f"/models/{model_name}/status",
280
+ timeout=60
281
+ )
282
+
283
+ if response and response.get('status') == 'loaded':
284
+ return True
285
+ return False
286
+
287
+ except Exception as e:
288
+ logging.error(f"Error checking model status for {model_name}: {str(e)}")
289
+ return False
290
+
291
    def load_model(self, model_name: str, model_path: Optional[str] = None, model_data: Optional[Dict] = None) -> bool:
        """Load a model from local storage.

        Creates ``<models_path>/<name>`` (slashes replaced by underscores),
        clears any previous files in it, optionally writes ``config.json``,
        records the model in the registry, and copies the model binary.

        Args:
            model_name: Logical model name; '/' is replaced by '_' on disk.
            model_path: Optional path to a model file copied to ``model.bin``.
            model_data: Optional config dict written to ``config.json``.

        Returns:
            bool: True on success (or when already loaded), False on error.
        """
        try:
            # Check if model is already loaded
            if self.is_model_loaded(model_name):
                logging.info(f"Model {model_name} already loaded")
                return True

            # Generate model directory path
            model_dir = self.models_path / model_name.replace('/', '_')
            model_dir.mkdir(parents=True, exist_ok=True)

            # Clean up any existing files
            for existing_file in model_dir.glob('*'):
                try:
                    if existing_file.is_file():
                        existing_file.unlink()
                except Exception as e:
                    logging.warning(f"Could not remove existing file {existing_file}: {e}")

            # Save model data if provided
            if model_data:
                model_config_path = model_dir / "config.json"
                with open(model_config_path, 'w') as f:
                    json.dump(model_data, f, indent=2)

            # Update model registry
            with self.lock:
                self.model_registry[model_name] = {
                    'path': str(model_dir),
                    'config': model_data,
                    'loaded_at': time.time(),
                    # Hash of the *source* file (computed before the copy below).
                    'hash': self._calculate_model_hash(model_path) if model_path else None
                }
                self.resource_monitor['loaded_models'].add(model_name)

            # Copy model files if path provided
            if model_path and os.path.exists(model_path):
                model_file_path = model_dir / "model.bin"
                shutil.copy2(model_path, model_file_path)

            logging.info(f"Successfully loaded model {model_name} to local storage")
            return True

        except Exception as e:
            logging.error(f"Error loading model {model_name}: {str(e)}")
            return False
338
+
339
+
340
+ def _calculate_model_hash(self, model_path: str) -> str:
341
+ """Calculate SHA256 hash of model file"""
342
+ try:
343
+ sha256_hash = hashlib.sha256()
344
+ with open(model_path, "rb") as f:
345
+ for byte_block in iter(lambda: f.read(4096), b""):
346
+ sha256_hash.update(byte_block)
347
+ return sha256_hash.hexdigest()
348
+ except Exception as e:
349
+ logging.error(f"Error calculating model hash: {str(e)}")
350
+ return ""
351
+
352
    def start_inference(self, model_name: str, input_data: np.ndarray) -> Optional[Dict[str, Any]]:
        """Start inference with a loaded model via HTTP API.

        Args:
            model_name: Name of a model that is_model_loaded() reports loaded.
            input_data: ndarray (serialized with tolist()) or any other
                JSON-serializable payload, passed through as-is.

        Returns:
            Dict with 'output' (ndarray or None), 'metrics', and 'model_info'
            on success; None on failure (logged).
        """
        try:
            if not self.is_model_loaded(model_name):
                logging.error(f"Model {model_name} not loaded. Please load the model first.")
                return None

            request_data = {
                "input_data": input_data.tolist() if isinstance(input_data, np.ndarray) else input_data
            }

            response = self._make_request(
                'POST',
                f'/models/{model_name}/inference',
                json=request_data
            )

            if response and response.get('status') == 'success':
                return {
                    'output': np.array(response['output']) if 'output' in response else None,
                    'metrics': response.get('metrics', {}),
                    'model_info': self.model_registry.get(model_name, {})
                }
            else:
                # NOTE(review): when response is None, response.get() raises
                # here and the outer except returns None as well.
                logging.error(f"Inference failed for model {model_name}: {response.get('message', 'Unknown error')}")
                return None

        except Exception as e:
            logging.error(f"Error during inference for model {model_name}: {str(e)}")
            return None
382
+
383
+ def ping(self) -> bool:
384
+ """Ping the server to check connection status."""
385
+ try:
386
+ response = self._make_request('GET', '/status')
387
+ return response and response.get('status') == 'ok'
388
+ except Exception as e:
389
+ logging.error(f"Ping failed: {e}")
390
+ return False
391
+
392
+ def is_connected(self) -> bool:
393
+ """Check if the client is connected to the server."""
394
+ return self.ping()
395
+
396
+ def get_connection_status(self) -> Dict[str, Any]:
397
+ """Get detailed connection status."""
398
+ if self.is_connected():
399
+ return {"status": "connected", "session_id": self.session_id}
400
+ else:
401
+ return {"status": "disconnected", "error_count": self.error_count}
402
+
403
+ def set_keep_alive(self, interval: int):
404
+ """Set keep-alive interval (compatibility method)."""
405
+ logging.info(f"Keep-alive interval set to {interval} seconds (HTTP client does not use websockets).")
406
+
407
+ def reconnect(self):
408
+ """Attempt to reconnect (compatibility method)."""
409
+ logging.info("Attempting to reconnect HTTP client...")
410
+ self._create_session()
411
+
412
+ def wait_for_connection(self, timeout: float = 30.0) -> bool:
413
+ """Wait for HTTP connection to be established (compatibility method)"""
414
+ start_time = time.time()
415
+ while time.time() - start_time < timeout:
416
+ if self.is_connected():
417
+ logging.info("HTTP connection established.")
418
+ return True
419
+ time.sleep(1) # Wait for 1 second before retrying
420
+ logging.error("HTTP connection not established within timeout.")
421
+ return False
422
+
423
+ def close(self):
424
+ """Close HTTP client"""
425
+ self._closing = True
426
+ logging.info("HTTP client is closing.")
427
+ # Invalidate session on server side if possible
428
+ if self.session_token:
429
+ try:
430
+ self.http_session.post(f"{self.api_base}/sessions/invalidate",
431
+ headers={'Authorization': f'Bearer {self.session_token}'},
432
+ timeout=5)
433
+ except Exception as e:
434
+ logging.warning(f"Failed to invalidate session on server: {e}")
435
+ self.http_session.close()
436
+ HTTPGPUStorage._instance = None # Clear singleton instance
437
+
438
+ # Compatibility alias for existing code
439
+ WebSocketGPUStorage = HTTPGPUStorage
440
+
441
+