Spaces:
Paused
Paused
| import json | |
| import os | |
| import time | |
| import uuid | |
| import threading | |
| import requests | |
| import ast | |
| import secrets | |
| import base64 | |
| from typing import Any, Dict, List, Optional, TypedDict, Union, Generator | |
| from fastapi import FastAPI, HTTPException, Depends | |
| from fastapi.responses import StreamingResponse | |
| from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | |
| from pydantic import BaseModel, Field | |
# Weam Account Management
class WeamAccount(TypedDict):
    """Mutable bookkeeping record for one upstream Weam JWT credential."""
    jwt: str          # raw JWT sent as "authorization: jwt <token>" upstream
    is_valid: bool    # set False once the account hits an auth (401) error
    last_used: float  # epoch seconds of last selection; 0 = never used
    error_count: int  # failure tally, compared against MAX_ERROR_COUNT
# Pydantic Models
class ChatMessage(BaseModel):
    # One OpenAI-style chat message. `content` is either plain text or the
    # multimodal list form, e.g. [{"type": "text", ...}, {"type": "image_url", ...}].
    role: str
    content: Union[str, List[Dict[str, Any]]]
class ChatCompletionRequest(BaseModel):
    # Incoming OpenAI-style chat-completions payload.
    model: str
    messages: List[ChatMessage]
    stream: bool = True  # defaults to streaming
    # Sampling knobs are accepted for client compatibility but are not
    # forwarded to the Weam upstream anywhere in this file.
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    top_p: Optional[float] = None
class ModelInfo(BaseModel):
    # One entry in the OpenAI-compatible model list.
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "weam"
class ModelList(BaseModel):
    # OpenAI-compatible model-list response envelope.
    object: str = "list"
    data: List[ModelInfo]
class ChatCompletionChoice(BaseModel):
    # Single choice of a non-streaming completion; this adapter always
    # builds exactly one choice with finish_reason "stop".
    message: ChatMessage
    index: int = 0
    finish_reason: str = "stop"
class ChatCompletionResponse(BaseModel):
    # Full (non-streaming) OpenAI-style completion envelope.
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex}")
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionChoice]
    # Token usage is never computed anywhere in this adapter, so it is
    # always reported as zeros.
    usage: Dict[str, int] = Field(
        default_factory=lambda: {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
        }
    )
class StreamChoice(BaseModel):
    # Single delta inside a chat.completion.chunk SSE event.
    delta: Dict[str, Any] = Field(default_factory=dict)
    index: int = 0
    finish_reason: Optional[str] = None  # "stop" only on the final chunk
class StreamResponse(BaseModel):
    # One OpenAI-style streaming chunk, serialized into an SSE `data:` line.
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex}")
    object: str = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[StreamChoice]
# Global variables
VALID_CLIENT_KEYS: set = set()         # API keys accepted from clients (filled by load_client_api_keys)
WEAM_ACCOUNTS: List[WeamAccount] = []  # upstream account pool (filled by load_weam_accounts)
account_lock = threading.Lock()        # guards mutation of WEAM_ACCOUNTS entries
MAX_ERROR_COUNT = 3                    # failures before an account is benched
ERROR_COOLDOWN = 300  # 5 minutes cooldown for accounts with errors
DEBUG_MODE = os.environ.get("DEBUG_MODE", "false").lower() == "true"
REQUEST_TIMEOUT = 60.0                 # seconds, applied to all upstream HTTP calls
# Weam models list
# Model ids advertised by the model-listing handlers; any other id in a
# chat request is rejected with 404.
WEAM_MODELS = [
    "claude-3-5-sonnet-latest",
    "claude-3-7-sonnet-latest",
    "claude-sonnet-4-20250514",
    "claude-opus-4-20250514",
]
# FastAPI App
app = FastAPI(title="Weam OpenAI API Adapter")
# auto_error=False so the missing-header case raises our own 401 instead
# of FastAPI's default 403.
security = HTTPBearer(auto_error=False)
def log_debug(message: str):
    """Emit *message* to stdout, but only when DEBUG_MODE is enabled."""
    if not DEBUG_MODE:
        return
    print(f"[DEBUG] {message}")
