Spaces:
Paused
Paused
| import base64 as _b64, json as _j, time as _t, uuid as _u, logging as _l, traceback as _tb, os as _o | |
| from fastapi import FastAPI as _FA, HTTPException as _HE | |
| from fastapi.responses import StreamingResponse as _SR, JSONResponse as _JR | |
| from pydantic import BaseModel as _BM, Field as _F | |
| from typing import List as _L, Optional as _O, Dict as _D, Any as _A, Union as _U | |
| import replicate as _r | |
| from contextlib import asynccontextmanager as _acm | |
# Logging and credential configuration
_l.basicConfig(level=_l.INFO)
_lg = _l.getLogger(__name__)
# SECURITY: the Replicate API token used to be embedded on this line as a
# base64 string. base64 is an encoding, not encryption, so anyone able to read
# the file could recover the credential. Treat the previously embedded token
# as leaked and revoke it. The token is now taken from the environment (for
# example a deployment secret) and never stored in source.
_TOKEN = _o.environ.get("REPLICATE_API_TOKEN", "")
if not _TOKEN:
    _lg.warning("REPLICATE_API_TOKEN is not set; Replicate requests will fail until it is provided")
# Supported models configuration
# Maps the short, OpenAI-style model name to its "<provider>/<model>" slug.
_MODELS = {
    # Anthropic Claude Models
    "claude-4-sonnet": "anthropic/claude-4-sonnet",
    "claude-3.7-sonnet": "anthropic/claude-3.7-sonnet",
    "claude-3.5-sonnet": "anthropic/claude-3.5-sonnet",
    "claude-3.5-haiku": "anthropic/claude-3.5-haiku",
    # OpenAI GPT Models
    "gpt-4.1": "openai/gpt-4.1",
    "gpt-4.1-mini": "openai/gpt-4.1-mini",
    "gpt-4.1-nano": "openai/gpt-4.1-nano",
    "gpt-5": "openai/gpt-5",
    "gpt-5-mini": "openai/gpt-5-mini",
    "gpt-5-nano": "openai/gpt-5-nano",
}
# Alternative naming (with provider prefix): every slug is also accepted
# verbatim and maps to itself. The aliases are appended after the short names,
# in the same order, so iteration order is short names first.
_MODELS.update({_slug: _slug for _slug in list(_MODELS.values())})
# Model metadata for OpenAI compatibility
# Built per (owner, context window) family; each model gets its own dict so
# entries can be mutated independently.
_MODEL_INFO = {
    _name: {"owned_by": _owner, "context_length": _window}
    for _owner, _window, _names in (
        ("anthropic", 200000, ("claude-4-sonnet", "claude-3.7-sonnet", "claude-3.5-sonnet", "claude-3.5-haiku")),
        ("openai", 128000, ("gpt-4.1", "gpt-4.1-mini", "gpt-4.1-nano")),
        ("openai", 400000, ("gpt-5", "gpt-5-mini", "gpt-5-nano")),
    )
    for _name in _names
}
| # OpenAI Compatible Models | |
# Chat message model: a role plus optional content and tool/function-call
# fields. Used both for incoming request messages and for the assistant
# message in a completion response.
class _CM(_BM):
    role: str = _F(..., description="Message role")
    # Either a plain string or a list of content-part dicts; may be omitted.
    content: _O[_U[str, _L[_D[str, _A]]]] = _F(default=None, description="Message content")
    name: _O[str] = _F(default=None, description="Message name")
    function_call: _O[_D[str, _A]] = _F(default=None, description="Function call")
    tool_calls: _O[_L[_D[str, _A]]] = _F(default=None, description="Tool calls")
    tool_call_id: _O[str] = _F(default=None, description="Tool call ID")
# Function-call payload: the function name and its arguments as a string.
# Both fields are required.
class _FC(_BM):
    name: str = _F(default=..., description="Function name")
    arguments: str = _F(default=..., description="Function arguments")
# Tool call: an identifier, a type tag that defaults to "function", and the
# nested function-call payload.
class _TC(_BM):
    id: str = _F(default=..., description="Tool call ID")
    type: str = _F("function", description="Tool call type")
    function: _FC = _F(default=..., description="Function call")
