Samfy001 committed (verified)
Commit b099426 · 1 Parent(s): 8cd1e07

Update app.py: sanitize sampling parameters before calling Replicate

Files changed (1): app.py (+75, -11)
app.py CHANGED
@@ -157,16 +157,58 @@ class _RC:
         _prompt += "Assistant: "
         return _prompt, _system
 
+    def _sanitize_params(self, **_kwargs):
+        """Sanitize parameters and set proper defaults"""
+        _params = {}
+
+        # Handle max_tokens
+        _max_tokens = _kwargs.get('max_tokens')
+        if _max_tokens is not None and _max_tokens > 0:
+            _params['max_tokens'] = _max_tokens
+        else:
+            _params['max_tokens'] = 4096
+
+        # Handle temperature
+        _temperature = _kwargs.get('temperature')
+        if _temperature is not None:
+            _params['temperature'] = max(0.0, min(2.0, float(_temperature)))
+        else:
+            _params['temperature'] = 0.7
+
+        # Handle top_p
+        _top_p = _kwargs.get('top_p')
+        if _top_p is not None:
+            _params['top_p'] = max(0.0, min(1.0, float(_top_p)))
+        else:
+            _params['top_p'] = 1.0
+
+        # Handle presence_penalty
+        _presence_penalty = _kwargs.get('presence_penalty')
+        if _presence_penalty is not None:
+            _params['presence_penalty'] = max(-2.0, min(2.0, float(_presence_penalty)))
+        else:
+            _params['presence_penalty'] = 0.0
+
+        # Handle frequency_penalty
+        _frequency_penalty = _kwargs.get('frequency_penalty')
+        if _frequency_penalty is not None:
+            _params['frequency_penalty'] = max(-2.0, min(2.0, float(_frequency_penalty)))
+        else:
+            _params['frequency_penalty'] = 0.0
+
+        return _params
+
     def _create_prediction(self, _model_name, _prompt, _system="", **_kwargs):
         """Create a prediction using Replicate API"""
         _replicate_model = self._get_replicate_model(_model_name)
+        _params = self._sanitize_params(**_kwargs)
 
         _input = {
             "prompt": _prompt,
             "system_prompt": _system,
-            "max_tokens": _kwargs.get('max_tokens', 4096),
-            "temperature": _kwargs.get('temperature', 0.7),
-            "top_p": _kwargs.get('top_p', 1.0)
+            "max_tokens": _params['max_tokens'],
+            "temperature": _params['temperature'],
+            "top_p": _params['top_p']
         }
 
         try:
@@ -200,13 +242,14 @@ class _RC:
     def _stream_chat(self, _model_name, _prompt, _system="", **_kwargs):
         """Stream chat using Replicate's streaming API"""
         _replicate_model = self._get_replicate_model(_model_name)
+        _params = self._sanitize_params(**_kwargs)
 
         _input = {
             "prompt": _prompt,
             "system_prompt": _system,
-            "max_tokens": _kwargs.get('max_tokens', 4096),
-            "temperature": _kwargs.get('temperature', 0.7),
-            "top_p": _kwargs.get('top_p', 1.0)
+            "max_tokens": _params['max_tokens'],
+            "temperature": _params['temperature'],
+            "top_p": _params['top_p']
         }
 
         try:
@@ -253,13 +296,14 @@ class _RC:
     def _complete_chat(self, _model_name, _prompt, _system="", **_kwargs):
         """Complete chat using Replicate's run method"""
         _replicate_model = self._get_replicate_model(_model_name)
+        _params = self._sanitize_params(**_kwargs)
 
         _input = {
             "prompt": _prompt,
             "system_prompt": _system,
-            "max_tokens": _kwargs.get('max_tokens', 4096),
-            "temperature": _kwargs.get('temperature', 0.7),
-            "top_p": _kwargs.get('top_p', 1.0)
+            "max_tokens": _params['max_tokens'],
+            "temperature": _params['temperature'],
+            "top_p": _params['top_p']
         }
 
         try:
@@ -433,8 +477,17 @@ async def _generate_stream_response(_request: _CCR, _prompt: str, _system: str,
     _total_content = ""
 
     try:
+        # Extract only relevant parameters for Replicate API
+        _api_params = {
+            'max_tokens': _request.max_tokens,
+            'temperature': _request.temperature,
+            'top_p': _request.top_p,
+            'presence_penalty': _request.presence_penalty,
+            'frequency_penalty': _request.frequency_penalty
+        }
+
         # Use Replicate's direct streaming method with model parameter
-        for _chunk in _client._stream_chat(_request.model, _prompt, _system, **_request.model_dump()):
+        for _chunk in _client._stream_chat(_request.model, _prompt, _system, **_api_params):
             if _chunk and isinstance(_chunk, str):
                 _chunk_count += 1
                 _total_content += _chunk
@@ -539,6 +592,17 @@ async def _create_chat_completion(_request: _CCR):
 
     _lg.info(f"[{_request_id}] Formatted prompt length: {len(_prompt)}")
 
+    # Extract only relevant parameters for Replicate API
+    _api_params = {
+        'max_tokens': _request.max_tokens,
+        'temperature': _request.temperature,
+        'top_p': _request.top_p,
+        'presence_penalty': _request.presence_penalty,
+        'frequency_penalty': _request.frequency_penalty
+    }
+
+    _lg.info(f"[{_request_id}] API parameters: {_api_params}")
+
     # Stream or complete
     if _request.stream:
         _lg.info(f"[{_request_id}] Starting streaming response")
@@ -554,7 +618,7 @@ async def _create_chat_completion(_request: _CCR):
     else:
        # Non-streaming completion
         _lg.info(f"[{_request_id}] Starting non-streaming completion")
-        _content = _client._complete_chat(_request.model, _prompt, _system, **_request.model_dump())
+        _content = _client._complete_chat(_request.model, _prompt, _system, **_api_params)
 
         _completion_id = f"chatcmpl-{_u.uuid4().hex}"
         _created_time = int(_t.time())
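
As a sanity check on the new behavior, here is a minimal standalone sketch that mirrors the clamping logic of _sanitize_params so it can be run without a Replicate client; the function name sanitize_params and the example values are ours, not part of the commit:

# Standalone mirror of _sanitize_params from the commit above; no
# Replicate client or _RC instance needed.
def sanitize_params(**kwargs):
    def clamp(value, lo, hi, default):
        # None (parameter not supplied) falls back to the default;
        # anything else is coerced to float and clamped into range.
        return default if value is None else max(lo, min(hi, float(value)))

    max_tokens = kwargs.get('max_tokens')
    return {
        'max_tokens': max_tokens if max_tokens is not None and max_tokens > 0 else 4096,
        'temperature': clamp(kwargs.get('temperature'), 0.0, 2.0, 0.7),
        'top_p': clamp(kwargs.get('top_p'), 0.0, 1.0, 1.0),
        'presence_penalty': clamp(kwargs.get('presence_penalty'), -2.0, 2.0, 0.0),
        'frequency_penalty': clamp(kwargs.get('frequency_penalty'), -2.0, 2.0, 0.0),
    }

# Hypothetical out-of-range inputs: temperature is clamped to 2.0,
# top_p to 0.0, and the invalid max_tokens falls back to 4096.
print(sanitize_params(temperature=3.5, top_p=-0.2, max_tokens=None))
# -> {'max_tokens': 4096, 'temperature': 2.0, 'top_p': 0.0,
#     'presence_penalty': 0.0, 'frequency_penalty': 0.0}

Because out-of-range values are clamped rather than rejected and unset values fall back to defaults, the _api_params dict built in the handlers can forward raw request fields (including None) without ever passing invalid sampling values on to Replicate.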