cjerzak committed on
Commit
5ceabec
·
verified ·
1 Parent(s): 17b3872

Update app.R

Browse files
Files changed (1) hide show
  1. app.R +189 -61
app.R CHANGED
@@ -89,6 +89,17 @@ ui <- dashboardPage(
89
  box(width = 7, title = "Preview of Covariates (X)",
90
  status = "info", solidHeader = TRUE,
91
  DTOutput("covariates_table"))
 
 
 
 
 
 
 
 
 
 
 
92
  )
93
  ),
94
 
@@ -102,7 +113,6 @@ ui <- dashboardPage(
102
  box(width = 4, title = "Rerandomization Parameters",
103
  status = "primary", solidHeader = TRUE,
104
 
105
- # If user hasn't chosen data in tab 1, we fallback or show an error
106
  numericInput("n_treated", "Number Treated (n_treated)", value = 10, min = 1),
107
  selectInput("random_type", "Randomization Type:",
108
  choices = c("Monte Carlo" = "monte_carlo",
@@ -127,6 +137,13 @@ ui <- dashboardPage(
127
  br(),
128
  plotOutput("balance_hist", height = "250px")
129
  )
 
 
 
 
 
 
 
130
  )
131
  ),
132
 
@@ -172,6 +189,20 @@ ui <- dashboardPage(
172
  br(),
173
  plotOutput("test_plot", height = "280px")
174
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  )
176
  )
177
 
@@ -184,18 +215,31 @@ ui <- dashboardPage(
184
  # ---------------------------------------------------------
185
  server <- function(input, output, session) {
186
 
 
 
 
 
 
 
 
 
187
  # -------------------------------------------------------
188
  # 1. Covariate Data Handling
189
  # -------------------------------------------------------
190
  # We store the covariate matrix X in a reactiveVal for convenient reuse
191
  X_data <- reactiveVal(NULL)
192
 
193
- # Observe file input or simulation for X
194
  observeEvent(input$file_covariates, {
195
  req(input$file_covariates)
196
  inFile <- input$file_covariates
 
 
197
  df <- tryCatch(read.csv(inFile$datapath, header = TRUE),
198
  error = function(e) NULL)
 
 
 
199
  if (!is.null(df)) {
200
  X_data(as.matrix(df))
201
  }
@@ -205,8 +249,13 @@ server <- function(input, output, session) {
205
  observeEvent(input$simulate_btn, {
206
  n <- input$sim_n
207
  p <- input$sim_p
 
 
208
  # Basic simulation of N(0,1) data
209
  simX <- matrix(rnorm(n * p), nrow = n, ncol = p)
 
 
 
210
  X_data(simX)
211
  })
212
 
@@ -217,6 +266,25 @@ server <- function(input, output, session) {
217
  options = list(scrollX = TRUE, pageLength = 5))
218
  })
219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  # -------------------------------------------------------
221
  # 2. Generate Rerandomizations
222
  # -------------------------------------------------------
@@ -230,27 +298,37 @@ server <- function(input, output, session) {
230
  "Number treated cannot exceed total units.")
231
  )
232
 
233
- # We call generate_randomizations() from fastrerandomize
234
- nunits <- nrow(X_data())
235
- # If user selected "exact" but it's huge, a warning would appear in the console
236
- out <- tryCatch({
237
- generate_randomizations(
238
- n_units = nunits,
239
- n_treated = input$n_treated,
240
- X = X_data(),
241
- randomization_accept_prob= input$accept_prob,
242
- randomization_type = input$random_type,
243
- max_draws = if (input$random_type == "monte_carlo") input$max_draws else NULL,
244
- batch_size = if (input$random_type == "monte_carlo") input$batch_size else NULL,
245
- verbose = FALSE
246
- )
247
- }, error = function(e) e)
248
-
249
- if (inherits(out, "error")) {
250
- showNotification(paste("Error generating randomizations:", out$message), type = "error")
251
- return(NULL)
252
- }
253
- RerandResult(out)
 
 
 
 
 
 
 
 
 
 
254
  })
255
 
256
  # Summaries of accepted randomizations
@@ -287,6 +365,16 @@ server <- function(input, output, session) {
287
  theme_minimal(base_size = 14)
288
  })