# Function definition supplied by the client in `functions` / `tools`.
class _FD(_BM):
    name: str = _F(..., description="Function name")
    description: _O[str] = _F(None, description="Function description")
    # The OpenAI API allows `parameters` to be omitted for functions that take
    # no arguments. Requiring it rejected such definitions at validation time,
    # so it now defaults to an empty schema (a fresh dict per instance).
    parameters: _D[str, _A] = _F(default_factory=dict, description="Function parameters")
# Tool definition: wraps a function definition under a type tag that defaults
# to "function".
class _TD(_BM):
    type: str = _F("function", description="Tool type")
    function: _FD = _F(default=..., description="Function definition")
# Chat completion request body. Only `model` and `messages` are required;
# every other field carries the default shown on its line.
class _CCR(_BM):
    model: str = _F(..., description="Model name")
    messages: _L[_CM] = _F(..., description="Messages")
    # Sampling controls
    max_tokens: _O[int] = _F(4096, description="Max tokens")
    temperature: _O[float] = _F(0.7, description="Temperature")
    top_p: _O[float] = _F(1.0, description="Top p")
    n: _O[int] = _F(1, description="Number of completions")
    stream: _O[bool] = _F(False, description="Stream response")
    stop: _O[_U[str, _L[str]]] = _F(default=None, description="Stop sequences")
    presence_penalty: _O[float] = _F(0.0, description="Presence penalty")
    frequency_penalty: _O[float] = _F(0.0, description="Frequency penalty")
    logit_bias: _O[_D[str, float]] = _F(default=None, description="Logit bias")
    user: _O[str] = _F(default=None, description="User ID")
    # Tool calling (current `tools` form and legacy `functions` form)
    tools: _O[_L[_TD]] = _F(default=None, description="Available tools")
    tool_choice: _O[_U[str, _D[str, _A]]] = _F(default=None, description="Tool choice")
    functions: _O[_L[_FD]] = _F(default=None, description="Available functions")
    function_call: _O[_U[str, _D[str, _A]]] = _F(default=None, description="Function call")
# One choice in a non-streaming completion response: its index, the message,
# and an optional finish reason.
class _CCC(_BM):
    index: int = _F(0, description="Choice index")
    message: _CM = _F(default=..., description="Message")
    finish_reason: _O[str] = _F(default=None, description="Finish reason")
# One choice in a streaming chunk: its index, a free-form delta dict, and an
# optional finish reason.
class _CCSC(_BM):
    index: int = _F(0, description="Choice index")
    delta: _D[str, _A] = _F(default=..., description="Delta")
    finish_reason: _O[str] = _F(default=None, description="Finish reason")
# Non-streaming completion response envelope; `object` defaults to
# "chat.completion".
class _CCRes(_BM):
    id: str = _F(default=..., description="Completion ID")
    object: str = _F("chat.completion", description="Object type")
    created: int = _F(default=..., description="Created timestamp")
    model: str = _F(default=..., description="Model name")
    choices: _L[_CCC] = _F(default=..., description="Choices")
    usage: _D[str, int] = _F(default=..., description="Usage stats")
    system_fingerprint: _O[str] = _F(default=None, description="System fingerprint")
# Streaming chunk envelope; `object` defaults to "chat.completion.chunk".
class _CCSR(_BM):
    id: str = _F(default=..., description="Completion ID")
    object: str = _F("chat.completion.chunk", description="Object type")
    created: int = _F(default=..., description="Created timestamp")
    model: str = _F(default=..., description="Model name")
    choices: _L[_CCSC] = _F(default=..., description="Choices")
    system_fingerprint: _O[str] = _F(default=None, description="System fingerprint")
# Entry in the model listing; `object` defaults to "model".
class _OM(_BM):
    id: str = _F(default=..., description="Model ID")
    object: str = _F("model", description="Object type")
    created: int = _F(default=..., description="Created timestamp")
    owned_by: str = _F(default=..., description="Owner")
