Spaces:
Running
Running
add SHAP values analysis
Browse files- Dockerfile +3 -0
- app.R +106 -1
- calculate_shap.R +328 -0
- plot_shap.R +299 -0
- selected_features.tsv +16 -0
Dockerfile
CHANGED
|
@@ -7,7 +7,10 @@ RUN install2.r --error \
|
|
| 7 |
ggExtra \
|
| 8 |
readr \
|
| 9 |
caret \
|
|
|
|
| 10 |
ggplot2 \
|
|
|
|
|
|
|
| 11 |
shiny
|
| 12 |
|
| 13 |
RUN install2.r --error \
|
|
|
|
| 7 |
ggExtra \
|
| 8 |
readr \
|
| 9 |
caret \
|
| 10 |
+
fastshap \
|
| 11 |
ggplot2 \
|
| 12 |
+
ggExtra \
|
| 13 |
+
forcats \
|
| 14 |
shiny
|
| 15 |
|
| 16 |
RUN install2.r --error \
|
app.R
CHANGED
|
@@ -5,6 +5,9 @@ library(readr)
|
|
| 5 |
library(catboost)
|
| 6 |
library(ggplot2)
|
| 7 |
|
|
|
|
|
|
|
|
|
|
| 8 |
# Load the pre-trained model
|
| 9 |
model <- readRDS("goat_behavior_model_caret.rds")
|
| 10 |
|
|
@@ -149,7 +152,18 @@ ui <- fluidPage(
|
|
| 149 |
tableOutput("contents"),
|
| 150 |
verbatimTextOutput("confusionMatText"),
|
| 151 |
plotOutput("confusionMatPlot"),
|
| 152 |
-
downloadButton("downloadData", "Download Predictions"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
)
|
| 154 |
)
|
| 155 |
)
|
|
@@ -229,6 +243,97 @@ server <- function(input, output) {
|
|
| 229 |
theme_minimal() +
|
| 230 |
theme(axis.text.x = element_text(angle = 45, vjust = 1, hjust = 1))
|
| 231 |
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
}
|
| 233 |
|
| 234 |
# Create a Shiny app object
|
|
|
|
| 5 |
library(catboost)
|
| 6 |
library(ggplot2)
|
| 7 |
|
| 8 |
+
source("calculate_shap.R")
|
| 9 |
+
source("plot_shap.R")
|
| 10 |
+
|
| 11 |
# Load the pre-trained model
|
| 12 |
model <- readRDS("goat_behavior_model_caret.rds")
|
| 13 |
|
|
|
|
| 152 |
tableOutput("contents"),
|
| 153 |
verbatimTextOutput("confusionMatText"),
|
| 154 |
plotOutput("confusionMatPlot"),
|
| 155 |
+
downloadButton("downloadData", "Download Predictions")),
|
| 156 |
+
|
| 157 |
+
tabPanel("SHAP Summary",
|
| 158 |
+
plotOutput("SHAPSummary")),
|
| 159 |
+
|
| 160 |
+
tabPanel("SHAP Summary per class",
|
| 161 |
+
plotOutput("SHAPSummaryperclass")),
|
| 162 |
+
|
| 163 |
+
tabPanel("SHAP Dependency",
|
| 164 |
+
plotOutput("SHAPDependency"))
|
| 165 |
+
|
| 166 |
+
|
| 167 |
)
|
| 168 |
)
|
| 169 |
)
|
|
|
|
| 243 |
theme_minimal() +
|
| 244 |
theme(axis.text.x = element_text(angle = 45, vjust = 1, hjust = 1))
|
| 245 |
})
|
| 246 |
+
|
| 247 |
+
# Shared SHAP computation for the three SHAP tabs.
#
# The uploaded file is read, scored with `model` and explained with
# calculate_shap() inside a single reactive, so the work is done once per
# upload instead of once per plot. Returns NULL until a file is uploaded,
# otherwise a list with:
#   dataset - selected feature columns plus Anim, Activity and predictions
#   shap    - the SHAP values computed by calculate_shap() for `dataset`
shap_inputs <- reactive({
  if (is.null(input$file1))
    return(NULL)

  uploaded <- readr::read_delim(input$file1$datapath, delim = '\t')
  predictions <- predict(model, uploaded)
  selected_variables <- readr::read_delim(
    "selected_features.tsv",
    col_types = readr::cols(),
    delim = '\t'
  )
  # all_of() states explicitly that the column names come from an external
  # character vector and fails loudly if one of them is missing.
  new_dataset <- uploaded %>%
    select(dplyr::all_of(selected_variables$variable), Anim, Activity)
  new_dataset <- cbind(new_dataset, predictions)

  list(
    dataset = new_dataset,
    shap = calculate_shap(new_dataset, model, nsim = 30)
  )
})

