# GitHub Copilot OpenAI-compatible API server
| from fastapi import FastAPI, HTTPException, Request | |
| from fastapi.responses import StreamingResponse, JSONResponse | |
| from pydantic import BaseModel, Field, ValidationError | |
| from typing import List, Optional, Dict, Any, Union | |
| import json | |
| import time | |
| import uuid | |
| import logging | |
| import traceback | |
| from curl_cffi import CurlError | |
| from curl_cffi.requests import Session | |
| import asyncio | |
| from threading import Lock | |
| import os | |
# Configure logging
# Process-wide INFO level via basicConfig; module-level logger used throughout.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class GithubChat:
    """Client for the GitHub Copilot chat backend.

    Authenticates with cookies exported from a logged-in github.com browser
    session, trades them for a short-lived bearer token, lazily creates a chat
    thread, and parses the server-sent-event (SSE) reply stream.

    Token and thread creation are each serialized with a ``threading.Lock``,
    so one instance can be shared by multiple request threads.
    """

    def __init__(self, cookie_path="cookies.json", model="gpt-4o"):
        """Build the HTTP session with browser-like headers and load cookies.

        Args:
            cookie_path: Path to a JSON list of cookie objects
                (browser-export shape: ``name``/``value``/``expirationDate``).
            model: Copilot model id attached to every outgoing message.
        """
        self.api_url = "https://api.individual.githubcopilot.com"
        self.session = Session()
        # Headers chosen to look like a normal fetch from github.com/copilot.
        self.session.headers.update({
            "Content-Type": "application/json",
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
            "Accept": "*/*",
            "Accept-Encoding": "gzip, deflate, br",
            "Accept-Language": "en-US,en;q=0.5",
            "Origin": "https://github.com",
            "Referer": "https://github.com/copilot",
            "GitHub-Verified-Fetch": "true",
            "X-Requested-With": "XMLHttpRequest",
            "Connection": "keep-alive",
            "Sec-Fetch-Dest": "empty",
            "Sec-Fetch-Mode": "cors",
            "Sec-Fetch-Site": "same-origin",
        })
        # Load cookies with better error handling
        self.cookie_path = cookie_path
        self._load_cookies()
        self.model = model
        self._access_token = None       # cached bearer token, fetched lazily
        self._conversation_id = None    # cached chat thread id, created lazily
        self._token_lock = Lock()
        self._conversation_lock = Lock()
        self.max_retries = 3            # attempts per API call
        self.retry_delay = 1.0          # base backoff seconds (doubles per retry)

    def _load_cookies(self):
        """Load cookies with robust error handling.

        Reads ``self.cookie_path`` (JSON list of cookie objects), drops
        malformed and expired entries, and installs the rest on the session.
        Missing or invalid files are logged and ignored — auth will then fail
        later at token retrieval rather than here.
        """
        try:
            if not os.path.exists(self.cookie_path):
                logger.warning(f"Cookie file {self.cookie_path} not found")
                return
            with open(self.cookie_path, 'r', encoding='utf-8') as f:
                cookies_data = json.load(f)
            if not isinstance(cookies_data, list):
                logger.error("Invalid cookie format: expected list")
                return
            cookies = {}
            current_time = time.time()
            for cookie in cookies_data:
                # Silently skip non-dict entries and entries missing name/value.
                if not isinstance(cookie, dict):
                    continue
                name = cookie.get('name')
                value = cookie.get('value')
                if not name or not value:
                    continue
                # Check expiration (expirationDate is a unix timestamp;
                # session cookies without one are always kept).
                expiry = cookie.get('expirationDate')
                if expiry and expiry <= current_time:
                    logger.debug(f"Cookie {name} expired")
                    continue
                cookies[name] = value
            self.session.cookies.update(cookies)
            logger.info(f"Loaded {len(cookies)} valid cookies")
        except json.JSONDecodeError as e:
            logger.error(f"Invalid JSON in cookie file: {e}")
        except Exception as e:
            logger.error(f"Error loading cookies: {e}")

    def _retry_request(self, func, *args, **kwargs):
        """Retry mechanism for API requests.

        Calls ``func`` up to ``self.max_retries`` times with exponential
        backoff on transport-level errors; any other exception aborts
        immediately. Re-raises the last exception when all attempts fail.
        """
        last_exception = None
        for attempt in range(self.max_retries):
            try:
                return func(*args, **kwargs)
            except (CurlError, ConnectionError, TimeoutError) as e:
                # Transient transport failure: back off exponentially and retry.
                last_exception = e
                if attempt < self.max_retries - 1:
                    wait_time = self.retry_delay * (2 ** attempt)
                    logger.warning(f"Request failed (attempt {attempt + 1}), retrying in {wait_time}s: {e}")
                    time.sleep(wait_time)
                else:
                    logger.error(f"Request failed after {self.max_retries} attempts: {e}")
            except Exception as e:
                # Non-transport errors (bad status, missing fields) are not retried.
                logger.error(f"Unexpected error in request: {e}")
                last_exception = e
                break
        raise last_exception or Exception("Request failed after retries")

    def get_access_token(self):
        """Get access token with thread safety and retry logic.

        Returns the cached token when present; otherwise posts to the GitHub
        token endpoint (authenticated by the session cookies).

        Returns:
            The bearer token string, or None (after logging) on failure.
        """
        with self._token_lock:
            if self._access_token:
                return self._access_token

            def _get_token():
                response = self.session.post(
                    "https://github.com/github-copilot/chat/token",
                    headers=self.session.headers,
                    timeout=30
                )
                if response.status_code == 200:
                    data = response.json()
                    token = data.get("token")
                    if token:
                        # Cache for subsequent calls.
                        self._access_token = token
                        logger.info("Successfully obtained access token")
                        return token
                    else:
                        raise Exception("No token in response")
                else:
                    raise Exception(f"Token request failed: {response.status_code} - {response.text}")

            try:
                return self._retry_request(_get_token)
            except Exception as e:
                logger.error(f"Failed to get access token: {e}")
                # Reset token on failure
                self._access_token = None
                return None

    def create_conversation(self):
        """Create conversation with thread safety and retry logic.

        Returns the cached thread id when present; otherwise creates a new
        Copilot chat thread.

        Returns:
            The thread id string, or None (after logging) on failure.
        """
        with self._conversation_lock:
            if self._conversation_id:
                return self._conversation_id
            access_token = self.get_access_token()
            if not access_token:
                logger.error("Cannot create conversation: no access token")
                return None

            def _create_conv():
                # Per-request header copy so Authorization never leaks into
                # the shared session headers.
                headers = self.session.headers.copy()
                headers["Authorization"] = f"GitHub-Bearer {access_token}"
                response = self.session.post(
                    f"{self.api_url}/github/chat/threads",
                    headers=headers,
                    impersonate="chrome120",
                    timeout=30
                )
                if response.status_code in [200, 201]:
                    data = response.json()
                    thread_id = data.get("thread_id")
                    if thread_id:
                        self._conversation_id = thread_id
                        logger.info(f"Created conversation: {thread_id}")
                        return thread_id
                    else:
                        raise Exception("No thread_id in response")
                else:
                    raise Exception(f"Conversation creation failed: {response.status_code} - {response.text}")

            try:
                return self._retry_request(_create_conv)
            except Exception as e:
                logger.error(f"Failed to create conversation: {e}")
                # Reset conversation on failure
                self._conversation_id = None
                return None

    def chat(self, prompt, stream=False):
        """Chat with robust error handling and validation.

        Args:
            prompt: Non-empty text to send.
            stream: When True, return a generator yielding content chunks as
                they arrive; when False, return the full reply as one string.

        Returns:
            Generator of str chunks, the complete str reply, or None on
            any failure (empty prompt, auth/thread failure, request error).
        """
        if not prompt or not prompt.strip():
            logger.error("Empty prompt provided")
            return None
        conversation_id = self.create_conversation()
        if not conversation_id:
            logger.error("Failed to create conversation")
            return None
        access_token = self.get_access_token()
        if not access_token:
            logger.error("Failed to get access token")
            return None

        def _send_message():
            headers = self.session.headers.copy()
            headers["Authorization"] = f"GitHub-Bearer {access_token}"
            data = {
                "content": prompt,
                "intent": "conversation",
                "references": [],
                "context": [],
                "currentURL": f"https://github.com/copilot/c/{conversation_id}",
                "streaming": True,  # GitHub Copilot API always uses streaming
                "confirmations": [],
                "customInstructions": [],
                "model": self.model,
                "mode": "immersive"
            }
            response = self.session.post(
                f"{self.api_url}/github/chat/threads/{conversation_id}/messages",
                json=data,
                headers=headers,
                impersonate="chrome120",
                stream=True,
                timeout=60  # Longer timeout for chat
            )
            if response.status_code not in [200, 201]:
                raise Exception(f"Chat request failed: {response.status_code} - {response.text}")
            return response

        try:
            response = self._retry_request(_send_message)
            if stream:
                # Return generator for streaming
                return self._stream_response(response)
            else:
                # Collect all chunks for non-streaming response
                response_text = ""
                try:
                    for chunk in self._stream_response(response):
                        if chunk:
                            response_text += chunk
                except Exception as e:
                    logger.error(f"Error collecting streaming response: {e}")
                    return None
                return response_text
        except Exception as e:
            logger.error(f"Chat request failed: {e}")
            # Reset conversation on failure to force new one next time
            self._conversation_id = None
            return None

    def _stream_response(self, response):
        """Helper method to parse streaming response with robust error handling.

        Yields the ``body`` of each SSE ``data:`` event whose ``type`` is
        ``"content"``. Malformed lines (bad UTF-8, bad JSON, wrong shape)
        are logged at debug level and skipped; only iteration-level failures
        propagate to the caller.
        """
        try:
            for line in response.iter_lines():
                if not line:
                    continue
                try:
                    # Handle different line formats ("data: x" vs "data:x").
                    if line.startswith(b'data: '):
                        data_str = line[6:].decode('utf-8')
                    elif line.startswith(b'data:'):
                        data_str = line[5:].decode('utf-8')
                    else:
                        continue
                    # Skip empty data or [DONE]
                    if not data_str.strip() or data_str.strip() == '[DONE]':
                        continue
                    try:
                        data = json.loads(data_str)
                        if isinstance(data, dict) and data.get("type") == "content":
                            body = data.get("body", "")
                            if body and isinstance(body, str):  # Only yield non-empty string content
                                yield body
                    except json.JSONDecodeError as e:
                        logger.debug(f"JSON decode error for line: {data_str[:100]}... Error: {e}")
                        continue
                except UnicodeDecodeError as e:
                    logger.debug(f"Unicode decode error: {e}")
                    continue
                except Exception as e:
                    logger.debug(f"Unexpected error processing line: {e}")
                    continue
        except Exception as e:
            logger.error(f"Error in stream response: {e}")
            raise