| # Replicate Client | |
class _RC:
    """Wrapper around the `replicate` module for chat-style text generation.

    It turns OpenAI-style message lists into a single prompt plus a system
    prompt, sanitises sampling parameters, and exposes streaming and
    non-streaming generation helpers.
    """

    # Event kinds that end a stream, and kinds that carry diagnostics rather
    # than model text. Both are checked BEFORE any payload is emitted.
    _DONE_EVENTS = frozenset({"completed", "done", "end"})
    _DIAGNOSTIC_EVENTS = frozenset({"error", "logs", "warning"})

    def __init__(self, _tk=_TOKEN):
        """Export the API token to the environment and bind the client module.

        An empty or None token is not exported, so a REPLICATE_API_TOKEN that
        is already present in the environment is left untouched.
        """
        if _tk:
            _o.environ['REPLICATE_API_TOKEN'] = _tk
        self._client = _r
        self._models = _MODELS
        self._model_info = _MODEL_INFO

    def _get_replicate_model(self, _model_name):
        """Return the Replicate slug for a model name (unknown names pass through)."""
        return self._models.get(_model_name, _model_name)

    def _validate_model(self, _model_name):
        """Return True when the name is a known alias or a known Replicate slug."""
        return _model_name in self._models or _model_name in self._models.values()

    @staticmethod
    def _as_text(_piece):
        """Decode bytes as UTF-8 (ignoring errors); return anything else unchanged."""
        if isinstance(_piece, bytes):
            return _piece.decode('utf-8', errors='ignore')
        return _piece

    @staticmethod
    def _content_text(_content):
        """Flatten a message `content` value into plain text.

        None becomes "", strings pass through, and a list of content parts is
        reduced to its string items and the string "text" values of its dict
        items, joined by newlines. Parts without text are skipped.
        """
        if _content is None:
            return ""
        if isinstance(_content, str):
            return _content
        if isinstance(_content, (list, tuple)):
            texts = []
            for part in _content:
                if isinstance(part, str):
                    texts.append(part)
                elif isinstance(part, dict) and isinstance(part.get('text'), str):
                    texts.append(part['text'])
            return "\n".join(texts)
        return str(_content)

    def _format_messages(self, _msgs):
        """Build the (prompt, system_prompt) pair from a list of message dicts.

        User and assistant turns are rendered as "Human: ..." and
        "Assistant: ..." paragraphs, and the prompt always ends with an open
        "Assistant: " cue. All system messages are kept, joined by blank
        lines. Messages with any other role are ignored.
        """
        turns = []
        system_parts = []
        for msg in _msgs:
            role = msg.get('role', '')
            text = self._content_text(msg.get('content'))
            if role == 'system':
                if text:
                    system_parts.append(text)
            elif role == 'user':
                turns.append(f"Human: {text}\n\n")
            elif role == 'assistant':
                turns.append(f"Assistant: {text}\n\n")
        turns.append("Assistant: ")
        return "".join(turns), "\n\n".join(system_parts)

    @staticmethod
    def _clamped(_raw, _low, _high, _default):
        """Return `_raw` as a float clamped to [_low, _high], or `_default` if unusable."""
        if _raw is None:
            return _default
        try:
            return max(_low, min(_high, float(_raw)))
        except (TypeError, ValueError):
            return _default

    def _sanitize_params(self, **_kwargs):
        """Return sampling parameters with defaults applied and values clamped.

        The result always contains max_tokens, temperature, top_p,
        presence_penalty and frequency_penalty.
        """
        # Convert before comparing so a non-numeric value falls back to the
        # default instead of raising on the `> 0` comparison.
        try:
            requested = int(_kwargs.get('max_tokens'))
        except (TypeError, ValueError, OverflowError):
            requested = 0
        # Replicate Anthropic models often require >= 1024; clamp to avoid 422s
        max_tokens = max(1024, requested) if requested > 0 else 4096
        return {
            'max_tokens': max_tokens,
            'temperature': self._clamped(_kwargs.get('temperature'), 0.0, 2.0, 0.7),
            'top_p': self._clamped(_kwargs.get('top_p'), 0.0, 1.0, 1.0),
            'presence_penalty': self._clamped(_kwargs.get('presence_penalty'), -2.0, 2.0, 0.0),
            'frequency_penalty': self._clamped(_kwargs.get('frequency_penalty'), -2.0, 2.0, 0.0),
        }

    def _build_input(self, _prompt, _system, _options, _with_stop=True):
        """Assemble the model input dict shared by all generation helpers."""
        params = self._sanitize_params(**_options)
        payload = {
            "prompt": _prompt,
            "system_prompt": _system,
            "max_tokens": params['max_tokens'],
            "temperature": params['temperature'],
            "top_p": params['top_p'],
        }
        stop = _options.get('stop')
        if _with_stop and stop is not None:
            # pass through stop sequences if provided
            payload["stop"] = stop
        return payload

    def _create_prediction(self, _model_name, _prompt, _system="", **_kwargs):
        """Create a prediction via the Replicate API; return None on failure."""
        _replicate_model = self._get_replicate_model(_model_name)
        # Stop sequences are intentionally not forwarded on this path.
        _input = self._build_input(_prompt, _system, _kwargs, _with_stop=False)
        try:
            return self._client.predictions.create(
                model=_replicate_model,
                input=_input
            )
        except Exception as _e:
            _lg.error(f"Prediction creation error for {_replicate_model}: {_e}")
            return None

    def _handle_tools(self, _tools, _tool_choice):
        """Render tool definitions as prompt text; return "" when there are none.

        `_tool_choice` is accepted for interface compatibility but not used.
        """
        if not _tools:
            return ""
        lines = ["\n\nYou have access to the following tools:\n"]
        for tool in _tools:
            # `or` rather than a .get() default: keys may be present with a
            # None value, which would otherwise be printed as "None".
            func = tool.get('function') or {}
            name = func.get('name') or ''
            desc = func.get('description') or ''
            params = func.get('parameters') or {}
            lines.append(f"- {name}: {desc}\n")
            lines.append(f" Parameters: {_j.dumps(params)}\n")
        lines.append("\nTo use a tool, respond with JSON in this format:\n")
        lines.append('{"tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "tool_name", "arguments": "{\\"param\\": \\"value\\"}"}}]}\n')
        return "".join(lines)

    def _payload_texts(self, _data):
        """Return the text pieces carried by an event payload.

        Returns None when the payload type is not recognised, so the caller
        can fall back to the string form of the event.
        """
        if isinstance(_data, (str, bytes)):
            return [self._as_text(_data)]
        if isinstance(_data, (list, tuple)):
            return [self._as_text(p) for p in _data if isinstance(p, (str, bytes))]
        if isinstance(_data, dict):
            # Common nested keys; the first one holding text wins.
            for key in ("text", "output", "delta"):
                value = _data.get(key)
                if isinstance(value, (str, bytes)):
                    return [self._as_text(value)]
            return []
        return None

    def _stream_chat(self, _model_name, _prompt, _system="", **_kwargs):
        """Stream chat using Replicate's streaming API, yielding only text chunks.

        Events may be plain strings/bytes, dicts, or objects exposing
        `type`/`event` and `data`. Completion events end the stream and
        error/log/warning events are logged; neither is yielded as text.
        On failure the error is logged and a single empty string is yielded.
        """
        _replicate_model = self._get_replicate_model(_model_name)
        _input = self._build_input(_prompt, _system, _kwargs)
        try:
            for event in self._client.stream(_replicate_model, input=_input):
                if not event:
                    continue
                # Fast path: plain string/bytes token
                if isinstance(event, (str, bytes)):
                    yield self._as_text(event)
                    continue
                # Normalize event interfaces (object, dict, or custom)
                if isinstance(event, dict):
                    kind = event.get('type') or event.get('event')
                    data = event.get('data') or event.get('output') or event.get('text')
                else:
                    kind = getattr(event, 'type', None) or getattr(event, 'event', None)
                    data = getattr(event, 'data', None)
                # The kind may be an enum member rather than a plain string;
                # unwrap it so the membership tests below can match.
                kind = getattr(kind, 'value', kind)
                # Control events are handled first. Previously any event with
                # a non-None payload was treated as output, so completion and
                # error payloads could be emitted as model text.
                if kind in self._DONE_EVENTS:
                    break
                if kind in self._DIAGNOSTIC_EVENTS:
                    _lg.warning(f"Replicate stream {kind}: {data}")
                    continue
                texts = self._payload_texts(data)
                if texts is None:
                    # Unknown/eventless object; fallback to string form
                    try:
                        fallback = str(event)
                    except Exception:
                        continue
                    if fallback:
                        yield fallback
                else:
                    for text in texts:
                        yield text
        except Exception as _e:
            _lg.error(f"Streaming error for {_replicate_model}: {_e}")
            # Surface a minimal safe error token
            yield ""

    def _stream_from_prediction(self, _prediction):
        """Yield the `data:` payloads from a prediction's server-sent-event stream.

        Stops at a `[DONE]` payload. If no stream URL is available an error is
        logged and nothing is yielded. Any other failure is logged and
        yielded as an "Error: ..." string.
        """
        try:
            import requests
            _stream_url = (getattr(_prediction, 'urls', None) or {}).get('stream')
            if not _stream_url:
                _lg.error("No stream URL available")
                return
            # The context manager closes the connection when the stream ends,
            # fails, or the consumer stops iterating early.
            with requests.get(
                _stream_url,
                headers={
                    "Accept": "text/event-stream",
                    "Cache-Control": "no-store"
                },
                stream=True
            ) as _response:
                for raw in _response.iter_lines():
                    if not raw:
                        continue
                    line = raw.decode('utf-8')
                    if not line.startswith('data: '):
                        continue
                    data = line[6:]
                    if data == '[DONE]':
                        break
                    yield data
        except Exception as _e:
            _lg.error(f"Stream from prediction error: {_e}")
            yield f"Error: {_e}"

    def _coalesce_output(self, _result):
        """Collapse a `run` result (scalar, bytes, or iterable of pieces) into one string."""
        from collections.abc import Iterable
        if _result is None:
            return ""
        if isinstance(_result, (str, bytes)):
            return self._as_text(_result)
        if isinstance(_result, Iterable):
            # An iterable that produces nothing yields "" here; it must not
            # fall through to str(), which would return the object's repr.
            return "".join(
                self._as_text(piece) if isinstance(piece, (str, bytes)) else str(piece)
                for piece in _result
            )
        return str(_result)

    def _complete_chat(self, _model_name, _prompt, _system="", **_kwargs):
        """Complete chat using Replicate's run method and coalesce into a single string.

        Returns "" on any failure so internal error details never reach
        user-visible content; the error is logged instead.
        """
        _replicate_model = self._get_replicate_model(_model_name)
        _input = self._build_input(_prompt, _system, _kwargs)
        try:
            # Coalescing stays inside the try: iterating a lazy result can raise.
            return self._coalesce_output(self._client.run(_replicate_model, input=_input))
        except Exception as _e:
            _lg.error(f"Completion error for {_replicate_model}: {_e}")
            return ""
