Samfy001 committed on
Commit
32c3c8a
·
verified ·
1 Parent(s): 4863b3c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -13
app.py CHANGED
@@ -173,7 +173,12 @@ class _RC:
173
  # Handle max_tokens
174
  _max_tokens = _kwargs.get('max_tokens')
175
  if _max_tokens is not None and _max_tokens > 0:
176
- _params['max_tokens'] = _max_tokens
 
 
 
 
 
177
  else:
178
  _params['max_tokens'] = 4096
179
 
@@ -249,7 +254,7 @@ class _RC:
249
  return _tool_prompt
250
 
251
  def _stream_chat(self, _model_name, _prompt, _system="", **_kwargs):
252
- """Stream chat using Replicate's streaming API"""
253
  _replicate_model = self._get_replicate_model(_model_name)
254
  _params = self._sanitize_params(**_kwargs)
255
 
@@ -261,14 +266,73 @@ class _RC:
261
  "top_p": _params['top_p']
262
  }
263
 
 
 
 
 
264
  try:
265
- # Use Replicate's streaming method
266
  for _event in self._client.stream(_replicate_model, input=_input):
267
- if _event:
268
- yield str(_event)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
  except Exception as _e:
270
  _lg.error(f"Streaming error for {_replicate_model}: {_e}")
271
- yield f"Error: {_e}"
 
272
 
273
  def _stream_from_prediction(self, _prediction):
274
  """Stream from a prediction using the stream URL"""
@@ -303,7 +367,7 @@ class _RC:
303
  yield f"Error: {_e}"
304
 
305
  def _complete_chat(self, _model_name, _prompt, _system="", **_kwargs):
306
- """Complete chat using Replicate's run method"""
307
  _replicate_model = self._get_replicate_model(_model_name)
308
  _params = self._sanitize_params(**_kwargs)
309
 
@@ -315,12 +379,42 @@ class _RC:
315
  "top_p": _params['top_p']
316
  }
317
 
 
 
 
318
  try:
319
  _result = self._client.run(_replicate_model, input=_input)
320
- return "".join(_result) if isinstance(_result, list) else str(_result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321
  except Exception as _e:
322
  _lg.error(f"Completion error for {_replicate_model}: {_e}")
323
- return f"Error: {_e}"
 
324
 
325
  # Global variables
326
  _client = None
@@ -492,7 +586,8 @@ async def _generate_stream_response(_request: _CCR, _prompt: str, _system: str,
492
  'temperature': _request.temperature,
493
  'top_p': _request.top_p,
494
  'presence_penalty': _request.presence_penalty,
495
- 'frequency_penalty': _request.frequency_penalty
 
496
  }
497
 
498
  # Use Replicate's direct streaming method with model parameter
@@ -617,11 +712,10 @@ async def _create_chat_completion(_request: _CCR):
617
  _lg.info(f"[{_request_id}] Starting streaming response")
618
  return _SR(
619
  _generate_stream_response(_request, _prompt, _system, _request_id),
620
- media_type="text/plain",
621
  headers={
622
  "Cache-Control": "no-cache",
623
- "Connection": "keep-alive",
624
- "Content-Type": "text/event-stream"
625
  }
626
  )
627
  else:
 
173
  # Handle max_tokens
174
  _max_tokens = _kwargs.get('max_tokens')
175
  if _max_tokens is not None and _max_tokens > 0:
176
+ # Replicate Anthropic models often require >= 1024; clamp to avoid 422s
177
+ try:
178
+ _mt = int(_max_tokens)
179
+ except Exception:
180
+ _mt = 4096
181
+ _params['max_tokens'] = max(1024, _mt)
182
  else:
183
  _params['max_tokens'] = 4096
184
 
 
254
  return _tool_prompt
255
 
256
  def _stream_chat(self, _model_name, _prompt, _system="", **_kwargs):
257
+ """Stream chat using Replicate's streaming API, yielding only text chunks."""
258
  _replicate_model = self._get_replicate_model(_model_name)
259
  _params = self._sanitize_params(**_kwargs)
260
 
 
266
  "top_p": _params['top_p']
267
  }
268
 
269
+ # pass through stop sequences if provided
270
+ if 'stop' in _kwargs and _kwargs['stop'] is not None:
271
+ _input["stop"] = _kwargs['stop']
272
+
273
  try:
 
274
  for _event in self._client.stream(_replicate_model, input=_input):