# OpenAI Compatible Models
class OpenAIModel(BaseModel):
    """One entry in the OpenAI-style ``/models`` listing."""
    id: str
    object: str = "model"
    # default_factory so each instance stamps the current time; a plain
    # ``int(time.time())`` default is evaluated once at import and every
    # model would share the stale import-time timestamp.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "github-copilot"
class FunctionCall(BaseModel):
    """OpenAI-style function call: function name plus its arguments string."""
    name: str
    arguments: str
class ToolCall(BaseModel):
    """OpenAI-style tool call wrapper around a FunctionCall."""
    id: str
    type: str = "function"  # only function tools are modeled here
    function: FunctionCall
class ChatMessage(BaseModel):
    """One message in an OpenAI-format conversation.

    ``content`` may be a plain string or a list of content-part dicts
    (e.g. text and image_url blocks); the tool/function fields mirror the
    OpenAI chat schema.
    """
    role: str
    content: Optional[Union[str, List[Dict[str, Any]]]] = None
    name: Optional[str] = None
    function_call: Optional[FunctionCall] = None
    tool_calls: Optional[List[ToolCall]] = None
    tool_call_id: Optional[str] = None
class Function(BaseModel):
    """Function declaration (name, description, JSON-schema parameters)."""
    name: str
    description: Optional[str] = None
    parameters: Optional[Dict[str, Any]] = None