def load_client_api_keys():
    """Refresh the global VALID_CLIENT_KEYS set from client_api_keys.json.

    Any failure (missing file, bad JSON, non-list content) leaves the set
    empty so authentication fails closed.
    """
    global VALID_CLIENT_KEYS
    try:
        with open("client_api_keys.json", "r", encoding="utf-8") as handle:
            parsed = json.loads(handle.read())
        # A JSON list becomes the key set; any other shape yields no keys.
        VALID_CLIENT_KEYS = set(parsed) if isinstance(parsed, list) else set()
        print(f"Successfully loaded {len(VALID_CLIENT_KEYS)} client API keys.")
    except FileNotFoundError:
        print("Error: client_api_keys.json not found. Client authentication will fail.")
        VALID_CLIENT_KEYS = set()
    except Exception as e:
        print(f"Error loading client_api_keys.json: {e}")
        VALID_CLIENT_KEYS = set()
def load_weam_accounts():
    """Rebuild the global WEAM_ACCOUNTS pool from weam.json.

    Each listed object that carries a non-empty "jwt" becomes a fresh
    account record (valid, never used, zero errors). Errors leave the
    pool empty.
    """
    global WEAM_ACCOUNTS
    WEAM_ACCOUNTS = []
    try:
        with open("weam.json", "r", encoding="utf-8") as handle:
            entries = json.load(handle)
        if not isinstance(entries, list):
            print("Warning: weam.json should contain a list of account objects.")
            return
        for entry in entries:
            token = entry.get("jwt")
            if not token:
                continue
            WEAM_ACCOUNTS.append(
                {"jwt": token, "is_valid": True, "last_used": 0, "error_count": 0}
            )
        print(f"Successfully loaded {len(WEAM_ACCOUNTS)} Weam accounts.")
    except FileNotFoundError:
        print("Error: weam.json not found. API calls will fail.")
    except Exception as e:
        print(f"Error loading weam.json: {e}")
def get_best_weam_account() -> Optional["WeamAccount"]:
    """Pick the least-recently-used healthy account, or None if none qualify.

    An account qualifies when it is still marked valid and either has not
    reached MAX_ERROR_COUNT or has been idle longer than ERROR_COOLDOWN.
    Selecting an account stamps its last_used time, so repeated calls
    rotate through the pool.
    """
    with account_lock:
        now = time.time()
        candidates = []
        for acc in WEAM_ACCOUNTS:
            if not acc["is_valid"]:
                continue
            cooled_down = now - acc["last_used"] > ERROR_COOLDOWN
            if acc["error_count"] < MAX_ERROR_COUNT or cooled_down:
                candidates.append(acc)
        if not candidates:
            return None
        # Accounts that sat out their cooldown get a clean error slate.
        for acc in candidates:
            if acc["error_count"] >= MAX_ERROR_COUNT and now - acc["last_used"] > ERROR_COOLDOWN:
                acc["error_count"] = 0
        # Equivalent to sorting by (last_used, error_count) and taking the head.
        chosen = min(candidates, key=lambda a: (a["last_used"], a["error_count"]))
        chosen["last_used"] = now
        return chosen
def upload_image(jwt: str, image_bytes: bytes, filename: str) -> str:
    """POST an image to the Weam upload endpoint and return its CDN path.

    Raises requests.HTTPError on a non-2xx upstream response.
    """
    endpoint = "https://api.weam.ai/api/upload/file"
    # brainId appears to be a fixed workspace identifier for uploads.
    form_fields = {"brainId": "699b908a19999177cf7e496a", "vectorApiCall": "true"}
    attachments = [("files", (filename, image_bytes, "image/png"))]
    request_headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/137.0.0.0 Safari/537.36 Edg/137.0.0.0",
        "Accept": "application/json, text/plain, */*",
        "authorization": f"jwt {jwt}",
        "origin": "https://app.weam.ai",
        "referer": "https://app.weam.ai/",
    }
    upstream = requests.post(
        endpoint,
        data=form_fields,
        files=attachments,
        headers=request_headers,
        timeout=REQUEST_TIMEOUT,
    )
    upstream.raise_for_status()
    body = upstream.json()
    return body["data"][0]["uri"]
