File size: 5,363 Bytes
6480add
 
 
 
 
 
 
 
 
 
 
5b62571
 
6480add
 
 
 
 
 
 
 
 
5b62571
 
 
 
 
6480add
 
 
 
 
 
5b62571
 
 
 
 
 
 
 
6480add
 
 
 
 
 
 
5b62571
6480add
 
 
 
5b62571
 
 
 
 
 
 
 
6480add
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b62571
6480add
 
 
 
 
 
 
 
 
5b62571
6480add
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b62571
 
 
 
 
 
 
 
 
6480add
 
 
 
5b62571
 
6480add
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b62571
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
"""
Z.AI Proxy - OpenAI-compatible API for Z.AI
"""
import asyncio
import logging
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, Depends, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials

from config import settings
from models import ChatCompletionRequest, ModelsResponse, ModelInfo
from proxy_handler import ProxyHandler  # 確保 proxy_handler.py 在同級目錄
from cookie_manager import cookie_manager

# Configure logging
logging.basicConfig(
    level=getattr(logging, settings.LOG_LEVEL),
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# --- MODIFICATION START ---
# 1. 創建一個全局變量來持有 handler 實例
proxy_handler: ProxyHandler | None = None
# --- MODIFICATION END ---

# Security
security = HTTPBearer(auto_error=False)

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager"""
    global proxy_handler
    logger.info("Application startup: Initializing ProxyHandler and starting background tasks...")

    # --- MODIFICATION START ---
    # 2. 在應用啟動時初始化 ProxyHandler
    proxy_handler = ProxyHandler()
    # --- MODIFICATION END ---

    # Start background tasks
    health_check_task = asyncio.create_task(cookie_manager.periodic_health_check())
    
    try:
        yield
    finally:
        # Cleanup
        logger.info("Application shutdown: Cleaning up resources...")
        health_check_task.cancel()
        try:
            await health_check_task
        except asyncio.CancelledError:
            logger.info("Cookie health check task cancelled.")
        
        # --- MODIFICATION START ---
        # 3. 在應用關閉時,安全地關閉 ProxyHandler 的 client
        if proxy_handler:
            await proxy_handler.aclose()
            logger.info("ProxyHandler client closed.")
        # --- MODIFICATION END ---

# Create FastAPI app
app = FastAPI(
    title="Z.AI Proxy",
    description="OpenAI-compatible API proxy for Z.AI",
    version="1.0.0",
    lifespan=lifespan
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

async def verify_auth(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Verify authentication with fixed API key"""
    if not credentials:
        raise HTTPException(status_code=401, detail="Authorization header required")

    if credentials.credentials != settings.API_KEY:
        raise HTTPException(status_code=401, detail="Invalid API key")

    return credentials.credentials

@app.get("/v1/models", response_model=ModelsResponse)
async def list_models():
    """List available models"""
    models = [
        ModelInfo(
            id=settings.MODEL_ID,  # 建議使用 settings.MODEL_NAME 以保持一致性
            object="model",
            owned_by="z-ai"
        )
    ]
    return ModelsResponse(data=models)

@app.post("/v1/chat/completions")
async def chat_completions(
    request: ChatCompletionRequest,
    _auth_token: str = Depends(verify_auth)  # 變數名加底線表示未使用
):
    """Create chat completion"""
    try:
        if not settings or not settings.COOKIES:
            raise HTTPException(
                status_code=503,
                detail="Service unavailable: No Z.AI cookies configured. Please set Z_AI_COOKIES environment variable."
            )

        if request.model != settings.MODEL_NAME:
            raise HTTPException(
                status_code=400,
                detail=f"Model '{request.model}' not supported. Use '{settings.MODEL_NAME}'"
            )

        # --- MODIFICATION START ---
        # 4. 直接使用全局的 proxy_handler 實例,並移除 async with 塊
        if not proxy_handler:
            # 這種情況理論上不應發生,因為 lifespan 會先執行
            logger.error("Proxy handler is not initialized.")
            raise HTTPException(status_code=503, detail="Service is not ready.")

        return await proxy_handler.handle_chat_completion(request)
        # --- MODIFICATION END ---

    except HTTPException:
        raise
    except Exception as e:
        # 使用 logger.exception 可以記錄完整的 traceback
        logger.exception(f"Unexpected error in chat_completions endpoint: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")

@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {"status": "healthy", "model": settings.MODEL_NAME}

@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Custom HTTP exception handler"""
    from fastapi.responses import JSONResponse
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "error": {
                "message": exc.detail,
                "type": "invalid_request_error",
                "code": exc.status_code
            }
        }
    )

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(
        "main:app",
        host=settings.HOST,
        port=settings.PORT,
        reload=False,
        log_level=settings.LOG_LEVEL.lower()
    )