cjerzak committed on
Commit
440bfdf
·
verified ·
1 Parent(s): 2caade6

Update app.R

Browse files
Files changed (1) hide show
  1. app.R +425 -48
app.R CHANGED
@@ -1,58 +1,435 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
  library(shiny)
2
- library(bslib)
3
- library(dplyr)
4
- library(ggplot2)
 
5
 
6
- df <- readr::read_csv("penguins.csv")
7
- # Find subset of columns that are suitable for scatter plot
8
- df_num <- df |> select(where(is.numeric), -Year)
9
 
10
- ui <- page_sidebar(
11
- theme = bs_theme(bootswatch = "minty"),
12
- title = "Penguins explorer",
13
- sidebar = sidebar(
14
- varSelectInput("xvar", "X variable", df_num, selected = "Bill Length (mm)"),
15
- varSelectInput("yvar", "Y variable", df_num, selected = "Bill Depth (mm)"),
16
- checkboxGroupInput("species", "Filter by species",
17
- choices = unique(df$Species), selected = unique(df$Species)
18
- ),
19
- hr(), # Add a horizontal rule
20
- checkboxInput("by_species", "Show species", TRUE),
21
- checkboxInput("show_margins", "Show marginal plots", TRUE),
22
- checkboxInput("smooth", "Add smoother"),
23
  ),
24
- plotOutput("scatter")
25
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
 
 
 
27
  server <- function(input, output, session) {
28
- subsetted <- reactive({
29
- req(input$species)
30
- df |> filter(Species %in% input$species)
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  })
32
-
33
- output$scatter <- renderPlot(
34
- {
35
- p <- ggplot(subsetted(), aes(!!input$xvar, !!input$yvar)) +
36
- theme_light() +
37
- list(
38
- theme(legend.position = "bottom"),
39
- if (input$by_species) aes(color = Species),
40
- geom_point(),
41
- if (input$smooth) geom_smooth()
42
- )
43
-
44
- if (input$show_margins) {
45
- margin_type <- if (input$by_species) "density" else "histogram"
46
- p <- p |> ggExtra::ggMarginal(
47
- type = margin_type, margins = "both",
48
- size = 8, groupColour = input$by_species, groupFill = input$by_species
49
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  }
51
-
52
- p
53
- },
54
- res = 100
55
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  }
57
 
58
- shinyApp(ui, server)
 
 
 
 
1
+ # ============================================================
2
+ # app.R | Shiny App for Rerandomization with fastrerandomize
3
+ # ============================================================
4
+ # 1) The user can upload or simulate a covariate dataset (X).
5
+ # 2) They specify rerandomization parameters: n_treated, acceptance prob, etc.
6
+ # 3) The app generates a set of accepted randomizations under rerandomization.
7
+ # 4) The user can optionally upload or simulate outcomes (Y) and run a randomization test.
8
+ # 5) The app displays distribution of the balance measure (e.g., Hotelling's T^2) and final p-value/fiducial interval.
9
+
10
+ # ----------------------------
11
+ # Load required packages
12
+ # ----------------------------
13
  library(shiny)
14
+ library(shinydashboard)
15
+ library(DT) # For data tables
16
+ library(ggplot2) # For basic plotting
17
+ library(fastrerandomize) # Our rerandomization package
18
 
19
+ # For production apps, ensure fastrerandomize is installed:
20
+ # install.packages("devtools")
21
+ # devtools::install_github("cjerzak/fastrerandomize-software/fastrerandomize")
22
 
23
+ # ---------------------------------------------------------
24
+ # UI Section
25
+ # ---------------------------------------------------------
26
# User interface: a shinydashboard page with three tabs (covariate data,
# randomization generation, randomization test). The pieces are assembled
# inside local() so that `ui` is the only name created at top level.
ui <- local({

  # ---- Header: application title ----
  header <- dashboardHeader(
    title = tags$span(
      "fastrerandomize Demo",
      style = "font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif;"
    )
  )

  # ---- Sidebar: one menu entry per tab ----
  sidebar <- dashboardSidebar(
    sidebarMenu(
      menuItem("Data & Covariates", tabName = "datatab", icon = icon("database")),
      menuItem("Generate Randomizations", tabName = "gennet", icon = icon("random")),
      menuItem("Randomization Test", tabName = "randtest", icon = icon("flask"))
    )
  )

  # ---- Page-wide CSS tweaks ----
  page_css <- tags$head(
    tags$style(HTML("
.smalltext { font-size: 90%; color: #555; }
.shiny-output-error { color: red; }
.shiny-input-container { margin-bottom: 15px; }
"))
  )

  # ---- Tab 1: upload or simulate the covariate matrix X ----
  covariate_source_box <- box(
    title = "Covariate Data: Upload or Simulate",
    width = 5, status = "primary", solidHeader = TRUE,
    radioButtons(
      "data_source", "Data Source:",
      choices = c("Upload CSV" = "upload", "Simulate data" = "simulate"),
      selected = "simulate"
    ),
    # Shown only while "Upload CSV" is selected
    conditionalPanel(
      "input.data_source == 'upload'",
      fileInput("file_covariates", "Choose CSV File", accept = c(".csv")),
      helpText("Columns = features/covariates, rows = units.")
    ),
    # Shown only while "Simulate data" is selected
    conditionalPanel(
      "input.data_source == 'simulate'",
      numericInput("sim_n", "Number of units (rows)", value = 20, min = 2),
      numericInput("sim_p", "Number of covariates (columns)", value = 3, min = 1),
      actionButton("simulate_btn", "Simulate X")
    )
  )
  covariate_preview_box <- box(
    title = "Preview of Covariates (X)",
    width = 7, status = "info", solidHeader = TRUE,
    DTOutput("covariates_table")
  )
  data_tab <- tabItem(
    tabName = "datatab",
    fluidRow(covariate_source_box, covariate_preview_box)
  )

  # ---- Tab 2: rerandomization parameters and summary ----
  rerand_params_box <- box(
    title = "Rerandomization Parameters",
    width = 4, status = "primary", solidHeader = TRUE,
    numericInput("n_treated", "Number Treated (n_treated)", value = 10, min = 1),
    selectInput(
      "random_type", "Randomization Type:",
      choices = c("Monte Carlo" = "monte_carlo", "Exact" = "exact"),
      selected = "monte_carlo"
    ),
    numericInput(
      "accept_prob", "Acceptance Probability (stringency)",
      value = 0.05, min = 0.0001, max = 1
    ),
    # Monte Carlo tuning knobs are hidden for the exact method
    conditionalPanel(
      "input.random_type == 'monte_carlo'",
      numericInput("max_draws", "Max Draws (MC)", value = 1e4, min = 1e3),
      numericInput("batch_size", "Batch Size (MC)", value = 1e3, min = 1e2)
    ),
    actionButton("generate_btn", "Generate Randomizations")
  )
  rerand_summary_box <- box(
    title = "Summary of Accepted Randomizations",
    width = 8, status = "info", solidHeader = TRUE,
    fluidRow(
      valueBoxOutput("n_accepted_box", width = 6),
      valueBoxOutput("balance_min_box", width = 6)
    ),
    br(),
    plotOutput("balance_hist", height = "250px")
  )
  generate_tab <- tabItem(
    tabName = "gennet",
    fluidRow(rerand_params_box, rerand_summary_box)
  )

  # ---- Tab 3: outcome data and randomization test ----
  test_setup_box <- box(
    title = "Randomization Test Setup",
    width = 4, status = "primary", solidHeader = TRUE,
    radioButtons(
      "outcome_source", "Outcome Data (Y):",
      choices = c("Simulate Y" = "simulate", "Upload CSV" = "uploadY"),
      selected = "simulate"
    ),
    conditionalPanel(
      "input.outcome_source == 'simulate'",
      numericInput("true_tau", "True Effect (simulate)", value = 1, step = 0.5),
      numericInput("noise_sd", "Noise SD for Y", value = 0.5, step = 0.1),
      actionButton("simulateY_btn", "Simulate Y")
    ),
    conditionalPanel(
      "input.outcome_source == 'uploadY'",
      fileInput(
        "file_outcomes", "Choose CSV File with outcome vector Y",
        accept = c(".csv")
      ),
      helpText("Single column with length = #units.")
    ),
    br(),
    actionButton("run_randtest_btn", "Run Randomization Test"),
    checkboxInput("findFI", "Compute Fiducial Interval?", value = FALSE)
  )
  test_results_box <- box(
    title = "Test Results",
    width = 8, status = "info", solidHeader = TRUE,
    fluidRow(
      valueBoxOutput("pvalue_box", width = 6),
      valueBoxOutput("tauobs_box", width = 6)
    ),
    uiOutput("fi_text"),
    br(),
    plotOutput("test_plot", height = "280px")
  )
  test_tab <- tabItem(
    tabName = "randtest",
    fluidRow(test_setup_box, test_results_box)
  )

  # ---- Assemble the page ----
  body <- dashboardBody(
    page_css,
    tabItems(data_tab, generate_tab, test_tab)
  )

  dashboardPage(header, sidebar, body)
})
181
 
182
+ # ---------------------------------------------------------
183
+ # SERVER
184
+ # ---------------------------------------------------------
185
  server <- function(input, output, session) {
186
+
187
+ # -------------------------------------------------------
188
+ # 1. Covariate Data Handling
189
+ # -------------------------------------------------------
190
# Covariate matrix X, shared by the rest of the server logic.
# NULL until the user uploads or simulates data.
X_data <- reactiveVal(NULL)

# Load X from an uploaded CSV (columns = covariates, rows = units).
# A file that cannot be parsed, is empty, or does not convert to a numeric
# matrix is rejected with a visible notification; the previously stored X
# (if any) is left untouched in that case.
observeEvent(input$file_covariates, {
  req(input$file_covariates)
  upload <- input$file_covariates

  df <- tryCatch(
    read.csv(upload$datapath, header = TRUE),
    error = function(e) {
      # Report the failure instead of silently ignoring the upload
      showNotification(
        paste("Could not read covariate CSV:", conditionMessage(e)),
        type = "error"
      )
      NULL
    }
  )
  if (is.null(df)) {
    return(NULL)
  }

  if (nrow(df) == 0 || ncol(df) == 0) {
    showNotification("Covariate CSV contains no data.", type = "error")
    return(NULL)
  }

  # as.matrix() yields a character matrix as soon as one column is
  # non-numeric; such a matrix cannot be used for balance computations.
  x_mat <- as.matrix(df)
  if (!is.numeric(x_mat)) {
    showNotification(
      "Covariate CSV must contain only numeric columns.",
      type = "error"
    )
    return(NULL)
  }

  X_data(x_mat)
})
203
+
204
# "Simulate X": fill X with independent N(0, 1) draws of the requested size.
# numericInput() limits are not enforced on the server and an emptied field
# arrives as NA, so the dimensions are checked here; an uncaught error inside
# an observer would otherwise terminate the session.
observeEvent(input$simulate_btn, {
  n <- input$sim_n
  p <- input$sim_p

  dims_ok <- is.numeric(n) && is.numeric(p) &&
    length(n) == 1 && length(p) == 1 &&
    is.finite(n) && is.finite(p) &&
    n >= 2 && p >= 1
  if (!dims_ok) {
    showNotification(
      "Number of units must be at least 2 and number of covariates at least 1.",
      type = "error"
    )
    return(NULL)
  }

  # Fractional entries are truncated so that n * p matches the matrix size
  n <- as.integer(floor(n))
  p <- as.integer(floor(p))

  X_data(matrix(rnorm(n * p), nrow = n, ncol = p))
})
212
+
213
# Preview table for X: rendered once covariates exist, five rows per page,
# horizontally scrollable for wide matrices.
output$covariates_table <- renderDT({
  covariates <- X_data()
  req(covariates)
  table_opts <- list(scrollX = TRUE, pageLength = 5)
  datatable(as.data.frame(covariates), options = table_opts)
})
219
+
220
+ # -------------------------------------------------------
221
+ # 2. Generate Rerandomizations
222
+ # -------------------------------------------------------
223
# Result of the most recent generate_randomizations() call; NULL until the
# user has generated randomizations successfully.
RerandResult <- reactiveVal(NULL)

# "Generate Randomizations": run fastrerandomize on the current X with the
# parameters chosen in the UI. Any problem is reported through a
# notification and the previously stored result is kept.
observeEvent(input$generate_btn, {
  req(X_data())
  X <- X_data()
  nunits <- nrow(X)
  n_treated <- input$n_treated

  # Input checks use showNotification() rather than validate()/need():
  # validation errors raised inside an observer are discarded silently, so
  # the user would never see the message.
  n_treated_ok <- is.numeric(n_treated) && length(n_treated) == 1 &&
    !is.na(n_treated) && n_treated >= 1
  if (!n_treated_ok) {
    showNotification("Number treated must be a positive number.", type = "error")
    return(NULL)
  }
  if (n_treated > nunits) {
    showNotification("Number treated cannot exceed total units.", type = "error")
    return(NULL)
  }

  # max_draws / batch_size are only meaningful for Monte Carlo sampling
  use_mc <- input$random_type == "monte_carlo"

  out <- tryCatch({
    generate_randomizations(
      n_units                   = nunits,
      n_treated                 = n_treated,
      X                         = X,
      randomization_accept_prob = input$accept_prob,
      randomization_type        = input$random_type,
      max_draws                 = if (use_mc) input$max_draws else NULL,
      batch_size                = if (use_mc) input$batch_size else NULL,
      verbose                   = FALSE
    )
  }, error = function(e) e)

  if (inherits(out, "error")) {
    showNotification(
      paste("Error generating randomizations:", conditionMessage(out)),
      type = "error"
    )
    return(NULL)
  }
  RerandResult(out)
})
255
+
256
# Value box with the number of accepted randomizations (rows of the
# randomization matrix); shows a red "0" until a result is available.
output$n_accepted_box <- renderValueBox({
  result <- RerandResult()
  has_draws <- !is.null(result) && !is.null(result$randomizations)
  if (has_draws) {
    valueBox(
      nrow(result$randomizations), "Accepted Randomizations",
      icon = icon("check"), color = "green"
    )
  } else {
    valueBox(
      "0", "Accepted Randomizations",
      icon = icon("ban"), color = "red"
    )
  }
})
266
+
267
# Value box with the smallest balance measure among the accepted
# randomizations (4 decimals); shows a placeholder until one is available.
output$balance_min_box <- renderValueBox({
  result <- RerandResult()
  has_balance <- !is.null(result) && !is.null(result$balance)
  if (has_balance) {
    smallest <- round(min(result$balance), 4)
    valueBox(
      smallest, "Min Balance Measure",
      icon = icon("thumbs-up"), color = "blue"
    )
  } else {
    valueBox(
      "---", "Min Balance Measure",
      icon = icon("question"), color = "orange"
    )
  }
})
276
+
277
# Histogram of the balance measure over the accepted randomizations.
# The bin width is 1/30 of the observed range. When every value coincides
# (for example a single accepted randomization) that width would be zero,
# which geom_histogram() rejects, so the plot falls back to a fixed bin
# count in that case.
output$balance_hist <- renderPlot({
  rr <- RerandResult()
  req(rr, rr$balance)
  df <- data.frame(balance = rr$balance)

  spread <- diff(range(df$balance))
  hist_layer <- if (is.finite(spread) && spread > 0) {
    geom_histogram(binwidth = spread / 30, fill = "darkblue", alpha = 0.7)
  } else {
    geom_histogram(bins = 30, fill = "darkblue", alpha = 0.7)
  }

  ggplot(df, aes(x = balance)) +
    hist_layer +
    labs(
      title = "Distribution of Balance Measure",
      x = "Balance (e.g. T^2)",
      y = "Frequency"
    ) +
    theme_minimal(base_size = 14)
})
289
+
290
+ # -------------------------------------------------------
291
+ # 3. Randomization Test
292
+ # -------------------------------------------------------
293
# Outcome vector Y (one value per unit); NULL until simulated or uploaded.
Y_data <- reactiveVal(NULL)

# (A) "Simulate Y": draw Y = X %*% beta + tau * W + noise, where W is the
# first accepted randomization (used as the "observed" assignment for this
# demo) and beta is a fresh vector of N(0, 1) coefficients.
observeEvent(input$simulateY_btn, {
  req(RerandResult())
  rr <- RerandResult()

  rand_mat <- rr$randomizations
  if (is.null(rand_mat) || NROW(rand_mat) == 0) {
    showNotification(
      "No accepted randomizations available. Generate randomizations first.",
      type = "error"
    )
    return(NULL)
  }

  # Each row is one accepted assignment, so the number of units is the
  # length of a row -- not nrow(), which counts accepted randomizations.
  obsW <- rand_mat[1, ]
  nunits <- length(obsW)

  tau <- input$true_tau
  noise_sd <- input$noise_sd
  params_ok <- is.numeric(tau) && is.numeric(noise_sd) &&
    length(tau) == 1 && length(noise_sd) == 1 &&
    !is.na(tau) && !is.na(noise_sd) && noise_sd >= 0
  if (!params_ok) {
    showNotification(
      "True effect must be a number and noise SD a non-negative number.",
      type = "error"
    )
    return(NULL)
  }

  Xval <- X_data()
  if (is.null(Xval)) {
    showNotification("No covariate data found to help simulate outcomes. Using intercept-only model.", type = "warning")
    Xval <- matrix(0, nrow = nunits, ncol = 1)
  }
  # X may have been replaced after the randomizations were generated
  if (!is.numeric(Xval) || nrow(Xval) != nunits) {
    showNotification(
      "Covariates do not match the generated randomizations. Regenerate randomizations first.",
      type = "error"
    )
    return(NULL)
  }

  # random coefficients
  beta <- rnorm(ncol(Xval), 0, 1)
  linear_part <- Xval %*% beta
  Ysim <- as.numeric(linear_part + obsW * tau + rnorm(nunits, 0, noise_sd))

  Y_data(Ysim)
})
318
+
319
# (B) Load Y from an uploaded CSV: a single column without a header row,
# one numeric value per unit. Files that cannot be read, have more than one
# column, or contain non-numeric / missing entries are rejected with a
# notification and the previously stored Y is kept.
observeEvent(input$file_outcomes, {
  req(input$file_outcomes)
  upload <- input$file_outcomes

  dfy <- tryCatch(
    read.csv(upload$datapath, header = FALSE),
    error = function(e) {
      # Report the failure instead of silently ignoring the upload
      showNotification(
        paste("Could not read outcome CSV:", conditionMessage(e)),
        type = "error"
      )
      NULL
    }
  )
  if (is.null(dfy)) {
    return(NULL)
  }

  if (ncol(dfy) > 1) {
    showNotification("Please provide a single-column CSV for Y.", type = "error")
    return(NULL)
  }

  # Conversion warnings are suppressed on purpose: entries that fail to
  # parse become NA and are reported through the check below.
  y <- suppressWarnings(as.numeric(as.character(dfy[[1]])))
  if (length(y) == 0 || anyNA(y)) {
    showNotification(
      "Y must contain only numeric values (no header row or missing entries).",
      type = "error"
    )
    return(NULL)
  }

  Y_data(y)
})
332
+
333
# Output of the most recent randomization_test() call; NULL until a test
# has completed successfully.
RandTestResult <- reactiveVal(NULL)

# "Run Randomization Test": test the stored outcomes against the accepted
# randomizations, treating the first accepted randomization as the observed
# assignment. Failures are reported via notifications and leave the previous
# result in place.
observeEvent(input$run_randtest_btn, {
  rerand <- RerandResult()
  req(rerand)
  req(rerand$randomizations)

  outcomes <- Y_data()
  if (is.null(outcomes)) {
    showNotification("No outcome data Y found. Upload or simulate first.", type = "error")
    return(NULL)
  }

  candidates <- rerand$randomizations
  assignment <- candidates[1, ]

  # Y needs exactly one entry per unit in the assignment vector
  if (length(outcomes) != length(assignment)) {
    showNotification("Dimension mismatch: Y must match number of units in the randomization.", type = "error")
    return(NULL)
  }

  # Capture any error object so it can be shown to the user
  test_out <- tryCatch(
    randomization_test(
      obsW = assignment,
      obsY = outcomes,
      candidate_randomizations = candidates,
      findFI = input$findFI
    ),
    error = function(err) err
  )

  if (inherits(test_out, "error")) {
    showNotification(paste("Error in randomization_test:", test_out$message), type = "error")
    return(NULL)
  }

  RandTestResult(test_out)
})
372
+
373
# Value box with the p-value of the latest test (4 decimals), or a
# placeholder while no test result exists.
output$pvalue_box <- renderValueBox({
  test_out <- RandTestResult()
  if (!is.null(test_out)) {
    valueBox(
      round(test_out$p_value, 4), "p-value",
      icon = icon("list-check"), color = "purple"
    )
  } else {
    valueBox(
      "---", "p-value",
      icon = icon("question"), color = "blue"
    )
  }
})
382
+
383
# Value box with the observed effect estimate of the latest test
# (4 decimals), or a placeholder while no test result exists.
output$tauobs_box <- renderValueBox({
  test_out <- RandTestResult()
  if (!is.null(test_out)) {
    valueBox(
      round(test_out$tau_obs, 4), "Observed Effect",
      icon = icon("bullseye"), color = "maroon"
    )
  } else {
    valueBox(
      "---", "Observed Effect",
      icon = icon("question"), color = "maroon"
    )
  }
})
391
+
392
# Fiducial interval text: rendered only when the latest test result carries
# an `FI` element; its first two entries are shown as [lower, upper].
output$fi_text <- renderUI({
  test_out <- RandTestResult()
  has_interval <- !is.null(test_out) && !is.null(test_out$FI)
  if (has_interval) {
    lower <- round(test_out$FI[1], 4)
    upper <- round(test_out$FI[2], 4)
    interval_label <- sprintf("[%.4f, %.4f]", lower, upper)
    tagList(
      strong("Fiducial Interval (95%):"),
      p(interval_label)
    )
  } else {
    NULL
  }
})
406
+
407
# Minimal plot for the test tab: a single red point at the observed effect
# with a dashed reference line at zero. No randomization distribution is
# read from the test result here, so none is drawn (see the plot subtitle).
output$test_plot <- renderPlot({
  test_out <- RandTestResult()
  if (is.null(test_out)) {
    return(NULL)
  }

  obs_val <- test_out$tau_obs

  # Symmetric window around the observed effect, always at least 2 wide
  x_lower <- obs_val - abs(obs_val) * 2 - 1
  x_upper <- obs_val + abs(obs_val) * 2 + 1
  point_df <- data.frame(x = obs_val, y = 0)

  ggplot(point_df, aes(x, y)) +
    geom_point(size = 4, color = "red") +
    xlim(c(x_lower, x_upper)) +
    labs(
      title = "Observed Treatment Effect",
      subtitle = "No randomization distribution stored to plot.\n(This is a minimal demonstration.)",
      x = "Effect Size",
      y = ""
    ) +
    theme_minimal(base_size = 14) +
    geom_vline(xintercept = 0, linetype = "dashed", color = "gray40")
})
430
  }
431
 
432
+ # ---------------------------------------------------------
433
+ # Run the Application
434
+ # ---------------------------------------------------------
435
# Launch the app from the UI definition and server function defined above.
shinyApp(ui, server)