289
 
 
 
 
 
 
 
 
 
 
 
290
  # -------------------------------------------------------
291
  # 3. Randomization Test
292
  # -------------------------------------------------------
@@ -296,10 +384,11 @@ server <- function(input, output, session) {
296
  observeEvent(input$simulateY_btn, {
297
  req(RerandResult())
298
  rr <- RerandResult()
299
- nunits <- nrow(rr$randomizations)
300
 
 
 
 
301
  # We'll just use the first accepted randomization as the "observed" assignment
302
- # in real usage, they'd pick or define their actual assignment
303
  obsW <- rr$randomizations[1, ]
304
 
305
  # Basic data generation: Y = X * beta + tau * W + noise
@@ -311,7 +400,10 @@ server <- function(input, output, session) {
311
  # random coefficients
312
  beta <- rnorm(ncol(Xval), 0, 1)
313
  linear_part <- Xval %*% beta
314
- Ysim <- as.numeric(linear_part + obsW*input$true_tau + rnorm(nunits, 0, input$noise_sd))
 
 
 
315
 
316
  Y_data(Ysim)
317
  })
@@ -320,7 +412,12 @@ server <- function(input, output, session) {
320
  observeEvent(input$file_outcomes, {
321
  req(input$file_outcomes)
322
  inFile <- input$file_outcomes
 
 
323
  dfy <- tryCatch(read.csv(inFile$datapath, header = FALSE), error=function(e) NULL)
 
 
 
324
  if (!is.null(dfy)) {
325
  if (ncol(dfy) > 1) {
326
  showNotification("Please provide a single-column CSV for Y.", type="error")
@@ -342,32 +439,41 @@ server <- function(input, output, session) {
342
  return(NULL)
343
  }
344
 
345
- # We'll do the test with:
346
- obsW <- rr$randomizations[1, ]
347
- obsY <- Y_data()
348
- cands <- rr$randomizations
349
-
350
- if (length(obsY) != length(obsW)) {
351
- showNotification("Dimension mismatch: Y must match number of units in the randomization.", type = "error")
352
- return(NULL)
353
- }
354
-
355
- # Call the randomization_test function
356
- outTest <- tryCatch({
357
- randomization_test(
358
- obsW = obsW,
359
- obsY = obsY,
360
- candidate_randomizations = cands,
361
- findFI = input$findFI
362
- )
363
- }, error=function(e) e)
364
-
365
- if (inherits(outTest, "error")) {
366
- showNotification(paste("Error in randomization_test:", outTest$message), type="error")
367
- return(NULL)
368
- }
369
-
370
- RandTestResult(outTest)
 
 
 
 
 
 
 
 
 
371
  })
372
 
373
  # Display p-value and observed tau
@@ -404,29 +510,51 @@ server <- function(input, output, session) {
404
  )
405
  })
406
 
407
- # A simple plot for the randomization distribution
 
408
  output$test_plot <- renderPlot({
409
  rt <- RandTestResult()
410
  if (is.null(rt)) {
411
  return(NULL)
412
  }
413
- # The distribution of test stats is stored in rt$stat_distribution if you used
414
- # advanced usage in the underlying code. The default version in
415
- # randomization_test() above only returns the final p-value, so we'll do a
416
- # simpler demonstration: we only plot a vertical line for the observed effect.
417
-
418
- # We'll just do a line:
419
  obs_val <- rt$tau_obs
420
 
421
- ggplot(data.frame(x=obs_val, y=0), aes(x, y)) +
422
  geom_point(size=4, color="red") +
423
  xlim(c(obs_val - abs(obs_val)*2 - 1, obs_val + abs(obs_val)*2 + 1)) +
424
  labs(title = "Observed Treatment Effect",
425
- subtitle = "No randomization distribution stored to plot.\n(This is a minimal demonstration.)",
426
  x = "Effect Size", y = "") +
427
  theme_minimal(base_size = 14) +
428
  geom_vline(xintercept = 0, linetype="dashed", color="gray40")
429
  })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
430
  }
431
 
432
  # ---------------------------------------------------------
 
89
  box(width = 7, title = "Preview of Covariates (X)",
90
  status = "info", solidHeader = TRUE,
91
  DTOutput("covariates_table"))
92
+ ),
93
+
94
+ # Performance info for data steps
95
+ fluidRow(
96
+ box(width = 12, title = "Performance Info for Data Steps", status = "warning", solidHeader = TRUE,
97
+ p("Time to upload X (CSV):"),
98
+ textOutput("time_data_upload"),
99
+ br(),
100
+ p("Time to simulate X:"),
101
+ textOutput("time_data_sim")
102
+ )
103
  )
104
  ),