# Stacked bar chart of mean |SHAP| per feature, one colour per activity.
output$SHAPSummary <- renderPlot({
  shap_result <- shap_inputs()
  if (is.null(shap_result))
    return(NULL)

  pall <- shap_summary_plot(shap_result$shap)
  pall + xlim(0, 0.35)
})

# One mean |SHAP| bar chart per activity, arranged in a grid.
output$SHAPSummaryperclass <- renderPlot({
  shap_result <- shap_inputs()
  if (is.null(shap_result))
    return(NULL)
  shap_values <- shap_result$shap

  pW <- shap_summary_plot_perclass(shap_values, class = "W", color = "#C77CFF") +
    xlab("Activity W") + xlim(0, 0.25)
  pGM <- shap_summary_plot_perclass(shap_values, class = "GM", color = "#7CAE00") +
    xlab("Activity GM") + xlim(0, 0.25)
  pG <- shap_summary_plot_perclass(shap_values, class = "G", color = "#F8766D") +
    xlab("Activity G") + xlim(0, 0.25)
  pR <- shap_summary_plot_perclass(shap_values, class = "R", color = "#00BFC4") +
    xlab("Activity R") + xlim(0, 0.25)

  gridExtra::grid.arrange(pW, pR, pG, pGM)
})

# SHAP dependency plots for a fixed set of features, stacked vertically.
output$SHAPDependency <- renderPlot({
  shap_result <- shap_inputs()
  if (is.null(shap_result))
    return(NULL)

  # Other selected features (prev_steps1, prev_headdown1, prev_Active1,
  # prev_Standing1, X_Act, Y_Act, DBA123, DFA123) can be appended here.
  features <- c("Steps", "%HeadDown", "Active", "Standing")
  li <- lapply(features, function(feature_name) {
    dependency_plot(
      feature_name,
      dataset = shap_result$dataset,
      shap = shap_result$shap
    )
  })
  do.call(gridExtra::grid.arrange, c(li, ncol = 1))
})
|
| 337 |
}
|
| 338 |
|
| 339 |
# Create a Shiny app object
|
calculate_shap.R
ADDED
|
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
suppressPackageStartupMessages(library(dplyr))
|
| 2 |
+
suppressPackageStartupMessages(library(fastshap)) # for fast (approximate) Shapley values
|
| 3 |
+
suppressPackageStartupMessages(library(caret))
|
| 4 |
+
suppressPackageStartupMessages(library(doMC))
|
| 5 |
+
|
| 6 |
+
# Register the parallel backend used by fastshap::explain(.parallel = TRUE).
# Keep the original ceiling of 10 workers, but never ask for more workers
# than the machine reports; fall back to 1 when the count is unknown.
available_cores <- parallel::detectCores()
if (is.na(available_cores)) {
  available_cores <- 1L
}
registerDoMC(cores = max(1L, min(10L, available_cores)))
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# Prediction wrappers handed to fastshap::explain(): each one asks the caret
# model for class probabilities and keeps the column of a single activity.

# Probability of class "G" for every row of `newdata`.
p_function_G <- function(object, newdata) {
  class_probs <- caret::predict.train(object, newdata = newdata, type = "prob")
  class_probs[, "G"]
}

# Probability of class "GM" for every row of `newdata`.
p_function_GM <- function(object, newdata) {
  class_probs <- caret::predict.train(object, newdata = newdata, type = "prob")
  class_probs[, "GM"]
}

# Probability of class "R" for every row of `newdata`.
p_function_R <- function(object, newdata) {
  class_probs <- caret::predict.train(object, newdata = newdata, type = "prob")
  class_probs[, "R"]
}