# Global variables
# Process-wide state: the shared client (None until one is created), the
# module load time used for uptime, and the request/error counters.
_client = None
_startup_time = _t.time()
_request_count = _error_count = 0
@_acm
async def _lifespan(_app: _FA):
    """Application lifespan: create the shared Replicate client on startup.

    This function is passed as the app's `lifespan=` argument. It is wrapped
    with `asynccontextmanager` so it is a real async context manager rather
    than a bare async generator. If client construction fails, the error is
    logged and `_client` stays None so the app still starts. Code after the
    `yield` runs at shutdown.
    """
    global _client
    try:
        _lg.info("Initializing Replicate client...")
        _client = _RC()
        _lg.info("Replicate client initialized successfully")
    except Exception as _e:
        _lg.error(f"Failed to initialize client: {_e}")
        _client = None
    yield
    _lg.info("Shutting down Replicate client...")
# FastAPI App
_app = _FA(
    title="Replicate Claude-4-Sonnet OpenAI API",
    version="1.0.0",
    description="OpenAI-compatible API for Claude-4-Sonnet via Replicate",
    lifespan=_lifespan
)
# CORS
try:
    from fastapi.middleware.cors import CORSMiddleware as _CORS
except ImportError:
    # Previously swallowed silently; log it so a missing middleware is visible.
    _lg.warning("CORS middleware could not be imported; no CORS headers will be added")