class Tool(BaseModel):
    """OpenAI-style tool wrapper around a Function declaration."""
    type: str = "function"
    function: Function
class ChatCompletionRequest(BaseModel):
    """Incoming ``/chat/completions`` request body (OpenAI-compatible).

    Accepts the full OpenAI parameter surface; note that many knobs
    (temperature, penalties, tools, ...) are accepted for compatibility but
    the visible handler only uses ``model``, ``messages`` and ``stream``.
    """
    model: str
    messages: List[ChatMessage]
    max_tokens: Optional[int] = None
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    stream: Optional[bool] = True  # Default to True for streaming (differs from OpenAI's default of False)
    stop: Optional[Union[str, List[str]]] = None
    tools: Optional[List[Tool]] = None
    tool_choice: Optional[Union[str, Dict[str, Any]]] = None
    functions: Optional[List[Function]] = None
    function_call: Optional[Union[str, Dict[str, Any]]] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None
    seed: Optional[int] = None
    logprobs: Optional[bool] = None
    top_logprobs: Optional[int] = None
    n: Optional[int] = 1
class ChatCompletionChoice(BaseModel):
    """One choice in a non-streaming chat completion response."""
    index: int = 0
    message: ChatMessage
    finish_reason: Optional[str] = None
    logprobs: Optional[Dict[str, Any]] = None
