File size: 5,427 Bytes
5a43cd4
41ace18
21577a3
 
5a43cd4
e0cac51
21577a3
 
 
 
 
 
41ace18
 
 
 
 
 
 
21577a3
 
 
 
e0cac51
21577a3
 
 
 
e0cac51
 
 
 
 
21577a3
1b2c6f8
5a43cd4
 
21577a3
e0cac51
21577a3
e0cac51
21577a3
 
 
 
 
 
5a43cd4
21577a3
 
41ace18
5a4e6c4
41ace18
 
 
 
 
 
 
 
 
 
78ec9da
41ace18
 
 
5a4e6c4
41ace18
 
 
 
 
 
21577a3
 
 
41ace18
21577a3
5a43cd4
1b2c6f8
 
 
e0cac51
41ace18
 
5a43cd4
21577a3
41ace18
21577a3
78ec9da
21577a3
e0cac51
 
21577a3
e0cac51
21577a3
5a43cd4
21577a3
1b2c6f8
21577a3
 
 
 
1b2c6f8
5a43cd4
21577a3
 
 
 
1b2c6f8
5a43cd4
1b2c6f8
e0cac51
1b2c6f8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import gradio as gr
import time
import json
import uuid
import uvicorn
import traceback
from fastapi import FastAPI, Request, HTTPException, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from starlette.responses import StreamingResponse
from pydantic import BaseModel
from typing import List, Optional

from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.common.keys import Keys
from selenium.webdriver.chrome.service import Service as ChromeService
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC

# --- 1. FastAPI application and bearer-token authentication setup ---
# Metadata surfaced in the generated OpenAPI schema and interactive docs.
app = FastAPI(
    version="1.2.0-final",
    title="SAI-ChatBot OpenAI-Compatible API",
    description="使用 Selenium 自动化在后台与 SAI-ChatBot 交互,并以 OpenAI API 格式返回结果。",
)
# Parses an ``Authorization: Bearer <token>`` header. auto_error=True is
# HTTPBearer's default, written out explicitly: requests lacking a usable
# bearer header are rejected with an HTTP error before the route runs.
auth_scheme = HTTPBearer(auto_error=True)

