"""Gradio demo that explains GPT-2 text generation with SHAP.

The UI lets a user optionally start a Comet experiment, then submit text.
Each submission is run through a SHAP explainer, the resulting text plot is
shown in the page and, when an experiment is active, logged to Comet as HTML.
"""

import logging

import comet_ml
import gradio as gr
import shap
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

logger = logging.getLogger(__name__)

device = "cuda" if torch.cuda.is_available() else "cpu"

MODEL_NAME = "gpt2"

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

# set model decoder to true
model.config.is_decoder = True

# set text-generation params under task_specific_params
# The config may not carry a task_specific_params dict at all; create one so
# the item assignment below cannot fail with a TypeError on None.
if model.config.task_specific_params is None:
    model.config.task_specific_params = {}
model.config.task_specific_params["text-generation"] = {
    "do_sample": True,
    "max_length": 50,
    "temperature": 0.7,
    "top_k": 50,
    "no_repeat_ngram_size": 2,
}

model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
explainer = shap.Explainer(model, tokenizer)


def start_experiment():
    """Create a Comet APIExperiment for logging inferences from the UI.

    An APIExperiment is used so that all inferences can be logged to a single
    Experiment (the original author notes it is thread safe).

    Returns:
        A ``(experiment, message)`` tuple. On success ``experiment`` is the
        new ``comet_ml.APIExperiment`` and ``message`` is a Markdown link to
        it. On any failure ``experiment`` is ``None`` and ``message`` is a
        short failure notice; the exception itself is written to the log.
    """
    try:
        api = comet_ml.API()
        workspace = api.get_default_workspace()
        project_name = comet_ml.config.get_config()["comet.project_name"]

        experiment = comet_ml.APIExperiment(
            workspace=workspace, project_name=project_name
        )
        experiment.log_other("Created from", "gradio-inference")
    except Exception:
        # Best-effort: the demo must keep working without Comet, so do not
        # re-raise. Log the details server-side instead of silently dropping
        # them, and keep the user-facing text generic because the app is
        # launched with a public share link.
        logger.exception("Could not start a Comet experiment")
        return None, "Failed to start Experiment; see the server logs for details."

    message = f"Started Experiment: [{experiment.name}]({experiment.url})"
    return (experiment, message)


def predict(text, state, message):
    """Explain ``text`` with SHAP and optionally log the result to Comet.

    Args:
        text: Input text passed to the SHAP explainer.
        state: The experiment stored in the Gradio session state, or ``None``
            when no experiment has been started.
        message: Free-form context logged to the experiment under the
            ``"message"`` key.

    Returns:
        The value produced by ``shap.plots.text(..., display=False)``, which
        is rendered by the HTML output component.
    """
    experiment = state

    shap_values = explainer([text])
    plot = shap.plots.text(shap_values, display=False)

    if experiment is not None:
        experiment.log_other("message", message)
        experiment.log_html(plot)

    return plot


with gr.Blocks() as demo:
    start_experiment_btn = gr.Button("Start New Experiment")
    experiment_status = gr.Markdown()

    # Log a message to the Experiment to provide more context
    experiment_message = gr.Textbox(label="Experiment Message")
    experiment = gr.State()

    input_text = gr.Textbox(label="Input Text", lines=5, interactive=True)
    submit_btn = gr.Button("Submit")

    output = gr.HTML()

    start_experiment_btn.click(
        start_experiment, outputs=[experiment, experiment_status]
    )
    submit_btn.click(
        predict,
        inputs=[input_text, experiment, experiment_message],
        outputs=[output],
    )

demo.launch(share=True)