Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -3,7 +3,6 @@ import textwrap
|
|
| 3 |
import datetime
|
| 4 |
import json
|
| 5 |
import gradio as gr
|
| 6 |
-
from openai import OpenAI
|
| 7 |
import urllib.request
|
| 8 |
import feedparser
|
| 9 |
import time
|
|
@@ -369,12 +368,18 @@ class LLM:
|
|
| 369 |
def __init__(self, max_model_len: int = 4096):
|
| 370 |
self.api_key = OAI_API_KEY
|
| 371 |
self.max_model_len = max_model_len
|
| 372 |
-
self.
|
| 373 |
-
#models_list = self.client.models.list()
|
| 374 |
-
#self.model_name = models_list.data[0].id
|
| 375 |
self.model_name = MODEL_NAME
|
| 376 |
|
| 377 |
def generate(self, prompt: str, sampling_params: dict) -> dict:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 378 |
completion_params = {
|
| 379 |
"model": self.model_name,
|
| 380 |
"prompt": prompt,
|
|
@@ -385,14 +390,47 @@ class LLM:
|
|
| 385 |
"stream": False,
|
| 386 |
}
|
| 387 |
|
|
|
|
| 388 |
if "stop" in sampling_params:
|
| 389 |
completion_params["stop"] = sampling_params["stop"]
|
| 390 |
if "presence_penalty" in sampling_params:
|
| 391 |
completion_params["presence_penalty"] = sampling_params["presence_penalty"]
|
| 392 |
if "frequency_penalty" in sampling_params:
|
| 393 |
completion_params["frequency_penalty"] = sampling_params["frequency_penalty"]
|
| 394 |
-
|
| 395 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 396 |
|
| 397 |
def form_chat_prompt(message_history, functions=functions_dict.keys()):
|
| 398 |
"""Builds the chat prompt for the LLM."""
|
|
|
|
| 3 |
import datetime
|
| 4 |
import json
|
| 5 |
import gradio as gr
|
|
|
|
| 6 |
import urllib.request
|
| 7 |
import feedparser
|
| 8 |
import time
|
|
|
|
| 368 |
def __init__(self, max_model_len: int = 4096):
    """Store the settings used when issuing completion requests.

    Copies the module-level ``OAI_API_KEY``, ``ENDPOINT_URL`` and
    ``MODEL_NAME`` values onto the instance, together with the
    caller-supplied ``max_model_len``.
    """
    # Caller-supplied value first, then the module-level configuration.
    self.max_model_len = max_model_len
    self.model_name = MODEL_NAME
    self.endpoint_url = ENDPOINT_URL
    self.api_key = OAI_API_KEY
|
| 373 |
|
| 374 |
def generate(self, prompt: str, sampling_params: dict) -> dict:
|
| 375 |
+
"""
|
| 376 |
+
Generate completion using direct HTTP request instead of OpenAI SDK.
|
| 377 |
+
"""
|
| 378 |
+
headers = {
|
| 379 |
+
"Content-Type": "application/json",
|
| 380 |
+
"Authorization": f"Bearer {self.api_key}"
|
| 381 |
+
}
|
| 382 |
+
|
| 383 |
completion_params = {
|
| 384 |
"model": self.model_name,
|
| 385 |
"prompt": prompt,
|
|
|
|
| 390 |
"stream": False,
|
| 391 |
}
|
| 392 |
|
| 393 |
+
# Add optional parameters if present
|
| 394 |
if "stop" in sampling_params:
|
| 395 |
completion_params["stop"] = sampling_params["stop"]
|
| 396 |
if "presence_penalty" in sampling_params:
|
| 397 |
completion_params["presence_penalty"] = sampling_params["presence_penalty"]
|
| 398 |
if "frequency_penalty" in sampling_params:
|
| 399 |
completion_params["frequency_penalty"] = sampling_params["frequency_penalty"]
|
| 400 |
+
|
| 401 |
+
# Add stop_token_ids if supported by Hyperbolic
|
| 402 |
+
if "stop_token_ids" in sampling_params:
|
| 403 |
+
completion_params["stop_token_ids"] = sampling_params["stop_token_ids"]
|
| 404 |
+
|
| 405 |
+
url = f"{self.endpoint_url}/completions"
|
| 406 |
+
|
| 407 |
+
try:
|
| 408 |
+
response = requests.post(url, headers=headers, json=completion_params)
|
| 409 |
+
response.raise_for_status()
|
| 410 |
+
|
| 411 |
+
# Format response to match expected structure
|
| 412 |
+
response_data = response.json()
|
| 413 |
+
|
| 414 |
+
# Create a response object that matches the OpenAI completion format
|
| 415 |
+
class CompletionResponse:
|
| 416 |
+
def __init__(self, data):
|
| 417 |
+
self.choices = []
|
| 418 |
+
if "choices" in data:
|
| 419 |
+
for choice in data["choices"]:
|
| 420 |
+
self.choices.append(type('Choice', (), {
|
| 421 |
+
'text': choice.get('text', ''),
|
| 422 |
+
'index': choice.get('index', 0),
|
| 423 |
+
'finish_reason': choice.get('finish_reason', None)
|
| 424 |
+
})())
|
| 425 |
+
|
| 426 |
+
return CompletionResponse(response_data)
|
| 427 |
+
|
| 428 |
+
except requests.exceptions.RequestException as e:
|
| 429 |
+
lgs(f"Request failed: {e}")
|
| 430 |
+
if hasattr(e, 'response') and e.response is not None:
|
| 431 |
+
lgs(f"Response status: {e.response.status_code}")
|
| 432 |
+
lgs(f"Response body: {e.response.text}")
|
| 433 |
+
raise Exception(f"API request failed: {str(e)}")
|
| 434 |
|
| 435 |
def form_chat_prompt(message_history, functions=functions_dict.keys()):
|
| 436 |
"""Builds the chat prompt for the LLM."""
|