Spaces:
Runtime error
Runtime error
J-Antoine ZAGATO
commited on
Commit
·
40d38f3
1
Parent(s):
5962754
Added multi model structure wo api key this time
Browse files
app.py
CHANGED
|
@@ -126,6 +126,18 @@ def generate(model_name,
|
|
| 126 |
|
| 127 |
return generated_sequences[0]
|
| 128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
def prepare_dataset(dataset):
|
| 130 |
dataset = load_dataset(dataset, split='train')
|
| 131 |
return dataset
|
|
@@ -252,9 +264,14 @@ def upload_flag(*args):
|
|
| 252 |
if flagging_callback.flag(list(args), flag_option = None):
|
| 253 |
return gr.update(visible=True)
|
| 254 |
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
|
| 259 |
dataset = gr.Variable(value=DATASET)
|
| 260 |
prompts_var = gr.Variable(value=None)
|
|
@@ -264,76 +281,106 @@ with gr.Blocks() as demo:
|
|
| 264 |
custom_model_path = gr.Variable(value=None)
|
| 265 |
flag_choice = gr.Variable(label = "Flag", value=None)
|
| 266 |
|
| 267 |
-
|
| 268 |
flagging_callback = gr.HuggingFaceDatasetSaver(hf_token = HF_AUTH_TOKEN,
|
| 269 |
dataset_name = "fsdlredteam/flagged_2",
|
| 270 |
organization = "fsdlredteam",
|
| 271 |
private = True )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
|
| 273 |
-
|
| 274 |
|
| 275 |
-
|
| 276 |
-
gr.Markdown("### 1. Select a prompt")
|
| 277 |
|
| 278 |
-
|
|
|
|
|
|
|
| 279 |
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
|
| 284 |
-
|
| 285 |
|
| 286 |
-
|
| 287 |
|
| 288 |
-
|
| 289 |
-
with gr.Column(scale=1): # Model choice & output
|
| 290 |
-
gr.Markdown("### 2. Evaluate output")
|
| 291 |
|
|
|
|
| 292 |
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
|
|
|
|
|
|
|
|
|
| 296 |
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
output_spans = gr.HighlightedText(visible=True, label="Generated text")
|
| 303 |
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
with gr.Row(): # Flagging
|
| 307 |
-
|
| 308 |
-
with gr.Column(scale=1):
|
| 309 |
-
flag_radio = gr.Radio(choices=["Toxic", "Offensive", "Repetitive", "Incorrect", "Other",],
|
| 310 |
-
label="What's wrong with the output ?",
|
| 311 |
-
interactive=True,
|
| 312 |
-
visible=False)
|
| 313 |
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
|
|
|
| 335 |
|
| 336 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 337 |
inspo_button.click(fn=show_dataset,
|
| 338 |
inputs=dataset,
|
| 339 |
outputs=[prompts_drop, randomize_button, prompts_var])
|
|
|
|
| 126 |
|
| 127 |
return generated_sequences[0]
|
| 128 |
|
| 129 |
+
def show_mode(mode):
|
| 130 |
+
if mode == 'Single Model':
|
| 131 |
+
return (
|
| 132 |
+
gr.update(visible=True),
|
| 133 |
+
gr.update(visible=False)
|
| 134 |
+
)
|
| 135 |
+
if mode == 'Multi-Model':
|
| 136 |
+
return (
|
| 137 |
+
gr.update(visible=False),
|
| 138 |
+
gr.update(visible=True)
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
def prepare_dataset(dataset):
|
| 142 |
dataset = load_dataset(dataset, split='train')
|
| 143 |
return dataset
|
|
|
|
| 264 |
if flagging_callback.flag(list(args), flag_option = None):
|
| 265 |
return gr.update(visible=True)
|
| 266 |
|
| 267 |
+
CSS = """
|
| 268 |
+
#inside_group {
|
| 269 |
+
padding-top: 0.6em;
|
| 270 |
+
padding-bottom: 0.6em;
|
| 271 |
+
}
|
| 272 |
+
"""
|
| 273 |
+
|
| 274 |
+
with gr.Blocks(css=CSS) as demo:
|
| 275 |
|
| 276 |
dataset = gr.Variable(value=DATASET)
|
| 277 |
prompts_var = gr.Variable(value=None)
|
|
|
|
| 281 |
custom_model_path = gr.Variable(value=None)
|
| 282 |
flag_choice = gr.Variable(label = "Flag", value=None)
|
| 283 |
|
|
|
|
| 284 |
flagging_callback = gr.HuggingFaceDatasetSaver(hf_token = HF_AUTH_TOKEN,
|
| 285 |
dataset_name = "fsdlredteam/flagged_2",
|
| 286 |
organization = "fsdlredteam",
|
| 287 |
private = True )
|
| 288 |
+
|
| 289 |
+
gr.Markdown("# Project Interface proposal")
|
| 290 |
+
gr.Markdown("### Pick a text generation model below, write a prompt and explore the output")
|
| 291 |
+
gr.Markdown("### Or compare multiple models")
|
| 292 |
+
|
| 293 |
+
choose_mode = gr.Radio(choices=['Single Model', "Multi-Model"],
|
| 294 |
+
value='Single Model',
|
| 295 |
+
interactive=True,
|
| 296 |
+
visible=True,
|
| 297 |
+
show_label=False)
|
| 298 |
+
|
| 299 |
+
with gr.Group() as single_model:
|
| 300 |
+
with gr.Row():
|
| 301 |
+
|
| 302 |
+
with gr.Column(scale=1): # input & prompts dataset exploration
|
| 303 |
+
gr.Markdown("### 1. Select a prompt", elem_id="inside_group")
|
| 304 |
+
|
| 305 |
+
input_text = gr.Textbox(label="Write your prompt below.",
|
| 306 |
+
interactive=True,
|
| 307 |
+
lines=4,
|
| 308 |
+
elem_id="inside_group")
|
| 309 |
+
|
| 310 |
+
gr.Markdown("— or —", elem_id="inside_group")
|
| 311 |
+
|
| 312 |
+
inspo_button = gr.Button('Click here if you need some inspiration', elem_id="inside_group")
|
| 313 |
|
| 314 |
+
prompts_drop = gr.Dropdown(visible=False, elem_id="inside_group")
|
| 315 |
|
| 316 |
+
randomize_button = gr.Button('Show another subset', visible=False, elem_id="inside_group")
|
|
|
|
| 317 |
|
| 318 |
+
|
| 319 |
+
with gr.Column(scale=1): # Model choice & output
|
| 320 |
+
gr.Markdown("### 2. Evaluate output")
|
| 321 |
|
| 322 |
+
|
| 323 |
+
model_radio = gr.Radio(choices=list(CHECKPOINTS.keys()),
|
| 324 |
+
label='Model',
|
| 325 |
+
interactive=True,
|
| 326 |
+
elem_id="inside_group")
|
| 327 |
+
|
| 328 |
+
search_bar = gr.Textbox(label="Search model",
|
| 329 |
+
interactive=True,
|
| 330 |
+
visible=False,
|
| 331 |
+
elem_id="inside_group")
|
| 332 |
+
model_drop = gr.Dropdown(visible=False)
|
| 333 |
|
| 334 |
+
generate_button = gr.Button('Submit your prompt')
|
| 335 |
|
| 336 |
+
output_spans = gr.HighlightedText(visible=True, label="Generated text", elem_id="inside_group")
|
| 337 |
|
| 338 |
+
flag_button = gr.Button("Report output here", visible=False)
|
|
|
|
|
|
|
| 339 |
|
| 340 |
+
with gr.Row(): # Flagging
|
| 341 |
|
| 342 |
+
with gr.Column(scale=1):
|
| 343 |
+
flag_radio = gr.Radio(choices=["Toxic", "Offensive", "Repetitive", "Incorrect", "Other",],
|
| 344 |
+
label="What's wrong with the output ?",
|
| 345 |
+
interactive=True,
|
| 346 |
+
visible=False,
|
| 347 |
+
elem_id="inside_group")
|
| 348 |
|
| 349 |
+
user_comment = gr.Textbox(label="(Optional) Briefly describe the issue",
|
| 350 |
+
visible=False,
|
| 351 |
+
interactive=True,
|
| 352 |
+
elem_id="inside_group")
|
|
|
|
|
|
|
| 353 |
|
| 354 |
+
confirm_flag_button = gr.Button("Confirm report", visible=False, elem_id="inside_group")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 355 |
|
| 356 |
+
with gr.Row(): # Flagging success
|
| 357 |
+
success_message = gr.Markdown("Your report has been successfully registered. Thank you!",
|
| 358 |
+
visible=False,
|
| 359 |
+
elem_id="inside_group")
|
| 360 |
+
|
| 361 |
+
with gr.Row(): # Toxicity buttons
|
| 362 |
+
toxi_button = gr.Button("Run a toxicity analysis of the model's output", visible=False, elem_id="inside_group")
|
| 363 |
+
toxi_button_compare = gr.Button("Compare toxicity on input and output", visible=False, elem_id="inside_group")
|
| 364 |
+
|
| 365 |
+
with gr.Row(): # Toxicity scores
|
| 366 |
+
toxi_scores_input = gr.JSON(label = "Detoxify classification of your input",
|
| 367 |
+
visible=False,
|
| 368 |
+
elem_id="inside_group")
|
| 369 |
+
toxi_scores_output = gr.JSON(label="Detoxify classification of the model's output",
|
| 370 |
+
visible=False,
|
| 371 |
+
elem_id="inside_group")
|
| 372 |
+
toxi_scores_compare = gr.JSON(label = "Percentage change between Input and Output",
|
| 373 |
+
visible=False,
|
| 374 |
+
elem_id="inside_group")
|
| 375 |
+
|
| 376 |
+
with gr.Group() as multi_model:
|
| 377 |
+
gr.Markdown("Model comparison will be here")
|
| 378 |
|
| 379 |
|
| 380 |
+
choose_mode.change(fn=show_mode,
|
| 381 |
+
inputs=choose_mode,
|
| 382 |
+
outputs=[single_model, multi_model])
|
| 383 |
+
|
| 384 |
inspo_button.click(fn=show_dataset,
|
| 385 |
inputs=dataset,
|
| 386 |
outputs=[prompts_drop, randomize_button, prompts_var])
|