|
|
import base64 as _b64, json as _j, time as _t, uuid as _u, logging as _l, traceback as _tb, os as _o |
|
|
from fastapi import FastAPI as _FA, HTTPException as _HE |
|
|
from fastapi.responses import StreamingResponse as _SR, JSONResponse as _JR |
|
|
from pydantic import BaseModel as _BM, Field as _F |
|
|
from typing import List as _L, Optional as _O, Dict as _D, Any as _A, Union as _U |
|
|
import replicate as _r |
|
|
from contextlib import asynccontextmanager as _acm |
|
|
|
|
|
|
|
|
_l.basicConfig(level=_l.INFO)

_lg = _l.getLogger(__name__)

# SECURITY: a Replicate API token is embedded below (base64-encoded) — it is
# effectively plaintext in the repo and should be revoked/rotated.  Prefer the
# REPLICATE_API_TOKEN environment variable; the embedded value is kept only as
# a backward-compatible fallback.
_TOKEN = _o.getenv("REPLICATE_API_TOKEN") or _b64.b64decode(
    b'cjhfWDdxeVpLTkZLZlZpUWdRaDJJcUhIa1BmdkFqRGhqSzFBWVl0Yw=='
).decode('utf-8')
|
|
|
|
|
|
|
|
# Short model alias -> fully-qualified Replicate model ID.
_BASE_MODELS = {
    "claude-4-sonnet": "anthropic/claude-4-sonnet",
    "claude-3.7-sonnet": "anthropic/claude-3.7-sonnet",
    "claude-3.5-sonnet": "anthropic/claude-3.5-sonnet",
    "claude-3.5-haiku": "anthropic/claude-3.5-haiku",
    "gpt-4.1": "openai/gpt-4.1",
    "gpt-4.1-mini": "openai/gpt-4.1-mini",
    "gpt-4.1-nano": "openai/gpt-4.1-nano",
    "gpt-5": "openai/gpt-5",
    "gpt-5-mini": "openai/gpt-5-mini",
    "gpt-5-nano": "openai/gpt-5-nano",
}

# Accept both the short alias and the fully-qualified name as model IDs:
# every fully-qualified name also maps to itself.
_MODELS = {**_BASE_MODELS, **{_full: _full for _full in _BASE_MODELS.values()}}
|
|
|
|
|
|
|
|
# Per-alias metadata: owning provider and context window size (tokens).
_MODEL_INFO = {
    _name: {"owned_by": _owner, "context_length": _ctx}
    for _name, _owner, _ctx in (
        ("claude-4-sonnet", "anthropic", 200000),
        ("claude-3.7-sonnet", "anthropic", 200000),
        ("claude-3.5-sonnet", "anthropic", 200000),
        ("claude-3.5-haiku", "anthropic", 200000),
        ("gpt-4.1", "openai", 128000),
        ("gpt-4.1-mini", "openai", 128000),
        ("gpt-4.1-nano", "openai", 128000),
        ("gpt-5", "openai", 400000),
        ("gpt-5-mini", "openai", 400000),
        ("gpt-5-nano", "openai", 400000),
    )
}
|
|
|
|
|
|
|
|
class _CM(_BM):
    """OpenAI-style chat message: role/content plus optional tool/function fields.

    `content` may be a plain string or a list of dict parts (per the type
    annotation); it is optional because tool-call responses carry no content.
    """
    role: str = _F(..., description="Message role")
    content: _O[_U[str, _L[_D[str, _A]]]] = _F(None, description="Message content")
    name: _O[str] = _F(None, description="Message name")
    function_call: _O[_D[str, _A]] = _F(None, description="Function call")
    tool_calls: _O[_L[_D[str, _A]]] = _F(None, description="Tool calls")
    tool_call_id: _O[str] = _F(None, description="Tool call ID")
|
|
|
|
|
class _FC(_BM):
    """Function-call payload: function name plus JSON-encoded argument string."""
    name: str = _F(..., description="Function name")
    arguments: str = _F(..., description="Function arguments")
|
|
|
|
|
class _TC(_BM):
    """A single tool call emitted by the assistant (id, type, function payload)."""
    id: str = _F(..., description="Tool call ID")
    type: str = _F(default="function", description="Tool call type")
    function: _FC = _F(..., description="Function call")