else:
    _app.add_middleware(
        _CORS,
        allow_origins=["*"],
        # Wildcard origins must not be combined with credentialed requests:
        # FastAPI's CORS documentation requires origins to be listed
        # explicitly when allow_credentials is True. Nothing in this file
        # reads cookies or other ambient credentials, so credentials are
        # disabled rather than narrowing the origin list.
        allow_credentials=False,
        allow_methods=["*"],
        allow_headers=["*"],
    )
| # Error handlers | |
async def _http_exception_handler(_request, _exc: _HE):
    """Log an HTTP exception and return it as an OpenAI-style JSON error body.

    The response keeps the exception's status code and places its detail in
    `error.message`.
    """
    # NOTE(review): no registration of this handler on the app is visible in
    # this view of the file - confirm it is attached as an exception handler.
    _status = _exc.status_code
    _detail = _exc.detail
    _lg.error(f"HTTP error: {_status} - {_detail}")
    _payload = {
        "error": {
            "message": _detail,
            "type": "api_error",
            "code": _status
        }
    }
    return _JR(status_code=_status, content=_payload)
async def _global_exception_handler(_request, _exc):
    """Log an unexpected exception with its traceback and return a generic 500.

    The response body never includes exception details; they go to the log only.
    """
    # NOTE(review): no registration of this handler on the app is visible in
    # this view of the file - confirm it is attached as an exception handler.
    # Format the traceback from the exception object that was passed in.
    # format_exc() depends on an exception being "currently handled" at call
    # time, which is not guaranteed for a handler invoked as a plain function.
    _trace = "".join(_tb.format_exception(type(_exc), _exc, _exc.__traceback__))
    _lg.error(f"Unexpected error: {_exc}\n{_trace}")
    return _JR(
        status_code=500,
        content={
            "error": {
                "message": "Internal server error",
                "type": "server_error",
                "code": 500
            }
        }
    )