class ChatCompletionResponse(BaseModel):
    """Non-streaming ``chat.completion`` response envelope."""
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[ChatCompletionChoice]
    usage: Optional[Dict[str, int]] = None
    system_fingerprint: Optional[str] = None
class ChatCompletionStreamChoice(BaseModel):
    """One choice in a streaming chunk; ``delta`` carries incremental fields."""
    index: int = 0
    delta: Dict[str, Any]
    finish_reason: Optional[str] = None
    logprobs: Optional[Dict[str, Any]] = None
class ChatCompletionStreamResponse(BaseModel):
    """Streaming ``chat.completion.chunk`` response envelope."""
    id: str
    object: str = "chat.completion.chunk"
    created: int
    model: str
    choices: List[ChatCompletionStreamChoice]
    system_fingerprint: Optional[str] = None
def format_prompt(messages: List[Dict[str, Any]], add_special_tokens: bool = False,
                  do_continue: bool = False, include_system: bool = True) -> str:
    """Collapse a chat transcript into a single prompt string.

    Args:
        messages: Message dicts with 'role' and 'content' keys.
        add_special_tokens: When False and exactly one message is given,
            its content is returned verbatim (no "Role:" prefix).
        do_continue: When True, omit the trailing "Assistant:" cue.
        include_system: When False, system messages are dropped.

    Returns:
        The transcript as "Role: content" lines, normally ending with an
        "Assistant:" prompt for the model to continue.
    """
    def _as_text(value) -> str:
        # Best-effort conversion of a 'content' field to plain text.
        if isinstance(value, str):
            return value
        if isinstance(value, dict):
            return value.get("text", "") if "text" in value else ""
        if isinstance(value, list):
            # Multi-part content: keep text parts, mark images with a placeholder.
            pieces = []
            for part in value:
                if isinstance(part, str):
                    pieces.append(part)
                elif isinstance(part, dict):
                    part_kind = part.get("type")
                    if part_kind == "text":
                        pieces.append(part.get("text", ""))
                    elif part_kind == "image_url":
                        pieces.append("[Image]")
            return "".join(pieces)
        return "" if value is None else str(value)

    # Single-message fast path: pass the content through untouched.
    if not add_special_tokens and len(messages) == 1:
        return _as_text(messages[0].get("content", ""))

    # Build "Role: content" lines, skipping filtered roles and blank content.
    transcript_lines = []
    for msg in messages:
        role = msg.get("role", "")
        if role == "system" and not include_system:
            continue
        text = _as_text(msg.get("content"))
        if text.strip():
            transcript_lines.append(f'{role.capitalize()}: {text}')

    transcript = "\n".join(transcript_lines)
    if do_continue:
        return transcript
    return f"{transcript}\nAssistant:" if transcript else "Assistant:"
