File size: 6,817 Bytes
d7e41fd
 
e22dcdd
 
d7e41fd
 
e22dcdd
 
 
 
 
 
 
 
 
 
 
 
d7e41fd
 
e22dcdd
 
 
d7e41fd
e22dcdd
 
 
 
 
 
 
 
d7e41fd
 
e22dcdd
 
 
 
 
d7e41fd
 
e22dcdd
 
d7e41fd
e22dcdd
d7e41fd
 
e22dcdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7e41fd
 
e22dcdd
 
 
d7e41fd
 
e22dcdd
 
d7e41fd
 
e22dcdd
 
d7e41fd
 
e22dcdd
 
 
 
 
 
d7e41fd
 
e22dcdd
 
 
 
 
d7e41fd
 
 
e22dcdd
d7e41fd
 
 
e22dcdd
d7e41fd
 
e22dcdd
 
d7e41fd
 
e22dcdd
d7e41fd
 
e22dcdd
d7e41fd
 
 
e22dcdd
d7e41fd
 
e22dcdd
d7e41fd
 
 
e22dcdd
 
 
 
 
 
 
 
 
 
d7e41fd
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import os
import gradio as gr
import openai
import requests
from huggingface_hub import InferenceClient

###############################################################################
# 1. List of Models: Some open-source (HF), some require paid API (OpenAI)
###############################################################################
# Dropdown choices for the UI. The "Open-Source:" / "OpenAI:" prefix is load-
# bearing: chat_with_model() routes on it and strips it to get the model id,
# so any new entry must keep this exact "<Backend>: <model-id>" format.
MODEL_OPTIONS = [
    # Open-Source (Hugging Face) — usable anonymously (rate-limited) or with
    # an optional HF token for higher limits.
    "Open-Source: bigscience/bloom-560m", 
    "Open-Source: tiiuae/falcon-7b-instruct",
    "Open-Source: openlm-research/open_llama_7b",

    # Paid (OpenAI) - require a valid OPENAI API key
    "OpenAI: gpt-3.5-turbo",
    "OpenAI: gpt-4",
]

###############################################################################
# 2. Chat function
###############################################################################
def chat_with_model(
    user_message,        # user's text input
    history,             # chat history (handled by ChatInterface; not replayed here)
    system_message,      # system instructions
    chosen_model,        # which model from the dropdown (see MODEL_OPTIONS)
    user_model_api_key,  # user-supplied API key; may be None or "" from the UI
    max_tokens,
    temperature,
    top_p
):
    """
    Stream a chat reply for the chosen model.

    Routing (by dropdown prefix):
      - "Open-Source: ..." -> Hugging Face InferenceClient; API key optional
        (None => anonymous access).
      - "OpenAI: ..."      -> OpenAI ChatCompletion endpoint; API key required.

    Yields:
        The accumulated response text after each streamed chunk (so the UI
        renders a growing message), or a single human-readable error string
        if validation or the backend call fails.
    """
    # Normalize the key once. Gradio can deliver None (not "") for an
    # untouched textbox; the previous code called .strip() on it unguarded
    # in the OpenAI branch and raised AttributeError instead of the
    # intended friendly error message.
    api_key = (user_model_api_key or "").strip()

    # Default system prompt when the user left the field empty (same
    # None-guard as above, for symmetry).
    system_text = (system_message or "").strip() or "You are a helpful AI assistant."

    # Accumulated output, re-yielded after every chunk so the UI streams.
    partial_response = ""

    ###############################
    # CASE A: OPEN-SOURCE (HF)
    ###############################
    if chosen_model.startswith("Open-Source:"):
        # "Open-Source: org/model" -> "org/model"
        hf_model = chosen_model.split("Open-Source:")[1].strip()

        # Empty key => None => anonymous (rate-limited) inference.
        client = InferenceClient(token=api_key or None)

        # Naive single-turn prompt; prior `history` is intentionally not
        # replayed for the raw text-generation endpoint.
        prompt = (
            f"{system_text}\n\n"
            f"User: {user_message}\n"
            "Assistant:"
        )

        generation_params = dict(
            temperature=temperature,
            max_new_tokens=int(max_tokens),
            top_p=top_p,
            repetition_penalty=1.0
        )

        try:
            response_stream = client.text_generation(
                prompt=prompt,
                model=hf_model,
                stream=True,
                details=True,  # details=True => chunks expose .token metadata
                **generation_params
            )
            for chunk in response_stream:
                # Skip control tokens (BOS/EOS etc.); surface only real text.
                if chunk.token.special:
                    continue
                partial_response += chunk.token.text
                yield partial_response

        except Exception as e:
            yield f"Error calling Hugging Face Inference API: {str(e)}"
            return

    ###############################
    # CASE B: OPENAI
    ###############################
    else:
        # Paid endpoint: refuse early, with a friendly message, if no key.
        if not api_key:
            yield "Error: This model requires a paid API key. Please provide a valid one."
            return

        openai.api_key = api_key
        # "OpenAI: gpt-4" -> "gpt-4"
        openai_model_name = chosen_model.split("OpenAI:")[1].strip()

        # Build OpenAI chat messages (single turn + system prompt).
        messages = [
            {"role": "system", "content": system_text},
            {"role": "user", "content": user_message}
        ]

        try:
            # NOTE(review): this is the legacy pre-1.0 `openai.ChatCompletion`
            # API; it requires `openai<1.0`. Pin that version or migrate to
            # `OpenAI().chat.completions.create(...)`.
            response = openai.ChatCompletion.create(
                model=openai_model_name,
                messages=messages,
                temperature=temperature,
                max_tokens=int(max_tokens),
                top_p=top_p,
                stream=True
            )
            for chunk in response:
                if "choices" in chunk and len(chunk["choices"]) > 0:
                    delta = chunk["choices"][0]["delta"]
                    if "content" in delta:
                        partial_response += delta["content"]
                        yield partial_response

        except Exception as e:
            yield f"Error calling OpenAI API: {str(e)}"
            return