def api_key_auth(credentials: HTTPAuthorizationCredentials = Depends(auth_scheme)):
    """FastAPI dependency that authenticates a request by its bearer token.

    ``auth_scheme`` parses the ``Authorization: Bearer <token>`` header and
    passes the result in as ``credentials``; the raw token string is
    ``credentials.credentials``.

    If the ``SAI_API_KEY`` environment variable is set to a non-empty value,
    the token must match it exactly. If it is unset or empty, any token is
    accepted, which preserves the previous behavior for deployments that have
    not configured a key.

    Returns:
        The bearer token string.

    Raises:
        HTTPException: 401 when credentials are missing or the token does not
            match the configured key.
    """
    import os
    import secrets

    # Defensive only: HTTPBearer with its default auto_error=True already
    # rejects a missing/malformed header before this dependency is invoked.
    if not credentials:
        raise HTTPException(status_code=401, detail="Not authenticated")

    token = credentials.credentials
    expected_key = os.environ.get("SAI_API_KEY")
    # compare_digest takes time independent of where the inputs first differ,
    # so the key cannot be recovered character-by-character via timing.
    # Bytes (not str) so non-ASCII keys are supported.
    if expected_key and not secrets.compare_digest(
        token.encode("utf-8"), expected_key.encode("utf-8")
    ):
        raise HTTPException(
            status_code=401,
            detail="Invalid API key",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return token

# --- 2. Request models mirroring the OpenAI chat-completions format ---
# A single chat turn: its speaker role (e.g. "user", "assistant") and text.
class ChatMessage(BaseModel):
    role: str
    content: str


# Body of POST /v1/chat/completions; ``stream`` selects SSE chunked output.
class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    stream: Optional[bool] = False

# --- 3. Selenium automation core ---
def get_sai_response(prompt_text: str, *, max_wait_time: float = 120):
    """Submit *prompt_text* to the SAI-ChatBot web UI and yield the reply incrementally.

    A headless Chromium session loads the chat page, types the prompt into the
    chat textarea and presses Enter. It then polls the last assistant message
    on the page. Whenever that message's rendered text changes, the part past
    the previously seen length is yielded. Polling stops once the text is
    non-empty and unchanged across a one-second pause, or after
    ``max_wait_time`` seconds.

    Args:
        prompt_text: The user prompt to submit.
        max_wait_time: Upper bound, in seconds, on how long to poll for the reply.

    Yields:
        str: Successive text chunks of the reply (a chunk may be empty if the
        rendered text shrank). If the browser session or page interaction
        fails, a single human-readable error message is yielded instead.

    The browser is always shut down when the generator finishes or is closed.
    """
    from selenium.common.exceptions import (
        NoSuchElementException,
        StaleElementReferenceException,
    )

    options = webdriver.ChromeOptions()
    options.add_argument("--headless")
    options.add_argument("--no-sandbox")
    options.add_argument("--disable-dev-shm-usage")
    options.add_argument("--disable-gpu")
    options.binary_location = "/usr/bin/chromium"

    service = ChromeService(executable_path='/usr/bin/chromedriver')
    driver = None
    try:
        driver = webdriver.Chrome(service=service, options=options)
        driver.get("https://sai.coludai.cn/")

        wait = WebDriverWait(driver, 20)
        textarea_selector = 'textarea[placeholder="随时与未来对话,探索无限可能...."]'
        textarea = wait.until(EC.presence_of_element_located((By.CSS_SELECTOR, textarea_selector)))

        textarea.send_keys(prompt_text)
        textarea.send_keys(Keys.RETURN)

        last_assistant_selector = "(.//div[@class='message-item' and @type='assistant'])[last()]"
        # NOTE(review): assumes the newest assistant bubble is the reply to this
        # prompt (i.e. no pre-existing assistant greeting is matched) — confirm.
        wait.until(EC.presence_of_element_located((By.XPATH, last_assistant_selector)))

        previous_text = ""
        # monotonic() is immune to wall-clock adjustments while measuring the deadline.
        start_time = time.monotonic()
        while time.monotonic() - start_time < max_wait_time:
            try:
                # Re-locate on every pass: if the page re-renders the message, a
                # cached WebElement would stay stale and the loop would spin
                # uselessly until the deadline.
                markdown_body = driver.find_element(
                    By.XPATH, last_assistant_selector
                ).find_element(By.CSS_SELECTOR, '.markdown-body')
                current_text = markdown_body.text
                if current_text != previous_text:
                    # Assumes the text only grows by appending; a rewrite of
                    # earlier text cannot be expressed as a delta.
                    yield current_text[len(previous_text):]
                    previous_text = current_text

                # Reply is considered complete once it is non-empty and stable for 1 s.
                time.sleep(1)
                final_text_check = markdown_body.text
                if final_text_check == previous_text and final_text_check != "":
                    break
            except (NoSuchElementException, StaleElementReferenceException):
                # The reply body is not rendered yet, or was replaced mid-read:
                # retry shortly. Any other WebDriver failure propagates to the
                # outer handler instead of being silently retried.
                time.sleep(0.5)
    except Exception as e:
        # Best-effort: surface the failure to the API client as reply text.
        error_message = f"自动化过程中发生严重错误: {e}\n\n详细信息请查看 Hugging Face Space 的日志。"
        yield error_message
    finally:
        if driver:
            driver.quit()

# --- 4. API endpoint definitions ---
# The auth dependency's return value (the bearer token) is bound to the unused
# name ``_``: the handler needs the dependency to run, not its result.
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest, _: str = Depends(api_key_auth)):
    """Serve an OpenAI-style chat completion by driving the SAI-ChatBot web UI.

    Only the content of the most recent ``user`` message is forwarded. Earlier
    turns, other roles and the requested ``model`` are not sent upstream, and
    every response reports the fixed model name ``sai-chatbot-l6``.

    With ``stream=True`` the reply is sent as Server-Sent Events of
    ``chat.completion.chunk`` objects, terminated by a chunk carrying
    ``finish_reason="stop"`` and then ``data: [DONE]``. Otherwise a single
    ``chat.completion`` object is returned.

    Raises:
        HTTPException: 400 if there is no non-empty user message.
    """
    from starlette.concurrency import run_in_threadpool

    last_user_message = next(
        (msg.content for msg in reversed(request.messages) if msg.role == 'user'),
        None,
    )
    if not last_user_message:
        raise HTTPException(status_code=400, detail="No user message found")

    model_name = "sai-chatbot-l6"
    response_id, created_timestamp = f"chatcmpl-{uuid.uuid4()}", int(time.time())

    def chunk_payload(delta, finish_reason=None):
        # One ``chat.completion.chunk`` body sharing this response's id/timestamp.
        return {
            "id": response_id,
            "object": "chat.completion.chunk",
            "created": created_timestamp,
            "model": model_name,
            "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
        }

    if request.stream:
        # Deliberately a *synchronous* generator: get_sai_response blocks
        # (Selenium calls, time.sleep). StreamingResponse iterates sync
        # iterators in a worker thread, keeping the event loop free for other
        # requests; an ``async def`` generator would run that blocking loop
        # on the event loop itself and stall the whole server.
        def stream_generator():
            for chunk in get_sai_response(last_user_message):
                if not chunk:
                    continue
                yield f"data: {json.dumps(chunk_payload({'content': chunk}))}\n\n"
            # OpenAI streams close with an empty-delta chunk carrying finish_reason.
            yield f"data: {json.dumps(chunk_payload({}, 'stop'))}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(stream_generator(), media_type="text/event-stream")

    def collect_reply():
        return "".join(get_sai_response(last_user_message))

    # Run the blocking Selenium session off the event loop.
    full_content = await run_in_threadpool(collect_reply)
    # NOTE: "token" counts are character counts, a rough stand-in for real tokenization.
    prompt_tokens, completion_tokens = len(last_user_message), len(full_content)
    return {
        "id": response_id,
        "object": "chat.completion",
        "created": created_timestamp,
        "model": model_name,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": full_content},
                "finish_reason": "stop",
            }
        ],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }

# --- 5. Server entry point ---
if __name__ == "__main__":
    # Bind on all interfaces so the server is reachable from outside the
    # container; 7860 is presumably the port the hosting Space expects — confirm.
    uvicorn.run(app, port=7860, host="0.0.0.0")