# HuggingFace Spaces demo — page status at extraction time: "Runtime error"
# (most likely cause: `comet_ml` is used below but was never imported)
import comet_ml
import gradio as gr
import shap
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Pick the fastest available device for inference.
device = "cuda" if torch.cuda.is_available() else "cpu"

MODEL_NAME = "gpt2"

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
# SHAP's text explainer needs the model flagged as a decoder (causal LM).
model.config.is_decoder = True
# Generation parameters that SHAP reads from task_specific_params when
# sampling continuations to attribute.
# Guard: task_specific_params is None on many configs; indexing into it
# unguarded would raise TypeError.
if getattr(model.config, "task_specific_params", None) is None:
    model.config.task_specific_params = {}
model.config.task_specific_params["text-generation"] = {
    "do_sample": True,
    "max_length": 50,
    "temperature": 0.7,
    "top_k": 50,
    "no_repeat_ngram_size": 2,
}
model = model.to(device)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Explainer that attributes generated tokens back to input tokens.
explainer = shap.Explainer(model, tokenizer)
def start_experiment():
    """Create a Comet APIExperiment for logging inferences.

    The returned APIExperiment is thread safe and can be shared across
    Gradio callback invocations to log to a single Experiment.

    Returns:
        tuple: ``(experiment, message)`` on success, where ``message`` is a
        Markdown link to the Experiment; ``(None, error_message)`` if the
        Experiment could not be created (e.g. missing or invalid Comet
        credentials / project configuration).
    """
    try:
        api = comet_ml.API()
        workspace = api.get_default_workspace()
        project_name = comet_ml.config.get_config()["comet.project_name"]
        experiment = comet_ml.APIExperiment(
            workspace=workspace, project_name=project_name
        )
        experiment.log_other("Created from", "gradio-inference")
        message = f"Started Experiment: [{experiment.name}]({experiment.url})"
        return (experiment, message)
    except Exception as e:
        # Surface the failure in the status Markdown instead of silently
        # returning (None, None), which left the UI blank with no hint.
        return None, f"Failed to start experiment: {e}"
def predict(text, state, message):
    """Run SHAP on the input text and optionally log the result to Comet.

    Args:
        text: Raw input text to explain.
        state: The active Comet APIExperiment, or None if none was started.
        message: Free-form note attached to the logged inference.

    Returns:
        An HTML string containing the SHAP text-attribution plot.
    """
    shap_values = explainer([text])
    html_plot = shap.plots.text(shap_values, display=False)

    active_experiment = state
    if active_experiment is not None:
        # Only log when an Experiment exists; prediction works without one.
        active_experiment.log_other("message", message)
        active_experiment.log_html(html_plot)

    return html_plot
# Gradio UI: start a Comet Experiment, then submit text to be explained.
with gr.Blocks() as demo:
    start_experiment_btn = gr.Button("Start New Experiment")
    # Shows the "Started Experiment: [...]" Markdown link (or nothing on failure).
    experiment_status = gr.Markdown()
    # Log a message to the Experiment to provide more context
    experiment_message = gr.Textbox(label="Experiment Message")
    # Session state holding the APIExperiment between callback invocations.
    experiment = gr.State()

    input_text = gr.Textbox(label="Input Text", lines=5, interactive=True)
    submit_btn = gr.Button("Submit")
    # Receives the SHAP text-attribution plot as raw HTML.
    output = gr.HTML()

    start_experiment_btn.click(
        start_experiment, outputs=[experiment, experiment_status]
    )
    submit_btn.click(
        predict, inputs=[input_text, experiment, experiment_message], outputs=[output]
    )
# share=True creates a public gradio.live link in addition to the local server.
demo.launch(share=True)