def process_messages(messages: List["ChatMessage"], jwt: str) -> tuple[str, List[str]]:
    """Flatten chat messages into one query string and upload inline images.

    Text parts (plain-string content and {"type": "text"} items) are joined
    with blank lines. base64 data-URI images are decoded, uploaded via
    upload_image, and returned as CDN URLs.

    NOTE(review): image_url items that are plain http(s) URLs (not data:
    URIs) are silently dropped, matching the original behavior — confirm
    whether they should be forwarded.

    Raises:
        HTTPException(400): when a data URI cannot be parsed/decoded.
        Any exception from upload_image: upload/network failures propagate
            unchanged so the caller's account-rotation retry can handle
            them instead of mislabeling them as client errors.
    """
    query_parts: List[str] = []
    image_urls: List[str] = []
    for msg in messages:
        if isinstance(msg.content, str):
            query_parts.append(msg.content)
        elif isinstance(msg.content, list):
            for item in msg.content:
                if item.get("type") == "text":
                    query_parts.append(item.get("text", ""))
                elif item.get("type") == "image_url":
                    image_url = item.get("image_url", {}).get("url", "")
                    if not image_url.startswith("data:image/"):
                        continue
                    try:
                        # "data:image/png;base64,<payload>" -> decode payload
                        header, encoded = image_url.split(",", 1)
                        image_bytes = base64.b64decode(encoded)
                    except (ValueError, TypeError) as e:
                        # Only malformed client data is a 400; upload errors
                        # below are server-side and must propagate.
                        log_debug(f"Failed to process image: {e}")
                        raise HTTPException(status_code=400, detail=f"Invalid image format: {e}")
                    filename = f"upload-{uuid.uuid4().hex}.png"
                    uri = upload_image(jwt, image_bytes, filename)
                    image_urls.append(f"https://cdn.weam.ai{uri}")
    return "\n\n".join(query_parts), image_urls
async def authenticate_client(
    auth: Optional[HTTPAuthorizationCredentials] = Depends(security),
):
    """Validate the Bearer token against the loaded client API keys.

    Raises 503 when no keys are configured, 401 when no credentials were
    supplied, and 403 when the supplied key is unknown.
    """
    if not VALID_CLIENT_KEYS:
        raise HTTPException(
            status_code=503,
            detail="Service unavailable: Client API keys not configured on server.",
        )
    supplied = auth.credentials if auth else None
    if not supplied:
        raise HTTPException(
            status_code=401,
            detail="API key required in Authorization header.",
            headers={"WWW-Authenticate": "Bearer"},
        )
    if supplied not in VALID_CLIENT_KEYS:
        raise HTTPException(status_code=403, detail="Invalid client API key.")
# Fix: the handler was defined but never registered with the app, so the
# key/account files were never loaded when running under an ASGI server.
@app.on_event("startup")
async def startup():
    """Initialize configuration when the application starts."""
    print("Starting Weam OpenAI API Adapter server...")
    load_client_api_keys()
    load_weam_accounts()
    print("Server initialization completed.")
# Fix: the handler was defined but never registered as a route.
@app.get("/v1/models")
async def list_v1_models(_: None = Depends(authenticate_client)):
    """List available models - authenticated."""
    return ModelList(data=[ModelInfo(id=model) for model in WEAM_MODELS])
# Fix: the handler was defined but never registered as a route.
# NOTE(review): the unauthenticated path is assumed to be /models — confirm
# against the clients this compatibility endpoint was added for.
@app.get("/models")
async def list_models_no_auth():
    """List available models without authentication - for client compatibility."""
    return ModelList(data=[ModelInfo(id=model) for model in WEAM_MODELS])
