Update app.py
Browse files
app.py
CHANGED
|
@@ -157,16 +157,58 @@ class _RC:
|
|
| 157 |
_prompt += "Assistant: "
|
| 158 |
return _prompt, _system
|
| 159 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
def _create_prediction(self, _model_name, _prompt, _system="", **_kwargs):
|
| 161 |
"""Create a prediction using Replicate API"""
|
| 162 |
_replicate_model = self._get_replicate_model(_model_name)
|
|
|
|
| 163 |
|
| 164 |
_input = {
|
| 165 |
"prompt": _prompt,
|
| 166 |
"system_prompt": _system,
|
| 167 |
-
"max_tokens":
|
| 168 |
-
"temperature":
|
| 169 |
-
"top_p":
|
| 170 |
}
|
| 171 |
|
| 172 |
try:
|
|
@@ -200,13 +242,14 @@ class _RC:
|
|
| 200 |
def _stream_chat(self, _model_name, _prompt, _system="", **_kwargs):
|
| 201 |
"""Stream chat using Replicate's streaming API"""
|
| 202 |
_replicate_model = self._get_replicate_model(_model_name)
|
|
|
|
| 203 |
|
| 204 |
_input = {
|
| 205 |
"prompt": _prompt,
|
| 206 |
"system_prompt": _system,
|
| 207 |
-
"max_tokens":
|
| 208 |
-
"temperature":
|
| 209 |
-
"top_p":
|
| 210 |
}
|
| 211 |
|
| 212 |
try:
|
|
@@ -253,13 +296,14 @@ class _RC:
|
|
| 253 |
def _complete_chat(self, _model_name, _prompt, _system="", **_kwargs):
|
| 254 |
"""Complete chat using Replicate's run method"""
|
| 255 |
_replicate_model = self._get_replicate_model(_model_name)
|
|
|
|
| 256 |
|
| 257 |
_input = {
|
| 258 |
"prompt": _prompt,
|
| 259 |
"system_prompt": _system,
|
| 260 |
-
"max_tokens":
|
| 261 |
-
"temperature":
|
| 262 |
-
"top_p":
|
| 263 |
}
|
| 264 |
|
| 265 |
try:
|
|
@@ -433,8 +477,17 @@ async def _generate_stream_response(_request: _CCR, _prompt: str, _system: str,
|
|
| 433 |
_total_content = ""
|
| 434 |
|
| 435 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 436 |
# Use Replicate's direct streaming method with model parameter
|
| 437 |
-
for _chunk in _client._stream_chat(_request.model, _prompt, _system, **
|
| 438 |
if _chunk and isinstance(_chunk, str):
|
| 439 |
_chunk_count += 1
|
| 440 |
_total_content += _chunk
|
|
@@ -539,6 +592,17 @@ async def _create_chat_completion(_request: _CCR):
|
|
| 539 |
|
| 540 |
_lg.info(f"[{_request_id}] Formatted prompt length: {len(_prompt)}")
|
| 541 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 542 |
# Stream or complete
|
| 543 |
if _request.stream:
|
| 544 |
_lg.info(f"[{_request_id}] Starting streaming response")
|
|
@@ -554,7 +618,7 @@ async def _create_chat_completion(_request: _CCR):
|
|
| 554 |
else:
|
| 555 |
# Non-streaming completion
|
| 556 |
_lg.info(f"[{_request_id}] Starting non-streaming completion")
|
| 557 |
-
_content = _client._complete_chat(_request.model, _prompt, _system, **
|
| 558 |
|
| 559 |
_completion_id = f"chatcmpl-{_u.uuid4().hex}"
|
| 560 |
_created_time = int(_t.time())
|
|
|
|
| 157 |
_prompt += "Assistant: "
|
| 158 |
return _prompt, _system
|
| 159 |
|
| 160 |
+
def _sanitize_params(self, **_kwargs):
|
| 161 |
+
"""Sanitize parameters and set proper defaults"""
|
| 162 |
+
_params = {}
|
| 163 |
+
|
| 164 |
+
# Handle max_tokens
|
| 165 |
+
_max_tokens = _kwargs.get('max_tokens')
|
| 166 |
+
if _max_tokens is not None and _max_tokens > 0:
|
| 167 |
+
_params['max_tokens'] = _max_tokens
|
| 168 |
+
else:
|
| 169 |
+
_params['max_tokens'] = 4096
|
| 170 |
+
|
| 171 |
+
# Handle temperature
|
| 172 |
+
_temperature = _kwargs.get('temperature')
|
| 173 |
+
if _temperature is not None:
|
| 174 |
+
_params['temperature'] = max(0.0, min(2.0, float(_temperature)))
|
| 175 |
+
else:
|
| 176 |
+
_params['temperature'] = 0.7
|
| 177 |
+
|
| 178 |
+
# Handle top_p
|
| 179 |
+
_top_p = _kwargs.get('top_p')
|
| 180 |
+
if _top_p is not None:
|
| 181 |
+
_params['top_p'] = max(0.0, min(1.0, float(_top_p)))
|
| 182 |
+
else:
|
| 183 |
+
_params['top_p'] = 1.0
|
| 184 |
+
|
| 185 |
+
# Handle presence_penalty
|
| 186 |
+
_presence_penalty = _kwargs.get('presence_penalty')
|
| 187 |
+
if _presence_penalty is not None:
|
| 188 |
+
_params['presence_penalty'] = max(-2.0, min(2.0, float(_presence_penalty)))
|
| 189 |
+
else:
|
| 190 |
+
_params['presence_penalty'] = 0.0
|
| 191 |
+
|
| 192 |
+
# Handle frequency_penalty
|
| 193 |
+
_frequency_penalty = _kwargs.get('frequency_penalty')
|
| 194 |
+
if _frequency_penalty is not None:
|
| 195 |
+
_params['frequency_penalty'] = max(-2.0, min(2.0, float(_frequency_penalty)))
|
| 196 |
+
else:
|
| 197 |
+
_params['frequency_penalty'] = 0.0
|
| 198 |
+
|
| 199 |
+
return _params
|
| 200 |
+
|
| 201 |
def _create_prediction(self, _model_name, _prompt, _system="", **_kwargs):
|
| 202 |
"""Create a prediction using Replicate API"""
|
| 203 |
_replicate_model = self._get_replicate_model(_model_name)
|
| 204 |
+
_params = self._sanitize_params(**_kwargs)
|
| 205 |
|
| 206 |
_input = {
|
| 207 |
"prompt": _prompt,
|
| 208 |
"system_prompt": _system,
|
| 209 |
+
"max_tokens": _params['max_tokens'],
|
| 210 |
+
"temperature": _params['temperature'],
|
| 211 |
+
"top_p": _params['top_p']
|
| 212 |
}
|
| 213 |
|
| 214 |
try:
|
|
|
|
| 242 |
def _stream_chat(self, _model_name, _prompt, _system="", **_kwargs):
|
| 243 |
"""Stream chat using Replicate's streaming API"""
|
| 244 |
_replicate_model = self._get_replicate_model(_model_name)
|
| 245 |
+
_params = self._sanitize_params(**_kwargs)
|
| 246 |
|
| 247 |
_input = {
|
| 248 |
"prompt": _prompt,
|
| 249 |
"system_prompt": _system,
|
| 250 |
+
"max_tokens": _params['max_tokens'],
|
| 251 |
+
"temperature": _params['temperature'],
|
| 252 |
+
"top_p": _params['top_p']
|
| 253 |
}
|
| 254 |
|
| 255 |
try:
|
|
|
|
| 296 |
def _complete_chat(self, _model_name, _prompt, _system="", **_kwargs):
|
| 297 |
"""Complete chat using Replicate's run method"""
|
| 298 |
_replicate_model = self._get_replicate_model(_model_name)
|
| 299 |
+
_params = self._sanitize_params(**_kwargs)
|
| 300 |
|
| 301 |
_input = {
|
| 302 |
"prompt": _prompt,
|
| 303 |
"system_prompt": _system,
|
| 304 |
+
"max_tokens": _params['max_tokens'],
|
| 305 |
+
"temperature": _params['temperature'],
|
| 306 |
+
"top_p": _params['top_p']
|
| 307 |
}
|
| 308 |
|
| 309 |
try:
|
|
|
|
| 477 |
_total_content = ""
|
| 478 |
|
| 479 |
try:
|
| 480 |
+
# Extract only relevant parameters for Replicate API
|
| 481 |
+
_api_params = {
|
| 482 |
+
'max_tokens': _request.max_tokens,
|
| 483 |
+
'temperature': _request.temperature,
|
| 484 |
+
'top_p': _request.top_p,
|
| 485 |
+
'presence_penalty': _request.presence_penalty,
|
| 486 |
+
'frequency_penalty': _request.frequency_penalty
|
| 487 |
+
}
|
| 488 |
+
|
| 489 |
# Use Replicate's direct streaming method with model parameter
|
| 490 |
+
for _chunk in _client._stream_chat(_request.model, _prompt, _system, **_api_params):
|
| 491 |
if _chunk and isinstance(_chunk, str):
|
| 492 |
_chunk_count += 1
|
| 493 |
_total_content += _chunk
|
|
|
|
| 592 |
|
| 593 |
_lg.info(f"[{_request_id}] Formatted prompt length: {len(_prompt)}")
|
| 594 |
|
| 595 |
+
# Extract only relevant parameters for Replicate API
|
| 596 |
+
_api_params = {
|
| 597 |
+
'max_tokens': _request.max_tokens,
|
| 598 |
+
'temperature': _request.temperature,
|
| 599 |
+
'top_p': _request.top_p,
|
| 600 |
+
'presence_penalty': _request.presence_penalty,
|
| 601 |
+
'frequency_penalty': _request.frequency_penalty
|
| 602 |
+
}
|
| 603 |
+
|
| 604 |
+
_lg.info(f"[{_request_id}] API parameters: {_api_params}")
|
| 605 |
+
|
| 606 |
# Stream or complete
|
| 607 |
if _request.stream:
|
| 608 |
_lg.info(f"[{_request_id}] Starting streaming response")
|
|
|
|
| 618 |
else:
|
| 619 |
# Non-streaming completion
|
| 620 |
_lg.info(f"[{_request_id}] Starting non-streaming completion")
|
| 621 |
+
_content = _client._complete_chat(_request.model, _prompt, _system, **_api_params)
|
| 622 |
|
| 623 |
_completion_id = f"chatcmpl-{_u.uuid4().hex}"
|
| 624 |
_created_time = int(_t.time())
|