105
 
 
113
  box(width = 4, title = "Rerandomization Parameters",
114
  status = "primary", solidHeader = TRUE,
115
 
 
116
  numericInput("n_treated", "Number Treated (n_treated)", value = 10, min = 1),
117
  selectInput("random_type", "Randomization Type:",
118
  choices = c("Monte Carlo" = "monte_carlo",
 
137
  br(),
138
  plotOutput("balance_hist", height = "250px")
139
  )
140
+ ),
141
+
142
+ # Performance info for randomization generation
143
+ fluidRow(
144
+ box(width = 12, title = "Performance Info for Generation", status = "warning", solidHeader = TRUE,
145
+ textOutput("time_generate")
146
+ )
147
  )
148
  ),
149
 
 
189
  br(),
190
  plotOutput("test_plot", height = "280px")
191
  )
192
+ ),
193
+
194
+ # Performance info for randomization test
195
+ fluidRow(
196
+ box(width = 12, title = "Performance Info for Randomization Test", status = "warning", solidHeader = TRUE,
197
+ p("Time to upload Y (CSV):"),
198
+ textOutput("time_data_uploadY"),
199
+ br(),
200
+ p("Time to simulate Y:"),
201
+ textOutput("time_data_simY"),
202
+ br(),
203
+ p("Time to run randomization test:"),
204
+ textOutput("time_randtest")
205
+ )
206
  )
207
  )
208
 
 
215
  # ---------------------------------------------------------
216
  server <- function(input, output, session) {
217
 
218
+ # -- ReactiveVals to store performance times (seconds)
219
+ time_data_upload <- reactiveVal(NA_real_)
220
+ time_data_sim <- reactiveVal(NA_real_)
221
+ time_generate <- reactiveVal(NA_real_)
222
+ time_data_uploadY <- reactiveVal(NA_real_)
223
+ time_data_simY <- reactiveVal(NA_real_)
224
+ time_randtest <- reactiveVal(NA_real_)
225
+
226
  # -------------------------------------------------------
227
  # 1. Covariate Data Handling
228
  # -------------------------------------------------------
229
  # We store the covariate matrix X in a reactiveVal for convenient reuse
230
  X_data <- reactiveVal(NULL)
231
 
232
+ # Observe file input (upload) for X
233
  observeEvent(input$file_covariates, {
234
  req(input$file_covariates)
235
  inFile <- input$file_covariates
236
+
237
+ start_time <- Sys.time()
238
  df <- tryCatch(read.csv(inFile$datapath, header = TRUE),
239
  error = function(e) NULL)
240
+ end_time <- Sys.time()
241
+ time_data_upload(as.numeric(difftime(end_time, start_time, units = "secs")))
242
+
243
  if (!is.null(df)) {
244
  X_data(as.matrix(df))
245
  }
 
249
  observeEvent(input$simulate_btn, {
250
  n <- input$sim_n
251
  p <- input$sim_p
252
+
253
+ start_time <- Sys.time()
254
  # Basic simulation of N(0,1) data
255
  simX <- matrix(rnorm(n * p), nrow = n, ncol = p)
256
+ end_time <- Sys.time()
257
+ time_data_sim(as.numeric(difftime(end_time, start_time, units = "secs")))
258
+
259
  X_data(simX)
260
  })
261
 
 
266
  options = list(scrollX = TRUE, pageLength = 5))
267
  })