# Fix: the handler was defined but never registered as a route.
@app.post("/v1/chat/completions")
async def chat_completions(
    request: ChatCompletionRequest, _: None = Depends(authenticate_client)
):
    """Creates a chat completion using the Weam API.

    Accounts are tried in least-recently-used order; each upstream failure
    bumps that account's error count (and invalidates it on auth errors)
    before the next account is tried. Client-side HTTPExceptions are
    re-raised immediately instead of being blamed on the account.
    """
    if request.model not in WEAM_MODELS:
        raise HTTPException(status_code=404, detail=f"Model '{request.model}' not found.")
    if not request.messages:
        raise HTTPException(status_code=400, detail="No messages provided in the request.")
    log_debug(f"Processing request for model: {request.model}")
    request_id = f"chatcmpl-{uuid.uuid4().hex}"
    # Try each account at most once per request.
    for _attempt in range(len(WEAM_ACCOUNTS)):
        account = get_best_weam_account()
        if not account:
            raise HTTPException(status_code=503, detail="No valid Weam accounts available.")
        jwt = account["jwt"]
        log_debug(f"Using account with JWT ending in ...{jwt[-6:]}")
        try:
            query, image_urls = process_messages(request.messages, jwt)
            log_debug(f"Query length: {len(query)}, Images: {len(image_urls)}")
            if request.stream:
                return StreamingResponse(
                    weam_stream_generator(request.model, query, image_urls, jwt, request_id),
                    media_type="text/event-stream",
                    headers={
                        "Cache-Control": "no-cache",
                        "Connection": "keep-alive",
                        "X-Accel-Buffering": "no",
                    },
                )
            else:
                return build_non_stream_response(request.model, query, image_urls, jwt, request_id)
        except HTTPException:
            # Fix: client errors (e.g. a bad image payload, 400) were
            # previously swallowed by the generic handler below, counted as
            # an account failure and retried against every account.
            raise
        except Exception as e:
            error_detail = str(e)
            log_debug(f"Weam API error: {error_detail}")
            with account_lock:
                account["error_count"] += 1
                if "401" in error_detail or "unauthorized" in error_detail.lower():
                    account["is_valid"] = False
                    log_debug("Account marked as invalid due to auth error")
    # All attempts failed
    raise HTTPException(status_code=503, detail="All attempts to contact Weam API failed.")
def weam_stream_generator(model: str, query: str, image_urls: List[str], jwt: str, request_id: str) -> Generator[str, None, None]:
    """Generate streaming response from Weam API.

    Translates the upstream Weam SSE stream into OpenAI-style
    ``chat.completion.chunk`` SSE events: an initial role delta, one content
    delta per upstream chunk, then a finish chunk and the ``[DONE]`` sentinel.
    Upstream ``data:`` lines carry the repr of a bytes object (``b'...'``),
    which is recovered via ast.literal_eval — safe on untrusted input,
    unlike eval.
    """
    url = "https://pyapi.weam.ai/api/tool/stream-tool-chat-with-openai"
    created_time = int(time.time())
    # NOTE(review): llm_apikey and company_id look like hard-coded tenant
    # identifiers shared by every account — confirm they are intended.
    payload = {
        "thread_id": secrets.token_hex(12),
        "query": query,
        "prompt_id": None,
        "llm_apikey": "684685b2f24ae32c999cbc93",
        "chat_session_id": secrets.token_hex(12),
        "image_url": image_urls,
        "company_id": "6846857af24ae32c7b1cbc32",
        "delay_chunk": 0.02,
        "code": "ANTHROPIC",
        "model_name": model,
        "msgCredit": 0,
    }
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/137.0.0.0 Safari/537.36 Edg/137.0.0.0",
        "Content-Type": "application/json",
        "authorization": f"jwt {jwt}",
        "origin": "https://app.weam.ai",
        "referer": "https://app.weam.ai/",
    }
    # Send initial role message
    yield f"data: {StreamResponse(id=request_id, created=created_time, model=model, choices=[StreamChoice(delta={'role': 'assistant'})]).json()}\n\n"
    try:
        with requests.post(url, data=json.dumps(payload), headers=headers, stream=True, timeout=REQUEST_TIMEOUT) as response:
            response.raise_for_status()
            for line in response.iter_lines():
                if line:
                    text = line.decode("utf-8")
                    if text.startswith("data: "):
                        byte_str = text[6:]  # strip the "data: " prefix
                        try:
                            # Parse the bytes-repr payload without executing code;
                            # non-bytes literals are ignored.
                            byte_obj = ast.literal_eval(byte_str)
                            if isinstance(byte_obj, bytes):
                                decoded = byte_obj.decode("utf-8")
                                yield f"data: {StreamResponse(id=request_id, created=created_time, model=model, choices=[StreamChoice(delta={'content': decoded})]).json()}\n\n"
                        except (SyntaxError, ValueError) as e:
                            log_debug(f"Parse error: {e}")
    except requests.exceptions.ChunkedEncodingError:
        # This error usually occurs when the response ends, so it can be
        # safely ignored.
        log_debug("ChunkedEncodingError caught - stream likely completed")
    except Exception as e:
        log_debug(f"Stream error: {e}")
        yield f"data: {json.dumps({'error': {'message': str(e)}})}\n\n"
    # Always send completion message
    yield f"data: {StreamResponse(id=request_id, created=created_time, model=model, choices=[StreamChoice(delta={}, finish_reason='stop')]).json()}\n\n"
    yield "data: [DONE]\n\n"