|
|
|
|
|
class _FD(_BM):
    """Callable function definition: name, optional description, parameter schema."""
    name: str = _F(..., description="Function name")
    description: _O[str] = _F(None, description="Function description")
    parameters: _D[str, _A] = _F(..., description="Function parameters")
|
|
|
|
|
class _TD(_BM):
    """Tool definition wrapping a function definition (OpenAI `tools` entry)."""
    type: str = _F(default="function", description="Tool type")
    function: _FD = _F(..., description="Function definition")
|
|
|
|
|
class _CCR(_BM):
    """Request body for /v1/chat/completions, mirroring the OpenAI schema.

    Includes both the current `tools`/`tool_choice` fields and the legacy
    `functions`/`function_call` fields for backward compatibility.
    """
    model: str = _F(..., description="Model name")
    messages: _L[_CM] = _F(..., description="Messages")
    max_tokens: _O[int] = _F(default=4096, description="Max tokens")
    temperature: _O[float] = _F(default=0.7, description="Temperature")
    top_p: _O[float] = _F(default=1.0, description="Top p")
    n: _O[int] = _F(default=1, description="Number of completions")
    stream: _O[bool] = _F(default=False, description="Stream response")
    stop: _O[_U[str, _L[str]]] = _F(None, description="Stop sequences")
    presence_penalty: _O[float] = _F(default=0.0, description="Presence penalty")
    frequency_penalty: _O[float] = _F(default=0.0, description="Frequency penalty")
    logit_bias: _O[_D[str, float]] = _F(None, description="Logit bias")
    user: _O[str] = _F(None, description="User ID")
    tools: _O[_L[_TD]] = _F(None, description="Available tools")
    tool_choice: _O[_U[str, _D[str, _A]]] = _F(None, description="Tool choice")
    functions: _O[_L[_FD]] = _F(None, description="Available functions")
    function_call: _O[_U[str, _D[str, _A]]] = _F(None, description="Function call")
|
|
|
|
|
class _CCC(_BM):
    """One choice of a non-streaming completion (full message + finish reason)."""
    index: int = _F(default=0, description="Choice index")
    message: _CM = _F(..., description="Message")
    finish_reason: _O[str] = _F(None, description="Finish reason")
|
|
|
|
|
class _CCSC(_BM):
    """One choice of a streaming chunk: carries a `delta` dict, not a full message."""
    index: int = _F(default=0, description="Choice index")
    delta: _D[str, _A] = _F(..., description="Delta")
    finish_reason: _O[str] = _F(None, description="Finish reason")
|
|
|
|
|
class _CCRes(_BM):
    """Non-streaming chat completion response envelope (object `chat.completion`)."""
    id: str = _F(..., description="Completion ID")
    object: str = _F(default="chat.completion", description="Object type")
    created: int = _F(..., description="Created timestamp")
    model: str = _F(..., description="Model name")
    choices: _L[_CCC] = _F(..., description="Choices")
    usage: _D[str, int] = _F(..., description="Usage stats")
    system_fingerprint: _O[str] = _F(None, description="System fingerprint")
|
|
|
|
|
class _CCSR(_BM):
    """Streaming chat completion chunk envelope (object `chat.completion.chunk`)."""
    id: str = _F(..., description="Completion ID")
    object: str = _F(default="chat.completion.chunk", description="Object type")
    created: int = _F(..., description="Created timestamp")
    model: str = _F(..., description="Model name")
    choices: _L[_CCSC] = _F(..., description="Choices")
    system_fingerprint: _O[str] = _F(None, description="System fingerprint")
|
|
|
|
|
class _OM(_BM):
    """Entry in the /v1/models listing (OpenAI `model` object)."""
    id: str = _F(..., description="Model ID")
    object: str = _F(default="model", description="Object type")
    created: int = _F(..., description="Created timestamp")
    owned_by: str = _F(..., description="Owner")
