Factor Studios committed
Commit e2b3d8e · verified · 1 Parent(s): b8076f9

Update virtual_gpu_server_http.py

Files changed (1)
1. virtual_gpu_server_http.py (+67 -77)
virtual_gpu_server_http.py CHANGED
@@ -16,14 +16,13 @@ from datetime import datetime, timedelta
 import hashlib
 import gzip
 import base64
-import logging
 from pydantic import BaseModel
+import urllib.parse
+import re
+import logging
 
-# Configure logging
-logging.basicConfig(
-    level=logging.INFO,
-    format='%(asctime)s - %(levelname)s - %(message)s'
-)
+# Configure basic logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
 # Create FastAPI instance with enhanced configuration
 app = FastAPI(
@@ -127,6 +126,8 @@ class VirtualGPUServer:
         self.state_cache: Dict[str, Any] = {}
         self.memory_cache: Dict[str, Any] = {}
         self.model_cache: Dict[str, Any] = {}
+        # map original model_name -> safe_filename used on disk
+        self.model_name_map: Dict[str, str] = {}
 
         # Session management for HTTP API
         self.http_sessions: Dict[str, Dict[str, Any]] = {}
@@ -197,6 +198,39 @@ class VirtualGPUServer:
         """Decompress gzip data"""
         return gzip.decompress(data)
 
+    def sanitize_model_name(self, model_name: str) -> str:
+        """Create a filesystem-safe filename from a provided model_name.
+        This will URL-decode percent-encoded values, then replace unsafe characters with underscores.
+        """
+        if not model_name:
+            return "unnamed_model"
+        # URL-decode first (handles cases where client sent %2F)
+        decoded = urllib.parse.unquote(model_name)
+        # Replace characters that are not alphanumeric, dot, underscore or dash
+        safe = re.sub(r'[^0-9A-Za-z._-]', '_', decoded)
+        # Trim length to avoid overly long filenames
+        return safe[:240]
+
+    def resolve_model_key(self, model_name: str) -> Optional[str]:
+        """Resolve the canonical model key used in model_cache.
+        Accepts several possible incoming forms (percent-encoded, decoded, sanitized) and returns
+        the key present in model_cache if any, otherwise None.
+        """
+        # direct hit
+        if model_name in self.model_cache:
+            return model_name
+        # try URL-decoded
+        decoded = urllib.parse.unquote(model_name)
+        if decoded in self.model_cache:
+            return decoded
+        # try sanitized form matching stored map
+        safe = self.sanitize_model_name(model_name)
+        # see if we have an original key that maps to safe filename
+        for orig, safe_name in self.model_name_map.items():
+            if safe_name == safe:
+                return orig
+        return None
+
     async def handle_vram_operation(self, operation: dict) -> dict:
         """Handle VRAM read/write operations (preserved from WebSocket implementation)"""
         try:
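Note: a minimal standalone sketch of what the new sanitizer produces, mirroring the method added above. The helper name demo_sanitize and the model names are illustrative only, not part of the commit:

    import re
    import urllib.parse

    def demo_sanitize(name: str) -> str:
        # Mirrors VirtualGPUServer.sanitize_model_name above
        decoded = urllib.parse.unquote(name)
        return re.sub(r'[^0-9A-Za-z._-]', '_', decoded)[:240]

    # A literal slash and its percent-encoded form map to the same safe filename
    assert demo_sanitize("org/example-model") == "org_example-model"
    assert demo_sanitize("org%2Fexample-model") == "org_example-model"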
@@ -628,15 +662,6 @@ async def get_cache(
             detail=f"Cache get operation failed: {str(e)}"
         )
 
-def sanitize_model_name(model_name: str) -> str:
-    """
-    Sanitize model name for safe file system usage.
-    Decodes URL-encoded name and replaces slashes with double underscores.
-    """
-    from urllib.parse import unquote
-    decoded_name = unquote(model_name)
-    return decoded_name.replace('/', '__')
-
 @app.post("/api/v1/models/{model_name}/load")
 async def load_model(
     model_name: str,
@@ -645,17 +670,12 @@ async def load_model(
 ):
     """Load AI model"""
     try:
-        # Log the received model name for debugging
-        logging.info(f"Received model load request - Raw name: {model_name}")
-
-        # Sanitize model name for filesystem operations
-        safe_name = sanitize_model_name(model_name)
-        logging.info(f"Sanitized model name: {safe_name}")
-
-        # Store model information
+        logging.info(f"Received model load request for: {model_name}")
+        # Create a safe filename and persist model info under the original key
+        safe_name = server.sanitize_model_name(model_name)
+
         model_info = {
-            'model_name': model_name, # Store original name
-            'safe_name': safe_name, # Store sanitized name
+            'model_name': model_name,
             'model_data': request.model_data,
             'model_path': request.model_path,
             'model_hash': request.model_hash,
@@ -663,28 +683,28 @@ async def load_model(
             'session_id': session['session_id']
         }
 
-        # Use sanitized name for cache and file operations
+        # Store mapping and cache
         server.model_cache[model_name] = model_info
+        server.model_name_map[model_name] = safe_name
 
-        # Store in persistent storage with safe name
+        # Store in persistent storage using safe filename
         model_file = server.models_path / f"{safe_name}.json"
-        logging.info(f"Storing model info at: {model_file}")
-
         with open(model_file, 'w') as f:
             json.dump(model_info, f)
 
         server.ops_counter += 1
+        logging.info(f"Model '{model_name}' saved to disk as '{safe_name}.json'")
         return {
             "status": "success",
            "message": f"Model {model_name} loaded successfully",
            "model_info": {
                "name": model_name,
-                "safe_name": safe_name,
                "loaded_at": model_info['loaded_at']
            }
        }
 
     except Exception as e:
+        logging.exception("Model load operation failed")
         raise HTTPException(
             status_code=500,
             detail=f"Model load operation failed: {str(e)}"
@@ -698,25 +718,12 @@ async def run_inference(
 ):
     """Run model inference"""
     try:
-        logging.info(f"Running inference - Raw model name: {model_name}")
-        safe_name = sanitize_model_name(model_name)
-        logging.info(f"Running inference - Safe model name: {safe_name}")
-
-        # Check if model is loaded (try both original and safe names)
-        if model_name not in server.model_cache:
-            # Try loading from file system using safe name
-            model_file = server.models_path / f"{safe_name}.json"
-            if not model_file.exists():
-                logging.error(f"Model {model_name} not found in cache or filesystem")
-                raise HTTPException(status_code=404, detail=f"Model {model_name} not loaded")
-
-            logging.info(f"Loading model info from file: {model_file}")
-            with open(model_file) as f:
-                model_info = json.load(f)
-            server.model_cache[model_name] = model_info
+        logging.info(f"Inference requested for model: {model_name}")
+        resolved_key = server.resolve_model_key(model_name)
+        if not resolved_key:
+            raise HTTPException(status_code=404, detail=f"Model {model_name} not loaded")
 
-        # Simulate inference processing
-        # In a real implementation, this would invoke the actual model
+        # Simulate inference processing (echo input for now)
         result = {
             "status": "success",
             "output": request.input_data, # Echo input for now
@@ -724,17 +731,16 @@ async def run_inference(
                 "inference_time": 0.1,
                 "tokens_processed": len(request.input_data)
             },
-            "model_info": server.model_cache[model_name]
+            "model_info": server.model_cache.get(resolved_key)
         }
 
         server.ops_counter += 1
-        logging.info(f"Inference completed successfully for model: {model_name}")
         return result
 
     except HTTPException:
         raise
     except Exception as e:
-        logging.error(f"Inference operation failed for {model_name}: {str(e)}")
+        logging.exception("Inference operation failed")
         raise HTTPException(
             status_code=500,
             detail=f"Inference operation failed: {str(e)}"
@@ -747,39 +753,20 @@ async def get_model_status(
 ):
     """Get model status"""
     try:
-        logging.info(f"Checking model status - Raw name: {model_name}")
-        safe_name = sanitize_model_name(model_name)
-        logging.info(f"Checking model status - Safe name: {safe_name}")
-
-        # Check cache first
-        if model_name in server.model_cache:
-            logging.info(f"Model {model_name} found in cache")
+        resolved_key = server.resolve_model_key(model_name)
+        if resolved_key:
             return {
                 "status": "loaded",
-                "model_info": server.model_cache[model_name]
+                "model_info": server.model_cache[resolved_key]
             }
-
-        # Check file system using safe name
-        model_file = server.models_path / f"{safe_name}.json"
-        if model_file.exists():
-            logging.info(f"Model file found: {model_file}")
-            with open(model_file) as f:
-                model_info = json.load(f)
-            # Update cache
-            server.model_cache[model_name] = model_info
+        else:
             return {
-                "status": "loaded",
-                "model_info": model_info
+                "status": "not_loaded",
+                "message": f"Model {model_name} is not loaded"
             }
-
-        logging.info(f"Model {model_name} not found in cache or filesystem")
-        return {
-            "status": "not_loaded",
-            "message": f"Model {model_name} is not loaded"
-        }
 
     except Exception as e:
+        logging.exception("Model status check failed")
         raise HTTPException(
             status_code=500,
             detail=f"Model status check failed: {str(e)}"
@@ -844,6 +831,7 @@ async def transfer_between_chips(
     except HTTPException:
         raise
     except Exception as e:
+        logging.exception("Chip transfer failed")
         raise HTTPException(
             status_code=500,
             detail=f"Chip transfer failed: {str(e)}"
@@ -876,6 +864,7 @@ async def create_sync_barrier(
         }
 
     except Exception as e:
+        logging.exception("Barrier creation failed")
         raise HTTPException(
             status_code=500,
             detail=f"Barrier creation failed: {str(e)}"
@@ -913,6 +902,7 @@ async def wait_sync_barrier(
     except HTTPException:
         raise
     except Exception as e:
+        logging.exception("Barrier wait failed")
         raise HTTPException(
             status_code=500,
             detail=f"Barrier wait failed: {str(e)}"
 