Spaces:
Sleeping
Sleeping
Factor Studios
commited on
Update virtual_gpu_server_http.py
Browse files- virtual_gpu_server_http.py +68 -65
virtual_gpu_server_http.py
CHANGED
|
@@ -16,13 +16,14 @@ from datetime import datetime, timedelta
|
|
| 16 |
import hashlib
|
| 17 |
import gzip
|
| 18 |
import base64
|
| 19 |
-
from pydantic import BaseModel
|
| 20 |
-
import urllib.parse
|
| 21 |
-
import re
|
| 22 |
import logging
|
|
|
|
| 23 |
|
| 24 |
-
# Configure
|
| 25 |
-
logging.basicConfig(
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
# Create FastAPI instance with enhanced configuration
|
| 28 |
app = FastAPI(
|
|
@@ -126,8 +127,6 @@ class VirtualGPUServer:
|
|
| 126 |
self.state_cache: Dict[str, Any] = {}
|
| 127 |
self.memory_cache: Dict[str, Any] = {}
|
| 128 |
self.model_cache: Dict[str, Any] = {}
|
| 129 |
-
# map original model_name -> safe_filename used on disk
|
| 130 |
-
self.model_name_map: Dict[str, str] = {}
|
| 131 |
|
| 132 |
# Session management for HTTP API
|
| 133 |
self.http_sessions: Dict[str, Dict[str, Any]] = {}
|
|
@@ -198,39 +197,6 @@ class VirtualGPUServer:
|
|
| 198 |
"""Decompress gzip data"""
|
| 199 |
return gzip.decompress(data)
|
| 200 |
|
| 201 |
-
def sanitize_model_name(self, model_name: str) -> str:
|
| 202 |
-
"""Create a filesystem-safe filename from a provided model_name.
|
| 203 |
-
This will URL-decode percent-encoded values, then replace unsafe characters with underscores.
|
| 204 |
-
"""
|
| 205 |
-
if not model_name:
|
| 206 |
-
return "unnamed_model"
|
| 207 |
-
# URL-decode first (handles cases where client sent %2F)
|
| 208 |
-
decoded = urllib.parse.unquote(model_name)
|
| 209 |
-
# Replace characters that are not alphanumeric, dot, underscore or dash
|
| 210 |
-
safe = re.sub(r'[^0-9A-Za-z._-]', '_', decoded)
|
| 211 |
-
# Trim length to avoid overly long filenames
|
| 212 |
-
return safe[:240]
|
| 213 |
-
|
| 214 |
-
def resolve_model_key(self, model_name: str) -> Optional[str]:
|
| 215 |
-
"""Resolve the canonical model key used in model_cache.
|
| 216 |
-
Accepts several possible incoming forms (percent-encoded, decoded, sanitized) and returns
|
| 217 |
-
the key present in model_cache if any, otherwise None.
|
| 218 |
-
"""
|
| 219 |
-
# direct hit
|
| 220 |
-
if model_name in self.model_cache:
|
| 221 |
-
return model_name
|
| 222 |
-
# try URL-decoded
|
| 223 |
-
decoded = urllib.parse.unquote(model_name)
|
| 224 |
-
if decoded in self.model_cache:
|
| 225 |
-
return decoded
|
| 226 |
-
# try sanitized form matching stored map
|
| 227 |
-
safe = self.sanitize_model_name(model_name)
|
| 228 |
-
# see if we have an original key that maps to safe filename
|
| 229 |
-
for orig, safe_name in self.model_name_map.items():
|
| 230 |
-
if safe_name == safe:
|
| 231 |
-
return orig
|
| 232 |
-
return None
|
| 233 |
-
|
| 234 |
async def handle_vram_operation(self, operation: dict) -> dict:
|
| 235 |
"""Handle VRAM read/write operations (preserved from WebSocket implementation)"""
|
| 236 |
try:
|
|
@@ -662,6 +628,13 @@ async def get_cache(
|
|
| 662 |
detail=f"Cache get operation failed: {str(e)}"
|
| 663 |
)
|
| 664 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 665 |
@app.post("/api/v1/models/{model_name}/load")
|
| 666 |
async def load_model(
|
| 667 |
model_name: str,
|
|
@@ -670,10 +643,13 @@ async def load_model(
|
|
| 670 |
):
|
| 671 |
"""Load AI model"""
|
| 672 |
try:
|
|
|
|
| 673 |
logging.info(f"Received model load request for: {model_name}")
|
| 674 |
-
|
| 675 |
-
|
| 676 |
-
|
|
|
|
|
|
|
| 677 |
model_info = {
|
| 678 |
'model_name': model_name,
|
| 679 |
'model_data': request.model_data,
|
|
@@ -683,17 +659,16 @@ async def load_model(
|
|
| 683 |
'session_id': session['session_id']
|
| 684 |
}
|
| 685 |
|
| 686 |
-
# Store mapping and cache
|
| 687 |
server.model_cache[model_name] = model_info
|
| 688 |
-
server.model_name_map[model_name] = safe_name
|
| 689 |
|
| 690 |
-
# Store in persistent storage
|
| 691 |
model_file = server.models_path / f"{safe_name}.json"
|
|
|
|
|
|
|
| 692 |
with open(model_file, 'w') as f:
|
| 693 |
json.dump(model_info, f)
|
| 694 |
|
| 695 |
server.ops_counter += 1
|
| 696 |
-
logging.info(f"Model '{model_name}' saved to disk as '{safe_name}.json'")
|
| 697 |
return {
|
| 698 |
"status": "success",
|
| 699 |
"message": f"Model {model_name} loaded successfully",
|
|
@@ -704,7 +679,6 @@ async def load_model(
|
|
| 704 |
}
|
| 705 |
|
| 706 |
except Exception as e:
|
| 707 |
-
logging.exception("Model load operation failed")
|
| 708 |
raise HTTPException(
|
| 709 |
status_code=500,
|
| 710 |
detail=f"Model load operation failed: {str(e)}"
|
|
@@ -718,12 +692,25 @@ async def run_inference(
|
|
| 718 |
):
|
| 719 |
"""Run model inference"""
|
| 720 |
try:
|
| 721 |
-
logging.info(f"
|
| 722 |
-
|
| 723 |
-
|
| 724 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 725 |
|
| 726 |
-
# Simulate inference processing
|
|
|
|
| 727 |
result = {
|
| 728 |
"status": "success",
|
| 729 |
"output": request.input_data, # Echo input for now
|
|
@@ -731,16 +718,17 @@ async def run_inference(
|
|
| 731 |
"inference_time": 0.1,
|
| 732 |
"tokens_processed": len(request.input_data)
|
| 733 |
},
|
| 734 |
-
"model_info": server.model_cache
|
| 735 |
}
|
| 736 |
|
| 737 |
server.ops_counter += 1
|
|
|
|
| 738 |
return result
|
| 739 |
|
| 740 |
except HTTPException:
|
| 741 |
raise
|
| 742 |
except Exception as e:
|
| 743 |
-
logging.
|
| 744 |
raise HTTPException(
|
| 745 |
status_code=500,
|
| 746 |
detail=f"Inference operation failed: {str(e)}"
|
|
@@ -753,20 +741,38 @@ async def get_model_status(
|
|
| 753 |
):
|
| 754 |
"""Get model status"""
|
| 755 |
try:
|
| 756 |
-
|
| 757 |
-
|
|
|
|
|
|
|
|
|
|
| 758 |
return {
|
| 759 |
"status": "loaded",
|
| 760 |
-
"model_info": server.model_cache[
|
| 761 |
}
|
| 762 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 763 |
return {
|
| 764 |
-
"status": "
|
| 765 |
-
"
|
| 766 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 767 |
|
| 768 |
except Exception as e:
|
| 769 |
-
logging.
|
| 770 |
raise HTTPException(
|
| 771 |
status_code=500,
|
| 772 |
detail=f"Model status check failed: {str(e)}"
|
|
@@ -831,7 +837,6 @@ async def transfer_between_chips(
|
|
| 831 |
except HTTPException:
|
| 832 |
raise
|
| 833 |
except Exception as e:
|
| 834 |
-
logging.exception("Chip transfer failed")
|
| 835 |
raise HTTPException(
|
| 836 |
status_code=500,
|
| 837 |
detail=f"Chip transfer failed: {str(e)}"
|
|
@@ -864,7 +869,6 @@ async def create_sync_barrier(
|
|
| 864 |
}
|
| 865 |
|
| 866 |
except Exception as e:
|
| 867 |
-
logging.exception("Barrier creation failed")
|
| 868 |
raise HTTPException(
|
| 869 |
status_code=500,
|
| 870 |
detail=f"Barrier creation failed: {str(e)}"
|
|
@@ -902,7 +906,6 @@ async def wait_sync_barrier(
|
|
| 902 |
except HTTPException:
|
| 903 |
raise
|
| 904 |
except Exception as e:
|
| 905 |
-
logging.exception("Barrier wait failed")
|
| 906 |
raise HTTPException(
|
| 907 |
status_code=500,
|
| 908 |
detail=f"Barrier wait failed: {str(e)}"
|
|
|
|
| 16 |
import hashlib
|
| 17 |
import gzip
|
| 18 |
import base64
|
|
|
|
|
|
|
|
|
|
| 19 |
import logging
|
| 20 |
+
from pydantic import BaseModel
|
| 21 |
|
| 22 |
+
# Configure logging
|
| 23 |
+
logging.basicConfig(
|
| 24 |
+
level=logging.INFO,
|
| 25 |
+
format='%(asctime)s - %(levelname)s - %(message)s'
|
| 26 |
+
)
|
| 27 |
|
| 28 |
# Create FastAPI instance with enhanced configuration
|
| 29 |
app = FastAPI(
|
|
|
|
| 127 |
self.state_cache: Dict[str, Any] = {}
|
| 128 |
self.memory_cache: Dict[str, Any] = {}
|
| 129 |
self.model_cache: Dict[str, Any] = {}
|
|
|
|
|
|
|
| 130 |
|
| 131 |
# Session management for HTTP API
|
| 132 |
self.http_sessions: Dict[str, Dict[str, Any]] = {}
|
|
|
|
| 197 |
"""Decompress gzip data"""
|
| 198 |
return gzip.decompress(data)
|
| 199 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
async def handle_vram_operation(self, operation: dict) -> dict:
|
| 201 |
"""Handle VRAM read/write operations (preserved from WebSocket implementation)"""
|
| 202 |
try:
|
|
|
|
| 628 |
detail=f"Cache get operation failed: {str(e)}"
|
| 629 |
)
|
| 630 |
|
| 631 |
+
def sanitize_filename(name: str) -> str:
|
| 632 |
+
"""
|
| 633 |
+
Sanitize a string for safe file system usage.
|
| 634 |
+
Replaces slashes with double underscores.
|
| 635 |
+
"""
|
| 636 |
+
return name.replace('/', '__')
|
| 637 |
+
|
| 638 |
@app.post("/api/v1/models/{model_name}/load")
|
| 639 |
async def load_model(
|
| 640 |
model_name: str,
|
|
|
|
| 643 |
):
|
| 644 |
"""Load AI model"""
|
| 645 |
try:
|
| 646 |
+
# Log the received model name for debugging
|
| 647 |
logging.info(f"Received model load request for: {model_name}")
|
| 648 |
+
|
| 649 |
+
# Get safe filename for storage
|
| 650 |
+
safe_name = sanitize_filename(model_name)
|
| 651 |
+
|
| 652 |
+
# Store model information
|
| 653 |
model_info = {
|
| 654 |
'model_name': model_name,
|
| 655 |
'model_data': request.model_data,
|
|
|
|
| 659 |
'session_id': session['session_id']
|
| 660 |
}
|
| 661 |
|
|
|
|
| 662 |
server.model_cache[model_name] = model_info
|
|
|
|
| 663 |
|
| 664 |
+
# Store in persistent storage
|
| 665 |
model_file = server.models_path / f"{safe_name}.json"
|
| 666 |
+
logging.info(f"Storing model info at: {model_file}")
|
| 667 |
+
|
| 668 |
with open(model_file, 'w') as f:
|
| 669 |
json.dump(model_info, f)
|
| 670 |
|
| 671 |
server.ops_counter += 1
|
|
|
|
| 672 |
return {
|
| 673 |
"status": "success",
|
| 674 |
"message": f"Model {model_name} loaded successfully",
|
|
|
|
| 679 |
}
|
| 680 |
|
| 681 |
except Exception as e:
|
|
|
|
| 682 |
raise HTTPException(
|
| 683 |
status_code=500,
|
| 684 |
detail=f"Model load operation failed: {str(e)}"
|
|
|
|
| 692 |
):
|
| 693 |
"""Run model inference"""
|
| 694 |
try:
|
| 695 |
+
logging.info(f"Running inference - Raw model name: {model_name}")
|
| 696 |
+
safe_name = sanitize_model_name(model_name)
|
| 697 |
+
logging.info(f"Running inference - Safe model name: {safe_name}")
|
| 698 |
+
|
| 699 |
+
# Check if model is loaded (try both original and safe names)
|
| 700 |
+
if model_name not in server.model_cache:
|
| 701 |
+
# Try loading from file system using safe name
|
| 702 |
+
model_file = server.models_path / f"{safe_name}.json"
|
| 703 |
+
if not model_file.exists():
|
| 704 |
+
logging.error(f"Model {model_name} not found in cache or filesystem")
|
| 705 |
+
raise HTTPException(status_code=404, detail=f"Model {model_name} not loaded")
|
| 706 |
+
|
| 707 |
+
logging.info(f"Loading model info from file: {model_file}")
|
| 708 |
+
with open(model_file) as f:
|
| 709 |
+
model_info = json.load(f)
|
| 710 |
+
server.model_cache[model_name] = model_info
|
| 711 |
|
| 712 |
+
# Simulate inference processing
|
| 713 |
+
# In a real implementation, this would invoke the actual model
|
| 714 |
result = {
|
| 715 |
"status": "success",
|
| 716 |
"output": request.input_data, # Echo input for now
|
|
|
|
| 718 |
"inference_time": 0.1,
|
| 719 |
"tokens_processed": len(request.input_data)
|
| 720 |
},
|
| 721 |
+
"model_info": server.model_cache[model_name]
|
| 722 |
}
|
| 723 |
|
| 724 |
server.ops_counter += 1
|
| 725 |
+
logging.info(f"Inference completed successfully for model: {model_name}")
|
| 726 |
return result
|
| 727 |
|
| 728 |
except HTTPException:
|
| 729 |
raise
|
| 730 |
except Exception as e:
|
| 731 |
+
logging.error(f"Inference operation failed for {model_name}: {str(e)}")
|
| 732 |
raise HTTPException(
|
| 733 |
status_code=500,
|
| 734 |
detail=f"Inference operation failed: {str(e)}"
|
|
|
|
| 741 |
):
|
| 742 |
"""Get model status"""
|
| 743 |
try:
|
| 744 |
+
logging.info(f"Checking model status for: {model_name}")
|
| 745 |
+
|
| 746 |
+
# Check cache first
|
| 747 |
+
if model_name in server.model_cache:
|
| 748 |
+
logging.info(f"Model {model_name} found in cache")
|
| 749 |
return {
|
| 750 |
"status": "loaded",
|
| 751 |
+
"model_info": server.model_cache[model_name]
|
| 752 |
}
|
| 753 |
+
|
| 754 |
+
# Check file system using safe name
|
| 755 |
+
safe_name = sanitize_filename(model_name)
|
| 756 |
+
model_file = server.models_path / f"{safe_name}.json"
|
| 757 |
+
if model_file.exists():
|
| 758 |
+
logging.info(f"Model file found: {model_file}")
|
| 759 |
+
with open(model_file) as f:
|
| 760 |
+
model_info = json.load(f)
|
| 761 |
+
# Update cache
|
| 762 |
+
server.model_cache[model_name] = model_info
|
| 763 |
return {
|
| 764 |
+
"status": "loaded",
|
| 765 |
+
"model_info": model_info
|
| 766 |
}
|
| 767 |
+
|
| 768 |
+
logging.info(f"Model {model_name} not found in cache or filesystem")
|
| 769 |
+
return {
|
| 770 |
+
"status": "not_loaded",
|
| 771 |
+
"message": f"Model {model_name} is not loaded"
|
| 772 |
+
}
|
| 773 |
|
| 774 |
except Exception as e:
|
| 775 |
+
logging.error(f"Model status check failed for {model_name}: {str(e)}")
|
| 776 |
raise HTTPException(
|
| 777 |
status_code=500,
|
| 778 |
detail=f"Model status check failed: {str(e)}"
|
|
|
|
| 837 |
except HTTPException:
|
| 838 |
raise
|
| 839 |
except Exception as e:
|
|
|
|
| 840 |
raise HTTPException(
|
| 841 |
status_code=500,
|
| 842 |
detail=f"Chip transfer failed: {str(e)}"
|
|
|
|
| 869 |
}
|
| 870 |
|
| 871 |
except Exception as e:
|
|
|
|
| 872 |
raise HTTPException(
|
| 873 |
status_code=500,
|
| 874 |
detail=f"Barrier creation failed: {str(e)}"
|
|
|
|
| 906 |
except HTTPException:
|
| 907 |
raise
|
| 908 |
except Exception as e:
|
|
|
|
| 909 |
raise HTTPException(
|
| 910 |
status_code=500,
|
| 911 |
detail=f"Barrier wait failed: {str(e)}"
|