275
+ if not _event:
276
+ continue
277
+
278
+ # Fast path: plain string/bytes token
279
+ if isinstance(_event, (str, bytes)):
280
+ yield (_event.decode('utf-8', errors='ignore') if isinstance(_event, bytes) else _event)
281
+ continue
282
+
283
+ # Normalize event interfaces (object, dict, or custom)
284
+ _etype, _edata = None, None
285
+ if isinstance(_event, dict):
286
+ _etype = _event.get('type') or _event.get('event')
287
+ _edata = _event.get('data') or _event.get('output') or _event.get('text')
288
+ else:
289
+ _etype = getattr(_event, 'type', None) or getattr(_event, 'event', None)
290
+ _edata = getattr(_event, 'data', None)
291
+
292
+ # Extract text payloads
293
+ if _etype == "output" or _edata is not None:
294
+ if isinstance(_edata, (list, tuple)):
295
+ for _piece in _edata:
296
+ if isinstance(_piece, (str, bytes)):
297
+ yield (_piece.decode('utf-8', errors='ignore') if isinstance(_piece, bytes) else _piece)
298
+ elif isinstance(_edata, (str, bytes)):
299
+ yield (_edata.decode('utf-8', errors='ignore') if isinstance(_edata, bytes) else _edata)
300
+ elif isinstance(_edata, dict):
301
+ # Common nested keys
302
+ for _k in ("text", "output", "delta"):
303
+ if _k in _edata and isinstance(_edata[_k], (str, bytes)):
304
+ _v = _edata[_k]
305
+ yield (_v.decode('utf-8', errors='ignore') if isinstance(_v, bytes) else _v)
306
+ break
307
+ elif _etype in {"completed", "done", "end"}:
308
+ break
309
+ else:
310
+ # Fallback to string form (restore old working behavior)
311
+ try:
312
+ _s = str(_event)
313
+ if _s:
314
+ yield _s
315
+ except Exception:
316
+ pass
317
+ elif _etype in {"error", "logs", "warning"}:
318
+ try:
319
+ _lg.warning(f"Replicate stream {_etype}: {_edata}")
320
+ except Exception:
321
+ pass
322
+ elif _etype in {"completed", "done", "end"}:
323
+ break
324
+ else:
325
+ # Unknown/eventless object; fallback to string form
326
+ try:
327
+ _s = str(_event)
328
+ if _s:
329
+ yield _s
330
+ except Exception:
331
+ pass
332
  except Exception as _e:
333
  _lg.error(f"Streaming error for {_replicate_model}: {_e}")
334
+ # Surface a minimal safe error token
335
+ yield ""
336
 
337
  def _stream_from_prediction(self, _prediction):
338
  """Stream from a prediction using the stream URL"""
 
367
  yield f"Error: {_e}"
368
 
369
  def _complete_chat(self, _model_name, _prompt, _system="", **_kwargs):
370
+ """Complete chat using Replicate's run method and coalesce into a single string."""
371
  _replicate_model = self._get_replicate_model(_model_name)
372
  _params = self._sanitize_params(**_kwargs)
373
 
 
379
  "top_p": _params['top_p']
380
  }
381
 
382
+ if 'stop' in _kwargs and _kwargs['stop'] is not None:
383
+ _input["stop"] = _kwargs['stop']
384
+
385
  try:
386
  _result = self._client.run(_replicate_model, input=_input)
387
+
388
+ # If it's a list of strings or chunks, join
389
+ if isinstance(_result, list):
390
+ _joined = "".join([x.decode("utf-8", errors="ignore") if isinstance(x, bytes) else str(x) for x in _result])
391
+ return _joined
392
+
393
+ # Some models return generators/iterables; accumulate
394
+ try:
395
+ from collections.abc import Iterator, Iterable
396
+ if isinstance(_result, Iterator) or (
397
+ isinstance(_result, Iterable) and not isinstance(_result, (str, bytes))
398
+ ):
399
+ _buf = []
400
+ for _piece in _result:
401
+ if isinstance(_piece, (str, bytes)):
402
+ _buf.append(_piece.decode("utf-8", errors="ignore") if isinstance(_piece, bytes) else _piece)
403
+ else:
404
+ _buf.append(str(_piece))
405
+ _text = "".join(_buf)
406
+ if _text:
407
+ return _text
408
+ except Exception:
409
+ pass
410
+
411
+ # FileOutput or scalar: cast to string; if empty, safe fallback
412
+ _text = str(_result) if _result is not None else ""
413
+ return _text
414
  except Exception as _e:
415
  _lg.error(f"Completion error for {_replicate_model}: {_e}")
416
+ # Return empty to avoid leaking internals into user-visible content
417
+ return ""
418
 
419
  # Global variables
420
  _client = None
 
586
  'temperature': _request.temperature,
587
  'top_p': _request.top_p,
588
  'presence_penalty': _request.presence_penalty,
589
+ 'frequency_penalty': _request.frequency_penalty,
590
+ 'stop': _request.stop
591
  }
592
 
593
  # Use Replicate's direct streaming method with model parameter
 
712
  _lg.info(f"[{_request_id}] Starting streaming response")
713
  return _SR(
714
  _generate_stream_response(_request, _prompt, _system, _request_id),
715
+ media_type="text/event-stream",
716
  headers={
717
  "Cache-Control": "no-cache",
718
+ "Connection": "keep-alive"
 
719
  }
720
  )
721
  else: