# bumi / app.py
# bunnybun07: "Rename main.py to app.py" (commit 12a9bcc, verified)
import json
import os
import time
import uuid
import threading
import requests
import ast
import secrets
import base64
from typing import Any, Dict, List, Optional, TypedDict, Union, Generator
from fastapi import FastAPI, HTTPException, Depends
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, Field
# Weam Account Management
class WeamAccount(TypedDict):
    """Mutable state for one upstream Weam account in the rotation pool."""
    jwt: str          # token sent upstream as "authorization: jwt <...>"
    is_valid: bool    # set False after an upstream 401/unauthorized error
    last_used: float  # epoch seconds of last selection; 0 = never used
    error_count: int  # consecutive failures; gates the cooldown logic
# Pydantic Models
class ChatMessage(BaseModel):
    """One chat turn; content is plain text or OpenAI-style content parts."""
    role: str
    content: Union[str, List[Dict[str, Any]]]
class ChatCompletionRequest(BaseModel):
    """OpenAI-compatible /v1/chat/completions request body."""
    model: str
    messages: List[ChatMessage]
    stream: bool = True  # defaults to streaming (unlike the OpenAI API default)
    # The following sampling knobs are accepted for client compatibility but
    # are not forwarded in the upstream Weam payload.
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    top_p: Optional[float] = None
class ModelInfo(BaseModel):
    """Single entry in the OpenAI-style model listing."""
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "weam"
class ModelList(BaseModel):
    """OpenAI-style wrapper for the /models response."""
    object: str = "list"
    data: List[ModelInfo]
class ChatCompletionChoice(BaseModel):
    """One completion choice in a non-streaming response."""
    message: ChatMessage
    index: int = 0
    finish_reason: str = "stop"
class ChatCompletionResponse(BaseModel):
    """OpenAI-compatible non-streaming chat completion response."""
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex}")
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionChoice]
    # Token accounting is never computed anywhere in this adapter, so usage
    # is always reported as zeros.
    usage: Dict[str, int] = Field(
        default_factory=lambda: {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
        }
    )
class StreamChoice(BaseModel):
    """One delta choice inside a streaming chunk."""
    delta: Dict[str, Any] = Field(default_factory=dict)
    index: int = 0
    finish_reason: Optional[str] = None
class StreamResponse(BaseModel):
    """OpenAI-compatible `chat.completion.chunk` SSE payload."""
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex}")
    object: str = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[StreamChoice]
# Global variables
VALID_CLIENT_KEYS: set = set()       # client API keys accepted by authenticate_client
WEAM_ACCOUNTS: List[WeamAccount] = []  # upstream account pool, rotated per request
account_lock = threading.Lock()      # guards reads/writes of WEAM_ACCOUNTS entries
MAX_ERROR_COUNT = 3                  # failures before an account is benched
ERROR_COOLDOWN = 300  # 5 minutes cooldown for accounts with errors
DEBUG_MODE = os.environ.get("DEBUG_MODE", "false").lower() == "true"
REQUEST_TIMEOUT = 60.0               # seconds, applied to every outbound request
# Weam models list
WEAM_MODELS = [
    "claude-3-5-sonnet-latest",
    "claude-3-7-sonnet-latest",
    "claude-sonnet-4-20250514",
    "claude-opus-4-20250514",
]
# FastAPI App
app = FastAPI(title="Weam OpenAI API Adapter")
# auto_error=False so missing credentials reach our own 401 handler instead
# of FastAPI's default 403.
security = HTTPBearer(auto_error=False)
def log_debug(message: str):
    """Print *message* with a [DEBUG] prefix when debug mode is enabled."""
    if not DEBUG_MODE:
        return
    print(f"[DEBUG] {message}")
def load_client_api_keys():
    """Populate the VALID_CLIENT_KEYS set from client_api_keys.json.

    Any failure (missing file, bad JSON, non-list content) leaves the set
    empty so authentication fails closed.
    """
    global VALID_CLIENT_KEYS
    try:
        with open("client_api_keys.json", "r", encoding="utf-8") as f:
            loaded = json.load(f)
    except FileNotFoundError:
        print("Error: client_api_keys.json not found. Client authentication will fail.")
        VALID_CLIENT_KEYS = set()
        return
    except Exception as e:
        print(f"Error loading client_api_keys.json: {e}")
        VALID_CLIENT_KEYS = set()
        return
    VALID_CLIENT_KEYS = set(loaded) if isinstance(loaded, list) else set()
    print(f"Successfully loaded {len(VALID_CLIENT_KEYS)} client API keys.")
def load_weam_accounts():
    """Rebuild WEAM_ACCOUNTS from weam.json, resetting all account state.

    Entries without a truthy "jwt" field are skipped; on any load error the
    pool is left empty.
    """
    global WEAM_ACCOUNTS
    WEAM_ACCOUNTS = []
    try:
        with open("weam.json", "r", encoding="utf-8") as f:
            raw = json.load(f)
        if not isinstance(raw, list):
            print("Warning: weam.json should contain a list of account objects.")
            return
        WEAM_ACCOUNTS = [
            {"jwt": token, "is_valid": True, "last_used": 0, "error_count": 0}
            for entry in raw
            if (token := entry.get("jwt"))
        ]
        print(f"Successfully loaded {len(WEAM_ACCOUNTS)} Weam accounts.")
    except FileNotFoundError:
        print("Error: weam.json not found. API calls will fail.")
    except Exception as e:
        print(f"Error loading weam.json: {e}")
def get_best_weam_account() -> Optional[WeamAccount]:
    """Pick the least-recently-used healthy account, honoring error cooldowns.

    Returns None when no account is usable. The chosen account's last_used
    timestamp is refreshed under the lock.
    """
    with account_lock:
        now = time.time()
        candidates = []
        for acc in WEAM_ACCOUNTS:
            if not acc["is_valid"]:
                continue
            # An account over the error limit becomes eligible again after
            # the cooldown window has elapsed.
            if acc["error_count"] < MAX_ERROR_COUNT or now - acc["last_used"] > ERROR_COOLDOWN:
                candidates.append(acc)
        if not candidates:
            return None
        # Clear error counts for accounts whose cooldown has expired.
        for acc in candidates:
            if acc["error_count"] >= MAX_ERROR_COUNT and now - acc["last_used"] > ERROR_COOLDOWN:
                acc["error_count"] = 0
        # Least-recently-used first, ties broken by lowest error count.
        chosen = min(candidates, key=lambda a: (a["last_used"], a["error_count"]))
        chosen["last_used"] = now
        return chosen
def upload_image(jwt: str, image_bytes: bytes, filename: str) -> str:
    """Upload image to weam.ai and return the URI.

    Raises requests.HTTPError (via raise_for_status) on non-2xx responses.
    """
    upload_url = "https://api.weam.ai/api/upload/file"
    # NOTE(review): brainId appears to be a fixed tenant identifier for this
    # deployment — confirm it is valid for every configured JWT.
    form_data = {"brainId": "699b908a19999177cf7e496a", "vectorApiCall": "true"}
    file_part = [("files", (filename, image_bytes, "image/png"))]
    request_headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/137.0.0.0 Safari/537.36 Edg/137.0.0.0",
        "Accept": "application/json, text/plain, */*",
        "authorization": f"jwt {jwt}",
        "origin": "https://app.weam.ai",
        "referer": "https://app.weam.ai/",
    }
    resp = requests.post(
        upload_url,
        data=form_data,
        files=file_part,
        headers=request_headers,
        timeout=REQUEST_TIMEOUT,
    )
    resp.raise_for_status()
    return resp.json()["data"][0]["uri"]
def process_messages(messages: List[ChatMessage], jwt: str) -> tuple[str, List[str]]:
    """Flatten chat messages into one query string and upload inline images.

    Text parts from every message are joined with blank lines; data-URI
    images are uploaded via upload_image and returned as CDN URLs.

    Raises:
        HTTPException 400: when a data-URI image cannot be decoded/uploaded.
    """
    texts: List[str] = []
    uploaded: List[str] = []
    for message in messages:
        content = message.content
        if isinstance(content, str):
            texts.append(content)
            continue
        if not isinstance(content, list):
            continue
        for part in content:
            part_type = part.get("type")
            if part_type == "text":
                texts.append(part.get("text", ""))
            elif part_type == "image_url":
                url = part.get("image_url", {}).get("url", "")
                # Only base64 data URIs are handled; plain http(s) URLs are skipped.
                if not url.startswith("data:image/"):
                    continue
                try:
                    _, b64_payload = url.split(",", 1)
                    raw = base64.b64decode(b64_payload)
                    generated_name = f"upload-{uuid.uuid4().hex}.png"
                    uri = upload_image(jwt, raw, generated_name)
                    uploaded.append(f"https://cdn.weam.ai{uri}")
                except Exception as e:
                    log_debug(f"Failed to process image: {e}")
                    raise HTTPException(status_code=400, detail=f"Invalid image format: {e}")
    return "\n\n".join(texts), uploaded
async def authenticate_client(
    auth: Optional[HTTPAuthorizationCredentials] = Depends(security),
):
    """FastAPI dependency: validate the Bearer token against VALID_CLIENT_KEYS.

    Raises:
        HTTPException 503: no client keys are configured at all.
        HTTPException 401: no credentials were supplied.
        HTTPException 403: the supplied key is not recognized.
    """
    if not VALID_CLIENT_KEYS:
        raise HTTPException(
            status_code=503,
            detail="Service unavailable: Client API keys not configured on server.",
        )
    supplied = auth.credentials if auth else None
    if not supplied:
        raise HTTPException(
            status_code=401,
            detail="API key required in Authorization header.",
            headers={"WWW-Authenticate": "Bearer"},
        )
    if supplied not in VALID_CLIENT_KEYS:
        raise HTTPException(status_code=403, detail="Invalid client API key.")
@app.on_event("startup")
async def startup():
    """Load client keys and Weam accounts when the server boots."""
    print("Starting Weam OpenAI API Adapter server...")
    load_client_api_keys()
    load_weam_accounts()
    print("Server initialization completed.")
@app.get("/v1/models", response_model=ModelList)
async def list_v1_models(_: None = Depends(authenticate_client)):
    """List available models - authenticated"""
    entries = [ModelInfo(id=name) for name in WEAM_MODELS]
    return ModelList(data=entries)
@app.get("/models", response_model=ModelList)
async def list_models_no_auth():
    """List available models without authentication - for client compatibility"""
    entries = [ModelInfo(id=name) for name in WEAM_MODELS]
    return ModelList(data=entries)
@app.post("/v1/chat/completions")
async def chat_completions(
    request: ChatCompletionRequest, _: None = Depends(authenticate_client)
):
    """Creates a chat completion using the Weam API.

    Rotates through the configured Weam accounts (each tried at most once),
    marking accounts that hit auth errors as invalid.

    Raises:
        HTTPException 404: unknown model.
        HTTPException 400: empty message list, or invalid inline image data.
        HTTPException 503: no usable accounts / all upstream attempts failed.
    """
    if request.model not in WEAM_MODELS:
        raise HTTPException(status_code=404, detail=f"Model '{request.model}' not found.")
    if not request.messages:
        raise HTTPException(status_code=400, detail="No messages provided in the request.")
    log_debug(f"Processing request for model: {request.model}")
    request_id = f"chatcmpl-{uuid.uuid4().hex}"
    # Try each configured account at most once.
    for _attempt in range(len(WEAM_ACCOUNTS)):
        account = get_best_weam_account()
        if not account:
            raise HTTPException(status_code=503, detail="No valid Weam accounts available.")
        jwt = account["jwt"]
        log_debug(f"Using account with JWT ending in ...{jwt[-6:]}")
        try:
            query, image_urls = process_messages(request.messages, jwt)
            log_debug(f"Query length: {len(query)}, Images: {len(image_urls)}")
            if request.stream:
                return StreamingResponse(
                    weam_stream_generator(request.model, query, image_urls, jwt, request_id),
                    media_type="text/event-stream",
                    headers={
                        "Cache-Control": "no-cache",
                        "Connection": "keep-alive",
                        "X-Accel-Buffering": "no",
                    },
                )
            return build_non_stream_response(request.model, query, image_urls, jwt, request_id)
        except HTTPException:
            # BUGFIX: client-side errors (e.g. bad base64 image data raising
            # HTTPException 400 in process_messages) were previously swallowed
            # by the generic handler below, penalizing the account and turning
            # a 400 into a 503. Re-raise them untouched instead.
            raise
        except Exception as e:
            error_detail = str(e)
            log_debug(f"Weam API error: {error_detail}")
            with account_lock:
                account["error_count"] += 1
                # Auth failures permanently bench this account.
                if "401" in error_detail or "unauthorized" in error_detail.lower():
                    account["is_valid"] = False
                    log_debug("Account marked as invalid due to auth error")
    # Reached when every account attempt failed (or none are configured).
    raise HTTPException(status_code=503, detail="All attempts to contact Weam API failed.")
