import os
import gradio as gr
import openai
from huggingface_hub import InferenceClient
###############################################################################
# 1. List of Models: Some open-source (HF), some require paid API (OpenAI)
###############################################################################
MODEL_OPTIONS = [
    # Open-source (Hugging Face)
    "Open-Source: bigscience/bloom-560m",
    "Open-Source: tiiuae/falcon-7b-instruct",
    "Open-Source: openlm-research/open_llama_7b",
    # Paid (OpenAI) - requires a valid OpenAI API key
    "OpenAI: gpt-3.5-turbo",
    "OpenAI: gpt-4",
]
###############################################################################
# 2. Chat function
###############################################################################
def chat_with_model(
    user_message,        # user's text input
    history,             # chat history (handled by ChatInterface)
    system_message,      # system instructions
    chosen_model,        # which model from the dropdown
    user_model_api_key,  # user-supplied API key for the chosen model
    max_tokens,
    temperature,
    top_p,
):
    """
    Depending on the user's chosen model:
    - If it starts with "Open-Source:", we call the Hugging Face InferenceClient.
    - If it starts with "OpenAI:", we call the OpenAI ChatCompletion endpoint.
    For open-source models, the API key can be left blank (anonymous access).
    For paid models, an API key must be supplied.
    """
    # Standard system text (if the user left it empty, we provide a default)
    system_text = system_message.strip() or "You are a helpful AI assistant."
    # We'll build the partial output as we stream
    partial_response = ""
    ###############################
    # CASE A: OPEN-SOURCE (HF)
    ###############################
    if chosen_model.startswith("Open-Source:"):
        # Extract the actual HF model name
        hf_model = chosen_model.split("Open-Source:")[1].strip()
        # If the user gave an API key, use it; otherwise None => anonymous access
        hf_token = (user_model_api_key or "").strip() or None
        client = InferenceClient(token=hf_token)
        # Build a naive prompt (only the latest user message; history is not included)
        prompt = (
            f"{system_text}\n\n"
            f"User: {user_message}\n"
            "Assistant:"
        )
        # Generation parameters forwarded to the HF text-generation endpoint
        generation_params = dict(
            temperature=temperature,
            max_new_tokens=int(max_tokens),
            top_p=top_p,
            repetition_penalty=1.0,
        )
        try:
            response_stream = client.text_generation(
                prompt=prompt,
                model=hf_model,
                stream=True,
                details=True,
                **generation_params,
            )
            # Stream tokens as they arrive, skipping special tokens
            for chunk in response_stream:
                if chunk.token.special:
                    continue
                partial_response += chunk.token.text
                yield partial_response
        except Exception as e:
            yield f"Error calling Hugging Face Inference API: {str(e)}"
            return
    ###############################
    # CASE B: OPENAI
    ###############################
    else:
        # Paid models must have an API key
        if not (user_model_api_key or "").strip():
            yield "Error: This model requires a paid API key. Please provide a valid one."
            return
        openai.api_key = user_model_api_key.strip()
        openai_model_name = chosen_model.split("OpenAI:")[1].strip()  # e.g. "gpt-4"
        # Build the OpenAI chat messages
        messages = [
            {"role": "system", "content": system_text},
            {"role": "user", "content": user_message},
        ]
        try:
            response = openai.ChatCompletion.create(
                model=openai_model_name,
                messages=messages,
                temperature=temperature,
                max_tokens=int(max_tokens),
                top_p=top_p,
                stream=True,
            )
            # Accumulate streamed deltas and yield the growing reply
            for chunk in response:
                if "choices" in chunk and len(chunk["choices"]) > 0:
                    delta = chunk["choices"][0]["delta"]
                    if "content" in delta:
                        partial_response += delta["content"]
                        yield partial_response
        except Exception as e:
            yield f"Error calling OpenAI API: {str(e)}"
            return
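
# ---------------------------------------------------------------------------
# Optional sketch (not wired into the app above): chat_with_model ignores the
# `history` argument when building the open-source prompt. If multi-turn
# context is wanted, a hypothetical helper like the one below could replace
# the naive prompt. It assumes the type="messages" history format, i.e. a
# list of {"role": ..., "content": ...} dicts.
# ---------------------------------------------------------------------------
def build_prompt_with_history(system_text, history, user_message):
    """Fold prior turns into a plain-text prompt (illustrative only)."""
    lines = [system_text, ""]
    for turn in history or []:
        speaker = "User" if turn.get("role") == "user" else "Assistant"
        lines.append(f"{speaker}: {turn.get('content', '')}")
    lines.append(f"User: {user_message}")
    lines.append("Assistant:")
    return "\n".join(lines)
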
###############################################################################
# 3. Build the Gradio Interface
###############################################################################
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Multi-Model Chatbot
        Choose from open-source or paid models, and provide an API key if needed.
        """
    )
    with gr.Row():
        # Left column for parameters
        with gr.Column(scale=1, min_width=300):
            system_message = gr.Textbox(
                label="System Message",
                value="You are a helpful open-source AI assistant.",
                lines=3,
            )
            # Let the user pick a model first
            chosen_model = gr.Dropdown(
                label="Select a Model for Your ChatBot",
                choices=MODEL_OPTIONS,
                value=MODEL_OPTIONS[0],  # default to the first entry
                info="Open-Source models can be used anonymously. Paid models require a valid API key.",
            )
            # Then the API key field
            user_model_api_key = gr.Textbox(
                label="API Key for the Chosen Model",
                placeholder="Required if you selected a paid model like GPT-4; optional otherwise",
                type="password",
            )
            max_tokens = gr.Slider(
                label="Max Tokens",
                minimum=1,
                maximum=2000,
                value=512,
                step=1,
            )
            temperature = gr.Slider(
                label="Temperature",
                minimum=0.0,
                maximum=2.0,
                value=0.7,
                step=0.1,
            )
            top_p = gr.Slider(
                label="Top-p",
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.01,
            )
        # Right column for the chat interface
        with gr.Column(scale=3):
            chatbot = gr.ChatInterface(
                fn=chat_with_model,
                # Extra inputs passed to chat_with_model after (message, history)
                additional_inputs=[
                    system_message,
                    chosen_model,
                    user_model_api_key,
                    max_tokens,
                    temperature,
                    top_p,
                ],
                type="messages",
            )
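
# Note: on older Gradio 3.x releases, streaming generator outputs require the
# queue to be enabled (demo.queue()) before launching; newer versions enable
# queueing by default, so launch() alone is typically sufficient.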
demo.launch()