Update app.py
Browse files
app.py
CHANGED
|
@@ -173,7 +173,12 @@ class _RC:
|
|
| 173 |
# Handle max_tokens
|
| 174 |
_max_tokens = _kwargs.get('max_tokens')
|
| 175 |
if _max_tokens is not None and _max_tokens > 0:
|
| 176 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
else:
|
| 178 |
_params['max_tokens'] = 4096
|
| 179 |
|
|
@@ -249,7 +254,7 @@ class _RC:
|
|
| 249 |
return _tool_prompt
|
| 250 |
|
| 251 |
def _stream_chat(self, _model_name, _prompt, _system="", **_kwargs):
|
| 252 |
-
"""Stream chat using Replicate's streaming API"""
|
| 253 |
_replicate_model = self._get_replicate_model(_model_name)
|
| 254 |
_params = self._sanitize_params(**_kwargs)
|
| 255 |
|
|
@@ -261,14 +266,73 @@ class _RC:
|
|
| 261 |
"top_p": _params['top_p']
|
| 262 |
}
|
| 263 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
try:
|
| 265 |
-
# Use Replicate's streaming method
|
| 266 |
for _event in self._client.stream(_replicate_model, input=_input):
|
| 267 |
-
if _event:
|
| 268 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
except Exception as _e:
|
| 270 |
_lg.error(f"Streaming error for {_replicate_model}: {_e}")
|
| 271 |
-
|
|
|
|
| 272 |
|
| 273 |
def _stream_from_prediction(self, _prediction):
|
| 274 |
"""Stream from a prediction using the stream URL"""
|
|
@@ -303,7 +367,7 @@ class _RC:
|
|
| 303 |
yield f"Error: {_e}"
|
| 304 |
|
| 305 |
def _complete_chat(self, _model_name, _prompt, _system="", **_kwargs):
|
| 306 |
-
"""Complete chat using Replicate's run method"""
|
| 307 |
_replicate_model = self._get_replicate_model(_model_name)
|
| 308 |
_params = self._sanitize_params(**_kwargs)
|
| 309 |
|
|
@@ -315,12 +379,42 @@ class _RC:
|
|
| 315 |
"top_p": _params['top_p']
|
| 316 |
}
|
| 317 |
|
|
|
|
|
|
|
|
|
|
| 318 |
try:
|
| 319 |
_result = self._client.run(_replicate_model, input=_input)
|
| 320 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
except Exception as _e:
|
| 322 |
_lg.error(f"Completion error for {_replicate_model}: {_e}")
|
| 323 |
-
|
|
|
|
| 324 |
|
| 325 |
# Global variables
|
| 326 |
_client = None
|
|
@@ -492,7 +586,8 @@ async def _generate_stream_response(_request: _CCR, _prompt: str, _system: str,
|
|
| 492 |
'temperature': _request.temperature,
|
| 493 |
'top_p': _request.top_p,
|
| 494 |
'presence_penalty': _request.presence_penalty,
|
| 495 |
-
'frequency_penalty': _request.frequency_penalty
|
|
|
|
| 496 |
}
|
| 497 |
|
| 498 |
# Use Replicate's direct streaming method with model parameter
|
|
@@ -617,11 +712,10 @@ async def _create_chat_completion(_request: _CCR):
|
|
| 617 |
_lg.info(f"[{_request_id}] Starting streaming response")
|
| 618 |
return _SR(
|
| 619 |
_generate_stream_response(_request, _prompt, _system, _request_id),
|
| 620 |
-
media_type="text/
|
| 621 |
headers={
|
| 622 |
"Cache-Control": "no-cache",
|
| 623 |
-
"Connection": "keep-alive"
|
| 624 |
-
"Content-Type": "text/event-stream"
|
| 625 |
}
|
| 626 |
)
|
| 627 |
else:
|
|
|
|
| 173 |
# Handle max_tokens
|
| 174 |
_max_tokens = _kwargs.get('max_tokens')
|
| 175 |
if _max_tokens is not None and _max_tokens > 0:
|
| 176 |
+
# Replicate Anthropic models often require >= 1024; clamp to avoid 422s
|
| 177 |
+
try:
|
| 178 |
+
_mt = int(_max_tokens)
|
| 179 |
+
except Exception:
|
| 180 |
+
_mt = 4096
|
| 181 |
+
_params['max_tokens'] = max(1024, _mt)
|
| 182 |
else:
|
| 183 |
_params['max_tokens'] = 4096
|
| 184 |
|
|
|
|
| 254 |
return _tool_prompt
|
| 255 |
|
| 256 |
def _stream_chat(self, _model_name, _prompt, _system="", **_kwargs):
|
| 257 |
+
"""Stream chat using Replicate's streaming API, yielding only text chunks."""
|
| 258 |
_replicate_model = self._get_replicate_model(_model_name)
|
| 259 |
_params = self._sanitize_params(**_kwargs)
|
| 260 |
|
|
|
|
| 266 |
"top_p": _params['top_p']
|
| 267 |
}
|
| 268 |
|
| 269 |
+
# pass through stop sequences if provided
|
| 270 |
+
if 'stop' in _kwargs and _kwargs['stop'] is not None:
|
| 271 |
+
_input["stop"] = _kwargs['stop']
|
| 272 |
+
|
| 273 |
try:
|
|
|
|
| 274 |
for _event in self._client.stream(_replicate_model, input=_input):
|
| 275 |
+
if not _event:
|
| 276 |
+
continue
|
| 277 |
+
|
| 278 |
+
# Fast path: plain string/bytes token
|
| 279 |
+
if isinstance(_event, (str, bytes)):
|
| 280 |
+
yield (_event.decode('utf-8', errors='ignore') if isinstance(_event, bytes) else _event)
|
| 281 |
+
continue
|
| 282 |
+
|
| 283 |
+
# Normalize event interfaces (object, dict, or custom)
|
| 284 |
+
_etype, _edata = None, None
|
| 285 |
+
if isinstance(_event, dict):
|
| 286 |
+
_etype = _event.get('type') or _event.get('event')
|
| 287 |
+
_edata = _event.get('data') or _event.get('output') or _event.get('text')
|
| 288 |
+
else:
|
| 289 |
+
_etype = getattr(_event, 'type', None) or getattr(_event, 'event', None)
|
| 290 |
+
_edata = getattr(_event, 'data', None)
|
| 291 |
+
|
| 292 |
+
# Extract text payloads
|
| 293 |
+
if _etype == "output" or _edata is not None:
|
| 294 |
+
if isinstance(_edata, (list, tuple)):
|
| 295 |
+
for _piece in _edata:
|
| 296 |
+
if isinstance(_piece, (str, bytes)):
|
| 297 |
+
yield (_piece.decode('utf-8', errors='ignore') if isinstance(_piece, bytes) else _piece)
|
| 298 |
+
elif isinstance(_edata, (str, bytes)):
|
| 299 |
+
yield (_edata.decode('utf-8', errors='ignore') if isinstance(_edata, bytes) else _edata)
|
| 300 |
+
elif isinstance(_edata, dict):
|
| 301 |
+
# Common nested keys
|
| 302 |
+
for _k in ("text", "output", "delta"):
|
| 303 |
+
if _k in _edata and isinstance(_edata[_k], (str, bytes)):
|
| 304 |
+
_v = _edata[_k]
|
| 305 |
+
yield (_v.decode('utf-8', errors='ignore') if isinstance(_v, bytes) else _v)
|
| 306 |
+
break
|
| 307 |
+
elif _etype in {"completed", "done", "end"}:
|
| 308 |
+
break
|
| 309 |
+
else:
|
| 310 |
+
# Fallback to string form (restore old working behavior)
|
| 311 |
+
try:
|
| 312 |
+
_s = str(_event)
|
| 313 |
+
if _s:
|
| 314 |
+
yield _s
|
| 315 |
+
except Exception:
|
| 316 |
+
pass
|
| 317 |
+
elif _etype in {"error", "logs", "warning"}:
|
| 318 |
+
try:
|
| 319 |
+
_lg.warning(f"Replicate stream {_etype}: {_edata}")
|
| 320 |
+
except Exception:
|
| 321 |
+
pass
|
| 322 |
+
elif _etype in {"completed", "done", "end"}:
|
| 323 |
+
break
|
| 324 |
+
else:
|
| 325 |
+
# Unknown/eventless object; fallback to string form
|
| 326 |
+
try:
|
| 327 |
+
_s = str(_event)
|
| 328 |
+
if _s:
|
| 329 |
+
yield _s
|
| 330 |
+
except Exception:
|
| 331 |
+
pass
|
| 332 |
except Exception as _e:
|
| 333 |
_lg.error(f"Streaming error for {_replicate_model}: {_e}")
|
| 334 |
+
# Surface a minimal safe error token
|
| 335 |
+
yield ""
|
| 336 |
|
| 337 |
def _stream_from_prediction(self, _prediction):
|
| 338 |
"""Stream from a prediction using the stream URL"""
|
|
|
|
| 367 |
yield f"Error: {_e}"
|
| 368 |
|
| 369 |
def _complete_chat(self, _model_name, _prompt, _system="", **_kwargs):
|
| 370 |
+
"""Complete chat using Replicate's run method and coalesce into a single string."""
|
| 371 |
_replicate_model = self._get_replicate_model(_model_name)
|
| 372 |
_params = self._sanitize_params(**_kwargs)
|
| 373 |
|
|
|
|
| 379 |
"top_p": _params['top_p']
|
| 380 |
}
|
| 381 |
|
| 382 |
+
if 'stop' in _kwargs and _kwargs['stop'] is not None:
|
| 383 |
+
_input["stop"] = _kwargs['stop']
|
| 384 |
+
|
| 385 |
try:
|
| 386 |
_result = self._client.run(_replicate_model, input=_input)
|
| 387 |
+
|
| 388 |
+
# If it's a list of strings or chunks, join
|
| 389 |
+
if isinstance(_result, list):
|
| 390 |
+
_joined = "".join([x.decode("utf-8", errors="ignore") if isinstance(x, bytes) else str(x) for x in _result])
|
| 391 |
+
return _joined
|
| 392 |
+
|
| 393 |
+
# Some models return generators/iterables; accumulate
|
| 394 |
+
try:
|
| 395 |
+
from collections.abc import Iterator, Iterable
|
| 396 |
+
if isinstance(_result, Iterator) or (
|
| 397 |
+
isinstance(_result, Iterable) and not isinstance(_result, (str, bytes))
|
| 398 |
+
):
|
| 399 |
+
_buf = []
|
| 400 |
+
for _piece in _result:
|
| 401 |
+
if isinstance(_piece, (str, bytes)):
|
| 402 |
+
_buf.append(_piece.decode("utf-8", errors="ignore") if isinstance(_piece, bytes) else _piece)
|
| 403 |
+
else:
|
| 404 |
+
_buf.append(str(_piece))
|
| 405 |
+
_text = "".join(_buf)
|
| 406 |
+
if _text:
|
| 407 |
+
return _text
|
| 408 |
+
except Exception:
|
| 409 |
+
pass
|
| 410 |
+
|
| 411 |
+
# FileOutput or scalar: cast to string; if empty, safe fallback
|
| 412 |
+
_text = str(_result) if _result is not None else ""
|
| 413 |
+
return _text
|
| 414 |
except Exception as _e:
|
| 415 |
_lg.error(f"Completion error for {_replicate_model}: {_e}")
|
| 416 |
+
# Return empty to avoid leaking internals into user-visible content
|
| 417 |
+
return ""
|
| 418 |
|
| 419 |
# Global variables
|
| 420 |
_client = None
|
|
|
|
| 586 |
'temperature': _request.temperature,
|
| 587 |
'top_p': _request.top_p,
|
| 588 |
'presence_penalty': _request.presence_penalty,
|
| 589 |
+
'frequency_penalty': _request.frequency_penalty,
|
| 590 |
+
'stop': _request.stop
|
| 591 |
}
|
| 592 |
|
| 593 |
# Use Replicate's direct streaming method with model parameter
|
|
|
|
| 712 |
_lg.info(f"[{_request_id}] Starting streaming response")
|
| 713 |
return _SR(
|
| 714 |
_generate_stream_response(_request, _prompt, _system, _request_id),
|
| 715 |
+
media_type="text/event-stream",
|
| 716 |
headers={
|
| 717 |
"Cache-Control": "no-cache",
|
| 718 |
+
"Connection": "keep-alive"
|
|
|
|
| 719 |
}
|
| 720 |
)
|
| 721 |
else:
|