File size: 6,299 Bytes
d8581cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
#!/usr/bin/env python3
"""
Launch script for the BCE Prediction Web Server
"""

import argparse
import uvicorn
import sys
import time
import torch
from pathlib import Path

def preload_model(device_id=-1, model_path=None, verbose=True):
    """
    Load the ReCEP model once at startup to reduce first-request latency.

    This is a best-effort step: every failure is caught, optionally
    reported, and signalled through the return value instead of being
    raised, so the caller can still start the server.

    NOTE(review): the loaded model is neither returned nor stored here, so
    any benefit to later requests comes from side effects of loading
    (imports done, CUDA initialised, ...) -- confirm that this is intended.

    Args:
        device_id: CUDA device index. A negative value, or CUDA being
            unavailable, selects the CPU.
        model_path: Checkpoint to load. ``None`` selects the default
            ``best_mcc_model.bin`` checkpoint under ``BASE_DIR``.
        verbose: When true, print progress and failure messages.

    Returns:
        True if the model was loaded (and, on CUDA, the warm-up finished),
        False if anything went wrong.
    """
    try:
        if verbose:
            print("🔄 Preloading ReCEP model...")

        # Make the script path absolute before walking upwards: with a
        # relative __file__ the parent chain is too short and parents[2]
        # would raise IndexError.
        script_dir = Path(__file__).absolute().parent
        # The project root is three directory levels above this script's
        # directory; it must be importable for the `bce` package below.
        project_root = str(script_dir.parents[2])
        # Keep the root at the front of sys.path without stacking a new
        # duplicate entry on every call.
        if sys.path[:1] != [project_root]:
            sys.path.insert(0, project_root)

        from bce.model.ReCEP import ReCEP
        from bce.utils.constants import BASE_DIR

        # Select the device.
        if device_id >= 0 and torch.cuda.is_available():
            device = torch.device(f"cuda:{device_id}")
            if verbose:
                print(f"🎯 Using GPU: {torch.cuda.get_device_name(device_id)}")
        else:
            device = torch.device("cpu")
            if verbose:
                print("🎯 Using CPU")

        # Fall back to the default checkpoint when none was given.
        if model_path is None:
            model_path = f"{BASE_DIR}/models/ReCEP/20250626_110438/best_mcc_model.bin"

        start_time = time.time()

        # Load the model and switch it to inference mode.
        model, threshold = ReCEP.load(model_path, device=device, strict=False, verbose=False)
        model.eval()

        # Warm up the GPU (only when one is in use) with a tiny throwaway
        # tensor operation, then wait for the device to finish.
        if device.type == 'cuda':
            dummy_tensor = torch.randn(10, 512).to(device)
            with torch.no_grad():
                _ = dummy_tensor.sum()
            del dummy_tensor
            torch.cuda.synchronize()

        load_time = time.time() - start_time

        if verbose:
            print(f"✅ Model preloaded successfully in {load_time:.2f}s")
            print(f"📏 Model threshold: {threshold:.4f}")

        return True

    except Exception as e:
        # Deliberately broad: preloading is optional and must never prevent
        # the server from starting.
        if verbose:
            print(f"⚠️  Model preload failed: {str(e)}")
            print("   Server will load model on first request")
        return False