###############################################################################
# 3. Build the Gradio Interface
###############################################################################
# Layout: a left column of parameters/credentials and a right column hosting
# the ChatInterface. Component creation order inside these context managers
# determines the rendered layout, so the structure below is order-sensitive.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Multi-Model Chatbot
        Choose from open-source or paid models, and provide an API key if needed.
        """
    )
    with gr.Row():
        # Left column for parameters
        with gr.Column(scale=1, min_width=300):
            # System prompt; chat_with_model falls back to a default when empty.
            system_message = gr.Textbox(
                label="System Message",
                value="You are a helpful open-source AI assistant.",
                lines=3,
            )

            # Let user pick model first. Choices come from MODEL_OPTIONS; the
            # "Open-Source:" / "OpenAI:" prefix drives backend routing.
            chosen_model = gr.Dropdown(
                label="Select a Model for Your ChatBot",
                choices=MODEL_OPTIONS,
                value=MODEL_OPTIONS[0],  # default to the first
                info="Open-Source models can be used anonymously. Paid models require a valid API key."
            )

            # Then the API key tile (masked input; required only for OpenAI).
            user_model_api_key = gr.Textbox(
                label="API Key for the Chosen Model",
                placeholder="Required if you selected a paid model like GPT-4; optional otherwise",
                type="password"
            )

            # Sampling controls, forwarded verbatim to chat_with_model.
            max_tokens = gr.Slider(
                label="Max Tokens",
                minimum=1,
                maximum=2000,
                value=512,
                step=1
            )
            temperature = gr.Slider(
                label="Temperature",
                minimum=0.0,
                maximum=2.0,
                value=0.7,
                step=0.1
            )
            top_p = gr.Slider(
                label="Top-p",
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.01
            )

        # Right column for the chat interface
        with gr.Column(scale=3):
            chatbot = gr.ChatInterface(
                fn=chat_with_model,
                # extra inputs, appended after (message, history) in this
                # exact order — must match chat_with_model's signature.
                additional_inputs=[
                    system_message,
                    chosen_model,
                    user_model_api_key,
                    max_tokens,
                    temperature,
                    top_p
                ],
                type="messages"  # history as [{"role": ..., "content": ...}]
            )

    # NOTE(review): launch() is called inside the Blocks context and runs on
    # import; consider moving it under `if __name__ == "__main__":` so the
    # module can be imported without starting a server.
    demo.launch()