# ASGI application instance (served by uvicorn in the __main__ block below).
app = FastAPI(
    title="GitHub Copilot OpenAI Compatible API",
    version="1.0.0",
    description="OpenAI-compatible API for GitHub Copilot with full streaming and tool support"
)
# Global variables for health monitoring
chat_client = None          # GithubChat singleton, created in startup_event()
startup_time = time.time()  # process start, used for uptime reporting
request_count = 0           # chat-completion requests received
error_count = 0             # requests that ended in any error
async def startup_event():
    """Initialize the chat client on startup.

    NOTE(review): no registration with ``app`` is visible here — expected an
    ``@app.on_event("startup")`` decorator or ``app.add_event_handler`` call.
    Confirm this is wired up elsewhere, otherwise ``chat_client`` stays None
    and every request 503s.
    """
    global chat_client
    try:
        logger.info("Initializing GitHub Copilot chat client...")
        chat_client = GithubChat()
        logger.info("Chat client initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize chat client: {e}")
        # Don't fail startup, but log the error
        chat_client = None
# Add CORS middleware for web clients.
# Wide-open policy (any origin/method/header) — fine for a local proxy;
# tighten allow_origins before any public deployment.
try:
    from fastapi.middleware.cors import CORSMiddleware
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
except ImportError:
    pass  # CORS middleware is optional
# Add comprehensive error handling
async def validation_exception_handler(request, exc: ValidationError):
    """Render a pydantic validation failure as an OpenAI-style 422 error body.

    NOTE(review): no @app.exception_handler decorator visible — confirm
    registration happens elsewhere.
    """
    logger.error(f"Validation error: {exc}")
    error_body = {
        "message": "Validation error",
        "type": "invalid_request_error",
        "details": exc.errors(),
    }
    return JSONResponse(status_code=422, content={"error": error_body})
async def http_exception_handler(request, exc: HTTPException):
    """Translate an HTTPException into an OpenAI-style error envelope."""
    logger.error(f"HTTP error: {exc.status_code} - {exc.detail}")
    error_body = {
        "message": exc.detail,
        "type": "api_error",
        "code": exc.status_code,
    }
    return JSONResponse(status_code=exc.status_code, content={"error": error_body})
async def global_exception_handler(request, exc):
    """Last-resort handler: log the full traceback, return a generic 500."""
    logger.error(f"Unexpected error: {exc}\n{traceback.format_exc()}")
    error_body = {
        "message": "Internal server error",
        "type": "server_error",
        "code": 500,
    }
    return JSONResponse(status_code=500, content={"error": error_body})
async def root():
    """Service banner; doubles as a cheap liveness probe."""
    banner = {
        "message": "GitHub Copilot OpenAI Compatible API",
        "version": "1.0.0",
        "status": "running",
    }
    return banner
async def health_check():
    """Report service health: uptime, client readiness and request stats."""
    global chat_client, startup_time, request_count, error_count
    uptime = time.time() - startup_time
    status = "healthy"

    # Probe the chat client: can it produce (or has it cached) a token?
    if chat_client is None:
        client_status = "not_initialized"
        status = "degraded"
    else:
        try:
            token = chat_client.get_access_token()
            if token:
                client_status = "ready"
            else:
                client_status = "auth_failed"
                status = "degraded"
        except Exception as e:
            client_status = f"error: {str(e)[:50]}"
            status = "degraded"

    return {
        "status": status,
        "timestamp": int(time.time()),
        "uptime_seconds": int(uptime),
        "client_status": client_status,
        "stats": {
            "total_requests": request_count,
            "total_errors": error_count,
            "error_rate": error_count / max(request_count, 1),
        },
    }
