Factor Studios committed on
Commit
f3885ff
·
verified ·
1 Parent(s): e2b3d8e

Update virtual_gpu_server_http.py

Browse files
Files changed (1) hide show
  1. virtual_gpu_server_http.py +68 -65
virtual_gpu_server_http.py CHANGED
@@ -16,13 +16,14 @@ from datetime import datetime, timedelta
16
  import hashlib
17
  import gzip
18
  import base64
19
- from pydantic import BaseModel
20
- import urllib.parse
21
- import re
22
  import logging
 
23
 
24
- # Configure basic logging
25
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
 
 
26
 
27
  # Create FastAPI instance with enhanced configuration
28
  app = FastAPI(
@@ -126,8 +127,6 @@ class VirtualGPUServer:
126
  self.state_cache: Dict[str, Any] = {}
127
  self.memory_cache: Dict[str, Any] = {}
128
  self.model_cache: Dict[str, Any] = {}
129
- # map original model_name -> safe_filename used on disk
130
- self.model_name_map: Dict[str, str] = {}
131
 
132
  # Session management for HTTP API
133
  self.http_sessions: Dict[str, Dict[str, Any]] = {}
@@ -198,39 +197,6 @@ class VirtualGPUServer:
198
  """Decompress gzip data"""
199
  return gzip.decompress(data)
200
 
201
- def sanitize_model_name(self, model_name: str) -> str:
202
- """Create a filesystem-safe filename from a provided model_name.
203
- This will URL-decode percent-encoded values, then replace unsafe characters with underscores.
204
- """
205
- if not model_name:
206
- return "unnamed_model"
207
- # URL-decode first (handles cases where client sent %2F)
208
- decoded = urllib.parse.unquote(model_name)
209
- # Replace characters that are not alphanumeric, dot, underscore or dash
210
- safe = re.sub(r'[^0-9A-Za-z._-]', '_', decoded)
211
- # Trim length to avoid overly long filenames
212
- return safe[:240]
213
-
214
- def resolve_model_key(self, model_name: str) -> Optional[str]:
215
- """Resolve the canonical model key used in model_cache.
216
- Accepts several possible incoming forms (percent-encoded, decoded, sanitized) and returns
217
- the key present in model_cache if any, otherwise None.
218
- """
219
- # direct hit
220
- if model_name in self.model_cache:
221
- return model_name
222
- # try URL-decoded
223
- decoded = urllib.parse.unquote(model_name)
224
- if decoded in self.model_cache:
225
- return decoded
226
- # try sanitized form matching stored map
227
- safe = self.sanitize_model_name(model_name)
228
- # see if we have an original key that maps to safe filename
229
- for orig, safe_name in self.model_name_map.items():
230
- if safe_name == safe:
231
- return orig
232
- return None
233
-
234
  async def handle_vram_operation(self, operation: dict) -> dict:
235
  """Handle VRAM read/write operations (preserved from WebSocket implementation)"""
236
  try:
@@ -662,6 +628,13 @@ async def get_cache(
662
  detail=f"Cache get operation failed: {str(e)}"
663
  )
664
 
 
 
 
 
 
 
 
665
  @app.post("/api/v1/models/{model_name}/load")
666
  async def load_model(
667
  model_name: str,
@@ -670,10 +643,13 @@ async def load_model(
670
  ):
671
  """Load AI model"""
672
  try:
 
673
  logging.info(f"Received model load request for: {model_name}")
674
- # Create a safe filename and persist model info under the original key
675
- safe_name = server.sanitize_model_name(model_name)
676
-
 
 
677
  model_info = {
678
  'model_name': model_name,
679
  'model_data': request.model_data,
@@ -683,17 +659,16 @@ async def load_model(
683
  'session_id': session['session_id']
684
  }
685
 
686
- # Store mapping and cache
687
  server.model_cache[model_name] = model_info
688
- server.model_name_map[model_name] = safe_name
689
 
690
- # Store in persistent storage using safe filename
691
  model_file = server.models_path / f"{safe_name}.json"
 
 
692
  with open(model_file, 'w') as f:
693
  json.dump(model_info, f)
694
 
695
  server.ops_counter += 1
696
- logging.info(f"Model '{model_name}' saved to disk as '{safe_name}.json'")
697
  return {
698
  "status": "success",
699
  "message": f"Model {model_name} loaded successfully",
@@ -704,7 +679,6 @@ async def load_model(
704
  }
705
 
706
  except Exception as e:
707
- logging.exception("Model load operation failed")
708
  raise HTTPException(
709
  status_code=500,
710
  detail=f"Model load operation failed: {str(e)}"
@@ -718,12 +692,25 @@ async def run_inference(
718
  ):
719
  """Run model inference"""
720
  try:
721
- logging.info(f"Inference requested for model: {model_name}")
722
- resolved_key = server.resolve_model_key(model_name)
723
- if not resolved_key:
724
- raise HTTPException(status_code=404, detail=f"Model {model_name} not loaded")
 
 
 
 
 
 
 
 
 
 
 
 
725
 
726
- # Simulate inference processing (echo input for now)
 
727
  result = {
728
  "status": "success",
729
  "output": request.input_data, # Echo input for now
@@ -731,16 +718,17 @@ async def run_inference(
731
  "inference_time": 0.1,
732
  "tokens_processed": len(request.input_data)
733
  },
734
- "model_info": server.model_cache.get(resolved_key)
735
  }
736
 
737
  server.ops_counter += 1
 
738
  return result
739
 
740
  except HTTPException:
741
  raise
742
  except Exception as e:
743
- logging.exception("Inference operation failed")
744
  raise HTTPException(
745
  status_code=500,
746
  detail=f"Inference operation failed: {str(e)}"
@@ -753,20 +741,38 @@ async def get_model_status(
753
  ):
754
  """Get model status"""
755
  try:
756
- resolved_key = server.resolve_model_key(model_name)
757
- if resolved_key:
 
 
 
758
  return {
759
  "status": "loaded",
760
- "model_info": server.model_cache[resolved_key]
761
  }
762
- else:
 
 
 
 
 
 
 
 
 
763
  return {
764
- "status": "not_loaded",
765
- "message": f"Model {model_name} is not loaded"
766
  }
 
 
 
 
 
 
767
 
768
  except Exception as e:
769
- logging.exception("Model status check failed")
770
  raise HTTPException(
771
  status_code=500,
772
  detail=f"Model status check failed: {str(e)}"
@@ -831,7 +837,6 @@ async def transfer_between_chips(
831
  except HTTPException:
832
  raise
833
  except Exception as e:
834
- logging.exception("Chip transfer failed")
835
  raise HTTPException(
836
  status_code=500,
837
  detail=f"Chip transfer failed: {str(e)}"
@@ -864,7 +869,6 @@ async def create_sync_barrier(
864
  }
865
 
866
  except Exception as e:
867
- logging.exception("Barrier creation failed")
868
  raise HTTPException(
869
  status_code=500,
870
  detail=f"Barrier creation failed: {str(e)}"
@@ -902,7 +906,6 @@ async def wait_sync_barrier(
902
  except HTTPException:
903
  raise
904
  except Exception as e:
905
- logging.exception("Barrier wait failed")
906
  raise HTTPException(
907
  status_code=500,
908
  detail=f"Barrier wait failed: {str(e)}"
 
16
  import hashlib
17
  import gzip
18
  import base64
 
 
 
19
  import logging
20
+ from pydantic import BaseModel
21
 
22
+ # Configure logging
23
+ logging.basicConfig(
24
+ level=logging.INFO,
25
+ format='%(asctime)s - %(levelname)s - %(message)s'
26
+ )
27
 
28
  # Create FastAPI instance with enhanced configuration
29
  app = FastAPI(
 
127
  self.state_cache: Dict[str, Any] = {}
128
  self.memory_cache: Dict[str, Any] = {}
129
  self.model_cache: Dict[str, Any] = {}
 
 
130
 
131
  # Session management for HTTP API
132
  self.http_sessions: Dict[str, Dict[str, Any]] = {}
 
197
  """Decompress gzip data"""
198
  return gzip.decompress(data)
199
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  async def handle_vram_operation(self, operation: dict) -> dict:
201
  """Handle VRAM read/write operations (preserved from WebSocket implementation)"""
202
  try:
 
628
  detail=f"Cache get operation failed: {str(e)}"
629
  )
630
 
631
def sanitize_filename(name: str) -> str:
    """
    Sanitize a string for safe file system usage.

    Replaces both POSIX ('/') and Windows ('\\') path separators with
    double underscores so a caller-supplied name (e.g. the URL path
    parameter "org/model") always collapses to a single filename
    component and cannot escape the models directory.

    Args:
        name: Raw model name, possibly containing path separators.

    Returns:
        A separator-free string safe to use as one path component.
    """
    # NOTE(review): '\\' must be neutralized too — on Windows it is a
    # path separator, so leaving it in would allow names like
    # "..\\..\\x" to traverse out of models_path.
    return name.replace('/', '__').replace('\\', '__')
637
+
638
  @app.post("/api/v1/models/{model_name}/load")
639
  async def load_model(
640
  model_name: str,
 
643
  ):
644
  """Load AI model"""
645
  try:
646
+ # Log the received model name for debugging
647
  logging.info(f"Received model load request for: {model_name}")
648
+
649
+ # Get safe filename for storage
650
+ safe_name = sanitize_filename(model_name)
651
+
652
+ # Store model information
653
  model_info = {
654
  'model_name': model_name,
655
  'model_data': request.model_data,
 
659
  'session_id': session['session_id']
660
  }
661
 
 
662
  server.model_cache[model_name] = model_info
 
663
 
664
+ # Store in persistent storage
665
  model_file = server.models_path / f"{safe_name}.json"
666
+ logging.info(f"Storing model info at: {model_file}")
667
+
668
  with open(model_file, 'w') as f:
669
  json.dump(model_info, f)
670
 
671
  server.ops_counter += 1
 
672
  return {
673
  "status": "success",
674
  "message": f"Model {model_name} loaded successfully",
 
679
  }
680
 
681
  except Exception as e:
 
682
  raise HTTPException(
683
  status_code=500,
684
  detail=f"Model load operation failed: {str(e)}"
 
692
  ):
693
  """Run model inference"""
694
  try:
695
+ logging.info(f"Running inference - Raw model name: {model_name}")
696
+ safe_name = sanitize_filename(model_name)
697
+ logging.info(f"Running inference - Safe model name: {safe_name}")
698
+
699
+ # Check if model is loaded (try both original and safe names)
700
+ if model_name not in server.model_cache:
701
+ # Try loading from file system using safe name
702
+ model_file = server.models_path / f"{safe_name}.json"
703
+ if not model_file.exists():
704
+ logging.error(f"Model {model_name} not found in cache or filesystem")
705
+ raise HTTPException(status_code=404, detail=f"Model {model_name} not loaded")
706
+
707
+ logging.info(f"Loading model info from file: {model_file}")
708
+ with open(model_file) as f:
709
+ model_info = json.load(f)
710
+ server.model_cache[model_name] = model_info
711
 
712
+ # Simulate inference processing
713
+ # In a real implementation, this would invoke the actual model
714
  result = {
715
  "status": "success",
716
  "output": request.input_data, # Echo input for now
 
718
  "inference_time": 0.1,
719
  "tokens_processed": len(request.input_data)
720
  },
721
+ "model_info": server.model_cache[model_name]
722
  }
723
 
724
  server.ops_counter += 1
725
+ logging.info(f"Inference completed successfully for model: {model_name}")
726
  return result
727
 
728
  except HTTPException:
729
  raise
730
  except Exception as e:
731
+ logging.error(f"Inference operation failed for {model_name}: {str(e)}")
732
  raise HTTPException(
733
  status_code=500,
734
  detail=f"Inference operation failed: {str(e)}"
 
741
  ):
742
  """Get model status"""
743
  try:
744
+ logging.info(f"Checking model status for: {model_name}")
745
+
746
+ # Check cache first
747
+ if model_name in server.model_cache:
748
+ logging.info(f"Model {model_name} found in cache")
749
  return {
750
  "status": "loaded",
751
+ "model_info": server.model_cache[model_name]
752
  }
753
+
754
+ # Check file system using safe name
755
+ safe_name = sanitize_filename(model_name)
756
+ model_file = server.models_path / f"{safe_name}.json"
757
+ if model_file.exists():
758
+ logging.info(f"Model file found: {model_file}")
759
+ with open(model_file) as f:
760
+ model_info = json.load(f)
761
+ # Update cache
762
+ server.model_cache[model_name] = model_info
763
  return {
764
+ "status": "loaded",
765
+ "model_info": model_info
766
  }
767
+
768
+ logging.info(f"Model {model_name} not found in cache or filesystem")
769
+ return {
770
+ "status": "not_loaded",
771
+ "message": f"Model {model_name} is not loaded"
772
+ }
773
 
774
  except Exception as e:
775
+ logging.error(f"Model status check failed for {model_name}: {str(e)}")
776
  raise HTTPException(
777
  status_code=500,
778
  detail=f"Model status check failed: {str(e)}"
 
837
  except HTTPException:
838
  raise
839
  except Exception as e:
 
840
  raise HTTPException(
841
  status_code=500,
842
  detail=f"Chip transfer failed: {str(e)}"
 
869
  }
870
 
871
  except Exception as e:
 
872
  raise HTTPException(
873
  status_code=500,
874
  detail=f"Barrier creation failed: {str(e)}"
 
906
  except HTTPException:
907
  raise
908
  except Exception as e:
 
909
  raise HTTPException(
910
  status_code=500,
911
  detail=f"Barrier wait failed: {str(e)}"