def weam_stream_generator(model: str, query: str, image_urls: List[str], jwt: str, request_id: str) -> Generator[str, None, None]:
    """Stream the Weam chat API and yield OpenAI-style SSE chunk strings.

    Emits an initial role delta, then one content delta per upstream chunk,
    and always terminates with a `finish_reason: stop` chunk followed by
    `data: [DONE]` — even after an error (the error is also emitted as an
    SSE event so clients see it).
    """
    url = "https://pyapi.weam.ai/api/tool/stream-tool-chat-with-openai"
    created_time = int(time.time())
    # NOTE(review): llm_apikey and company_id look like fixed tenant IDs for
    # this deployment — confirm they are valid for every configured JWT.
    payload = {
        "thread_id": secrets.token_hex(12),
        "query": query,
        "prompt_id": None,
        "llm_apikey": "684685b2f24ae32c999cbc93",
        "chat_session_id": secrets.token_hex(12),
        "image_url": image_urls,
        "company_id": "6846857af24ae32c7b1cbc32",
        "delay_chunk": 0.02,
        "code": "ANTHROPIC",
        "model_name": model,
        "msgCredit": 0,
    }
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/137.0.0.0 Safari/537.36 Edg/137.0.0.0",
        "Content-Type": "application/json",
        "authorization": f"jwt {jwt}",
        "origin": "https://app.weam.ai",
        "referer": "https://app.weam.ai/",
    }
    # Send initial role message
    yield f"data: {StreamResponse(id=request_id, created=created_time, model=model, choices=[StreamChoice(delta={'role': 'assistant'})]).json()}\n\n"
    try:
        with requests.post(url, data=json.dumps(payload), headers=headers, stream=True, timeout=REQUEST_TIMEOUT) as response:
            response.raise_for_status()
            for line in response.iter_lines():
                if line:
                    text = line.decode("utf-8")
                    if text.startswith("data: "):
                        byte_str = text[6:]  # strip the "data: " prefix
                        try:
                            # The upstream frames each chunk as the textual repr
                            # of a Python bytes literal; literal_eval parses it
                            # without executing arbitrary code.
                            byte_obj = ast.literal_eval(byte_str)
                            if isinstance(byte_obj, bytes):
                                decoded = byte_obj.decode("utf-8")
                                yield f"data: {StreamResponse(id=request_id, created=created_time, model=model, choices=[StreamChoice(delta={'content': decoded})]).json()}\n\n"
                        except (SyntaxError, ValueError) as e:
                            # Malformed frames are logged and skipped.
                            log_debug(f"Parse error: {e}")
    except requests.exceptions.ChunkedEncodingError:
        # This error usually occurs when the response ends; safe to ignore.
        log_debug("ChunkedEncodingError caught - stream likely completed")
    except Exception as e:
        log_debug(f"Stream error: {e}")
        yield f"data: {json.dumps({'error': {'message': str(e)}})}\n\n"
    # Always send completion message
    yield f"data: {StreamResponse(id=request_id, created=created_time, model=model, choices=[StreamChoice(delta={}, finish_reason='stop')]).json()}\n\n"
    yield "data: [DONE]\n\n"
def build_non_stream_response(model: str, query: str, image_urls: List[str], jwt: str, request_id: str) -> ChatCompletionResponse:
    """Consume the SSE stream and return one aggregated completion response."""
    pieces: List[str] = []
    for event in weam_stream_generator(model, query, image_urls, jwt, request_id):
        if not event.startswith("data: "):
            continue
        if event.strip() == "data: [DONE]":
            continue
        try:
            parsed = json.loads(event[6:])  # drop the "data: " prefix
        except json.JSONDecodeError:
            continue
        choices = parsed.get("choices")
        if choices:
            text = choices[0].get("delta", {}).get("content")
            if text:
                pieces.append(text)
    assistant_message = ChatMessage(role="assistant", content="".join(pieces))
    return ChatCompletionResponse(
        id=request_id,
        model=model,
        choices=[ChatCompletionChoice(message=assistant_message)],
    )
if __name__ == "__main__":
    import uvicorn

    # Allow enabling debug output via the environment at launch time
    # (mirrors the module-level DEBUG_MODE initialization).
    if os.environ.get("DEBUG_MODE", "").lower() == "true":
        DEBUG_MODE = True
        print("Debug mode enabled via environment variable")
    # Create placeholder config files on first run so the server can start.
    if not os.path.exists("weam.json"):
        print("Warning: weam.json not found. Creating a dummy file.")
        dummy_data = [{"jwt": "your_jwt_here"}]
        with open("weam.json", "w", encoding="utf-8") as f:
            json.dump(dummy_data, f, indent=4)
        print("Created dummy weam.json. Please replace with valid Weam JWTs.")
    if not os.path.exists("client_api_keys.json"):
        print("Warning: client_api_keys.json not found. Creating a dummy file.")
        dummy_key = f"sk-dummy-{uuid.uuid4().hex}"
        with open("client_api_keys.json", "w", encoding="utf-8") as f:
            json.dump([dummy_key], f, indent=2)
        print(f"Created dummy client_api_keys.json with key: {dummy_key}")
    # Load config now so the summary below reflects reality; the FastAPI
    # startup hook loads it again when uvicorn starts the app.
    load_client_api_keys()
    load_weam_accounts()
    print("\n--- Weam OpenAI API Adapter ---")
    print(f"Debug Mode: {DEBUG_MODE}")
    print(f"Client API Keys: {len(VALID_CLIENT_KEYS)}")
    print(f"Weam Accounts: {len(WEAM_ACCOUNTS)}")
    print(f"Available Models: {', '.join(WEAM_MODELS[:5])}{'...' if len(WEAM_MODELS) > 5 else ''}")
    print("------------------------------------")
    uvicorn.run(app, host="0.0.0.0", port=8000)