Spaces:
Runtime error
Runtime error
Commit
·
f027a65
1
Parent(s):
91fbea0
Dynamic radio buttons
Browse files
app.py
CHANGED
|
@@ -11,8 +11,6 @@ from shared import Client
|
|
| 11 |
config = json.loads(os.environ['CONFIG'])
|
| 12 |
|
| 13 |
|
| 14 |
-
model_names = list(config.keys())
|
| 15 |
-
|
| 16 |
|
| 17 |
clients = {}
|
| 18 |
for name in config:
|
|
@@ -25,19 +23,29 @@ for name in config:
|
|
| 25 |
clients[name] = client
|
| 26 |
|
| 27 |
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
|
| 32 |
def respond(
|
| 33 |
message,
|
| 34 |
history: List[Tuple[str, str]],
|
| 35 |
-
persona,
|
| 36 |
-
model,
|
| 37 |
-
info,
|
| 38 |
conversational,
|
| 39 |
max_tokens,
|
|
|
|
| 40 |
):
|
|
|
|
|
|
|
| 41 |
client = clients[model]
|
| 42 |
|
| 43 |
messages = []
|
|
@@ -74,20 +82,45 @@ def respond(
|
|
| 74 |
return response
|
| 75 |
|
| 76 |
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
|
| 92 |
if __name__ == "__main__":
|
| 93 |
-
|
|
|
|
| 11 |
config = json.loads(os.environ['CONFIG'])
|
| 12 |
|
| 13 |
|
|
|
|
|
|
|
| 14 |
|
| 15 |
clients = {}
|
| 16 |
for name in config:
|
|
|
|
| 23 |
clients[name] = client
|
| 24 |
|
| 25 |
|
| 26 |
+
model_names = list(config.keys())
|
| 27 |
+
radio_infos = [f"{name} ({clients[name].vllm_model_name})" for name in model_names]
|
| 28 |
+
accordion_info = "Config"
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def parse_radio_select(radio_select):
|
| 33 |
+
value_index = next(i for i in range(len(radio_select)) if radio_select[i] is not None)
|
| 34 |
+
model = model_names[value_index]
|
| 35 |
+
persona = radio_select[value_index]
|
| 36 |
+
return model, persona
|
| 37 |
+
|
| 38 |
|
| 39 |
|
| 40 |
def respond(
|
| 41 |
message,
|
| 42 |
history: List[Tuple[str, str]],
|
|
|
|
|
|
|
|
|
|
| 43 |
conversational,
|
| 44 |
max_tokens,
|
| 45 |
+
*radio_select,
|
| 46 |
):
|
| 47 |
+
model, persona = parse_radio_select(radio_select)
|
| 48 |
+
|
| 49 |
client = clients[model]
|
| 50 |
|
| 51 |
messages = []
|
|
|
|
| 82 |
return response
|
| 83 |
|
| 84 |
|
| 85 |
+
# Components
|
| 86 |
+
radios = [gr.Radio(choices=clients[name].personas.keys(), value=None, label=info) for name, info in zip(model_names, radio_infos)]
|
| 87 |
+
radios[0].value = list(clients[model_names[0]].personas.keys())[0]
|
| 88 |
+
|
| 89 |
+
conversational_checkbox = gr.Checkbox(value=True, label="conversational")
|
| 90 |
+
max_tokens_slider = gr.Slider(minimum=64, maximum=2048, value=512, step=64, label="Max new tokens")
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
with gr.Blocks() as blocks:
|
| 95 |
+
# Events
|
| 96 |
+
radio_state = gr.State([radio.value for radio in radios])
|
| 97 |
+
@gr.on(triggers=[radio.input for radio in radios], inputs=[radio_state, *radios], outputs=[radio_state, *radios])
|
| 98 |
+
def radio_click(state, *new_state):
|
| 99 |
+
changed_index = next(i for i in range(len(state)) if state[i] != new_state[i])
|
| 100 |
+
changed_value = new_state[changed_index]
|
| 101 |
+
clean_state = [None if i != changed_index else changed_value for i in range(len(state))]
|
| 102 |
+
|
| 103 |
+
return clean_state, *clean_state
|
| 104 |
+
|
| 105 |
+
# Compile
|
| 106 |
+
with gr.Accordion(label=accordion_info, open=True, render=False) as accordion:
|
| 107 |
+
[radio.render() for radio in radios]
|
| 108 |
+
conversational_checkbox.render()
|
| 109 |
+
max_tokens_slider.render()
|
| 110 |
+
|
| 111 |
+
demo = gr.ChatInterface(
|
| 112 |
+
respond,
|
| 113 |
+
additional_inputs=[
|
| 114 |
+
conversational_checkbox,
|
| 115 |
+
max_tokens_slider,
|
| 116 |
+
*radios,
|
| 117 |
+
],
|
| 118 |
+
additional_inputs_accordion=accordion,
|
| 119 |
+
title="NeonLLM (v2024-07-03)",
|
| 120 |
+
concurrency_limit=5,
|
| 121 |
+
)
|
| 122 |
+
accordion.render()
|
| 123 |
|
| 124 |
|
| 125 |
if __name__ == "__main__":
|
| 126 |
+
blocks.launch()
|