Factor Studios committed on
Commit
f02307a
·
verified ·
1 Parent(s): 20c4c8d

Update virtual_gpu_server_http.py

Browse files
Files changed (1) hide show
  1. virtual_gpu_server_http.py +33 -2
virtual_gpu_server_http.py CHANGED
@@ -649,24 +649,54 @@ async def load_model(
649
  # Get safe filename for storage
650
  safe_name = sanitize_filename(model_name)
651
 
652
- # Store model information
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
653
  model_info = {
654
  'model_name': model_name,
655
  'model_data': request.model_data,
656
  'model_path': request.model_path,
657
  'model_hash': request.model_hash,
658
  'loaded_at': time.time(),
659
- 'session_id': session['session_id']
 
 
 
 
 
 
 
660
  }
661
 
662
  server.model_cache[model_name] = model_info
663
 
664
  # Store in persistent storage
665
  model_file = server.models_path / f"{safe_name}.json"
 
666
  logging.info(f"Storing model info at: {model_file}")
667
 
 
668
  with open(model_file, 'w') as f:
669
  json.dump(model_info, f)
 
 
 
 
 
 
670
 
671
  server.ops_counter += 1
672
  return {
@@ -674,6 +704,7 @@ async def load_model(
674
  "message": f"Model {model_name} loaded successfully",
675
  "model_info": {
676
  "name": model_name,
 
677
  "loaded_at": model_info['loaded_at']
678
  }
679
  }
 
649
  # Get safe filename for storage
650
  safe_name = sanitize_filename(model_name)
651
 
652
+ if not request.model_data:
653
+ raise HTTPException(
654
+ status_code=400,
655
+ detail="model_data is required and must include architecture configuration"
656
+ )
657
+
658
+ # Validate required model configuration
659
+ required_fields = ['num_sms', 'tensor_cores_per_sm', 'cuda_cores_per_sm']
660
+ missing_fields = [field for field in required_fields if field not in request.model_data]
661
+ if missing_fields:
662
+ raise HTTPException(
663
+ status_code=400,
664
+ detail=f"Missing required model configuration fields: {missing_fields}"
665
+ )
666
+
667
+ # Store model information with full configuration
668
  model_info = {
669
  'model_name': model_name,
670
  'model_data': request.model_data,
671
  'model_path': request.model_path,
672
  'model_hash': request.model_hash,
673
  'loaded_at': time.time(),
674
+ 'session_id': session['session_id'],
675
+ 'architecture': {
676
+ 'num_sms': request.model_data['num_sms'],
677
+ 'tensor_cores_per_sm': request.model_data['tensor_cores_per_sm'],
678
+ 'cuda_cores_per_sm': request.model_data['cuda_cores_per_sm'],
679
+ 'vram_allocation': request.model_data.get('vram_allocation', 'dynamic'),
680
+ 'compute_capability': request.model_data.get('compute_capability', '8.0')
681
+ }
682
  }
683
 
684
  server.model_cache[model_name] = model_info
685
 
686
  # Store in persistent storage
687
  model_file = server.models_path / f"{safe_name}.json"
688
+ model_data_file = server.models_path / f"{safe_name}.data"
689
  logging.info(f"Storing model info at: {model_file}")
690
 
691
+ # Store metadata and configuration
692
  with open(model_file, 'w') as f:
693
  json.dump(model_info, f)
694
+
695
+ # Store actual model data separately
696
+ if request.model_data.get('weights') or request.model_data.get('parameters'):
697
+ logging.info(f"Storing model data at: {model_data_file}")
698
+ with open(model_data_file, 'w') as f:
699
+ json.dump(request.model_data, f)
700
 
701
  server.ops_counter += 1
702
  return {
 
704
  "message": f"Model {model_name} loaded successfully",
705
  "model_info": {
706
  "name": model_name,
707
+ "architecture": model_info['architecture'],
708
  "loaded_at": model_info['loaded_at']
709
  }
710
  }