Kilos1 commited on
Commit
4efe98f
·
verified ·
1 Parent(s): 8d0ca55

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -0
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import shap
3
+ import torch
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
5
# Pick the compute device: prefer the GPU when CUDA is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

MODEL_NAME = "gpt2"

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

# Mark the model as a decoder and register default sampling parameters for
# the text-generation task so downstream tooling (e.g. the SHAP explainer)
# sees a generation-ready configuration.
model.config.is_decoder = True
model.config.task_specific_params["text-generation"] = {
    "do_sample": True,
    "max_length": 50,
    "temperature": 0.7,
    "top_k": 50,
    "no_repeat_ngram_size": 2,
}
model = model.to(device)

# Tokenizer plus a SHAP explainer over the (model, tokenizer) pair; the
# explainer is shared by every Gradio request.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
explainer = shap.Explainer(model, tokenizer)
29
+
30
+
31
def start_experiment():
    """Create a Comet ``APIExperiment`` for logging inferences.

    Returns:
        tuple: ``(experiment, message)`` — ``experiment`` is a
        ``comet_ml.APIExperiment`` (thread safe, so it can be shared by
        Gradio callbacks via ``gr.State``) and ``message`` is a Markdown
        string linking to it. On any failure (missing package, missing
        credentials/config) returns ``(None, <error text>)`` instead of
        raising, so the UI keeps working without Comet.
    """
    try:
        # Imported lazily: comet_ml is never imported at module level in
        # this file, so a top-level reference would raise NameError. The
        # lazy import keeps the app importable when comet_ml is absent.
        import comet_ml

        api = comet_ml.API()
        workspace = api.get_default_workspace()
        project_name = comet_ml.config.get_config()["comet.project_name"]

        experiment = comet_ml.APIExperiment(
            workspace=workspace, project_name=project_name
        )
        experiment.log_other("Created from", "gradio-inference")

        message = f"Started Experiment: [{experiment.name}]({experiment.url})"

        return (experiment, message)

    except Exception as e:
        # The original silently returned (None, None), leaving the status
        # Markdown blank; surface the reason to the user instead.
        return None, f"Failed to start experiment: {e}"
51
+
52
+
53
def predict(text, state, message):
    """Explain *text* with SHAP and return the rendered HTML plot.

    Args:
        text: Input string to run through the module-level ``explainer``.
        state: The active experiment held in ``gr.State`` (may be ``None``
            when no experiment has been started).
        message: Free-form note to attach to the experiment.

    Returns:
        str: HTML produced by ``shap.plots.text``.
    """
    values = explainer([text])
    html_plot = shap.plots.text(values, display=False)

    active_experiment = state
    if active_experiment is not None:
        # Best-effort logging: only when an experiment was started.
        active_experiment.log_other("message", message)
        active_experiment.log_html(html_plot)

    return html_plot
64
+
65
+
66
# --- Gradio UI wiring ------------------------------------------------------
with gr.Blocks() as demo:
    # Button + Markdown pair: starts a Comet experiment and shows its link
    # (or the failure reason) returned by start_experiment().
    start_experiment_btn = gr.Button("Start New Experiment")
    experiment_status = gr.Markdown()

    # Log a message to the Experiment to provide more context
    experiment_message = gr.Textbox(label="Experiment Message")
    # Session-scoped holder for the APIExperiment object passed to predict().
    experiment = gr.State()

    input_text = gr.Textbox(label="Input Text", lines=5, interactive=True)
    submit_btn = gr.Button("Submit")

    # shap.plots.text emits HTML, so render predictions in an HTML component.
    output = gr.HTML()

    start_experiment_btn.click(
        start_experiment, outputs=[experiment, experiment_status]
    )
    submit_btn.click(
        predict, inputs=[input_text, experiment, experiment_message], outputs=[output]
    )

# Hugging Face Spaces launches `demo` automatically, but the original file
# never called launch(), so the script did nothing when run directly. The
# guard keeps Spaces behavior unchanged while making local runs work.
if __name__ == "__main__":
    demo.launch()