Spaces:
Sleeping
Sleeping
Factor Studios
committed on
Update virtual_gpu_server_http.py
Browse files- virtual_gpu_server_http.py +67 -77
virtual_gpu_server_http.py
CHANGED
|
@@ -16,14 +16,13 @@ from datetime import datetime, timedelta
|
|
| 16 |
import hashlib
|
| 17 |
import gzip
|
| 18 |
import base64
|
| 19 |
-
import logging
|
| 20 |
from pydantic import BaseModel
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
-
# Configure logging
|
| 23 |
-
logging.basicConfig(
|
| 24 |
-
level=logging.INFO,
|
| 25 |
-
format='%(asctime)s - %(levelname)s - %(message)s'
|
| 26 |
-
)
|
| 27 |
|
| 28 |
# Create FastAPI instance with enhanced configuration
|
| 29 |
app = FastAPI(
|
|
@@ -127,6 +126,8 @@ class VirtualGPUServer:
|
|
| 127 |
self.state_cache: Dict[str, Any] = {}
|
| 128 |
self.memory_cache: Dict[str, Any] = {}
|
| 129 |
self.model_cache: Dict[str, Any] = {}
|
|
|
|
|
|
|
| 130 |
|
| 131 |
# Session management for HTTP API
|
| 132 |
self.http_sessions: Dict[str, Dict[str, Any]] = {}
|
|
@@ -197,6 +198,39 @@ class VirtualGPUServer:
|
|
| 197 |
"""Decompress gzip data"""
|
| 198 |
return gzip.decompress(data)
|
| 199 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
async def handle_vram_operation(self, operation: dict) -> dict:
|
| 201 |
"""Handle VRAM read/write operations (preserved from WebSocket implementation)"""
|
| 202 |
try:
|
|
@@ -628,15 +662,6 @@ async def get_cache(
|
|
| 628 |
detail=f"Cache get operation failed: {str(e)}"
|
| 629 |
)
|
| 630 |
|
| 631 |
-
def sanitize_model_name(model_name: str) -> str:
|
| 632 |
-
"""
|
| 633 |
-
Sanitize model name for safe file system usage.
|
| 634 |
-
Decodes URL-encoded name and replaces slashes with double underscores.
|
| 635 |
-
"""
|
| 636 |
-
from urllib.parse import unquote
|
| 637 |
-
decoded_name = unquote(model_name)
|
| 638 |
-
return decoded_name.replace('/', '__')
|
| 639 |
-
|
| 640 |
@app.post("/api/v1/models/{model_name}/load")
|
| 641 |
async def load_model(
|
| 642 |
model_name: str,
|
|
@@ -645,17 +670,12 @@ async def load_model(
|
|
| 645 |
):
|
| 646 |
"""Load AI model"""
|
| 647 |
try:
|
| 648 |
-
|
| 649 |
-
|
| 650 |
-
|
| 651 |
-
|
| 652 |
-
safe_name = sanitize_model_name(model_name)
|
| 653 |
-
logging.info(f"Sanitized model name: {safe_name}")
|
| 654 |
-
|
| 655 |
-
# Store model information
|
| 656 |
model_info = {
|
| 657 |
-
'model_name': model_name,
|
| 658 |
-
'safe_name': safe_name, # Store sanitized name
|
| 659 |
'model_data': request.model_data,
|
| 660 |
'model_path': request.model_path,
|
| 661 |
'model_hash': request.model_hash,
|
|
@@ -663,28 +683,28 @@ async def load_model(
|
|
| 663 |
'session_id': session['session_id']
|
| 664 |
}
|
| 665 |
|
| 666 |
-
#
|
| 667 |
server.model_cache[model_name] = model_info
|
|
|
|
| 668 |
|
| 669 |
-
# Store in persistent storage
|
| 670 |
model_file = server.models_path / f"{safe_name}.json"
|
| 671 |
-
logging.info(f"Storing model info at: {model_file}")
|
| 672 |
-
|
| 673 |
with open(model_file, 'w') as f:
|
| 674 |
json.dump(model_info, f)
|
| 675 |
|
| 676 |
server.ops_counter += 1
|
|
|
|
| 677 |
return {
|
| 678 |
"status": "success",
|
| 679 |
"message": f"Model {model_name} loaded successfully",
|
| 680 |
"model_info": {
|
| 681 |
"name": model_name,
|
| 682 |
-
"safe_name": safe_name,
|
| 683 |
"loaded_at": model_info['loaded_at']
|
| 684 |
}
|
| 685 |
}
|
| 686 |
|
| 687 |
except Exception as e:
|
|
|
|
| 688 |
raise HTTPException(
|
| 689 |
status_code=500,
|
| 690 |
detail=f"Model load operation failed: {str(e)}"
|
|
@@ -698,25 +718,12 @@ async def run_inference(
|
|
| 698 |
):
|
| 699 |
"""Run model inference"""
|
| 700 |
try:
|
| 701 |
-
logging.info(f"
|
| 702 |
-
|
| 703 |
-
|
| 704 |
-
|
| 705 |
-
# Check if model is loaded (try both original and safe names)
|
| 706 |
-
if model_name not in server.model_cache:
|
| 707 |
-
# Try loading from file system using safe name
|
| 708 |
-
model_file = server.models_path / f"{safe_name}.json"
|
| 709 |
-
if not model_file.exists():
|
| 710 |
-
logging.error(f"Model {model_name} not found in cache or filesystem")
|
| 711 |
-
raise HTTPException(status_code=404, detail=f"Model {model_name} not loaded")
|
| 712 |
-
|
| 713 |
-
logging.info(f"Loading model info from file: {model_file}")
|
| 714 |
-
with open(model_file) as f:
|
| 715 |
-
model_info = json.load(f)
|
| 716 |
-
server.model_cache[model_name] = model_info
|
| 717 |
|
| 718 |
-
# Simulate inference processing
|
| 719 |
-
# In a real implementation, this would invoke the actual model
|
| 720 |
result = {
|
| 721 |
"status": "success",
|
| 722 |
"output": request.input_data, # Echo input for now
|
|
@@ -724,17 +731,16 @@ async def run_inference(
|
|
| 724 |
"inference_time": 0.1,
|
| 725 |
"tokens_processed": len(request.input_data)
|
| 726 |
},
|
| 727 |
-
"model_info": server.model_cache
|
| 728 |
}
|
| 729 |
|
| 730 |
server.ops_counter += 1
|
| 731 |
-
logging.info(f"Inference completed successfully for model: {model_name}")
|
| 732 |
return result
|
| 733 |
|
| 734 |
except HTTPException:
|
| 735 |
raise
|
| 736 |
except Exception as e:
|
| 737 |
-
logging.
|
| 738 |
raise HTTPException(
|
| 739 |
status_code=500,
|
| 740 |
detail=f"Inference operation failed: {str(e)}"
|
|
@@ -747,39 +753,20 @@ async def get_model_status(
|
|
| 747 |
):
|
| 748 |
"""Get model status"""
|
| 749 |
try:
|
| 750 |
-
|
| 751 |
-
|
| 752 |
-
logging.info(f"Checking model status - Safe name: {safe_name}")
|
| 753 |
-
|
| 754 |
-
# Check cache first
|
| 755 |
-
if model_name in server.model_cache:
|
| 756 |
-
logging.info(f"Model {model_name} found in cache")
|
| 757 |
return {
|
| 758 |
"status": "loaded",
|
| 759 |
-
"model_info": server.model_cache[
|
| 760 |
}
|
| 761 |
-
|
| 762 |
-
# Check file system using safe name
|
| 763 |
-
model_file = server.models_path / f"{safe_name}.json"
|
| 764 |
-
if model_file.exists():
|
| 765 |
-
logging.info(f"Model file found: {model_file}")
|
| 766 |
-
with open(model_file) as f:
|
| 767 |
-
model_info = json.load(f)
|
| 768 |
-
# Update cache
|
| 769 |
-
server.model_cache[model_name] = model_info
|
| 770 |
return {
|
| 771 |
-
"status": "
|
| 772 |
-
"
|
| 773 |
}
|
| 774 |
-
|
| 775 |
-
logging.info(f"Model {model_name} not found in cache or filesystem")
|
| 776 |
-
return {
|
| 777 |
-
"status": "not_loaded",
|
| 778 |
-
"message": f"Model {model_name} is not loaded"
|
| 779 |
-
}
|
| 780 |
|
| 781 |
except Exception as e:
|
| 782 |
-
logging.
|
| 783 |
raise HTTPException(
|
| 784 |
status_code=500,
|
| 785 |
detail=f"Model status check failed: {str(e)}"
|
|
@@ -844,6 +831,7 @@ async def transfer_between_chips(
|
|
| 844 |
except HTTPException:
|
| 845 |
raise
|
| 846 |
except Exception as e:
|
|
|
|
| 847 |
raise HTTPException(
|
| 848 |
status_code=500,
|
| 849 |
detail=f"Chip transfer failed: {str(e)}"
|
|
@@ -876,6 +864,7 @@ async def create_sync_barrier(
|
|
| 876 |
}
|
| 877 |
|
| 878 |
except Exception as e:
|
|
|
|
| 879 |
raise HTTPException(
|
| 880 |
status_code=500,
|
| 881 |
detail=f"Barrier creation failed: {str(e)}"
|
|
@@ -913,6 +902,7 @@ async def wait_sync_barrier(
|
|
| 913 |
except HTTPException:
|
| 914 |
raise
|
| 915 |
except Exception as e:
|
|
|
|
| 916 |
raise HTTPException(
|
| 917 |
status_code=500,
|
| 918 |
detail=f"Barrier wait failed: {str(e)}"
|
|
|
|
| 16 |
import hashlib
|
| 17 |
import gzip
|
| 18 |
import base64
|
|
|
|
| 19 |
from pydantic import BaseModel
|
| 20 |
+
import urllib.parse
|
| 21 |
+
import re
|
| 22 |
+
import logging
|
| 23 |
|
| 24 |
+
# Configure basic logging
|
| 25 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
# Create FastAPI instance with enhanced configuration
|
| 28 |
app = FastAPI(
|
|
|
|
| 126 |
self.state_cache: Dict[str, Any] = {}
|
| 127 |
self.memory_cache: Dict[str, Any] = {}
|
| 128 |
self.model_cache: Dict[str, Any] = {}
|
| 129 |
+
# map original model_name -> safe_filename used on disk
|
| 130 |
+
self.model_name_map: Dict[str, str] = {}
|
| 131 |
|
| 132 |
# Session management for HTTP API
|
| 133 |
self.http_sessions: Dict[str, Dict[str, Any]] = {}
|
|
|
|
| 198 |
"""Decompress gzip data"""
|
| 199 |
return gzip.decompress(data)
|
| 200 |
|
| 201 |
+
def sanitize_model_name(self, model_name: str) -> str:
|
| 202 |
+
"""Create a filesystem-safe filename from a provided model_name.
|
| 203 |
+
This will URL-decode percent-encoded values, then replace unsafe characters with underscores.
|
| 204 |
+
"""
|
| 205 |
+
if not model_name:
|
| 206 |
+
return "unnamed_model"
|
| 207 |
+
# URL-decode first (handles cases where client sent %2F)
|
| 208 |
+
decoded = urllib.parse.unquote(model_name)
|
| 209 |
+
# Replace characters that are not alphanumeric, dot, underscore or dash
|
| 210 |
+
safe = re.sub(r'[^0-9A-Za-z._-]', '_', decoded)
|
| 211 |
+
# Trim length to avoid overly long filenames
|
| 212 |
+
return safe[:240]
|
| 213 |
+
|
| 214 |
+
def resolve_model_key(self, model_name: str) -> Optional[str]:
|
| 215 |
+
"""Resolve the canonical model key used in model_cache.
|
| 216 |
+
Accepts several possible incoming forms (percent-encoded, decoded, sanitized) and returns
|
| 217 |
+
the key present in model_cache if any, otherwise None.
|
| 218 |
+
"""
|
| 219 |
+
# direct hit
|
| 220 |
+
if model_name in self.model_cache:
|
| 221 |
+
return model_name
|
| 222 |
+
# try URL-decoded
|
| 223 |
+
decoded = urllib.parse.unquote(model_name)
|
| 224 |
+
if decoded in self.model_cache:
|
| 225 |
+
return decoded
|
| 226 |
+
# try sanitized form matching stored map
|
| 227 |
+
safe = self.sanitize_model_name(model_name)
|
| 228 |
+
# see if we have an original key that maps to safe filename
|
| 229 |
+
for orig, safe_name in self.model_name_map.items():
|
| 230 |
+
if safe_name == safe:
|
| 231 |
+
return orig
|
| 232 |
+
return None
|
| 233 |
+
|
| 234 |
async def handle_vram_operation(self, operation: dict) -> dict:
|
| 235 |
"""Handle VRAM read/write operations (preserved from WebSocket implementation)"""
|
| 236 |
try:
|
|
|
|
| 662 |
detail=f"Cache get operation failed: {str(e)}"
|
| 663 |
)
|
| 664 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 665 |
@app.post("/api/v1/models/{model_name}/load")
|
| 666 |
async def load_model(
|
| 667 |
model_name: str,
|
|
|
|
| 670 |
):
|
| 671 |
"""Load AI model"""
|
| 672 |
try:
|
| 673 |
+
logging.info(f"Received model load request for: {model_name}")
|
| 674 |
+
# Create a safe filename and persist model info under the original key
|
| 675 |
+
safe_name = server.sanitize_model_name(model_name)
|
| 676 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 677 |
model_info = {
|
| 678 |
+
'model_name': model_name,
|
|
|
|
| 679 |
'model_data': request.model_data,
|
| 680 |
'model_path': request.model_path,
|
| 681 |
'model_hash': request.model_hash,
|
|
|
|
| 683 |
'session_id': session['session_id']
|
| 684 |
}
|
| 685 |
|
| 686 |
+
# Store mapping and cache
|
| 687 |
server.model_cache[model_name] = model_info
|
| 688 |
+
server.model_name_map[model_name] = safe_name
|
| 689 |
|
| 690 |
+
# Store in persistent storage using safe filename
|
| 691 |
model_file = server.models_path / f"{safe_name}.json"
|
|
|
|
|
|
|
| 692 |
with open(model_file, 'w') as f:
|
| 693 |
json.dump(model_info, f)
|
| 694 |
|
| 695 |
server.ops_counter += 1
|
| 696 |
+
logging.info(f"Model '{model_name}' saved to disk as '{safe_name}.json'")
|
| 697 |
return {
|
| 698 |
"status": "success",
|
| 699 |
"message": f"Model {model_name} loaded successfully",
|
| 700 |
"model_info": {
|
| 701 |
"name": model_name,
|
|
|
|
| 702 |
"loaded_at": model_info['loaded_at']
|
| 703 |
}
|
| 704 |
}
|
| 705 |
|
| 706 |
except Exception as e:
|
| 707 |
+
logging.exception("Model load operation failed")
|
| 708 |
raise HTTPException(
|
| 709 |
status_code=500,
|
| 710 |
detail=f"Model load operation failed: {str(e)}"
|
|
|
|
| 718 |
):
|
| 719 |
"""Run model inference"""
|
| 720 |
try:
|
| 721 |
+
logging.info(f"Inference requested for model: {model_name}")
|
| 722 |
+
resolved_key = server.resolve_model_key(model_name)
|
| 723 |
+
if not resolved_key:
|
| 724 |
+
raise HTTPException(status_code=404, detail=f"Model {model_name} not loaded")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 725 |
|
| 726 |
+
# Simulate inference processing (echo input for now)
|
|
|
|
| 727 |
result = {
|
| 728 |
"status": "success",
|
| 729 |
"output": request.input_data, # Echo input for now
|
|
|
|
| 731 |
"inference_time": 0.1,
|
| 732 |
"tokens_processed": len(request.input_data)
|
| 733 |
},
|
| 734 |
+
"model_info": server.model_cache.get(resolved_key)
|
| 735 |
}
|
| 736 |
|
| 737 |
server.ops_counter += 1
|
|
|
|
| 738 |
return result
|
| 739 |
|
| 740 |
except HTTPException:
|
| 741 |
raise
|
| 742 |
except Exception as e:
|
| 743 |
+
logging.exception("Inference operation failed")
|
| 744 |
raise HTTPException(
|
| 745 |
status_code=500,
|
| 746 |
detail=f"Inference operation failed: {str(e)}"
|
|
|
|
| 753 |
):
|
| 754 |
"""Get model status"""
|
| 755 |
try:
|
| 756 |
+
resolved_key = server.resolve_model_key(model_name)
|
| 757 |
+
if resolved_key:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 758 |
return {
|
| 759 |
"status": "loaded",
|
| 760 |
+
"model_info": server.model_cache[resolved_key]
|
| 761 |
}
|
| 762 |
+
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 763 |
return {
|
| 764 |
+
"status": "not_loaded",
|
| 765 |
+
"message": f"Model {model_name} is not loaded"
|
| 766 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 767 |
|
| 768 |
except Exception as e:
|
| 769 |
+
logging.exception("Model status check failed")
|
| 770 |
raise HTTPException(
|
| 771 |
status_code=500,
|
| 772 |
detail=f"Model status check failed: {str(e)}"
|
|
|
|
| 831 |
except HTTPException:
|
| 832 |
raise
|
| 833 |
except Exception as e:
|
| 834 |
+
logging.exception("Chip transfer failed")
|
| 835 |
raise HTTPException(
|
| 836 |
status_code=500,
|
| 837 |
detail=f"Chip transfer failed: {str(e)}"
|
|
|
|
| 864 |
}
|
| 865 |
|
| 866 |
except Exception as e:
|
| 867 |
+
logging.exception("Barrier creation failed")
|
| 868 |
raise HTTPException(
|
| 869 |
status_code=500,
|
| 870 |
detail=f"Barrier creation failed: {str(e)}"
|
|
|
|
| 902 |
except HTTPException:
|
| 903 |
raise
|
| 904 |
except Exception as e:
|
| 905 |
+
logging.exception("Barrier wait failed")
|
| 906 |
raise HTTPException(
|
| 907 |
status_code=500,
|
| 908 |
detail=f"Barrier wait failed: {str(e)}"
|