async def list_models():
    """Return the static catalogue of supported model ids (OpenAI list format)."""
    model_ids = [
        "gpt-4o",
        "o3-mini",
        "o1",
        "claude-3.5-sonnet",
        "claude-3.7-sonnet",
        "claude-3.7-sonnet-thought",
        "claude-sonnet-4",
        "gemini-2.0-flash-001",
        "gemini-2.5-pro",
        "gpt-4.1",
        "o4-mini",
    ]
    return {"object": "list", "data": [OpenAIModel(id=mid) for mid in model_ids]}
async def list_models_alt():
    """Alternative models route; simply delegates to list_models()."""
    listing = await list_models()
    return listing
async def validate_chat_request(request: ChatCompletionRequest):
    """Dry-run validation of a chat request; reports per-message problems
    without sending anything upstream."""
    try:
        if not request.messages:
            return {"valid": False, "error": "Messages cannot be empty"}

        valid_roles = ["system", "user", "assistant", "function", "tool"]
        results = []
        for idx, message in enumerate(request.messages):
            entry = {
                "index": idx,
                "role": repr(message.role),
                "role_type": type(message.role).__name__,
                "content_type": type(message.content).__name__ if message.content is not None else "None",
                "valid": True,
                "errors": [],
            }
            # Role must be present and one of the allowed values (case-insensitive).
            role = getattr(message, 'role', None)
            if not role:
                entry["valid"] = False
                entry["errors"].append("Missing role")
            else:
                role_str = str(role).lower().strip()
                if role_str not in valid_roles:
                    entry["valid"] = False
                    entry["errors"].append(f"Invalid role '{role_str}'. Valid: {valid_roles}")
            results.append(entry)

        return {
            "valid": all(entry["valid"] for entry in results),
            "model": request.model,
            "message_count": len(request.messages),
            "messages": results,
        }
    except Exception as e:
        return {"valid": False, "error": f"Validation error: {str(e)}"}