268
 
269
+ # --- Performance outputs for Data & Covariates
270
+ output$time_data_upload <- renderText({
271
+ t <- time_data_upload()
272
+ if (is.na(t)) {
273
+ "Not run yet."
274
+ } else {
275
+ paste0(round(t, 3), " seconds")
276
+ }
277
+ })
278
+
279
+ output$time_data_sim <- renderText({
280
+ t <- time_data_sim()
281
+ if (is.na(t)) {
282
+ "Not run yet."
283
+ } else {
284
+ paste0(round(t, 3), " seconds")
285
+ }
286
+ })
287
+
288
  # -------------------------------------------------------
289
  # 2. Generate Rerandomizations
290
  # -------------------------------------------------------
 
298
  "Number treated cannot exceed total units.")
299
  )
300
 
301
+ # withProgress to show progress bar in the UI
302
+ withProgress(message = "Computing randomizations...", value = 0, {
303
+
304
+ # Measure time
305
+ start_time <- Sys.time()
306
+
307
+ # We call generate_randomizations() from fastrerandomize
308
+ nunits <- nrow(X_data())
309
+ out <- tryCatch({
310
+ generate_randomizations(
311
+ n_units = nunits,
312
+ n_treated = input$n_treated,
313
+ X = X_data(),
314
+ randomization_accept_prob= input$accept_prob,
315
+ randomization_type = input$random_type,
316
+ max_draws = if (input$random_type == "monte_carlo") input$max_draws else NULL,
317
+ batch_size = if (input$random_type == "monte_carlo") input$batch_size else NULL,
318
+ verbose = FALSE
319
+ )
320
+ }, error = function(e) e)
321
+
322
+ # End time
323
+ end_time <- Sys.time()
324
+ time_generate(as.numeric(difftime(end_time, start_time, units = "secs")))
325
+
326
+ if (inherits(out, "error")) {
327
+ showNotification(paste("Error generating randomizations:", out$message), type = "error")
328
+ return(NULL)
329
+ }
330
+ RerandResult(out)
331
+ })
332
  })
333
 
334
  # Summaries of accepted randomizations
 
365
  theme_minimal(base_size = 14)
366
  })
367
 
368
+ # --- Performance output for randomization generation
369
+ output$time_generate <- renderText({
370
+ t <- time_generate()
371
+ if (is.na(t)) {
372
+ "Not run yet."
373
+ } else {
374
+ paste0(round(t, 3), " seconds")
375
+ }
376
+ })
377
+
378
  # -------------------------------------------------------
379
  # 3. Randomization Test
380
  # -------------------------------------------------------
 
384
  observeEvent(input$simulateY_btn, {
385
  req(RerandResult())
386
  rr <- RerandResult()
 
387
 
388
+ nunits <- ncol(rr$randomizations) # number of units = number of columns in randomizations (each row is one assignment)
389
+
390
+ start_time <- Sys.time()
391
  # We'll just use the first accepted randomization as the "observed" assignment
 
392
  obsW <- rr$randomizations[1, ]
393
 
394
  # Basic data generation: Y = X * beta + tau * W + noise
 
400
  # random coefficients
401
  beta <- rnorm(ncol(Xval), 0, 1)
402
  linear_part <- Xval %*% beta
403
+ Ysim <- as.numeric(linear_part + obsW * input$true_tau + rnorm(nunits, 0, input$noise_sd))
404
+ end_time <- Sys.time()
405
+
406
+ time_data_simY(as.numeric(difftime(end_time, start_time, units = "secs")))
407
 
408
  Y_data(Ysim)
409
  })
 