|
|
|
|
|
|
|
|
class _RC:
    """Wrapper around the `replicate` SDK exposing chat-style helpers.

    Translates OpenAI-shaped chat requests (message list + sampling params)
    into Replicate model inputs, and normalizes streaming / non-streaming
    model output back into plain text.
    """

    def __init__(self, _tk=None):
        """Configure the replicate SDK with an API token.

        `_tk` defaults to the module-level ``_TOKEN``; the default is resolved
        at call time (not def time) so the class can be defined before the
        constant exists.  Backward compatible: omitting `_tk` behaves as before.
        """
        if _tk is None:
            _tk = _TOKEN
        # The replicate SDK reads its credentials from the environment.
        _o.environ['REPLICATE_API_TOKEN'] = _tk
        self._client = _r
        self._models = _MODELS
        self._model_info = _MODEL_INFO

    def _get_replicate_model(self, _model_name):
        """Get the Replicate model ID from OpenAI model name.

        Unknown names pass through unchanged so fully-qualified IDs still work.
        """
        return self._models.get(_model_name, _model_name)

    def _validate_model(self, _model_name):
        """Validate if model is supported (as an alias or a fully-qualified ID)."""
        return _model_name in self._models or _model_name in self._models.values()

    def _format_messages(self, _msgs):
        """Flatten OpenAI chat messages into a (prompt, system) string pair.

        User/assistant turns become "Human: ..."/"Assistant: ..." lines; the
        last system message wins; a trailing "Assistant: " invites the reply.
        Roles other than system/user/assistant are ignored.
        """
        _prompt = ""
        _system = ""

        for _msg in _msgs:
            _role = _msg.get('role', '')
            _content = _msg.get('content', '')

            if _role == 'system':
                _system = _content
            elif _role == 'user':
                _prompt += f"Human: {_content}\n\n"
            elif _role == 'assistant':
                _prompt += f"Assistant: {_content}\n\n"

        _prompt += "Assistant: "
        return _prompt, _system

    @staticmethod
    def _clamp(_value, _lo, _hi, _default):
        """Clamp a numeric parameter into [_lo, _hi]; missing/invalid -> _default."""
        if _value is None:
            return _default
        try:
            return max(_lo, min(_hi, float(_value)))
        except (TypeError, ValueError):
            return _default

    def _sanitize_params(self, **_kwargs):
        """Sanitize sampling parameters and set proper defaults.

        max_tokens is floored at 1024 (default 4096); temperature is clamped
        to [0, 2] (default 0.7); top_p to [0, 1] (default 1.0); both penalties
        to [-2, 2] (default 0.0).  Invalid values fall back to the defaults.
        """
        _params = {}

        _max_tokens = _kwargs.get('max_tokens')
        try:
            _positive = _max_tokens is not None and _max_tokens > 0
        except TypeError:
            # A non-numeric value (e.g. a string) previously crashed the
            # comparison before the int() guard could run; treat it as absent.
            _positive = False
        if _positive:
            try:
                _mt = int(_max_tokens)
            except Exception:
                _mt = 4096
            _params['max_tokens'] = max(1024, _mt)
        else:
            _params['max_tokens'] = 4096

        _params['temperature'] = self._clamp(_kwargs.get('temperature'), 0.0, 2.0, 0.7)
        _params['top_p'] = self._clamp(_kwargs.get('top_p'), 0.0, 1.0, 1.0)
        _params['presence_penalty'] = self._clamp(_kwargs.get('presence_penalty'), -2.0, 2.0, 0.0)
        _params['frequency_penalty'] = self._clamp(_kwargs.get('frequency_penalty'), -2.0, 2.0, 0.0)

        return _params

    @staticmethod
    def _as_text(_value):
        """Return `_value` as str (bytes decoded lossily); None if not text."""
        if isinstance(_value, bytes):
            return _value.decode('utf-8', errors='ignore')
        if isinstance(_value, str):
            return _value
        return None

    def _create_prediction(self, _model_name, _prompt, _system="", **_kwargs):
        """Create a prediction using Replicate API.

        Returns the prediction object, or None on failure (error is logged).
        """
        _replicate_model = self._get_replicate_model(_model_name)
        _params = self._sanitize_params(**_kwargs)

        _input = {
            "prompt": _prompt,
            "system_prompt": _system,
            "max_tokens": _params['max_tokens'],
            "temperature": _params['temperature'],
            "top_p": _params['top_p']
        }

        try:
            _prediction = self._client.predictions.create(
                model=_replicate_model,
                input=_input
            )
            return _prediction
        except Exception as _e:
            _lg.error(f"Prediction creation error for {_replicate_model}: {_e}")
            return None

    def _handle_tools(self, _tools, _tool_choice):
        """Render tool definitions as a prompt suffix describing the call format.

        NOTE(review): `_tool_choice` is accepted but not used to constrain the
        model — tools are only described in prompt text.
        """
        if not _tools:
            return ""

        _tool_prompt = "\n\nYou have access to the following tools:\n"
        for _tool in _tools:
            _func = _tool.get('function', {})
            _name = _func.get('name', '')
            _desc = _func.get('description', '')
            _params = _func.get('parameters', {})
            _tool_prompt += f"- {_name}: {_desc}\n"
            _tool_prompt += f" Parameters: {_j.dumps(_params)}\n"

        _tool_prompt += "\nTo use a tool, respond with JSON in this format:\n"
        _tool_prompt += '{"tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "tool_name", "arguments": "{\\"param\\": \\"value\\"}"}}]}\n'

        return _tool_prompt

    def _stream_chat(self, _model_name, _prompt, _system="", **_kwargs):
        """Stream chat using Replicate's streaming API, yielding only text chunks."""
        _replicate_model = self._get_replicate_model(_model_name)
        _params = self._sanitize_params(**_kwargs)

        _input = {
            "prompt": _prompt,
            "system_prompt": _system,
            "max_tokens": _params['max_tokens'],
            "temperature": _params['temperature'],
            "top_p": _params['top_p']
        }

        if 'stop' in _kwargs and _kwargs['stop'] is not None:
            _input["stop"] = _kwargs['stop']

        try:
            for _event in self._client.stream(_replicate_model, input=_input):
                if not _event:
                    continue

                # Fast path: the SDK may yield plain text chunks directly.
                _text = self._as_text(_event)
                if _text is not None:
                    yield _text
                    continue

                # Structured event: pull type + payload from dict keys or attributes.
                if isinstance(_event, dict):
                    _etype = _event.get('type') or _event.get('event')
                    _edata = _event.get('data') or _event.get('output') or _event.get('text')
                else:
                    _etype = getattr(_event, 'type', None) or getattr(_event, 'event', None)
                    _edata = getattr(_event, 'data', None)

                if _etype == "output" or _edata is not None:
                    if isinstance(_edata, (list, tuple)):
                        # A batch of fragments: emit each textual piece.
                        for _piece in _edata:
                            _piece_text = self._as_text(_piece)
                            if _piece_text is not None:
                                yield _piece_text
                    elif isinstance(_edata, (str, bytes)):
                        yield self._as_text(_edata)
                    elif isinstance(_edata, dict):
                        # Emit the first recognized textual field only.
                        for _key in ("text", "output", "delta"):
                            if _key in _edata and isinstance(_edata[_key], (str, bytes)):
                                yield self._as_text(_edata[_key])
                                break
                    elif _etype in {"completed", "done", "end"}:
                        break
                    else:
                        # Unknown payload shape: best-effort stringify.
                        try:
                            _s = str(_event)
                            if _s:
                                yield _s
                        except Exception:
                            pass
                elif _etype in {"error", "logs", "warning"}:
                    # Diagnostic events are logged, never yielded as content.
                    try:
                        _lg.warning(f"Replicate stream {_etype}: {_edata}")
                    except Exception:
                        pass
                elif _etype in {"completed", "done", "end"}:
                    break
                else:
                    try:
                        _s = str(_event)
                        if _s:
                            yield _s
                    except Exception:
                        pass
        except Exception as _e:
            _lg.error(f"Streaming error for {_replicate_model}: {_e}")
            # Yield an empty chunk so consumers terminate cleanly.
            yield ""

    def _stream_from_prediction(self, _prediction):
        """Stream SSE payloads from a prediction's stream URL.

        Yields the payload of each `data:` line until a `[DONE]` sentinel.
        """
        try:
            import requests
            _stream_url = _prediction.urls.get('stream')
            if not _stream_url:
                _lg.error("No stream URL available")
                return

            # timeout=(connect, read): requests has NO default timeout, so a
            # hung connection would otherwise block this generator forever.
            # The `with` block guarantees the connection is released.
            with requests.get(
                _stream_url,
                headers={
                    "Accept": "text/event-stream",
                    "Cache-Control": "no-store"
                },
                stream=True,
                timeout=(10, 300)
            ) as _response:
                for _line in _response.iter_lines():
                    if _line:
                        _line = _line.decode('utf-8')
                        if _line.startswith('data: '):
                            _data = _line[6:]
                            if _data != '[DONE]':
                                yield _data
                            else:
                                break

        except Exception as _e:
            _lg.error(f"Stream from prediction error: {_e}")
            yield f"Error: {_e}"

    def _complete_chat(self, _model_name, _prompt, _system="", **_kwargs):
        """Complete chat using Replicate's run method and coalesce into a single string.

        Returns "" on any error (logged) rather than raising.
        """
        _replicate_model = self._get_replicate_model(_model_name)
        _params = self._sanitize_params(**_kwargs)

        _input = {
            "prompt": _prompt,
            "system_prompt": _system,
            "max_tokens": _params['max_tokens'],
            "temperature": _params['temperature'],
            "top_p": _params['top_p']
        }

        if 'stop' in _kwargs and _kwargs['stop'] is not None:
            _input["stop"] = _kwargs['stop']

        try:
            _result = self._client.run(_replicate_model, input=_input)

            # Common case: a list of string fragments.
            if isinstance(_result, list):
                return "".join(
                    _x.decode("utf-8", errors="ignore") if isinstance(_x, bytes) else str(_x)
                    for _x in _result
                )

            # Some models return a lazy iterator/iterable of fragments.
            try:
                from collections.abc import Iterator, Iterable
                if isinstance(_result, Iterator) or (
                    isinstance(_result, Iterable) and not isinstance(_result, (str, bytes))
                ):
                    _buf = []
                    for _piece in _result:
                        if isinstance(_piece, (str, bytes)):
                            _buf.append(_piece.decode("utf-8", errors="ignore") if isinstance(_piece, bytes) else _piece)
                        else:
                            _buf.append(str(_piece))
                    _text = "".join(_buf)
                    if _text:
                        return _text
            except Exception:
                pass

            # Fallback: stringify whatever came back.
            return str(_result) if _result is not None else ""
        except Exception as _e:
            _lg.error(f"Completion error for {_replicate_model}: {_e}")
            # Degrade to an empty completion rather than raising.
            return ""