async def create_chat_completion(request: ChatCompletionRequest):
    """Enhanced chat completions endpoint with robust error handling.

    OpenAI-compatible /chat/completions handler: validates the request,
    flattens the message list into one prompt via format_prompt(), then
    either streams SSE chunks or returns a complete ChatCompletionResponse.

    Raises:
        HTTPException: 400 on invalid input, 422 on validation failure,
            503 when the client/upstream is unavailable, 500 otherwise.

    NOTE(review): no @app.post decorator visible here — confirm route
    registration happens elsewhere.
    """
    global request_count, error_count, chat_client
    request_count += 1
    request_id = f"req-{uuid.uuid4().hex[:8]}"
    logger.info(f"[{request_id}] Chat completion request: model={request.model}, messages={len(request.messages)}, stream={request.stream}")
    # Debug log the first message for troubleshooting
    if request.messages:
        first_msg = request.messages[0]
        logger.debug(f"[{request_id}] First message - role: {repr(first_msg.role)}, content type: {type(first_msg.content)}")
    # Check if chat client is available
    if chat_client is None:
        error_count += 1
        logger.error(f"[{request_id}] Chat client not initialized")
        raise HTTPException(status_code=503, detail="Service temporarily unavailable - chat client not initialized")
    try:
        # Comprehensive validation
        if not request.messages:
            raise HTTPException(status_code=400, detail="Messages array cannot be empty")
        if len(request.messages) > 100:  # Reasonable limit
            raise HTTPException(status_code=400, detail="Too many messages (max 100)")
        # Validate model
        if not request.model or not isinstance(request.model, str):
            raise HTTPException(status_code=400, detail="Model must be a non-empty string")
        # Extract and validate prompt: normalize roles, track content size.
        message_dicts = []
        total_content_length = 0
        for i, msg in enumerate(request.messages):
            # More flexible role validation
            role = getattr(msg, 'role', None)
            if not role:
                raise HTTPException(status_code=400, detail=f"Missing role in message {i}")
            # Convert to string and validate
            role_str = str(role).lower().strip()
            valid_roles = ["system", "user", "assistant", "function", "tool"]
            if role_str not in valid_roles:
                raise HTTPException(status_code=400, detail=f"Invalid role '{role_str}' in message {i}. Valid roles: {valid_roles}")
            msg_dict = {"role": role_str}
            if msg.content is not None:
                # Handle different content types (str, multi-part list, other)
                if isinstance(msg.content, str):
                    content_length = len(msg.content)
                elif isinstance(msg.content, list):
                    content_length = sum(len(str(item)) for item in msg.content)
                else:
                    content_length = len(str(msg.content))
                total_content_length += content_length
                msg_dict["content"] = msg.content
            message_dicts.append(msg_dict)
        # Check total content length (reasonable limit)
        if total_content_length > 100000:  # 100KB limit
            raise HTTPException(status_code=400, detail="Total message content too large")
        prompt = format_prompt(message_dicts)
        if not prompt.strip():
            raise HTTPException(status_code=400, detail="No valid message content found")
        logger.info(f"[{request_id}] Formatted prompt length: {len(prompt)}")
        # Determine streaming mode (defaults to streaming when unspecified)
        should_stream = request.stream if request.stream is not None else True
        if should_stream:
            logger.info(f"[{request_id}] Starting streaming response")
            return StreamingResponse(
                generate_stream_response(request, prompt, request_id),
                media_type="text/event-stream",
                headers={
                    # Disable proxy/nginx buffering so chunks flush immediately.
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "X-Accel-Buffering": "no",
                    "Access-Control-Allow-Origin": "*",
                    "Access-Control-Allow-Headers": "*"
                }
            )
        else:
            logger.info(f"[{request_id}] Starting non-streaming response")
            # Get non-streaming response
            result = chat_client.chat(prompt, stream=False)
            if result is None:
                logger.error(f"[{request_id}] Chat client returned None")
                raise HTTPException(status_code=503, detail="GitHub Copilot service unavailable")
            # Ensure result is a string
            response_text = result if isinstance(result, str) else str(result)
            if not response_text.strip():
                logger.warning(f"[{request_id}] Empty response from chat client")
                response_text = "I apologize, but I couldn't generate a response. Please try again."
            logger.info(f"[{request_id}] Non-streaming response length: {len(response_text)}")
            response = ChatCompletionResponse(
                id=f"chatcmpl-{uuid.uuid4().hex}",
                created=int(time.time()),
                model=request.model,
                choices=[ChatCompletionChoice(
                    message=ChatMessage(role="assistant", content=response_text),
                    finish_reason="stop"
                )],
                # Whitespace-split word counts, not real tokenizer counts —
                # rough usage approximation only.
                usage={
                    "prompt_tokens": len(prompt.split()),
                    "completion_tokens": len(response_text.split()),
                    "total_tokens": len(prompt.split()) + len(response_text.split())
                }
            )
            return response
    except HTTPException:
        error_count += 1
        raise
    except ValidationError as e:
        error_count += 1
        logger.error(f"[{request_id}] Validation error: {e}")
        raise HTTPException(status_code=422, detail=f"Request validation failed: {str(e)}")
    except Exception as e:
        error_count += 1
        logger.error(f"[{request_id}] Unexpected error: {e}\n{traceback.format_exc()}")
        raise HTTPException(status_code=500, detail="Internal server error occurred")