412
  observeEvent(input$file_outcomes, {
413
  req(input$file_outcomes)
414
  inFile <- input$file_outcomes
415
+
416
+ start_time <- Sys.time()
417
  dfy <- tryCatch(read.csv(inFile$datapath, header = FALSE), error=function(e) NULL)
418
+ end_time <- Sys.time()
419
+ time_data_uploadY(as.numeric(difftime(end_time, start_time, units = "secs")))
420
+
421
  if (!is.null(dfy)) {
422
  if (ncol(dfy) > 1) {
423
  showNotification("Please provide a single-column CSV for Y.", type="error")
 
439
  return(NULL)
440
  }
441
 
442
+ withProgress(message = "Computing randomization test...", value = 0, {
443
+
444
+ start_time <- Sys.time()
445
+
446
+ obsW <- rr$randomizations[1, ]
447
+ obsY <- Y_data()
448
+ cands <- rr$randomizations
449
+
450
+ # Check that Y has same length as a single W
451
+ if (length(obsY) != length(obsW)) {
452
+ showNotification("Dimension mismatch: Y must match number of units in the randomization.",
453
+ type = "error")
454
+ return(NULL)
455
+ }
456
+
457
+ # Call the randomization_test function
458
+ outTest <- tryCatch({
459
+ randomization_test(
460
+ obsW = obsW,
461
+ obsY = obsY,
462
+ candidate_randomizations = cands,
463
+ findFI = input$findFI
464
+ )
465
+ }, error=function(e) e)
466
+
467
+ end_time <- Sys.time()
468
+ time_randtest(as.numeric(difftime(end_time, start_time, units = "secs")))
469
+
470
+ if (inherits(outTest, "error")) {
471
+ showNotification(paste("Error in randomization_test:", outTest$message), type="error")
472
+ return(NULL)
473
+ }
474
+
475
+ RandTestResult(outTest)
476
+ })
477
  })
478
 
479
  # Display p-value and observed tau
 
510
  )
511
  })
512
 
513
+ # A simple plot for the randomization distribution
514
+ # (no distribution stored by default, so just show the observed effect)
515
  output$test_plot <- renderPlot({
516
  rt <- RandTestResult()
517
  if (is.null(rt)) {
518
  return(NULL)
519
  }
 
 
 
 
 
 
520
  obs_val <- rt$tau_obs
521
 
522
+ ggplot(data.frame(x = obs_val, y = 0), aes(x, y)) +
523
  geom_point(size=4, color="red") +
524
  xlim(c(obs_val - abs(obs_val)*2 - 1, obs_val + abs(obs_val)*2 + 1)) +
525
  labs(title = "Observed Treatment Effect",
 
526
  x = "Effect Size", y = "") +
527
  theme_minimal(base_size = 14) +
528
  geom_vline(xintercept = 0, linetype="dashed", color="gray40")
529
  })
530
+
531
+ # --- Performance outputs for outcomes and randomization test
532
+ output$time_data_uploadY <- renderText({
533
+ t <- time_data_uploadY()
534
+ if (is.na(t)) {
535
+ "Not run yet."
536
+ } else {
537
+ paste0(round(t, 3), " seconds")
538
+ }
539
+ })
540
+
541
+ output$time_data_simY <- renderText({
542
+ t <- time_data_simY()
543
+ if (is.na(t)) {
544
+ "Not run yet."
545
+ } else {
546
+ paste0(round(t, 3), " seconds")
547
+ }
548
+ })
549
+
550
+ output$time_randtest <- renderText({
551
+ t <- time_randtest()
552
+ if (is.na(t)) {
553
+ "Not run yet."
554
+ } else {
555
+ paste0(round(t, 3), " seconds")
556
+ }
557
+ })
558
  }
559
 
560
  # ---------------------------------------------------------