#!/usr/bin/env python3
"""
Launch script for the BCE Prediction Web Server
"""
import argparse
import os
import sys
import time
from pathlib import Path

import torch
import uvicorn
def preload_model(device_id=-1, model_path=None, verbose=True):
    """Preload the ReCEP model to reduce first-request latency.

    Args:
        device_id: CUDA device index; -1 (the default) selects CPU.
            A non-negative index only takes effect when CUDA is available.
        model_path: Path to a model checkpoint. When None, a default
            checkpoint under ``BASE_DIR`` is used.
        verbose: When True, print progress/diagnostic messages.

    Returns:
        True if the model was loaded (and, on GPU, warmed up) successfully,
        False otherwise. All failures are swallowed deliberately so the
        server can still start and load the model lazily on first request.
    """
    try:
        if verbose:
            print("🔄 Preloading ReCEP model...")

        # Make the project root importable so the `bce.*` packages resolve.
        script_dir = Path(__file__).parent
        project_root = script_dir.parents[2]
        # Guard against inserting a duplicate entry on repeated calls.
        if str(project_root) not in sys.path:
            sys.path.insert(0, str(project_root))

        from bce.model.ReCEP import ReCEP
        from bce.utils.constants import BASE_DIR

        # Device selection: fall back to CPU when CUDA is unavailable.
        if device_id >= 0 and torch.cuda.is_available():
            device = torch.device(f"cuda:{device_id}")
            if verbose:
                print(f"🎯 Using GPU: {torch.cuda.get_device_name(device_id)}")
        else:
            device = torch.device("cpu")
            if verbose:
                print("🎯 Using CPU")

        # Use the default checkpoint when no explicit path was given.
        if model_path is None:
            model_path = f"{BASE_DIR}/models/ReCEP/20250626_110438/best_mcc_model.bin"

        start_time = time.time()

        # Load the model and switch to inference mode.
        model, threshold = ReCEP.load(model_path, device=device, strict=False, verbose=False)
        model.eval()

        # Warm up the GPU with a tiny tensor so the first real request
        # does not pay the CUDA context/kernel initialization cost.
        if device.type == 'cuda':
            dummy_tensor = torch.randn(10, 512).to(device)
            with torch.no_grad():
                _ = dummy_tensor.sum()
            del dummy_tensor
            torch.cuda.synchronize()

        load_time = time.time() - start_time
        if verbose:
            print(f"✅ Model preloaded successfully in {load_time:.2f}s")
            print(f"📏 Model threshold: {threshold:.4f}")
        return True

    except Exception as e:
        # Best-effort by design: never block server startup on a preload
        # failure; the app will load the model on the first request instead.
        if verbose:
            print(f"⚠️ Model preload failed: {str(e)}")
            print(" Server will load model on first request")
        return False
def main():
    """Parse CLI arguments, optionally preload the model, and start the
    uvicorn server hosting the FastAPI app defined in this directory's
    ``main.py``.
    """
    parser = argparse.ArgumentParser(description="Launch BCE Prediction Web Server")
    parser.add_argument("--host", default="0.0.0.0", help="Host to bind to (default: 0.0.0.0)")
    parser.add_argument("--port", type=int, default=8000, help="Port to bind to (default: 8000)")
    parser.add_argument("--reload", action="store_true", help="Enable auto-reload for development")
    parser.add_argument("--workers", type=int, default=1, help="Number of worker processes")
    parser.add_argument("--log-level", default="info", choices=["debug", "info", "warning", "error", "critical"],
                        help="Log level (default: info)")
    # NOTE(review): with action="store_true" AND default=True this flag is
    # always True; preloading is effectively opt-out via --no-preload only.
    # Kept as-is for CLI backward compatibility.
    parser.add_argument("--preload", action="store_true", default=True, help="Preload model on startup (default: True)")
    parser.add_argument("--no-preload", action="store_true", help="Skip model preloading")
    parser.add_argument("--device-id", type=int, default=-1, help="Device ID for model preloading (default: -1 for CPU)")
    parser.add_argument("--model-path", type=str, default=None, help="Custom model path for preloading")
    args = parser.parse_args()

    # Make this script's directory importable and current so that
    # `import main` and relative resources resolve correctly.
    script_dir = Path(__file__).parent
    sys.path.insert(0, str(script_dir))
    os.chdir(script_dir)

    print(f"📁 Working directory: {script_dir}")
    print(f"🐍 Python path includes: {script_dir}")
    print(f"🚀 Starting BCE Prediction Server...")
    print(f"📍 Server will be available at: http://{args.host}:{args.port}")
    print(f"📖 API documentation at: http://{args.host}:{args.port}/docs")
    print(f"🔄 Auto-reload: {'enabled' if args.reload else 'disabled'}")

    # Preloading is skipped in reload mode: the reloader re-imports the app
    # in a child process, so any model loaded here would be wasted work.
    should_preload = args.preload and not args.no_preload and not args.reload
    if should_preload:
        success = preload_model(
            device_id=args.device_id,
            model_path=args.model_path,
            verbose=True
        )
        if success:
            print("🎉 Server ready for fast predictions!")
        else:
            print("⚠️ Server starting without preload")
    else:
        if args.reload:
            print("ℹ️ Skipping preload in reload mode")
        else:
            print("ℹ️ Skipping model preload (use --preload to enable)")

    print("-" * 50)

    # Run the server - Use direct import to avoid module resolution issues
    print("🔧 Starting server with direct import method...")

    # Keep the website directory first on sys.path and drop the project
    # root so `import main` resolves to this directory's main.py rather
    # than any main module at the repository root.
    if str(script_dir) not in sys.path:
        sys.path.insert(0, str(script_dir))
    project_root = script_dir.parents[2]
    if str(project_root) in sys.path:
        sys.path.remove(str(project_root))
    print(f"🔍 Python path: {sys.path[:3]}...")  # Show first 3 paths

    # BUGFIX: uvicorn requires an *import string* ("main:app"), not an app
    # object, to honor reload=True or workers>1 — with an app object it
    # logs an error and calls sys.exit(1), which bypasses `except
    # Exception`. Select the import-string path up front in those modes.
    use_import_string = args.reload or args.workers > 1
    try:
        if use_import_string:
            uvicorn.run(
                "main:app",
                host=args.host,
                port=args.port,
                reload=args.reload,
                workers=args.workers if not args.reload else 1,
                log_level=args.log_level,
                access_log=True
            )
        else:
            import main
            print(f"📄 Imported main from: {main.__file__}")
            app = main.app
            uvicorn.run(
                app,
                host=args.host,
                port=args.port,
                reload=args.reload,
                workers=args.workers if not args.reload else 1,
                log_level=args.log_level,
                access_log=True
            )
    except Exception as e:
        print(f"❌ Direct import failed: {str(e)}")
        print("💡 Trying string-based import...")
        # Fallback: try string-based import
        try:
            uvicorn.run(
                "main:app",
                host=args.host,
                port=args.port,
                reload=args.reload,
                workers=args.workers if not args.reload else 1,
                log_level=args.log_level,
                access_log=True
            )
        except Exception as e2:
            print(f"❌ String-based import also failed: {str(e2)}")
            print("🔧 Please check your environment and dependencies")
# Script entry point: only launch the server when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()