def build_non_stream_response(model: str, query: str, image_urls: List[str], jwt: str, request_id: str) -> "ChatCompletionResponse":
    """Build non-streaming response by accumulating stream chunks.

    Consumes weam_stream_generator and concatenates every content delta
    into one assistant message.

    Raises:
        RuntimeError: when the stream reported an upstream error, so the
            caller's account-rotation logic can retry with another account
            instead of returning an empty message.
    """
    full_content = ""
    for chunk in weam_stream_generator(model, query, image_urls, jwt, request_id):
        if not chunk.startswith("data: ") or chunk.strip() == "data: [DONE]":
            continue
        try:
            data = json.loads(chunk[6:])  # strip the "data: " prefix
        except json.JSONDecodeError:
            continue
        # Fix: the generator serializes upstream failures as {"error": ...}
        # chunks, which were previously ignored — surface them instead of
        # silently returning an empty completion.
        if "error" in data:
            raise RuntimeError(data["error"].get("message", "Weam stream error"))
        if "choices" in data and data["choices"]:
            delta = data["choices"][0].get("delta", {})
            if delta.get("content"):
                full_content += delta["content"]
    return ChatCompletionResponse(
        id=request_id,
        model=model,
        choices=[ChatCompletionChoice(message=ChatMessage(role="assistant", content=full_content))],
    )
| if __name__ == "__main__": | |
| import uvicorn | |
| if os.environ.get("DEBUG_MODE", "").lower() == "true": | |
| DEBUG_MODE = True | |
| print("Debug mode enabled via environment variable") | |
| if not os.path.exists("weam.json"): | |
| print("Warning: weam.json not found. Creating a dummy file.") | |
| dummy_data = [{"jwt": "your_jwt_here"}] | |
| with open("weam.json", "w", encoding="utf-8") as f: | |
| json.dump(dummy_data, f, indent=4) | |
| print("Created dummy weam.json. Please replace with valid Weam JWTs.") | |
| if not os.path.exists("client_api_keys.json"): | |
| print("Warning: client_api_keys.json not found. Creating a dummy file.") | |
| dummy_key = f"sk-dummy-{uuid.uuid4().hex}" | |
| with open("client_api_keys.json", "w", encoding="utf-8") as f: | |
| json.dump([dummy_key], f, indent=2) | |
| print(f"Created dummy client_api_keys.json with key: {dummy_key}") | |
| load_client_api_keys() | |
| load_weam_accounts() | |
| print("\n--- Weam OpenAI API Adapter ---") | |
| print(f"Debug Mode: {DEBUG_MODE}") | |
| print(f"Client API Keys: {len(VALID_CLIENT_KEYS)}") | |
| print(f"Weam Accounts: {len(WEAM_ACCOUNTS)}") | |
| print(f"Available Models: {', '.join(WEAM_MODELS[:5])}{'...' if len(WEAM_MODELS) > 5 else ''}") | |
| print("------------------------------------") | |
| uvicorn.run(app, host="0.0.0.0", port=8000) |