# Probability of class "W" for every row of `newdata`.
p_function_W <- function(object, newdata) {
  class_probs <- caret::predict.train(object, newdata = newdata, type = "prob")
  class_probs[, "W"]
}
|
| 21 |
+
|
| 22 |
+
# DEPRECATED: kept for reference; calculate_shap() is the newer version.
#
# Computes approximate SHAP values separately for the rows of each
# ground-truth class (G, GM, R, W), explaining that class's predicted
# probability, and row-binds the four results with a `class` column.
#
# dataset: data frame with an `Activity` column; every other column is
#          passed to the model as a feature.
# model:   model object forwarded to the p_function_* prediction wrappers.
# nsim:    number of Monte Carlo repetitions passed to fastshap::explain().
#
# Returns the row-bound SHAP values, grouped by class in the order
# G, GM, R, W.
calculate_shap_deprecated <- function(dataset, model, nsim = 10) {
  trainset <- dataset %>%
    na.omit() %>%
    as.data.frame()
  # Take the labels from the rows that survived na.omit(); deriving them from
  # a separate na.omit() on the Activity column alone misaligns labels and
  # features whenever a missing value sits in any other column.
  trainset_y <- trainset$Activity
  trainset <- trainset %>% select(-Activity)

  wrappers <- list(
    G = p_function_G,
    GM = p_function_GM,
    R = p_function_R,
    W = p_function_W
  )

  per_class <- lapply(names(wrappers), function(class_label) {
    message(" - Calculating SHAP values for class ", class_label)
    class_shap <- fastshap::explain(
      model,
      X = trainset,
      pred_wrapper = wrappers[[class_label]],
      nsim = nsim,
      newdata = trainset[which(trainset_y == class_label), , drop = FALSE],
      .parallel = TRUE
    )
    class_shap$class <- class_label
    class_shap
  })

  do.call(rbind, per_class)
}
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
#' Calculate SHAP values for every observation, by ground-truth class
#'
#' Rows with missing values are dropped first. The remaining rows are split
#' by their `Activity` label (G, GM, R, W) and, for each class, the SHAP
#' values of that class's predicted probability are computed with
#' `fastshap::explain()`. The per-class results are combined and returned in
#' the row order of the NA-free dataset.
#'
#' The result also carries the `Anim` and `predictions` columns of the input.
#' Note that SHAP values are calculated for the class given by the ground
#' truth (`Activity`), not by the prediction; the `predictions` column is
#' only kept for filtering and analysis. Use `calculate_shap_class()` to
#' explain a class of your choosing on new data.
#'
#' @param dataset a data frame with the model features plus the columns
#'   `Activity`, `Anim` and `predictions`. Its feature columns are used both
#'   as the background data for the permutations and as the rows to explain.
#' @param model the model to explain; it is forwarded to the `p_function_*`
#'   prediction wrappers.
#' @param nsim number of Monte Carlo simulations per feature.
#'
#' @return the SHAP values (one column per feature) with the extra columns
#'   `class`, `Anim` and `predictions`, ordered like the NA-free `dataset`.
#'   Rows whose `Activity` is not one of G, GM, R, W are not included.
#' @export
calculate_shap <- function(dataset, model, nsim = 10) {
  complete <- dataset %>%
    na.omit() %>%
    as.data.frame()
  # Labels must come from the same NA-filtered rows as the features; taking
  # them from a separate na.omit() on Activity alone misaligns them whenever
  # a missing value sits in another column.
  labels <- complete$Activity
  features <- complete %>% select(-Activity, -Anim, -predictions)

  wrappers <- list(
    G = p_function_G,
    GM = p_function_GM,
    R = p_function_R,
    W = p_function_W
  )
  rows_by_class <- lapply(
    names(wrappers),
    function(class_label) which(labels == class_label)
  )
  names(rows_by_class) <- names(wrappers)

  per_class <- lapply(names(wrappers), function(class_label) {
    class_rows <- rows_by_class[[class_label]]
    # Nothing to explain for a class that does not occur in the data.
    if (length(class_rows) == 0) {
      return(NULL)
    }
    message(" - Calculating SHAP values for class ", class_label)
    class_shap <- fastshap::explain(
      model,
      X = features,
      pred_wrapper = wrappers[[class_label]],
      nsim = nsim,
      newdata = features[class_rows, , drop = FALSE],
      .parallel = TRUE
    )
    class_shap$class <- class_label
    class_shap
  })
  per_class <- Filter(Negate(is.null), per_class)
  if (length(per_class) == 0) {
    stop("`dataset` has no complete rows with Activity G, GM, R or W",
         call. = FALSE)
  }
  shap_values <- do.call(rbind, per_class)

  # Original row position of every row of `shap_values` (grouped by class).
  id <- unlist(rows_by_class, use.names = FALSE)

  # Index the original columns instead of concatenating per-class pieces, so
  # factor columns keep their type on every R version.
  shap_values <- shap_values %>%
    tibble::add_column(Anim = complete$Anim[id])
  shap_values <- shap_values %>%
    tibble::add_column(predictions = complete$predictions[id])

  # Restore the row order of the NA-free dataset.
  shap_values[order(id), ]
}
|
| 222 |
+
|
| 223 |
+
#' Calculate SHAP values of one class probability for new data
#'
#' Every row of `new_data` is explained with respect to the output of
#' `function_class`. The returned `class` column holds the ground-truth
#' `Activity` of each row; `Anim` and `predictions` are carried over from
#' `new_data` unchanged.
#'
#' @param dataset the dataset used for permutation during SHAP calculation.
#'   Must contain `Activity`, `predictions` and `Anim`; rows with missing
#'   values are dropped and these three columns are removed before use.
#' @param new_data the rows to explain. Must contain `Activity`,
#'   `predictions` and `Anim` in addition to the model features.
#' @param model the model used for explanation
#' @param nsim the number of Monte Carlo simulations
#' @param function_class a prediction wrapper returning the probability of
#'   one particular class (e.g. `p_function_G`)
#' @param class_name the name of the class; only used in the progress message
#'
#' @return the SHAP values for `new_data` (one column per feature) with the
#'   extra columns `class`, `Anim` and `predictions`.
#' @export
#'
#' @examples
#'
#' # Calculate the SHAP values for class G on new data
#' shap_values_G <- calculate_shap_class(
#'   dataset,
#'   new_data = newdata,
#'   model = goat_model,
#'   nsim = 100,
#'   function_class = p_function_G,
#'   class_name = "G")
#'
calculate_shap_class <- function(dataset, new_data, model, nsim = 10,
                                 function_class, class_name = "G") {
  # Background data for the permutations: complete rows, features only.
  background <- dataset %>%
    na.omit() %>%
    as.data.frame() %>%
    select(-Activity, -predictions, -Anim)

  # Keep the bookkeeping columns aside and explain the feature columns only.
  Anim <- new_data$Anim
  Activity <- new_data$Activity
  predictions <- new_data$predictions
  rows_to_explain <- new_data %>% select(-Anim, -Activity, -predictions)

  message(" - Calculating SHAP values for class ", class_name)
  shap_values <- fastshap::explain(
    model,
    X = background,
    pred_wrapper = function_class,
    nsim = nsim,
    newdata = rows_to_explain,
    .parallel = TRUE
  )

  shap_values$class <- Activity
  shap_values <- shap_values %>% tibble::add_column(Anim)
  shap_values <- shap_values %>% tibble::add_column(predictions)
  shap_values
}
|
| 288 |
+
|
| 289 |
+
# Dark-themed stacked bar chart of the mean absolute SHAP value of every
# feature, with one bar segment per class.
# NOTE(review): plot_shap.R defines a function with the same name; whichever
# file is sourced last wins.
shap_summary_plot <- function(shap_values) {
  long_shap <- reshape2::melt(shap_values)

  mean_abs_shap <- long_shap %>%
    group_by(class, variable) %>%
    summarise(mean = mean(abs(value))) %>%
    arrange(desc(mean))

  bar_mapping <- aes(
    y = variable,
    x = mean,
    group = class,
    fill = class
  )

  ggplot(mean_abs_shap) +
    ggdark::dark_theme_classic() +
    geom_col(bar_mapping, position = "stack") +
    xlab("Mean(|Shap Value|) Average impact on model output magnitude")
}
|
| 306 |
+
|
| 307 |
+
# Beeswarm (sina) plot of SHAP values per feature, faceted by class and
# coloured by the feature value rescaled to [0, 1] with range01().
#
# shap_values: SHAP values with a `class` column.
# dataset:     the feature data the SHAP values were computed from, with an
#              `Activity` column.
#
# The SHAP table and the feature table are melted independently and then
# paired positionally by cbind(), so both must melt to the same row order.
# NOTE(review): plot_shap.R defines a function with the same name; whichever
# file is sourced last wins.
shap_beeswarm_plot <- function(shap_values, dataset) {
  long_shap <- shap_values %>% reshape2::melt()

  scaled_features <- dataset %>%
    mutate(class = Activity) %>%
    select(-Activity) %>%
    reshape2::melt() %>%
    group_by(variable) %>%
    mutate(value_scale = range01(value))

  cbind(long_shap, feature_value = scaled_features$value_scale) %>%
    ggplot() +
    facet_wrap(~class) +
    theme_classic() +
    geom_hline(yintercept = 0, color = "red", size = 0.5) +
    ggforce::geom_sina(
      aes(x = variable, y = value, color = feature_value),
      size = 0.5,
      bins = 4,
      alpha = 0.9,
      shape = 15
    ) +
    # Only one colour scale can be active; the earlier yellow-to-red scale
    # was always replaced by this one, so it has been removed.
    scale_colour_gradient(low = "skyblue", high = "orange", na.value = NA) +
    xlab("Feature") +
    ylab("SHAP value") +
    theme(axis.text.x = element_text(angle = 45, vjust = 0.5, hjust = 1))
}
|
plot_shap.R
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
suppressPackageStartupMessages(library(dplyr))
|
| 2 |
+
suppressPackageStartupMessages(library(ggplot2))
|
| 3 |
+
suppressPackageStartupMessages(library(ggExtra))
|
| 4 |
+
suppressPackageStartupMessages(library(forcats))
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# Rescale a numeric vector linearly to [0, 1].
#
# Missing values are ignored when computing the range and stay NA in the
# result. A vector with a single distinct value has no spread to rescale,
# so it maps to 0 instead of NaN (0 / 0).
range01 <- function(x) {
  if (length(x) == 0 || all(is.na(x))) {
    return(as.numeric(x))
  }
  bounds <- range(x, na.rm = TRUE)
  span <- bounds[2] - bounds[1]
  if (span == 0) {
    return(ifelse(is.na(x), NA_real_, 0))
  }
  (x - bounds[1]) / span
}
|
| 9 |
+
|
| 10 |
+
# Stacked bar chart of the mean absolute SHAP value of every feature, with
# one bar segment (and legend entry) per activity.
shap_summary_plot <- function(shap_values) {
  long_shap <- reshape2::melt(shap_values)

  mean_abs_shap <- long_shap %>%
    group_by(class, variable) %>%
    summarise(mean = mean(abs(value))) %>%
    arrange(desc(mean))

  bar_mapping <- aes(
    y = variable,
    x = mean,
    group = class,
    fill = class
  )

  ggplot(mean_abs_shap) +
    theme_classic() +
    geom_col(bar_mapping, position = "stack") +
    ylab("Feature") +
    xlab("Mean(|Shap Value|) Average impact on model output magnitude per activity.") +
    guides(fill = guide_legend(title = "Activity"))
}
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# Bar chart of the mean absolute SHAP value of every feature for a single
# activity, with the features ordered by that mean.
#
# shap_values: SHAP values with a `class` column.
# class:       the activity whose rows are kept.
# color:       fill colour of the bars.
shap_summary_plot_perclass <- function(shap_values, class = "G", color = "#F8766D") {
  # `!!class` injects the value of the argument, so the left-hand `class`
  # refers to the column and the right-hand side to the requested activity.
  class_shap <- as.data.frame(shap_values) %>%
    filter(class == !!class)

  mean_abs_shap <- reshape2::melt(class_shap) %>%
    group_by(variable) %>%
    summarise(mean = mean(abs(value)))

  x_title <- paste0(
    "Mean(|Shap Value|) Average impact on model output magnitude for activity ",
    class
  )

  ggplot(mean_abs_shap) +
    theme_classic() +
    geom_col(
      aes(x = mean, y = fct_reorder(variable, mean)),
      fill = color
    ) +
    ylab("Feature") +
    xlab(x_title) +
    guides(fill = guide_legend(title = "Activity"))
}
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# Beeswarm (sina) plot of SHAP values per feature, faceted by class and
# filled by the feature value rescaled to [0, 1] with range01().
#
# shap_values: SHAP values with a `class` column.
# dataset:     the feature data the SHAP values were computed from, with an
#              `Activity` column.
#
# The SHAP table and the feature table are melted independently and then
# paired positionally by cbind(), so both must melt to the same row order.
shap_beeswarm_plot <- function(shap_values, dataset) {
  long_shap <- shap_values %>% reshape2::melt()

  scaled_features <- dataset %>%
    mutate(class = Activity) %>%
    select(-Activity) %>%
    reshape2::melt() %>%
    group_by(variable) %>%
    mutate(value_scale = range01(value))

  cbind(long_shap, feature_value = scaled_features$value_scale) %>%
    ggplot() +
    facet_wrap(~class) +
    theme_classic() +
    geom_hline(yintercept = 0, color = "red", size = 0.5) +
    ggforce::geom_sina(
      aes(x = variable, y = value, fill = feature_value),
      color = "black",
      size = 2.4,
      bins = 4,
      alpha = 0.9,
      shape = 22
    ) +
    # Only one fill scale can be active; the earlier yellow-to-red scale was
    # always replaced by this one, so it has been removed.
    scale_fill_gradient(low = "skyblue", high = "orange", na.value = NA) +
    xlab("Feature") +
    ylab("SHAP value") +
    theme(axis.text.x = element_text(angle = 45, vjust = 0.5, hjust = 1))
}
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
#' Dependency plot for a particular feature. The plot considers
#' activities and FP/TP
#'
#' For every activity found in `dataset`, the SHAP value of `feature` is
#' plotted against the feature value (rescaled to [0, 1]), with points
#' coloured by whether the prediction matches the activity and marginal
#' histograms on both axes. The panels are drawn side by side with
#' `gridExtra::grid.arrange()`.
#'
#' @param feature name of the feature to plot, as a string
#' @param dataset a dataset with goat information. Must contain the columns
#'   `Activity` and `predictions` as well as `feature`.
#' @param shap a shap value dataset for each feature, with a `class` column.
#'   Its rows for each class must line up one-to-one with the rows of
#'   `dataset` for the same activity.
#'
#' @return the arranged panels, as returned by `gridExtra::grid.arrange()`
#' @export
#'
#' @examples
#'
#' dataset <-
#'   readr::read_delim("data/split/seba-caprino_loocv.tsv",
#'                     delim = '\t')
#' selected_variables <-
#'   readr::read_delim(
#'     "data/topnfeatures/seba-caprino_selected_features.tsv",
#'     col_types = cols(),
#'     delim = '\t'
#'   )
#' goat_model <- readRDS("models/boost/seba-caprino_model.rds")
#' predictions <- predict(goat_model, dataset)
#' dataset <-
#'   dataset %>% select(selected_variables$variable,
#'                      Anim,
#'                      Activity)
#' dataset <- cbind(dataset, predictions)
#' shap_values <- calculate_shap(dataset,
#'                               model = goat_model,
#'                               nsim = 30)
#' dependency_plot(feature = "Steps",
#'                 dataset = dataset,
#'                 shap = shap_values)