async def generate_stream_response(request: ChatCompletionRequest, prompt: str, request_id: str = None):
    """Enhanced streaming response with comprehensive error handling.

    Async generator yielding OpenAI-style SSE lines: an initial role delta,
    one content delta per chunk from the Copilot client, and — guaranteed by
    the ``finally`` block — a terminating empty delta with
    finish_reason="stop" followed by ``data: [DONE]``.

    NOTE(review): chat_client.chat() is a synchronous call iterated inside
    this async generator, so it blocks the event loop while streaming —
    consider offloading to a thread/executor. Confirm acceptable for the
    expected concurrency.
    """
    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    created_time = int(time.time())
    request_id = request_id or f"req-{uuid.uuid4().hex[:8]}"
    logger.info(f"[{request_id}] Starting stream generation")
    try:
        # Send initial chunk with role
        initial_chunk = {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created_time,
            "model": request.model,
            "choices": [{
                "index": 0,
                "delta": {"role": "assistant"},
                "finish_reason": None
            }]
        }
        yield f"data: {json.dumps(initial_chunk)}\n\n"
        # Stream content chunks with enhanced error handling
        chunk_count = 0
        total_content = ""
        try:
            chat_stream = chat_client.chat(prompt, stream=True)
            if chat_stream is None:
                raise Exception("Chat client returned None for streaming")
            for chunk in chat_stream:
                if chunk and isinstance(chunk, str):  # Only process non-empty string chunks
                    chunk_count += 1
                    total_content += chunk
                    stream_response = ChatCompletionStreamResponse(
                        id=completion_id,
                        created=created_time,
                        model=request.model,
                        choices=[ChatCompletionStreamChoice(
                            delta={"content": chunk},
                            finish_reason=None
                        )]
                    )
                    try:
                        chunk_json = json.dumps(stream_response.model_dump())
                        yield f"data: {chunk_json}\n\n"
                    except Exception as json_error:
                        # Drop only the unserializable chunk, keep streaming.
                        logger.error(f"[{request_id}] JSON serialization error: {json_error}")
                        continue
        except Exception as stream_error:
            logger.error(f"[{request_id}] Streaming error after {chunk_count} chunks: {stream_error}")
            # If we got some content, continue gracefully
            if chunk_count > 0:
                logger.info(f"[{request_id}] Partial stream completed with {chunk_count} chunks")
            else:
                # Send error content if no chunks were received
                error_content = "I apologize, but I encountered an error while generating the response. Please try again."
                error_response = ChatCompletionStreamResponse(
                    id=completion_id,
                    created=created_time,
                    model=request.model,
                    choices=[ChatCompletionStreamChoice(
                        delta={"content": error_content},
                        finish_reason=None
                    )]
                )
                yield f"data: {json.dumps(error_response.model_dump())}\n\n"
        logger.info(f"[{request_id}] Stream completed: {chunk_count} chunks, {len(total_content)} characters")
    except Exception as e:
        logger.error(f"[{request_id}] Critical streaming error: {e}")
        # Send error chunk
        error_chunk = {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created_time,
            "model": request.model,
            "choices": [{
                "index": 0,
                "delta": {"content": "Error occurred while streaming response."},
                "finish_reason": "stop"
            }]
        }
        yield f"data: {json.dumps(error_chunk)}\n\n"
    finally:
        # Always send final chunk (stop marker + [DONE]) regardless of errors.
        try:
            final_chunk = {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created_time,
                "model": request.model,
                "choices": [{
                    "index": 0,
                    "delta": {},
                    "finish_reason": "stop"
                }]
            }
            yield f"data: {json.dumps(final_chunk)}\n\n"
            yield "data: [DONE]\n\n"
            logger.info(f"[{request_id}] Stream finalized")
        except Exception as final_error:
            logger.error(f"[{request_id}] Error sending final chunk: {final_error}")
            yield "data: [DONE]\n\n"
# Script entry point: serve the ASGI app with uvicorn.
if __name__ == "__main__":
    try:
        import uvicorn
        # Bind address/port are configurable via environment variables.
        port = int(os.getenv("PORT", 8000))
        host = os.getenv("HOST", "0.0.0.0")
        logger.info(f"Starting server on {host}:{port}")
        uvicorn.run(
            "server:app",  # import string — assumes this file is named server.py
            host=host,
            port=port,
            reload=False,  # Disable reload in production
            log_level="info",
            access_log=True
        )
    except ImportError:
        logger.error("uvicorn not installed. Install with: pip install uvicorn")
    except Exception as e:
        logger.error(f"Failed to start server: {e}")