Spaces:
Sleeping
Sleeping
File size: 3,355 Bytes
a138eb5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import requests
import json
from openai import OpenAI
from config.config import Config
def get_api_config(model_name):
    """
    Resolve the API endpoint and credentials for a given model name.

    The local "llama3.1" deployment needs no API key; hosted
    "meta-llama/*" models use the configured hosted endpoint and key.

    Args:
        model_name: model identifier selecting the deployment.

    Returns:
        tuple: (base_url, api_key) — api_key is None for the local model.

    Raises:
        ValueError: if the model name matches neither deployment.
    """
    if model_name == "llama3.1":
        return Config.LOCAL_BASE_URL, None
    if model_name.startswith("meta-llama/"):
        return Config.HOSTED_BASE_URL, Config.HOSTED_API_KEY
    raise ValueError(f"Invalid model name: {model_name}")
def handle_hosted_request(client, model_name, messages, container):
    """
    Stream a chat completion from the hosted model into a UI container.

    Renders the accumulating response with a trailing cursor glyph after
    each chunk, then re-renders the final text without the cursor.

    Args:
        client: OpenAI-compatible client used to create the completion.
        model_name: identifier of the hosted model.
        messages: chat history to send.
        container: UI container exposing empty()/error() (e.g. Streamlit).

    Returns:
        str | None: the full response text, or None if the request failed.
    """
    try:
        stream = client.chat.completions.create(
            model=model_name,
            messages=messages,
            stream=True,
        )
        # Placeholder is created only after the request is issued, so a
        # failed call never leaves an empty slot in the UI.
        placeholder = container.empty()
        accumulated = ""
        for event in stream:
            delta = event.choices[0].delta.content
            if delta is not None:
                accumulated += delta
                placeholder.markdown(accumulated + "▌")
        placeholder.markdown(accumulated)
        return accumulated
    except Exception as exc:
        container.error(f"API Error: {str(exc)}")
        return None
def handle_local_request(base_url, model_name, messages, container):
    """
    Stream a chat response from the locally hosted model into a UI container.

    Posts an Ollama-style streaming request and renders each NDJSON chunk
    incrementally with a trailing cursor glyph, then re-renders the final
    text without it.

    Args:
        base_url: endpoint URL of the local model server.
        model_name: identifier of the local model.
        messages: chat history to send.
        container: UI container exposing empty()/error() (e.g. Streamlit).

    Returns:
        str | None: the full response text, or None if the request failed.
    """
    try:
        payload = {
            "model": model_name,
            "messages": messages,
            "stream": True,
        }
        headers = {"Content-Type": "application/json"}
        response_placeholder = container.empty()
        full_response = ""
        # Connect timeout prevents an unreachable server from hanging the
        # UI forever; read timeout stays None because streamed generations
        # can legitimately pause between chunks. requests.Timeout is a
        # RequestException, so the existing handler below covers it.
        with requests.post(
            base_url,
            json=payload,
            headers=headers,
            stream=True,
            timeout=(10, None),
        ) as response:
            response.raise_for_status()
            for line in response.iter_lines():
                if not line:
                    continue  # keep-alive / blank lines between chunks
                try:
                    chunk = json.loads(line)
                except json.JSONDecodeError:
                    continue  # ignore malformed partial lines
                if chunk.get("done"):
                    break
                # .get() chain tolerates a missing message or a null
                # content field without raising.
                content = chunk.get("message", {}).get("content")
                if content is not None:
                    full_response += content
                    response_placeholder.markdown(full_response + "▌")
        response_placeholder.markdown(full_response)
        return full_response
    except requests.RequestException as e:
        error_message = f"API Error: {str(e)}"
        container.error(error_message)
        return None
def stream_response(messages, container, model_name):
    """
    Dispatch a chat request to the hosted or local backend and stream the reply.

    Args:
        messages: chat history to send.
        container: UI container the response is rendered into.
        model_name: selects the backend — hosted "meta-llama/*" models or
            the local "llama3.1" model.

    Returns:
        str | None: the full response text, or None on request failure.

    Raises:
        ValueError: if the model name matches no known backend.
    """
    base_url, api_key = get_api_config(model_name)
    if model_name == "llama3.1":
        return handle_local_request(base_url, model_name, messages, container)
    if model_name.startswith("meta-llama/"):
        hosted_client = OpenAI(api_key=api_key, base_url=base_url)
        return handle_hosted_request(hosted_client, model_name, messages, container)
    raise ValueError("Unsupported model selected.")