Factor Studios committed on
Commit
9105b87
·
verified ·
1 Parent(s): e264c14

Update http_storage.py

Browse files
Files changed (1) hide show
  1. http_storage.py +496 -475
http_storage.py CHANGED
@@ -1,475 +1,496 @@
1
- import requests
2
- import json
3
- import numpy as np
4
- from typing import Dict, Any, Optional, Union
5
- import threading
6
- import time
7
- import hashlib
8
- import logging
9
- from requests.adapters import HTTPAdapter
10
- from urllib3.util.retry import Retry
11
-
12
- class HTTPGPUStorage:
13
- """
14
- HTTP-based GPU storage client that replaces WebSocket functionality.
15
- Maintains the same interface as WebSocketGPUStorage for backward compatibility.
16
- """
17
-
18
- # Singleton instance
19
- _instance = None
20
- _lock = threading.Lock()
21
-
22
- def __new__(cls, base_url: str = "https://factorst-intiv.hf.space"):
23
- with cls._lock:
24
- if cls._instance is None:
25
- cls._instance = super().__new__(cls)
26
- cls._instance._init_singleton(base_url)
27
- return cls._instance
28
-
29
- def _init_singleton(self, base_url: str):
30
- """Initialize the singleton instance"""
31
- if hasattr(self, 'initialized'):
32
- return
33
-
34
- self.base_url = base_url.rstrip('/')
35
- self.api_base = f"{self.base_url}/api/v1"
36
- self.session_token = None
37
- self.session_id = None
38
- self.lock = threading.Lock()
39
- self._closing = False
40
- self.error_count = 0
41
- self.last_error_time = 0
42
- self.max_retries = 5
43
-
44
- # Tensor and model registries (maintained for compatibility)
45
- self.tensor_registry: Dict[str, Dict[str, Any]] = {}
46
- self.model_registry: Dict[str, Dict[str, Any]] = {}
47
- self.resource_monitor = {
48
- 'vram_used': 0,
49
- 'active_tensors': 0,
50
- 'loaded_models': set()
51
- }
52
-
53
- # Configure HTTP session with connection pooling and retries
54
- self.http_session = requests.Session()
55
-
56
- # Configure retry strategy
57
- retry_strategy = Retry(
58
- total=3,
59
- status_forcelist=[429, 500, 502, 503, 504],
60
- allowed_methods=["HEAD", "GET", "OPTIONS", "POST", "PUT", "DELETE"], # Updated parameter name
61
- backoff_factor=1
62
- )
63
-
64
- adapter = HTTPAdapter(
65
- max_retries=retry_strategy,
66
- pool_connections=10,
67
- pool_maxsize=20
68
- )
69
-
70
- self.http_session.mount("http://", adapter)
71
- self.http_session.mount("https://", adapter)
72
-
73
- # Set default headers
74
- self.http_session.headers.update({
75
- 'Content-Type': 'application/json',
76
- 'User-Agent': 'VirtualGPU-HTTP-Client/2.0'
77
- })
78
-
79
- # Initialize session
80
- self._create_session()
81
- self.initialized = True
82
-
83
- def __init__(self, base_url: str = "https://factorst-intiv.hf.space"):
84
- """This will actually just return the singleton instance"""
85
- pass
86
-
87
- def _create_session(self):
88
- """Create HTTP session with the server"""
89
- try:
90
- response = self.http_session.post(
91
- f"{self.api_base}/sessions",
92
- json={"client_id": "virtual_gpu_client"},
93
- timeout=30
94
- )
95
- response.raise_for_status()
96
-
97
- session_data = response.json()
98
- self.session_token = session_data['session_token']
99
- self.session_id = session_data['session_id']
100
-
101
- # Update session headers
102
- self.http_session.headers.update({
103
- 'Authorization': f'Bearer {self.session_token}'
104
- })
105
-
106
- logging.info(f"HTTP session created: {self.session_id}")
107
- return True
108
-
109
- except Exception as e:
110
- logging.error(f"Failed to create HTTP session: {e}")
111
- self.error_count += 1
112
- self.last_error_time = time.time()
113
- return False
114
-
115
- def _make_request(self, method: str, endpoint: str, **kwargs) -> Optional[Dict[str, Any]]:
116
- """Make HTTP request with error handling and retries"""
117
- if self._closing:
118
- return {"status": "error", "message": "HTTP client is closing"}
119
-
120
- url = f"{self.api_base}{endpoint}"
121
-
122
- try:
123
- # Ensure we have a valid session
124
- if not self.session_token:
125
- if not self._create_session():
126
- return {"status": "error", "message": "Failed to create session"}
127
-
128
- response = self.http_session.request(method, url, timeout=30, **kwargs)
129
-
130
- # Handle authentication errors by recreating session
131
- if response.status_code == 401:
132
- logging.warning("Session expired, recreating...")
133
- if self._create_session():
134
- response = self.http_session.request(method, url, timeout=30, **kwargs)
135
- else:
136
- return {"status": "error", "message": "Failed to recreate session"}
137
-
138
- response.raise_for_status()
139
-
140
- # Reset error count on successful request
141
- self.error_count = 0
142
-
143
- return response.json()
144
-
145
- except requests.exceptions.RequestException as e:
146
- self.error_count += 1
147
- self.last_error_time = time.time()
148
- logging.error(f"HTTP request failed: {e}")
149
- return {"status": "error", "message": f"HTTP request failed: {str(e)}"}
150
- except Exception as e:
151
- self.error_count += 1
152
- self.last_error_time = time.time()
153
- logging.error(f"Unexpected error in HTTP request: {e}")
154
- return {"status": "error", "message": f"Unexpected error: {str(e)}"}
155
-
156
- def store_tensor(self, tensor_id: str, data: np.ndarray, model_size: Optional[int] = None) -> bool:
157
- """Store tensor data via HTTP API"""
158
- try:
159
- if data is None:
160
- raise ValueError("Cannot store None tensor")
161
-
162
- # Calculate tensor metadata
163
- tensor_shape = data.shape
164
- tensor_dtype = str(data.dtype)
165
- tensor_size = data.nbytes
166
-
167
- request_data = {
168
- "data": data.tolist(),
169
- "metadata": {
170
- 'shape': tensor_shape,
171
- 'dtype': tensor_dtype,
172
- 'size': tensor_size,
173
- 'timestamp': time.time()
174
- },
175
- "model_size": model_size if model_size is not None else -1
176
- }
177
-
178
- response = self._make_request(
179
- 'POST',
180
- f'/vram/blocks/{tensor_id}',
181
- json=request_data
182
- )
183
-
184
- if response and response.get('status') == 'success':
185
- # Update tensor registry
186
- with self.lock:
187
- self.tensor_registry[tensor_id] = {
188
- 'shape': tensor_shape,
189
- 'dtype': tensor_dtype,
190
- 'size': tensor_size,
191
- 'timestamp': time.time()
192
- }
193
- self.resource_monitor['vram_used'] += tensor_size
194
- self.resource_monitor['active_tensors'] += 1
195
- return True
196
- else:
197
- logging.error(f"Failed to store tensor {tensor_id}: {response.get('message', 'Unknown error')}")
198
- return False
199
-
200
- except Exception as e:
201
- logging.error(f"Error storing tensor {tensor_id}: {str(e)}")
202
- return False
203
-
204
- def load_tensor(self, tensor_id: str) -> Optional[np.ndarray]:
205
- """Load tensor data via HTTP API"""
206
- try:
207
- # Check tensor registry first
208
- if tensor_id not in self.tensor_registry:
209
- logging.warning(f"Tensor {tensor_id} not registered in VRAM")
210
- # Still try to load it in case it exists on server
211
-
212
- response = self._make_request('GET', f'/vram/blocks/{tensor_id}')
213
-
214
- if response and response.get('status') == 'success':
215
- data = response.get('data')
216
- metadata = response.get('metadata', {})
217
-
218
- if data is None:
219
- logging.error(f"No data found for tensor {tensor_id}")
220
- return None
221
-
222
- try:
223
- # Convert to numpy array with correct dtype
224
- expected_dtype = metadata.get('dtype', 'float32')
225
- expected_shape = metadata.get('shape')
226
-
227
- arr = np.array(data, dtype=np.dtype(expected_dtype))
228
- if expected_shape and arr.shape != tuple(expected_shape):
229
- arr = arr.reshape(expected_shape)
230
-
231
- # Update registry if not present
232
- if tensor_id not in self.tensor_registry:
233
- with self.lock:
234
- self.tensor_registry[tensor_id] = metadata
235
-
236
- return arr
237
-
238
- except Exception as e:
239
- logging.error(f"Error converting tensor data: {str(e)}")
240
- return None
241
- else:
242
- logging.error(f"Failed to load tensor {tensor_id}: {response.get('message', 'Unknown error')}")
243
- return None
244
-
245
- except Exception as e:
246
- logging.error(f"Error loading tensor {tensor_id}: {str(e)}")
247
- return None
248
-
249
- def store_state(self, component: str, state_id: str, state_data: Dict[str, Any]) -> bool:
250
- """Store component state via HTTP API"""
251
- try:
252
- request_data = {
253
- "data": state_data,
254
- "timestamp": time.time()
255
- }
256
-
257
- response = self._make_request(
258
- 'POST',
259
- f'/state/{component}/{state_id}',
260
- json=request_data
261
- )
262
-
263
- if response and response.get('status') == 'success':
264
- return True
265
- else:
266
- logging.error(f"Failed to store state for {component}/{state_id}: {response.get('message', 'Unknown error')}")
267
- return False
268
-
269
- except Exception as e:
270
- logging.error(f"Error storing state for {component}/{state_id}: {str(e)}")
271
- return False
272
-
273
- def load_state(self, component: str, state_id: str) -> Optional[Dict[str, Any]]:
274
- """Load component state via HTTP API"""
275
- try:
276
- response = self._make_request("GET", f"/api/v1/state/{component}/{state_id}")
277
-
278
- if response and response.get('status') == 'success':
279
- return response.get('data')
280
- else:
281
- logging.error(f"Failed to load state for {component}/{state_id}: {response.get('message', 'Unknown error')}")
282
- return None
283
-
284
- except Exception as e:
285
- logging.error(f"Error loading state for {component}/{state_id}: {str(e)}")
286
- return None
287
-
288
- def cache_data(self, key: str, data: Any) -> bool:
289
- """Cache data via HTTP API"""
290
- try:
291
- request_data = {"data": data}
292
-
293
- response = self._make_request(
294
- 'POST',
295
- f'/cache/{key}',
296
- json=request_data
297
- )
298
-
299
- return response and response.get('status') == 'success'
300
-
301
- except Exception as e:
302
- logging.error(f"Error caching data for key {key}: {str(e)}")
303
- return False
304
-
305
- def get_cached_data(self, key: str) -> Optional[Any]:
306
- """Get cached data via HTTP API"""
307
- try:
308
- response = self._make_request("GET", f"/cache/{key}")
309
-
310
- if response and response.get('status') == 'success':
311
- return response.get('data')
312
- return None
313
-
314
- except Exception as e:
315
- logging.error(f"Error getting cached data for key {key}: {str(e)}")
316
- return None
317
-
318
- def is_model_loaded(self, model_name: str) -> bool:
319
- """Check if a model is loaded via HTTP API"""
320
- try:
321
- response = self._make_request("GET", f"/models/{model_name}/status")
322
-
323
- if response and response.get('status') == 'loaded':
324
- return True
325
- return False
326
-
327
- except Exception as e:
328
- logging.error(f"Error checking model status for {model_name}: {str(e)}")
329
- return False
330
-
331
- def load_model(self, model_name: str, model_path: Optional[str] = None, model_data: Optional[Dict] = None) -> bool:
332
- """Load a model via HTTP API"""
333
- try:
334
- # Check if model is already loaded
335
- if self.is_model_loaded(model_name):
336
- logging.info(f"Model {model_name} already loaded")
337
- return True
338
-
339
- # Calculate model hash if path provided
340
- model_hash = None
341
- if model_path:
342
- model_hash = self._calculate_model_hash(model_path)
343
-
344
- request_data = {
345
- "model_data": model_data,
346
- "model_path": model_path,
347
- "model_hash": model_hash
348
- }
349
-
350
- response = self._make_request(
351
- 'POST',
352
- f'/models/{model_name}/load',
353
- json=request_data
354
- )
355
-
356
- if response and response.get('status') == 'success':
357
- with self.lock:
358
- self.model_registry[model_name] = {
359
- 'hash': model_hash,
360
- 'timestamp': time.time(),
361
- 'model_data': model_data
362
- }
363
- self.resource_monitor['loaded_models'].add(model_name)
364
- logging.info(f"Successfully loaded model {model_name}")
365
- return True
366
- else:
367
- logging.error(f"Failed to load model {model_name}: {response.get('message', 'Unknown error')}")
368
- return False
369
-
370
- except Exception as e:
371
- logging.error(f"Error loading model {model_name}: {str(e)}")
372
- return False
373
-
374
- def _calculate_model_hash(self, model_path: str) -> str:
375
- """Calculate SHA256 hash of model file"""
376
- try:
377
- sha256_hash = hashlib.sha256()
378
- with open(model_path, "rb") as f:
379
- for byte_block in iter(lambda: f.read(4096), b""):
380
- sha256_hash.update(byte_block)
381
- return sha256_hash.hexdigest()
382
- except Exception as e:
383
- logging.error(f"Error calculating model hash: {str(e)}")
384
- return ""
385
-
386
- def start_inference(self, model_name: str, input_data: np.ndarray) -> Optional[Dict[str, Any]]:
387
- """Start inference with a loaded model via HTTP API"""
388
- try:
389
- if not self.is_model_loaded(model_name):
390
- logging.error(f"Model {model_name} not loaded. Please load the model first.")
391
- return None
392
-
393
- request_data = {
394
- "input_data": input_data.tolist() if isinstance(input_data, np.ndarray) else input_data
395
- }
396
-
397
- response = self._make_request(
398
- 'POST',
399
- f'/models/{model_name}/inference',
400
- json=request_data
401
- )
402
-
403
- if response and response.get('status') == 'success':
404
- return {
405
- 'output': np.array(response['output']) if 'output' in response else None,
406
- 'metrics': response.get('metrics', {}),
407
- 'model_info': self.model_registry.get(model_name, {})
408
- }
409
- else:
410
- logging.error(f"Inference failed for model {model_name}: {response.get('message', 'Unknown error')}")
411
- return None
412
-
413
- except Exception as e:
414
- logging.error(f"Error during inference for model {model_name}: {str(e)}")
415
- return None
416
-
417
- def ping(self) -> bool:
418
- """Ping the server to check connection status."""
419
- try:
420
- response = self._make_request('GET', '/status')
421
- return response and response.get('status') == 'ok'
422
- except Exception as e:
423
- logging.error(f"Ping failed: {e}")
424
- return False
425
-
426
- def is_connected(self) -> bool:
427
- """Check if the client is connected to the server."""
428
- return self.ping()
429
-
430
- def get_connection_status(self) -> Dict[str, Any]:
431
- """Get detailed connection status."""
432
- if self.is_connected():
433
- return {"status": "connected", "session_id": self.session_id}
434
- else:
435
- return {"status": "disconnected", "error_count": self.error_count}
436
-
437
- def set_keep_alive(self, interval: int):
438
- """Set keep-alive interval (compatibility method)."""
439
- logging.info(f"Keep-alive interval set to {interval} seconds (HTTP client does not use websockets).")
440
-
441
- def reconnect(self):
442
- """Attempt to reconnect (compatibility method)."""
443
- logging.info("Attempting to reconnect HTTP client...")
444
- self._create_session()
445
-
446
- def wait_for_connection(self, timeout: float = 30.0) -> bool:
447
- """Wait for HTTP connection to be established (compatibility method)"""
448
- start_time = time.time()
449
- while time.time() - start_time < timeout:
450
- if self.is_connected():
451
- logging.info("HTTP connection established.")
452
- return True
453
- time.sleep(1) # Wait for 1 second before retrying
454
- logging.error("HTTP connection not established within timeout.")
455
- return False
456
-
457
- def close(self):
458
- """Close HTTP client"""
459
- self._closing = True
460
- logging.info("HTTP client is closing.")
461
- # Invalidate session on server side if possible
462
- if self.session_token:
463
- try:
464
- self.http_session.post(f"{self.api_base}/sessions/invalidate",
465
- headers={'Authorization': f'Bearer {self.session_token}'},
466
- timeout=5)
467
- except Exception as e:
468
- logging.warning(f"Failed to invalidate session on server: {e}")
469
- self.http_session.close()
470
- HTTPGPUStorage._instance = None # Clear singleton instance
471
-
472
- # Compatibility alias for existing code
473
- WebSocketGPUStorage = HTTPGPUStorage
474
-
475
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import json
3
+ import numpy as np
4
+ from typing import Dict, Any, Optional, Union
5
+ import threading
6
+ import time
7
+ import hashlib
8
+ import logging
9
+ from requests.adapters import HTTPAdapter
10
+ from urllib3.util.retry import Retry
11
+
12
+ class HTTPGPUStorage:
13
+ """
14
+ HTTP-based GPU storage client that replaces WebSocket functionality.
15
+ Maintains the same interface as WebSocketGPUStorage for backward compatibility.
16
+ """
17
+
18
+ # Singleton instance
19
+ _instance = None
20
+ _lock = threading.Lock()
21
+
22
+ def __new__(cls, base_url: str = "https://factorst-intiv.hf.space"):
23
+ with cls._lock:
24
+ if cls._instance is None:
25
+ cls._instance = super().__new__(cls)
26
+ cls._instance._init_singleton(base_url)
27
+ return cls._instance
28
+
29
+ def _init_singleton(self, base_url: str):
30
+ """Initialize the singleton instance"""
31
+ if hasattr(self, 'initialized'):
32
+ return
33
+
34
+ self.base_url = base_url.rstrip('/')
35
+ self.api_base = f"{self.base_url}/api/v1"
36
+ self.session_token = None
37
+ self.session_id = None
38
+ self.lock = threading.Lock()
39
+ self._closing = False
40
+ self.error_count = 0
41
+ self.last_error_time = 0
42
+ self.max_retries = 5
43
+
44
+ # Tensor and model registries (maintained for compatibility)
45
+ self.tensor_registry: Dict[str, Dict[str, Any]] = {}
46
+ self.model_registry: Dict[str, Dict[str, Any]] = {}
47
+ self.resource_monitor = {
48
+ 'vram_used': 0,
49
+ 'active_tensors': 0,
50
+ 'loaded_models': set()
51
+ }
52
+
53
+ # Configure HTTP session with connection pooling and retries
54
+ self.http_session = requests.Session()
55
+
56
+ # Configure retry strategy
57
+ retry_strategy = Retry(
58
+ total=3,
59
+ status_forcelist=[429, 500, 502, 503, 504],
60
+ allowed_methods=["HEAD", "GET", "OPTIONS", "POST", "PUT", "DELETE"], # Updated parameter name
61
+ backoff_factor=1
62
+ )
63
+
64
+ adapter = HTTPAdapter(
65
+ max_retries=retry_strategy,
66
+ pool_connections=10,
67
+ pool_maxsize=20
68
+ )
69
+
70
+ self.http_session.mount("http://", adapter)
71
+ self.http_session.mount("https://", adapter)
72
+
73
+ # Set default headers
74
+ self.http_session.headers.update({
75
+ 'Content-Type': 'application/json',
76
+ 'User-Agent': 'VirtualGPU-HTTP-Client/2.0'
77
+ })
78
+
79
+ # Initialize session
80
+ self._create_session()
81
+ self.initialized = True
82
+
83
+ def __init__(self, base_url: str = "https://factorst-intiv.hf.space"):
84
+ """This will actually just return the singleton instance"""
85
+ pass
86
+
87
+ def _create_session(self):
88
+ """Create HTTP session with the server"""
89
+ try:
90
+ response = self.http_session.post(
91
+ f"{self.api_base}/sessions",
92
+ json={"client_id": "virtual_gpu_client"},
93
+ timeout=30
94
+ )
95
+ response.raise_for_status()
96
+
97
+ session_data = response.json()
98
+ self.session_token = session_data['session_token']
99
+ self.session_id = session_data['session_id']
100
+
101
+ # Update session headers
102
+ self.http_session.headers.update({
103
+ 'Authorization': f'Bearer {self.session_token}'
104
+ })
105
+
106
+ logging.info(f"HTTP session created: {self.session_id}")
107
+ return True
108
+
109
+ except Exception as e:
110
+ logging.error(f"Failed to create HTTP session: {e}")
111
+ self.error_count += 1
112
+ self.last_error_time = time.time()
113
+ return False
114
+
115
+ def _make_request(self, method: str, endpoint: str, **kwargs) -> Optional[Dict[str, Any]]:
116
+ """Make HTTP request with error handling and retries"""
117
+ if self._closing:
118
+ return {"status": "error", "message": "HTTP client is closing"}
119
+
120
+ url = f"{self.api_base}{endpoint}"
121
+
122
+ try:
123
+ # Ensure we have a valid session
124
+ if not self.session_token:
125
+ if not self._create_session():
126
+ return {"status": "error", "message": "Failed to create session"}
127
+
128
+ response = self.http_session.request(method, url, timeout=30, **kwargs)
129
+
130
+ # Handle authentication errors by recreating session
131
+ if response.status_code == 401:
132
+ logging.warning("Session expired, recreating...")
133
+ if self._create_session():
134
+ response = self.http_session.request(method, url, timeout=30, **kwargs)
135
+ else:
136
+ return {"status": "error", "message": "Failed to recreate session"}
137
+
138
+ response.raise_for_status()
139
+
140
+ # Reset error count on successful request
141
+ self.error_count = 0
142
+
143
+ return response.json()
144
+
145
+ except requests.exceptions.RequestException as e:
146
+ self.error_count += 1
147
+ self.last_error_time = time.time()
148
+ logging.error(f"HTTP request failed: {e}")
149
+ return {"status": "error", "message": f"HTTP request failed: {str(e)}"}
150
+ except Exception as e:
151
+ self.error_count += 1
152
+ self.last_error_time = time.time()
153
+ logging.error(f"Unexpected error in HTTP request: {e}")
154
+ return {"status": "error", "message": f"Unexpected error: {str(e)}"}
155
+
156
+ def store_tensor(self, tensor_id: str, data: np.ndarray, model_size: Optional[int] = None) -> bool:
157
+ """Store tensor data via HTTP API"""
158
+ try:
159
+ if data is None:
160
+ raise ValueError("Cannot store None tensor")
161
+
162
+ # Calculate tensor metadata
163
+ tensor_shape = data.shape
164
+ tensor_dtype = str(data.dtype)
165
+ tensor_size = data.nbytes
166
+
167
+ request_data = {
168
+ "data": data.tolist(),
169
+ "metadata": {
170
+ 'shape': tensor_shape,
171
+ 'dtype': tensor_dtype,
172
+ 'size': tensor_size,
173
+ 'timestamp': time.time()
174
+ },
175
+ "model_size": model_size if model_size is not None else -1
176
+ }
177
+
178
+ response = self._make_request(
179
+ 'POST',
180
+ f'/vram/blocks/{tensor_id}',
181
+ json=request_data
182
+ )
183
+
184
+ if response and response.get('status') == 'success':
185
+ # Update tensor registry
186
+ with self.lock:
187
+ self.tensor_registry[tensor_id] = {
188
+ 'shape': tensor_shape,
189
+ 'dtype': tensor_dtype,
190
+ 'size': tensor_size,
191
+ 'timestamp': time.time()
192
+ }
193
+ self.resource_monitor['vram_used'] += tensor_size
194
+ self.resource_monitor['active_tensors'] += 1
195
+ return True
196
+ else:
197
+ logging.error(f"Failed to store tensor {tensor_id}: {response.get('message', 'Unknown error')}")
198
+ return False
199
+
200
+ except Exception as e:
201
+ logging.error(f"Error storing tensor {tensor_id}: {str(e)}")
202
+ return False
203
+
204
+ def load_tensor(self, tensor_id: str) -> Optional[np.ndarray]:
205
+ """Load tensor data via HTTP API"""
206
+ try:
207
+ # Check tensor registry first
208
+ if tensor_id not in self.tensor_registry:
209
+ logging.warning(f"Tensor {tensor_id} not registered in VRAM")
210
+ # Still try to load it in case it exists on server
211
+
212
+ response = self._make_request('GET', f'/vram/blocks/{tensor_id}')
213
+
214
+ if response and response.get('status') == 'success':
215
+ data = response.get('data')
216
+ metadata = response.get('metadata', {})
217
+
218
+ if data is None:
219
+ logging.error(f"No data found for tensor {tensor_id}")
220
+ return None
221
+
222
+ try:
223
+ # Convert to numpy array with correct dtype
224
+ expected_dtype = metadata.get('dtype', 'float32')
225
+ expected_shape = metadata.get('shape')
226
+
227
+ arr = np.array(data, dtype=np.dtype(expected_dtype))
228
+ if expected_shape and arr.shape != tuple(expected_shape):
229
+ arr = arr.reshape(expected_shape)
230
+
231
+ # Update registry if not present
232
+ if tensor_id not in self.tensor_registry:
233
+ with self.lock:
234
+ self.tensor_registry[tensor_id] = metadata
235
+
236
+ return arr
237
+
238
+ except Exception as e:
239
+ logging.error(f"Error converting tensor data: {str(e)}")
240
+ return None
241
+ else:
242
+ logging.error(f"Failed to load tensor {tensor_id}: {response.get('message', 'Unknown error')}")
243
+ return None
244
+
245
+ except Exception as e:
246
+ logging.error(f"Error loading tensor {tensor_id}: {str(e)}")
247
+ return None
248
+
249
+ def store_state(self, component: str, state_id: str, state_data: Dict[str, Any]) -> bool:
250
+ """Store component state via HTTP API"""
251
+ try:
252
+ request_data = {
253
+ "data": state_data,
254
+ "timestamp": time.time()
255
+ }
256
+
257
+ response = self._make_request(
258
+ 'POST',
259
+ f'/state/{component}/{state_id}',
260
+ json=request_data
261
+ )
262
+
263
+ if response and response.get('status') == 'success':
264
+ return True
265
+ else:
266
+ logging.error(f"Failed to store state for {component}/{state_id}: {response.get('message', 'Unknown error')}")
267
+ return False
268
+
269
+ except Exception as e:
270
+ logging.error(f"Error storing state for {component}/{state_id}: {str(e)}")
271
+ return False
272
+
273
+ def load_state(self, component: str, state_id: str) -> Optional[Dict[str, Any]]:
274
+ """Load component state via HTTP API"""
275
+ try:
276
+ response = self._make_request("GET", f"/api/v1/state/{component}/{state_id}")
277
+
278
+ if response and response.get('status') == 'success':
279
+ return response.get('data')
280
+ else:
281
+ logging.error(f"Failed to load state for {component}/{state_id}: {response.get('message', 'Unknown error')}")
282
+ return None
283
+
284
+ except Exception as e:
285
+ logging.error(f"Error loading state for {component}/{state_id}: {str(e)}")
286
+ return None
287
+
288
+ def cache_data(self, key: str, data: Any) -> bool:
289
+ """Cache data via HTTP API"""
290
+ try:
291
+ request_data = {"data": data}
292
+
293
+ response = self._make_request(
294
+ 'POST',
295
+ f'/cache/{key}',
296
+ json=request_data
297
+ )
298
+
299
+ return response and response.get('status') == 'success'
300
+
301
+ except Exception as e:
302
+ logging.error(f"Error caching data for key {key}: {str(e)}")
303
+ return False
304
+
305
+ def get_cached_data(self, key: str) -> Optional[Any]:
306
+ """Get cached data via HTTP API"""
307
+ try:
308
+ response = self._make_request("GET", f"/cache/{key}")
309
+
310
+ if response and response.get('status') == 'success':
311
+ return response.get('data')
312
+ return None
313
+
314
+ except Exception as e:
315
+ logging.error(f"Error getting cached data for key {key}: {str(e)}")
316
+ return None
317
+
318
+ def is_model_loaded(self, model_name: str) -> bool:
319
+ """Check if a model is loaded via HTTP API"""
320
+ try:
321
+ response = self._make_request("GET", f"/models/{model_name}/status")
322
+
323
+ if response and response.get('status') == 'loaded':
324
+ return True
325
+ return False
326
+
327
+ except Exception as e:
328
+ logging.error(f"Error checking model status for {model_name}: {str(e)}")
329
+ return False
330
+
331
+ def load_model(self, model_name: str, model_path: Optional[str] = None, model_data: Optional[Dict] = None) -> bool:
332
+ """Load a model via HTTP API"""
333
+ max_retries = 3
334
+ retry_delay = 2
335
+ last_error = None
336
+
337
+ for attempt in range(max_retries):
338
+ try:
339
+ # Ensure connection is active
340
+ if self._closing:
341
+ self._closing = False
342
+ if not self._create_session():
343
+ raise ConnectionError("Failed to recreate session")
344
+
345
+ # Check if model is already loaded
346
+ if self.is_model_loaded(model_name):
347
+ logging.info(f"Model {model_name} already loaded")
348
+ return True
349
+
350
+ # Calculate model hash if path provided
351
+ model_hash = None
352
+ if model_path:
353
+ model_hash = self._calculate_model_hash(model_path)
354
+
355
+ request_data = {
356
+ "model_data": model_data,
357
+ "model_path": model_path,
358
+ "model_hash": model_hash
359
+ }
360
+
361
+ response = self._make_request(
362
+ 'POST',
363
+ f'/models/{model_name}/load',
364
+ json=request_data,
365
+ timeout=22020 # Increased timeout for model loading
366
+ )
367
+
368
+ if response and response.get('status') == 'success':
369
+ with self.lock:
370
+ self.model_registry[model_name] = {
371
+ 'hash': model_hash,
372
+ 'timestamp': time.time(),
373
+ 'model_data': model_data
374
+ }
375
+ self.resource_monitor['loaded_models'].add(model_name)
376
+ logging.info(f"Successfully loaded model {model_name}")
377
+ return True
378
+ else:
379
+ last_error = response.get('message', 'HTTP connection unresponsive')
380
+ logging.error(f"Load attempt {attempt + 1} failed: {last_error}")
381
+ if attempt < max_retries - 1:
382
+ time.sleep(retry_delay * (1.5 ** attempt))
383
+ continue
384
+
385
+ except Exception as e:
386
+ last_error = str(e)
387
+ logging.error(f"Load attempt {attempt + 1} failed: {last_error}")
388
+ if attempt < max_retries - 1:
389
+ time.sleep(retry_delay * (1.5 ** attempt))
390
+ continue
391
+
392
+ logging.error(f"Failed to load model {model_name}: {last_error}")
393
+ return False
394
+
395
+ def _calculate_model_hash(self, model_path: str) -> str:
396
+ """Calculate SHA256 hash of model file"""
397
+ try:
398
+ sha256_hash = hashlib.sha256()
399
+ with open(model_path, "rb") as f:
400
+ for byte_block in iter(lambda: f.read(4096), b""):
401
+ sha256_hash.update(byte_block)
402
+ return sha256_hash.hexdigest()
403
+ except Exception as e:
404
+ logging.error(f"Error calculating model hash: {str(e)}")
405
+ return ""
406
+
407
+ def start_inference(self, model_name: str, input_data: np.ndarray) -> Optional[Dict[str, Any]]:
408
+ """Start inference with a loaded model via HTTP API"""
409
+ try:
410
+ if not self.is_model_loaded(model_name):
411
+ logging.error(f"Model {model_name} not loaded. Please load the model first.")
412
+ return None
413
+
414
+ request_data = {
415
+ "input_data": input_data.tolist() if isinstance(input_data, np.ndarray) else input_data
416
+ }
417
+
418
+ response = self._make_request(
419
+ 'POST',
420
+ f'/models/{model_name}/inference',
421
+ json=request_data
422
+ )
423
+
424
+ if response and response.get('status') == 'success':
425
+ return {
426
+ 'output': np.array(response['output']) if 'output' in response else None,
427
+ 'metrics': response.get('metrics', {}),
428
+ 'model_info': self.model_registry.get(model_name, {})
429
+ }
430
+ else:
431
+ logging.error(f"Inference failed for model {model_name}: {response.get('message', 'Unknown error')}")
432
+ return None
433
+
434
+ except Exception as e:
435
+ logging.error(f"Error during inference for model {model_name}: {str(e)}")
436
+ return None
437
+
438
+ def ping(self) -> bool:
439
+ """Ping the server to check connection status."""
440
+ try:
441
+ response = self._make_request('GET', '/status')
442
+ return response and response.get('status') == 'ok'
443
+ except Exception as e:
444
+ logging.error(f"Ping failed: {e}")
445
+ return False
446
+
447
+ def is_connected(self) -> bool:
448
+ """Check if the client is connected to the server."""
449
+ return self.ping()
450
+
451
+ def get_connection_status(self) -> Dict[str, Any]:
452
+ """Get detailed connection status."""
453
+ if self.is_connected():
454
+ return {"status": "connected", "session_id": self.session_id}
455
+ else:
456
+ return {"status": "disconnected", "error_count": self.error_count}
457
+
458
+ def set_keep_alive(self, interval: int):
459
+ """Set keep-alive interval (compatibility method)."""
460
+ logging.info(f"Keep-alive interval set to {interval} seconds (HTTP client does not use websockets).")
461
+
462
+ def reconnect(self):
463
+ """Attempt to reconnect (compatibility method)."""
464
+ logging.info("Attempting to reconnect HTTP client...")
465
+ self._create_session()
466
+
467
+ def wait_for_connection(self, timeout: float = 30.0) -> bool:
468
+ """Wait for HTTP connection to be established (compatibility method)"""
469
+ start_time = time.time()
470
+ while time.time() - start_time < timeout:
471
+ if self.is_connected():
472
+ logging.info("HTTP connection established.")
473
+ return True
474
+ time.sleep(1) # Wait for 1 second before retrying
475
+ logging.error("HTTP connection not established within timeout.")
476
+ return False
477
+
478
+ def close(self):
479
+ """Close HTTP client"""
480
+ self._closing = True
481
+ logging.info("HTTP client is closing.")
482
+ # Invalidate session on server side if possible
483
+ if self.session_token:
484
+ try:
485
+ self.http_session.post(f"{self.api_base}/sessions/invalidate",
486
+ headers={'Authorization': f'Bearer {self.session_token}'},
487
+ timeout=5)
488
+ except Exception as e:
489
+ logging.warning(f"Failed to invalidate session on server: {e}")
490
+ self.http_session.close()
491
+ HTTPGPUStorage._instance = None # Clear singleton instance
492
+
493
# Compatibility alias for existing code: legacy callers that import
# WebSocketGPUStorage transparently receive the HTTP implementation
# (same method surface, per the class docstring's compatibility claim).
WebSocketGPUStorage = HTTPGPUStorage
495
+
496
+