async def _root():
    """Return a small service banner with the number of distinct supported models.

    Provider-prefixed aliases are excluded from the count so each model is
    counted once.
    """
    # NOTE(review): no route decorator is visible in this view - confirm how
    # this handler is mounted.
    _provider_prefixes = ('anthropic/', 'openai/')
    _model_count = sum(1 for _name in _MODELS if not _name.startswith(_provider_prefixes))
    _banner = {
        "message": "Replicate Multi-Model OpenAI API",
        "version": "1.0.0",
        "status": "running",
        "supported_models": _model_count,
        "providers": ["anthropic", "openai"]
    }
    return _banner
async def _health_check():
    """Report service health, uptime, client readiness and request counters.

    Status is "degraded" while the shared client is None, otherwise "healthy".
    The error rate divides by at least 1 to avoid division by zero.
    """
    # Module globals are only read here, so no `global` declaration is needed.
    _uptime = _t.time() - _startup_time
    _ready = _client is not None
    return {
        "status": "healthy" if _ready else "degraded",
        "timestamp": int(_t.time()),
        "uptime_seconds": int(_uptime),
        "client_status": "ready" if _ready else "not_initialized",
        "stats": {
            "total_requests": _request_count,
            "total_errors": _error_count,
            "error_rate": _error_count / max(_request_count, 1)
        }
    }
async def _list_models():
    """List all supported models, one entry per short model name, sorted by name.

    Provider-prefixed aliases are skipped. Every entry shares the same
    `created` timestamp, taken at call time.
    """
    _created_time = int(_t.time())
    # Set comprehension mirrors the de-duplication of alternative names.
    _short_names = sorted(
        {_name for _name in _MODELS if not _name.startswith(('anthropic/', 'openai/'))}
    )
    _entries = []
    for _name in _short_names:
        _meta = _MODEL_INFO.get(_name, {"owned_by": "unknown", "context_length": 4096})
        _entries.append(_OM(id=_name, created=_created_time, owned_by=_meta["owned_by"]))
    return {"object": "list", "data": _entries}
async def _list_models_alt():
    """Alias endpoint: return exactly what `_list_models` returns."""
    _listing = await _list_models()
    return _listing