|
|
|
|
|
|
|
|
# --- Module-level service state ---
# Shared _RC client; set by the lifespan handler, None until startup
# (or if initialization failed).
_client = None
# Process start time, used for the /health uptime figure.
_startup_time = _t.time()
# Best-effort in-process request/error counters surfaced by /health.
_request_count = 0
_error_count = 0
|
|
|
|
|
@_acm
async def _lifespan(_app: _FA):
    """App lifespan: build the shared Replicate client on startup, log on shutdown.

    A failed initialization leaves `_client` as None (endpoints respond 503).
    """
    global _client
    _lg.info("Initializing Replicate client...")
    try:
        _client = _RC()
    except Exception as _err:
        _client = None
        _lg.error(f"Failed to initialize client: {_err}")
    else:
        _lg.info("Replicate client initialized successfully")
    yield
    _lg.info("Shutting down Replicate client...")
|
|
|
|
|
|
|
|
# FastAPI application exposing an OpenAI-compatible surface (models listing +
# chat completions) backed by Replicate.  Metadata made consistent with the
# multi-model reality ("/" endpoint, _MODELS table) — it previously claimed
# Claude-4-Sonnet only.
_app = _FA(
    title="Replicate Multi-Model OpenAI API",
    version="1.0.0",
    description="OpenAI-compatible API for Anthropic and OpenAI models via Replicate",
    lifespan=_lifespan
)
|
|
|
|
|
|
|
|
# CORS is optional: when the FastAPI/Starlette CORS middleware is importable,
# open the API wide (any origin/method/header, credentials allowed);
# otherwise run without CORS headers.
try:
    from fastapi.middleware.cors import CORSMiddleware as _CORS
    _app.add_middleware(
        _CORS,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
except ImportError:
    pass
|
|
|
|
|
|
|
|
@_app.exception_handler(_HE)
async def _http_exception_handler(_request, _exc: _HE):
    """Render HTTPExceptions in the OpenAI-style error envelope."""
    _lg.error(f"HTTP error: {_exc.status_code} - {_exc.detail}")
    _payload = {
        "error": {
            "message": _exc.detail,
            "type": "api_error",
            "code": _exc.status_code,
        }
    }
    return _JR(status_code=_exc.status_code, content=_payload)
|
|
|
|
|
@_app.exception_handler(Exception)
async def _global_exception_handler(_request, _exc):
    """Catch-all handler: log the traceback, return a generic 500 envelope."""
    _lg.error(f"Unexpected error: {_exc}\n{_tb.format_exc()}")
    _payload = {
        "error": {
            "message": "Internal server error",
            "type": "server_error",
            "code": 500,
        }
    }
    return _JR(status_code=500, content=_payload)
|
|
|
|
|
@_app.get("/") |
|
|
async def _root(): |
|
|
_model_count = len([m for m in _MODELS.keys() if not m.startswith(('anthropic/', 'openai/'))]) |
|
|
return { |
|
|
"message": "Replicate Multi-Model OpenAI API", |
|
|
"version": "1.0.0", |
|
|
"status": "running", |
|
|
"supported_models": _model_count, |
|
|
"providers": ["anthropic", "openai"] |
|
|
} |
|
|
|
|
|
@_app.get("/health") |
|
|
async def _health_check(): |
|
|
global _client, _startup_time, _request_count, _error_count |
|
|
|
|
|
_uptime = _t.time() - _startup_time |
|
|
_status = "healthy" |
|
|
|
|
|
_client_status = "unknown" |
|
|
if _client is None: |
|
|
_client_status = "not_initialized" |
|
|
_status = "degraded" |
|
|
else: |
|
|
_client_status = "ready" |
|
|
|
|
|
return { |
|
|
"status": _status, |
|
|
"timestamp": int(_t.time()), |
|
|
"uptime_seconds": int(_uptime), |
|
|
"client_status": _client_status, |
|
|
"stats": { |
|
|
"total_requests": _request_count, |
|
|
"total_errors": _error_count, |
|
|
"error_rate": _error_count / max(_request_count, 1) |
|
|
} |
|
|
} |
|
|
|
|
|
@_app.get("/v1/models") |
|
|
async def _list_models(): |
|
|
"""List all supported models""" |
|
|
_models_list = [] |
|
|
_created_time = int(_t.time()) |
|
|
|
|
|
|
|
|
_unique_models = set() |
|
|
for _model_name in _MODELS.keys(): |
|
|
if not _model_name.startswith(('anthropic/', 'openai/')): |
|
|
_unique_models.add(_model_name) |
|
|
|
|
|
|
|
|
for _model_name in sorted(_unique_models): |
|
|
_info = _MODEL_INFO.get(_model_name, {"owned_by": "unknown", "context_length": 4096}) |
|
|
_models_list.append(_OM( |
|
|
id=_model_name, |
|
|
created=_created_time, |
|
|
owned_by=_info["owned_by"] |
|
|
)) |
|
|
|
|
|
return { |
|
|
"object": "list", |
|
|
"data": _models_list |
|
|
} |
|
|
|
|
|
@_app.get("/models") |
|
|
async def _list_models_alt(): |
|
|
return await _list_models() |
|
|
|
|
|
async def _generate_stream_response(_request: _CCR, _prompt: str, _system: str, _request_id: str = None):
    """Async generator producing an OpenAI-style SSE stream for one completion.

    Emits, in order: an initial role-only chunk, one content-delta chunk per
    text piece received from the Replicate stream, a final empty-delta chunk
    with finish_reason "stop", and the terminating "data: [DONE]" line (the
    last two are guaranteed by the `finally` block even on errors).
    """
    _completion_id = f"chatcmpl-{_u.uuid4().hex}"
    _created_time = int(_t.time())
    _request_id = _request_id or f"req-{_u.uuid4().hex[:8]}"

    _lg.info(f"[{_request_id}] Starting stream generation")

    try:
        # First chunk carries only the assistant role, per the OpenAI protocol.
        _initial_chunk = {
            "id": _completion_id,
            "object": "chat.completion.chunk",
            "created": _created_time,
            "model": _request.model,
            "choices": [{
                "index": 0,
                "delta": {"role": "assistant"},
                "finish_reason": None
            }]
        }
        yield f"data: {_j.dumps(_initial_chunk)}\n\n"

        _chunk_count = 0
        _total_content = ""

        try:
            # Sampling parameters forwarded to the Replicate client.
            _api_params = {
                'max_tokens': _request.max_tokens,
                'temperature': _request.temperature,
                'top_p': _request.top_p,
                'presence_penalty': _request.presence_penalty,
                'frequency_penalty': _request.frequency_penalty,
                'stop': _request.stop
            }

            # Relay each text piece from the model as a content-delta chunk.
            for _chunk in _client._stream_chat(_request.model, _prompt, _system, **_api_params):
                if _chunk and isinstance(_chunk, str):
                    _chunk_count += 1
                    _total_content += _chunk

                    _stream_response = _CCSR(
                        id=_completion_id,
                        created=_created_time,
                        model=_request.model,
                        choices=[_CCSC(
                            delta={"content": _chunk},
                            finish_reason=None
                        )]
                    )

                    try:
                        _chunk_json = _j.dumps(_stream_response.model_dump())
                        yield f"data: {_chunk_json}\n\n"
                    except Exception as _json_error:
                        # Skip a chunk that fails to serialize rather than
                        # aborting the whole stream.
                        _lg.error(f"[{_request_id}] JSON serialization error: {_json_error}")
                        continue

        except Exception as _stream_error:
            _lg.error(f"[{_request_id}] Streaming error after {_chunk_count} chunks: {_stream_error}")

            # If nothing was delivered before the failure, send an apology
            # chunk so the client receives some content before finalization.
            if _chunk_count == 0:
                _error_content = "I apologize, but I encountered an error while generating the response. Please try again."
                _error_response = _CCSR(
                    id=_completion_id,
                    created=_created_time,
                    model=_request.model,
                    choices=[_CCSC(
                        delta={"content": _error_content},
                        finish_reason=None
                    )]
                )
                yield f"data: {_j.dumps(_error_response.model_dump())}\n\n"

        _lg.info(f"[{_request_id}] Stream completed: {_chunk_count} chunks, {len(_total_content)} characters")

    except Exception as _e:
        # Failure outside the inner loop (e.g. while emitting the initial
        # chunk): send a terminal error chunk with finish_reason "stop".
        _lg.error(f"[{_request_id}] Critical streaming error: {_e}")
        _error_chunk = {
            "id": _completion_id,
            "object": "chat.completion.chunk",
            "created": _created_time,
            "model": _request.model,
            "choices": [{
                "index": 0,
                "delta": {"content": "Error occurred while streaming response."},
                "finish_reason": "stop"
            }]
        }
        yield f"data: {_j.dumps(_error_chunk)}\n\n"

    finally:
        # Always emit the closing empty-delta chunk and the [DONE] sentinel,
        # whatever happened above, so clients can terminate cleanly.
        try:
            _final_chunk = {
                "id": _completion_id,
                "object": "chat.completion.chunk",
                "created": _created_time,
                "model": _request.model,
                "choices": [{
                    "index": 0,
                    "delta": {},
                    "finish_reason": "stop"
                }]
            }
            yield f"data: {_j.dumps(_final_chunk)}\n\n"
            yield "data: [DONE]\n\n"
            _lg.info(f"[{_request_id}] Stream finalized")
        except Exception as _final_error:
            _lg.error(f"[{_request_id}] Error sending final chunk: {_final_error}")
            yield "data: [DONE]\n\n"
|
|
|
|
|
@_app.post("/v1/chat/completions") |
|
|
async def _create_chat_completion(_request: _CCR): |
|
|
global _request_count, _error_count, _client |
|
|
|
|
|
_request_count += 1 |
|
|
_request_id = f"req-{_u.uuid4().hex[:8]}" |
|
|
_lg.info(f"[{_request_id}] Chat completion request: model={_request.model}, stream={_request.stream}") |
|
|
|
|
|
if _client is None: |
|
|
_error_count += 1 |
|
|
_lg.error(f"[{_request_id}] Client not initialized") |
|
|
raise _HE(status_code=503, detail="Service temporarily unavailable") |
|
|
|
|
|
try: |
|
|
|
|
|
if not _client._validate_model(_request.model): |
|
|
_supported_models = list(_MODELS.keys()) |
|
|
raise _HE(status_code=400, detail=f"Model '{_request.model}' not supported. Supported models: {_supported_models}") |
|
|
|
|
|
|
|
|
_prompt, _system = _client._format_messages([_msg.model_dump() for _msg in _request.messages]) |
|
|
|
|
|
|
|
|
if _request.tools or _request.functions: |
|
|
_tools = _request.tools or [_TD(function=_func) for _func in (_request.functions or [])] |
|
|
_tool_prompt = _client._handle_tools([_tool.model_dump() for _tool in _tools], _request.tool_choice) |
|
|
_prompt += _tool_prompt |
|
|
|
|
|
_lg.info(f"[{_request_id}] Formatted prompt length: {len(_prompt)}") |
|
|
|
|
|
|
|
|
_api_params = { |
|
|
'max_tokens': _request.max_tokens, |
|
|
'temperature': _request.temperature, |
|
|
'top_p': _request.top_p, |
|
|
'presence_penalty': _request.presence_penalty, |
|
|
'frequency_penalty': _request.frequency_penalty |
|
|
} |
|
|
|
|
|
_lg.info(f"[{_request_id}] API parameters: {_api_params}") |
|
|
|
|
|
|
|
|
if _request.stream: |
|
|
_lg.info(f"[{_request_id}] Starting streaming response") |
|
|
return _SR( |
|
|
_generate_stream_response(_request, _prompt, _system, _request_id), |
|
|
media_type="text/event-stream", |
|
|
headers={ |
|
|
"Cache-Control": "no-cache", |
|
|
"Connection": "keep-alive" |
|
|
} |
|
|
) |
|
|
else: |
|
|
|
|
|
_lg.info(f"[{_request_id}] Starting non-streaming completion") |
|
|
_content = _client._complete_chat(_request.model, _prompt, _system, **_api_params) |
|
|
|
|
|
_completion_id = f"chatcmpl-{_u.uuid4().hex}" |
|
|
_created_time = int(_t.time()) |
|
|
|
|
|
|
|
|
_tool_calls = None |
|
|
_finish_reason = "stop" |
|
|
|
|
|
try: |
|
|
if _content.strip().startswith('{"tool_calls"'): |
|
|
_tool_data = _j.loads(_content.strip()) |
|
|
if "tool_calls" in _tool_data: |
|
|
_tool_calls = [_TC(**_tc) for _tc in _tool_data["tool_calls"]] |
|
|
_finish_reason = "tool_calls" |
|
|
_content = None |
|
|
except: |
|
|
pass |
|
|
|
|
|
_response = _CCRes( |
|
|
id=_completion_id, |
|
|
created=_created_time, |
|
|
model=_request.model, |
|
|
choices=[_CCC( |
|
|
message=_CM( |
|
|
role="assistant", |
|
|
content=_content, |
|
|
tool_calls=[_tc.model_dump() for _tc in _tool_calls] if _tool_calls else None |
|
|
), |
|
|
finish_reason=_finish_reason |
|
|
)], |
|
|
usage={ |
|
|
"prompt_tokens": len(_prompt.split()), |
|
|
"completion_tokens": len(_content.split()) if _content else 0, |
|
|
"total_tokens": len(_prompt.split()) + (len(_content.split()) if _content else 0) |
|
|
} |
|
|
) |
|
|
|
|
|
_lg.info(f"[{_request_id}] Non-streaming completion finished") |
|
|
return _response |
|
|
|
|
|
except _HE: |
|
|
_error_count += 1 |
|
|
raise |
|
|
except Exception as _e: |
|
|
_error_count += 1 |
|
|
_lg.error(f"[{_request_id}] Unexpected error: {_e}\n{_tb.format_exc()}") |
|
|
raise _HE(status_code=500, detail="Internal server error occurred") |
|
|
|
|
|
@_app.post("/chat/completions") |
|
|
async def _create_chat_completion_alt(_request: _CCR): |
|
|
return await _create_chat_completion(_request) |
|
|
|
|
|
if __name__ == "__main__": |
|
|
try: |
|
|
import uvicorn as _uv |
|
|
_port = int(_o.getenv("PORT", 7860)) |
|
|
_host = _o.getenv("HOST", "0.0.0.0") |
|
|
|
|
|
_lg.info(f"Starting Replicate Multi-Model server on {_host}:{_port}") |
|
|
_lg.info(f"Supported models: {list(_MODELS.keys())[:7]}") |
|
|
_uv.run( |
|
|
_app, |
|
|
host=_host, |
|
|
port=_port, |
|
|
reload=False, |
|
|
log_level="info", |
|
|
access_log=True |
|
|
) |
|
|
except ImportError: |
|
|
_lg.error("uvicorn not installed. Install with: pip install uvicorn") |
|
|
except Exception as _e: |
|
|
_lg.error(f"Failed to start server: {_e}") |