def main():
    """
    Parse the command line, optionally preload the model, and run uvicorn.

    Steps:
      1. Parse host/port/reload/workers/log-level/preload options.
      2. Put this script's directory on ``sys.path`` and make it the
         working directory so that ``main.py`` next to it is importable.
      3. Preload the model unless ``--no-preload`` or ``--reload`` is given.
      4. Start uvicorn. The app object is passed directly for a plain
         single-process server; the ``"main:app"`` import string is used
         when reload or several workers are requested, and as a fallback
         when importing ``main`` directly fails.

    Exits with status 1 if the server cannot be started.
    """
    parser = argparse.ArgumentParser(description="Launch BCE Prediction Web Server")
    parser.add_argument("--host", default="0.0.0.0", help="Host to bind to (default: 0.0.0.0)")
    parser.add_argument("--port", type=int, default=8000, help="Port to bind to (default: 8000)")
    parser.add_argument("--reload", action="store_true", help="Enable auto-reload for development")
    parser.add_argument("--workers", type=int, default=1, help="Number of worker processes")
    parser.add_argument("--log-level", default="info", choices=["debug", "info", "warning", "error", "critical"],
                       help="Log level (default: info)")
    # --preload defaults to True and is kept only for backward compatibility;
    # --no-preload is the switch that actually changes behaviour.
    parser.add_argument("--preload", action="store_true", default=True, help="Preload model on startup (default: True)")
    parser.add_argument("--no-preload", action="store_true", help="Skip model preloading")
    parser.add_argument("--device-id", type=int, default=-1, help="Device ID for model preloading (default: -1 for CPU)")
    parser.add_argument("--model-path", type=str, default=None, help="Custom model path for preloading")

    args = parser.parse_args()

    # Make the path absolute *before* changing directory: a relative
    # __file__ would otherwise point somewhere else after os.chdir and be
    # too short for the parents[2] lookup further down.
    script_dir = Path(__file__).absolute().parent
    sys.path.insert(0, str(script_dir))

    # Change to the website directory to ensure relative imports work
    import os
    os.chdir(script_dir)

    print(f"📁 Working directory: {script_dir}")
    print(f"🐍 Python path includes: {script_dir}")

    print("🚀 Starting BCE Prediction Server...")
    print(f"📍 Server will be available at: http://{args.host}:{args.port}")
    print(f"📖 API documentation at: http://{args.host}:{args.port}/docs")
    print(f"🔄 Auto-reload: {'enabled' if args.reload else 'disabled'}")

    # Model preloading
    should_preload = args.preload and not args.no_preload and not args.reload
    if should_preload:
        success = preload_model(
            device_id=args.device_id,
            model_path=args.model_path,
            verbose=True
        )
        if success:
            print("🎉 Server ready for fast predictions!")
        else:
            print("⚠️  Server starting without preload")
    elif args.reload:
        print("ℹ️  Skipping preload in reload mode")
    else:
        # Only reachable through --no-preload, since --preload is always True.
        print("ℹ️  Skipping model preload (--no-preload given)")

    print("-" * 50)

    # Ensure the website directory is on the Python path
    if str(script_dir) not in sys.path:
        sys.path.insert(0, str(script_dir))

    # Remove the project root (added by the preload step) so that a main.py
    # living there cannot shadow the one next to this script.
    try:
        project_root = str(script_dir.parents[2])
    except IndexError:
        # Script sits fewer than three levels deep; nothing to remove.
        project_root = None
    if project_root is not None:
        while project_root in sys.path:
            sys.path.remove(project_root)

    print(f"🔍 Python path: {sys.path[:3]}...")  # Show first 3 paths

    server_options = dict(
        host=args.host,
        port=args.port,
        reload=args.reload,
        workers=args.workers if not args.reload else 1,
        log_level=args.log_level,
        access_log=True,
    )

    # uvicorn only supports reload / multiple workers when it is given an
    # import string; handed an app object in those modes it exits instead
    # of serving, so the object is used for the plain single-process case only.
    needs_import_string = args.reload or args.workers > 1
    target = "main:app"
    if needs_import_string:
        print("🔧 Starting server with import string (required for reload/workers)...")
    else:
        print("🔧 Starting server with direct import method...")
        try:
            import main as app_module
            print(f"📄 Imported main from: {app_module.__file__}")
            target = app_module.app
        except Exception as e:
            print(f"❌ Direct import failed: {str(e)}")
            print("💡 Trying string-based import...")

    # The server is started exactly once; a failure here is reported and
    # turned into a non-zero exit status instead of a silent return.
    try:
        uvicorn.run(target, **server_options)
    except Exception as e:
        print(f"❌ Server failed to start: {str(e)}")
        print("🔧 Please check your environment and dependencies")
        sys.exit(1)

# Run the launcher only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()