async def _generate_stream_response(_request: _CCR, _prompt: str, _system: str, _request_id: str = None):
    """Yield an OpenAI-style server-sent-event stream for one chat completion.

    Frames are emitted in this order: a role chunk, one chunk per non-empty
    text piece from the client, a final chunk with finish_reason "stop", then
    the "[DONE]" sentinel. If streaming fails before any text was produced, a
    single apology chunk is sent instead of content.

    The closing frames are emitted after the try block, not from `finally`.
    Yielding from `finally` breaks generator shutdown when the consumer closes
    the stream early (for example on client disconnect).
    """
    _completion_id = f"chatcmpl-{_u.uuid4().hex}"
    _created_time = int(_t.time())
    _request_id = _request_id or f"req-{_u.uuid4().hex[:8]}"
    _lg.info(f"[{_request_id}] Starting stream generation")

    def _raw_frame(_delta, _finish_reason=None):
        # Role, error and final frames are built from plain dicts.
        return "data: " + _j.dumps({
            "id": _completion_id,
            "object": "chat.completion.chunk",
            "created": _created_time,
            "model": _request.model,
            "choices": [{
                "index": 0,
                "delta": _delta,
                "finish_reason": _finish_reason
            }]
        }) + "\n\n"

    def _content_frame(_text):
        # Content frames go through the response model, as before.
        _model = _CCSR(
            id=_completion_id,
            created=_created_time,
            model=_request.model,
            choices=[_CCSC(delta={"content": _text}, finish_reason=None)]
        )
        return f"data: {_j.dumps(_model.model_dump())}\n\n"

    try:
        # Send initial chunk with role
        yield _raw_frame({"role": "assistant"})
        _chunk_count = 0
        _total_chars = 0  # only the length is reported, so do not keep the text
        try:
            # Extract only relevant parameters for Replicate API
            _api_params = {
                'max_tokens': _request.max_tokens,
                'temperature': _request.temperature,
                'top_p': _request.top_p,
                'presence_penalty': _request.presence_penalty,
                'frequency_penalty': _request.frequency_penalty,
                'stop': _request.stop
            }
            # NOTE(review): _stream_chat is a synchronous generator, so this
            # loop blocks the event loop between chunks - consider running it
            # in a worker thread.
            for _chunk in _client._stream_chat(_request.model, _prompt, _system, **_api_params):
                if not (_chunk and isinstance(_chunk, str)):
                    continue
                _chunk_count += 1
                _total_chars += len(_chunk)
                try:
                    _frame = _content_frame(_chunk)
                except Exception as _json_error:
                    _lg.error(f"[{_request_id}] JSON serialization error: {_json_error}")
                    continue
                yield _frame
        except Exception as _stream_error:
            _lg.error(f"[{_request_id}] Streaming error after {_chunk_count} chunks: {_stream_error}")
            if _chunk_count == 0:
                yield _content_frame(
                    "I apologize, but I encountered an error while generating the response. Please try again."
                )
        _lg.info(f"[{_request_id}] Stream completed: {_chunk_count} chunks, {_total_chars} characters")
    except Exception as _e:
        _lg.error(f"[{_request_id}] Critical streaming error: {_e}")
        # finish_reason stays None here: the final frame below carries the
        # single "stop" marker for the stream.
        yield _raw_frame({"content": "Error occurred while streaming response."})
    # Reached on every path except cancellation/close, where nothing more may
    # be yielded.
    yield _raw_frame({}, "stop")
    yield "data: [DONE]\n\n"
    _lg.info(f"[{_request_id}] Stream finalized")