dependency_plot <- function(feature, dataset, shap) {
  if (!feature %in% names(dataset)) {
    stop("`dataset` has no column named '", feature, "'", call. = FALSE)
  }

  # Rescale the feature so every panel shares the same [0, 1] x axis.
  newdata <- dataset
  newdata[[feature]] <- range01(newdata[[feature]])

  activities <- unique(as.character(dataset$Activity))
  activities <- activities[!is.na(activities)]

  plots <- list()
  for (activity in activities) {
    class_shap <- as.data.frame(shap[which(shap$class == activity), , drop = FALSE])
    class_rows <- newdata[which(newdata$Activity == activity), , drop = FALSE]
    if (nrow(class_shap) != nrow(class_rows)) {
      stop("`shap` and `dataset` have a different number of rows for activity ",
           activity, call. = FALSE)
    }

    # Compare as character so factors with different level sets are allowed.
    is_correct <- as.character(class_rows$Activity) ==
      as.character(class_rows$predictions)
    data <- data.frame(
      shap = class_shap[[feature]],
      feature = class_rows[[feature]],
      tp = ifelse(is_correct, "TP", "FP")
    )

    p <- ggplot(data, aes(x = feature)) +
      geom_point(aes(y = shap, color = tp), alpha = 0.3, size = 0.8) +
      geom_smooth(aes(y = shap),
                  se = FALSE,
                  size = 0.5,
                  linetype = "dashed") +
      geom_hline(
        yintercept = 0,
        color = 'red',
        size = 0.5,
        alpha = 0.5
      ) +
      xlab(feature) +
      labs(title = paste0("Activity ", activity)) +
      ylab("SHAP Value") +
      ylim(-0.1, 0.4) +
      xlim(0, 1) +
      theme_light() +
      theme(legend.position = 'none')

    plots[[activity]] <- ggMarginal(
      p,
      type = "histogram",
      fill = 'gray',
      color = 'white',
      size = 10,
      xparams = list(bins = 25),
      yparams = list(bins = 15)
    )
  }

  # Qualified call: this file does not attach gridExtra itself.
  do.call(gridExtra::grid.arrange, c(plots, ncol = 4))
}
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
#' Dependency plot for a particular feature on a particular animal.
#' The plot considers activities and FP/TP
#'
#' For every activity recorded for animal `anim`, the SHAP value of
#' `feature` is plotted against the feature value (rescaled to [0, 1] over
#' the whole dataset), with points coloured by whether the prediction
#' matches the activity and marginal histograms on both axes. The panels are
#' drawn side by side with `gridExtra::grid.arrange()`.
#'
#' @param feature name of the feature to plot, as a string
#' @param dataset a dataset with goat information. Must contain the columns
#'   `Anim`, `Activity` and `predictions` as well as `feature`.
#' @param shap a shap value dataset for each feature, with `class` and
#'   `Anim` columns. Its rows for each class and animal must line up
#'   one-to-one with the matching rows of `dataset`.
#' @param anim the id of the animal
#' @return the arranged panels, as returned by `gridExtra::grid.arrange()`
#' @export
#'
#' @examples
#'
#' dataset <-
#'   readr::read_delim("data/split/seba-caprino_loocv.tsv",
#'                     delim = '\t')
#' selected_variables <-
#'   readr::read_delim(
#'     "data/topnfeatures/seba-caprino_selected_features.tsv",
#'     col_types = cols(),
#'     delim = '\t'
#'   )
#' goat_model <- readRDS("models/boost/seba-caprino_model.rds")
#' predictions <- predict(goat_model, dataset)
#' dataset <-
#'   dataset %>% select(selected_variables$variable,
#'                      Anim,
#'                      Activity)
#' dataset <- cbind(dataset, predictions)
#' shap_values <- calculate_shap(dataset,
#'                               model = goat_model,
#'                               nsim = 30)
#' dependency_plot_anim(feature = "Steps",
#'                      dataset = dataset,
#'                      shap = shap_values,
#'                      anim = 'a13')
dependency_plot_anim <- function(feature, dataset, shap, anim) {
  if (!feature %in% names(dataset)) {
    stop("`dataset` has no column named '", feature, "'", call. = FALSE)
  }

  # Rescale over the whole dataset so animals remain comparable.
  newdata <- dataset
  newdata[[feature]] <- range01(newdata[[feature]])

  activities <- unique(as.character(newdata$Activity[which(newdata$Anim == anim)]))
  activities <- activities[!is.na(activities)]
  if (length(activities) == 0) {
    stop("`dataset` has no rows for animal '", anim, "'", call. = FALSE)
  }

  plots <- list()
  for (activity in activities) {
    class_shap <- as.data.frame(
      shap[which(shap$class == activity & shap$Anim == anim), , drop = FALSE]
    )
    class_rows <- newdata[
      which(newdata$Activity == activity & newdata$Anim == anim), ,
      drop = FALSE
    ]
    if (nrow(class_shap) != nrow(class_rows)) {
      stop("`shap` and `dataset` have a different number of rows for activity ",
           activity, call. = FALSE)
    }

    # Compare as character so factors with different level sets are allowed.
    is_correct <- as.character(class_rows$Activity) ==
      as.character(class_rows$predictions)
    data <- data.frame(
      shap = class_shap[[feature]],
      feature = class_rows[[feature]],
      tp = ifelse(is_correct, "TP", "FP")
    )

    p <- ggplot(data, aes(x = feature)) +
      geom_point(aes(y = shap, color = tp), alpha = 0.3, size = 1.8) +
      geom_smooth(aes(y = shap),
                  se = FALSE,
                  size = 0.5,
                  linetype = "dashed") +
      geom_hline(
        yintercept = 0,
        color = 'red',
        size = 0.5,
        alpha = 0.5
      ) +
      xlab(feature) +
      labs(title = paste0("Activity ", activity)) +
      ylab("SHAP Value") +
      ylim(-0.1, 0.4) +
      xlim(0, 1) +
      theme_light() +
      theme(legend.position = 'none')

    plots[[activity]] <- ggMarginal(
      p,
      type = "histogram",
      fill = 'gray',
      color = 'white',
      size = 15,
      xparams = list(bins = 25),
      yparams = list(bins = 15)
    )
  }

  # Qualified call: this file does not attach gridExtra itself.
  do.call(gridExtra::grid.arrange, c(plots, ncol = length(activities)))
}
|
| 247 |
+
|
| 248 |
+
#' Contribution plot for SHAP values
#'
#' Horizontal bar chart of the SHAP value of every feature for the selected
#' observation(s). When `num_row` selects several rows, their SHAP values
#' are summed per feature.
#'
#' @param s shap values for a particular class, animal, etc. The feature
#'   columns must come first.
#' @param num_row the row number of the observation to show
#' @param n_features number of leading columns of `s` that hold feature SHAP
#'   values (defaults to 15)
#'
#' @return ggplot object
#' @export
#'
#' @examples
#'
#' shap_values_G <- calculate_shap_class(
#'   dataset = dataset,
#'   new_data = newdata,
#'   model = model,
#'   nsim = 100,
#'   function_class = p_function_G,
#'   class_name = "G")
#' p1 <- contribution_plot(shap_values_G, num_row = 1) +
#'   labs(title = "Anim a13: class G (FN)", subtitle = "SHAP analysis for class G")
#'
contribution_plot <- function(s, num_row = 1, n_features = 15) {
  selected <- s[num_row, seq_len(n_features), drop = FALSE]
  contributions <- data.frame(
    Variable = names(selected),
    Importance = colSums(as.data.frame(selected))
  )
  ggplot(contributions, aes(Variable, Importance, fill = Importance)) +
    geom_col() +
    coord_flip() +
    xlab("") +
    ylab("Shapley value") +
    theme_classic() +
    theme(legend.position = 'none')
}
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
#' Contribution plot for SHAP values, annotated with the feature values
#'
#' Horizontal bar chart of the SHAP value of every feature for the selected
#' observation(s). Each bar is filled by, and labelled with, the value of
#' the feature itself. When `num_row` selects several rows, SHAP values and
#' feature values are summed per feature.
#'
#' @param s shap values; the feature columns must come first
#' @param f feature values, with the same leading columns as `s`
#' @param num_row the row number of the observation to show
#' @param n_features number of leading columns of `s` and `f` that hold
#'   features (defaults to 15)
#'
#' @return ggplot object
contribution_plot_w_feature <- function(s, f, num_row = 1, n_features = 15) {
  feature_cols <- seq_len(n_features)
  selected_shap <- s[num_row, feature_cols, drop = FALSE]
  selected_values <- f[num_row, feature_cols, drop = FALSE]

  d <- data.frame(
    variable = names(selected_shap),
    importance = colSums(as.data.frame(selected_shap)),
    value = colSums(as.data.frame(selected_values))
  )
  ggplot(d, aes(variable, importance, fill = value)) +
    geom_col() +
    geom_text(aes(label = round(value, digits = 2)), hjust = 1.0, size = 2) +
    coord_flip() +
    xlab("") +
    ylab("Shapley value") +
    scale_fill_gradient(low = 'lightgray', high = 'skyblue') +
    theme_classic() +
    theme(legend.position = 'none')
}
|
selected_features.tsv
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
variable
|
| 2 |
+
Steps
|
| 3 |
+
%HeadDown
|
| 4 |
+
Standing
|
| 5 |
+
Active
|
| 6 |
+
MeanXY
|
| 7 |
+
distance(m)
|
| 8 |
+
prev_steps1
|
| 9 |
+
X_Act
|
| 10 |
+
prev_Active1
|
| 11 |
+
prev_Standing1
|
| 12 |
+
DFA123
|
| 13 |
+
prev_headdown1
|
| 14 |
+
Lying
|
| 15 |
+
Y_Act
|
| 16 |
+
DBA123
|