async def _create_chat_completion(_request: _CCR):
    """Handle a chat completion request, streaming or non-streaming.

    Raises:
        HTTPException 503: the shared client has not been initialised.
        HTTPException 400: the requested model is not supported.
        HTTPException 500: any other unexpected failure (details are logged).

    Returns a StreamingResponse when `stream` is set, otherwise a completion
    response whose usage figures are whitespace-separated word counts.
    """
    # NOTE(review): no route decorator is visible in this view - confirm how
    # this handler is mounted.
    global _request_count, _error_count
    _request_count += 1
    _request_id = f"req-{_u.uuid4().hex[:8]}"
    _lg.info(f"[{_request_id}] Chat completion request: model={_request.model}, stream={_request.stream}")
    if _client is None:
        _error_count += 1
        _lg.error(f"[{_request_id}] Client not initialized")
        raise _HE(status_code=503, detail="Service temporarily unavailable")
    try:
        # Validate model
        if not _client._validate_model(_request.model):
            _supported_models = list(_MODELS.keys())
            raise _HE(status_code=400, detail=f"Model '{_request.model}' not supported. Supported models: {_supported_models}")
        # Format messages
        _prompt, _system = _client._format_messages([_msg.model_dump() for _msg in _request.messages])
        # Handle tools/functions
        if _request.tools or _request.functions:
            _tools = _request.tools or [_TD(function=_func) for _func in (_request.functions or [])]
            _tool_prompt = _client._handle_tools([_tool.model_dump() for _tool in _tools], _request.tool_choice)
            if _tool_prompt:
                # The prompt already ends with the open "Assistant: " cue, so
                # text appended to it reads as the start of the model's own
                # reply. Tool instructions go into the system prompt instead.
                _base_system = _system if isinstance(_system, str) else ""
                _system = f"{_base_system}{_tool_prompt}".strip()
        _lg.info(f"[{_request_id}] Formatted prompt length: {len(_prompt)}")
        # Extract only relevant parameters for Replicate API. `stop` is
        # included so non-streaming requests honour stop sequences the same
        # way the streaming path does.
        _api_params = {
            'max_tokens': _request.max_tokens,
            'temperature': _request.temperature,
            'top_p': _request.top_p,
            'presence_penalty': _request.presence_penalty,
            'frequency_penalty': _request.frequency_penalty,
            'stop': _request.stop
        }
        _lg.info(f"[{_request_id}] API parameters: {_api_params}")
        # Stream or complete
        if _request.stream:
            _lg.info(f"[{_request_id}] Starting streaming response")
            return _SR(
                _generate_stream_response(_request, _prompt, _system, _request_id),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive"
                }
            )
        # Non-streaming completion
        # NOTE(review): _complete_chat is a synchronous call made from an
        # async handler, so it blocks the event loop while it runs.
        _lg.info(f"[{_request_id}] Starting non-streaming completion")
        _content = _client._complete_chat(_request.model, _prompt, _system, **_api_params)
        _completion_id = f"chatcmpl-{_u.uuid4().hex}"
        _created_time = int(_t.time())
        # Check for tool calls in response
        _tool_calls = None
        _finish_reason = "stop"
        _stripped = (_content or "").strip()
        if _stripped.startswith('{"tool_calls"'):
            try:
                _tool_data = _j.loads(_stripped)
                _parsed_calls = [_TC(**_tc) for _tc in _tool_data["tool_calls"]]
            except (ValueError, TypeError, KeyError) as _parse_error:
                # Malformed tool-call JSON: keep the raw text as normal content.
                _lg.warning(f"[{_request_id}] Ignoring unparseable tool call payload: {_parse_error}")
            else:
                if _parsed_calls:
                    _tool_calls = _parsed_calls
                    _finish_reason = "tool_calls"
                    _content = None
        _prompt_tokens = len(_prompt.split())
        _completion_tokens = len(_content.split()) if _content else 0
        _response = _CCRes(
            id=_completion_id,
            created=_created_time,
            model=_request.model,
            choices=[_CCC(
                message=_CM(
                    role="assistant",
                    content=_content,
                    tool_calls=[_tc.model_dump() for _tc in _tool_calls] if _tool_calls else None
                ),
                finish_reason=_finish_reason
            )],
            usage={
                "prompt_tokens": _prompt_tokens,
                "completion_tokens": _completion_tokens,
                "total_tokens": _prompt_tokens + _completion_tokens
            }
        )
        _lg.info(f"[{_request_id}] Non-streaming completion finished")
        return _response
    except _HE:
        _error_count += 1
        raise
    except Exception as _e:
        _error_count += 1
        _lg.error(f"[{_request_id}] Unexpected error: {_e}\n{_tb.format_exc()}")
        raise _HE(status_code=500, detail="Internal server error occurred")
async def _create_chat_completion_alt(_request: _CCR):
    """Alias endpoint: delegate to `_create_chat_completion` unchanged."""
    _result = await _create_chat_completion(_request)
    return _result
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn. Host and port come from
    # the HOST / PORT environment variables. Failures are logged rather than
    # raised.
    try:
        import uvicorn as _uv
        _port = int(_o.getenv("PORT", 7860))  # Hugging Face default port
        _host = _o.getenv("HOST", "0.0.0.0")
        _lg.info(f"Starting Replicate Multi-Model server on {_host}:{_port}")
        _lg.info(f"Supported models: {list(_MODELS.keys())[:7]}")  # Show first 7 models
        _server_options = {
            "host": _host,
            "port": _port,
            "reload": False,
            "log_level": "info",
            "access_log": True,
        }
        _uv.run(_app, **_server_options)
    except ImportError:
        _lg.error("uvicorn not installed. Install with: pip install uvicorn")
    except Exception as _e:
        _lg.error(f"Failed to start server: {_e}")