cjerzak commited on
Commit
9691d6d
Β·
verified Β·
1 Parent(s): 6f82bce

Update app.R

Browse files
Files changed (1) hide show
  1. app.R +763 -315
app.R CHANGED
@@ -1,399 +1,847 @@
1
- # setwd("~/Downloads")
2
- {
3
- # app.R
4
- options(error = NULL)
5
-
6
- # ------------------------------
7
- # 1. Load Packages
8
- # ------------------------------
 
 
 
 
 
 
 
9
  library(shiny)
10
  library(shinydashboard)
11
- library(leaflet)
12
- library(raster)
13
- library(DT)
14
- library(readr)
15
- library(dplyr) # For data manipulation
16
- library(ggplot2) # For histogram
17
- library(RColorBrewer)
18
- library(sp) # For handling map clicks/extracting raster values
19
 
20
- # ------------------------------
21
- # 2. Data & Config
22
- # ------------------------------
23
 
24
- # Define time periods corresponding to each band in the GeoTIFF
25
- time_periods <- c("1990–1992", "1993–1995", "1996–1998", "1999–2001", "2002–2004",
26
- "2005–2007", "2008–2010", "2011–2013", "2014–2016", "2017–2019")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- # Load GeoTIFF data (multi-band)
29
- wealth_stack <- stack("wealth_map.tif")
 
 
30
 
31
- # Clean up out-of-range values
32
- wealth_stack[wealth_stack <= 0 | wealth_stack > 1] <- NA
 
 
 
 
 
 
 
 
 
33
 
34
- # Load improvement data (change in IWI by state/province)
35
- improvement_data <- read_csv("poverty_improvement_by_state.csv")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
- # Pre-calculate the mean IWI for each band (for the "Trends Over Time" chart).
38
- band_means <- sapply(seq_len(nlayers(wealth_stack)), function(i) {
39
- vals <- values(wealth_stack[[i]])
40
- vals <- vals[!is.na(vals)]
41
- mean(vals)
42
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
- # ------------------------------
45
- # 3. UI
46
- # ------------------------------
47
  ui <- dashboardPage(
48
- # -- Header
49
- dashboardHeader(title = tags$span("aidevlab.org", style = "font-family: 'OCR A Std', monospace;")),
50
 
51
- # -- Sidebar
 
 
 
 
 
 
 
 
52
  dashboardSidebar(
53
  sidebarMenu(
54
- id = "tabs",
55
- menuItem("Wealth Map", tabName = "mapTab", icon = icon("map")),
56
- menuItem("Improvement Data", tabName = "improvementTab", icon = icon("table")),
57
- menuItem("Trends Over Time", tabName = "trendTab", icon = icon("chart-line"))
58
- ),
59
- # Show inputs only for the map tab
60
- conditionalPanel(
61
- condition = "input.tabs == 'mapTab'",
62
- br(),
63
- # Replaces the old selectInput for time periods with a slider that can animate
64
- sliderInput(
65
- inputId = "time_index",
66
- label = "Select Time Period:",
67
- min = 1,
68
- max = length(time_periods),
69
- value = 1,
70
- step = 1,
71
- animate = animationOptions(interval = 1500, loop = TRUE)
72
- ),
73
- selectInput("color_palette", "Select Color Palette:",
74
- choices = c("Viridis" = "viridis",
75
- "Plasma" = "plasma",
76
- "Magma" = "magma",
77
- "Inferno"= "inferno",
78
- "Spectral (Brewer)" = "Spectral"),
79
- selected = "plasma"),
80
- sliderInput("opacity", "Map Opacity:", min = 0.2, max = 1, value = 0.8, step = 0.1)
81
  )
82
  ),
83
 
84
- # -- Body
85
  dashboardBody(
 
 
86
  tags$head(
87
- tags$link(rel = "stylesheet", href = "https://fonts.cdnfonts.com/css/ocr-a-std"),
88
  tags$style(HTML("
89
- body {
90
- font-family: 'OCR A Std', monospace !important;
91
- }
92
  "))
93
  ),
 
94
  tabItems(
95
- # ---------- MAP TAB ----------
 
 
 
96
  tabItem(
97
- tabName = "mapTab",
98
- fluidRow(
99
- # Value Boxes across the top for key stats
100
- valueBoxOutput("highest_iwi_vb", width = 4),
101
- valueBoxOutput("lowest_iwi_vb", width = 4),
102
- valueBoxOutput("avg_iwi_vb", width = 4)
103
- ),
104
  fluidRow(
105
- # Map
106
- box(
107
- title = "Wealth Map of Africa", width = 8, solidHeader = TRUE, status = "primary",
108
- leafletOutput("map", height = "550px"),
109
- p("Click anywhere on the map to view the time-series of IWI for that specific location (shown below).")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  ),
111
- # Histogram
112
- box(
113
- title = "IWI Distribution (Selected Period)", width = 4, solidHeader = TRUE, status = "info",
114
- plotOutput("iwi_histogram", height = "250px"),
115
- p("This histogram shows the distribution of the International Wealth Index (IWI) values for the selected time period across Africa."),
116
- br(),
117
- strong("Note:"),
118
- " Wealth estimates for areas without human settlements have been excluded from the analysis."
119
- )
120
- ),
121
- # Time series at clicked location
122
- fluidRow(
123
- box(
124
- title = "Time Series at Clicked Location", width = 12, solidHeader = TRUE, status = "warning",
125
- plotOutput("clicked_ts_plot", height = "300px"),
126
- p("Click on the map to see the full IWI time-series (1990–2019) for that location.")
127
- )
128
  )
129
  ),
130
 
131
- # ---------- IMPROVEMENT DATA TAB ----------
 
 
132
  tabItem(
133
- tabName = "improvementTab",
 
134
  fluidRow(
135
- box(
136
- width = 12, title = "Poverty Improvement by State", status = "primary", solidHeader = TRUE,
137
- p("This table shows the estimated improvement in mean IWI between 1990–1992 and 2017–2019 for each province in Africa.
138
- The 'Improvement' column indicates the change in IWI over this period. You can sort or filter the table,
139
- and use the download button to export the data."),
140
- downloadButton("download_data", "Download CSV", icon = icon("download")),
141
- br(), br(),
142
- DTOutput("improvement_table")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  )
144
  )
145
  ),
146
 
147
- # ---------- TRENDS OVER TIME TAB ----------
 
 
148
  tabItem(
149
- tabName = "trendTab",
 
150
  fluidRow(
151
- box(
152
- width = 12, title = "Average Wealth Index Across Africa Over Time", status = "success", solidHeader = TRUE,
153
- p("This chart aggregates the mean IWI across all of Africa in each of the ten time periods.
154
- It provides a high-level view of how wealth (as measured by IWI) has changed over time."),
155
- plotOutput("trend_plot", height = "400px")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  )
157
  )
158
  )
159
- )
160
- )
161
- )
 
162
 
163
- # ------------------------------
164
- # 4. Server
165
- # ------------------------------
166
  server <- function(input, output, session) {
167
 
168
- # ReactiveVal to store the time-series of the last clicked point (across all periods).
169
- clicked_point_vals <- reactiveVal(NULL)
 
 
 
170
 
171
- # ----------------------------------
172
- # Reactive expression for selected raster layer
173
- # ----------------------------------
174
- selected_raster <- reactive({
175
- req(input$time_index)
176
- wealth_stack[[input$time_index]]
 
 
 
177
  })
178
 
179
- # ----------------------------------
180
- # Custom color palette function
181
- # (reactive to user-selected palette)
182
- # ----------------------------------
183
- color_pal <- reactive({
184
- palette_choice <- switch(
185
- input$color_palette,
186
- "viridis" = "viridis",
187
- "plasma" = "plasma",
188
- "magma" = "magma",
189
- "inferno" = "inferno",
190
- # Fallback to a Brewer palette for "Spectral"
191
- "Spectral" = "Spectral"
192
- )
193
- colorNumeric(
194
- palette = palette_choice,
195
- domain = c(0, 1), # Domain for map: 0 to 1
196
- na.color = "transparent"
197
- )
198
  })
199
 
200
- # ----------------------------------
201
- # 1. MAP OUTPUT
202
- # ----------------------------------
203
- output$map <- renderLeaflet({
204
- # We'll create 5 legend steps: 1, 0.75, 0.5, 0.25, 0
205
- legend_values <- seq(1, 0, length.out = 5)
206
-
207
- leaflet() %>%
208
- addProviderTiles(providers$OpenStreetMap) %>%
209
- setView(lng = 20, lat = 0, zoom = 3) %>% # Center on Africa
210
- addLegend(
211
- position = "bottomright",
212
- colors = color_pal()(legend_values),
213
- labels = sprintf("%.2f", legend_values),
214
- title = "IWI",
215
- opacity = 1
216
- )
217
  })
218
 
219
- # Redraw the raster when inputs change
220
- observeEvent(list(input$time_index, input$color_palette, input$opacity), {
221
- leafletProxy("map") %>%
222
- clearImages() %>%
223
- addRasterImage(
224
- selected_raster(),
225
- colors = color_pal(),
226
- opacity = input$opacity,
227
- project = TRUE
228
- )
229
- })
230
 
231
- # ----------------------------------
232
- # Handle clicks on the map to show full time-series at that location
233
- # ----------------------------------
234
- observeEvent(input$map_click, {
235
- click <- input$map_click
236
- if (!is.null(click)) {
237
- lat <- click$lat
238
- lng <- click$lng
 
 
 
 
239
 
240
- # Convert clicked point to SpatialPoints
241
- coords <- data.frame(lng = lng, lat = lat)
242
- sp_pt <- SpatialPoints(coords, proj4string = CRS("+proj=longlat +datum=WGS84 +no_defs"))
 
 
 
 
 
 
 
 
 
 
 
 
243
 
244
- # Extract values across ALL bands at the clicked location
245
- extracted_vals <- raster::extract(wealth_stack, sp_pt)
246
- # extracted_vals is a 1x10 matrix if the point is valid
247
- if (!is.null(extracted_vals)) {
248
- # Convert to numeric vector
249
- clicked_point_vals(as.numeric(extracted_vals))
250
  } else {
251
- # If the point is outside the raster or invalid
252
- clicked_point_vals(NULL)
253
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
  }
255
  })
256
 
257
- # Plot the time-series for the clicked location
258
- output$clicked_ts_plot <- renderPlot({
259
- vals <- clicked_point_vals()
260
- if (is.null(vals)) {
261
- # No location clicked yet or invalid click
262
- plot.new()
263
- title("Click on the map to see the IWI time-series here.")
264
- return()
265
  }
266
-
267
- # If user clicked in a region with all NAs, do not plot
268
- if (all(is.na(vals))) {
269
- plot.new()
270
- title("No data at this location. Try another spot.")
271
- return()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
  }
273
 
274
- df <- data.frame(Period = factor(time_periods, levels = time_periods),
275
- IWI = vals)
276
 
277
- ggplot(df, aes(x = Period, y = IWI, group = 1)) +
278
- geom_line(color = "darkorange", size = 1) +
279
- geom_point(color = "darkorange", size = 2) +
280
- labs(title = "Time Series of IWI at Clicked Location",
281
- x = "Time Period",
282
- y = "IWI (0 to 1)") +
283
- ylim(0, 1) +
284
- theme_minimal(base_size = 14) +
285
- theme(axis.text.x = element_text(angle = 45, hjust = 1))
 
 
 
286
  })
287
 
288
- # ----------------------------------
289
- # 2. HISTOGRAM OUTPUT (for selected time period)
290
- # ----------------------------------
291
- output$iwi_histogram <- renderPlot({
292
- # Extract raster values for histogram
293
- r_vals <- values(selected_raster())
294
- r_vals <- r_vals[!is.na(r_vals)]
295
-
296
- ggplot(data.frame(iwi = r_vals), aes(x = iwi)) +
297
- geom_histogram(binwidth = 0.02, fill = "#2c7bb6", color = "white", alpha = 0.7) +
298
- labs(x = "IWI (0 to 1)", y = "Frequency") +
299
- theme_minimal(base_size = 14)
300
  })
301
 
302
- # ----------------------------------
303
- # 3. VALUE BOXES FOR KEY STATS
304
- # ----------------------------------
305
- # Compute stats for current raster
306
- raster_stats <- reactive({
307
- r_vals <- values(selected_raster())
308
- r_vals <- r_vals[!is.na(r_vals)]
309
- list(
310
- highest = max(r_vals, na.rm = TRUE),
311
- lowest = min(r_vals, na.rm = TRUE),
312
- average = mean(r_vals, na.rm = TRUE)
313
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
  })
315
 
316
- # Highest IWI
317
- output$highest_iwi_vb <- renderValueBox({
318
- valueBox(
319
- value = round(raster_stats()$highest, 3),
320
- subtitle = "Highest IWI",
321
- icon = icon("arrow-up"),
322
- color = "green"
323
- )
324
  })
325
 
326
- # Lowest IWI
327
- output$lowest_iwi_vb <- renderValueBox({
328
- valueBox(
329
- value = round(raster_stats()$lowest, 3),
330
- subtitle = "Lowest IWI",
331
- icon = icon("arrow-down"),
332
- color = "red"
333
- )
334
  })
335
 
336
- # Average IWI
337
- output$avg_iwi_vb <- renderValueBox({
338
- valueBox(
339
- value = round(raster_stats()$average, 3),
340
- subtitle = "Average IWI",
341
- icon = icon("balance-scale"),
342
- color = "blue"
343
- )
 
344
  })
345
 
346
- # ----------------------------------
347
- # 4. IMPROVEMENT DATA TABLE
348
- # ----------------------------------
349
- output$improvement_table <- renderDT({
350
- datatable(
351
- improvement_data,
352
- filter = "top",
353
- options = list(
354
- scrollX = TRUE,
355
- pageLength = 20,
356
- autoWidth = TRUE
357
- )
 
 
 
 
 
 
 
 
 
 
358
  )
359
  })
360
 
361
- # Download CSV
362
- output$download_data <- downloadHandler(
363
- filename = function() {
364
- paste0("poverty_improvement_", Sys.Date(), ".csv")
365
- },
366
- content = function(file) {
367
- write.csv(improvement_data, file, row.names = FALSE)
368
  }
369
- )
370
-
371
- # ----------------------------------
372
- # 5. TRENDS OVER TIME (line chart of mean IWI across all Africa)
373
- # ----------------------------------
374
- output$trend_plot <- renderPlot({
375
- df <- data.frame(
376
- Period = factor(time_periods, levels = time_periods),
377
- MeanIWI = band_means
378
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
379
 
380
- ggplot(df, aes(x = Period, y = MeanIWI, group = 1)) +
381
- geom_line(color = "#2c7bb6", size = 1.1) +
382
- geom_point(color = "#2c7bb6", size = 2) +
383
- labs(
384
- title = "Average IWI Over Time (Africa)",
385
- x = "Time Period",
386
- y = "Mean IWI"
387
- ) +
388
- ylim(0, 1) +
389
  theme_minimal(base_size = 14) +
390
- theme(axis.text.x = element_text(angle = 45, hjust = 1))
391
  })
392
  }
393
 
394
- # ------------------------------
395
- # 6. Run the App
396
- # ------------------------------
397
  shinyApp(ui = ui, server = server)
398
- }
399
-
 
1
+ #
2
+ # ============================================================
3
+ # app.R | Shiny App for Rerandomization with fastrerandomize
4
+ # ============================================================
5
+ # 1) The user can upload or simulate a covariate dataset (X).
6
+ # 2) They specify rerandomization parameters: n_treated, acceptance prob, etc.
7
+ # 3) The app generates a set of accepted randomizations under rerandomization.
8
+ # 4) The user can optionally upload or simulate outcomes (Y) and run a randomization test.
9
+ # 5) The app displays distribution of the balance measure (e.g., Hotelling's T^2)
10
+ # and final p-value/fiducial interval, along with run-time comparisons between
11
+ # fastrerandomize and base R methods.
12
+ #
13
+ # ----------------------------
14
+ # Load required packages
15
+ # ----------------------------
16
  library(shiny)
17
  library(shinydashboard)
18
+ library(DT) # For data tables
19
+ library(ggplot2) # For basic plotting
20
+ library(fastrerandomize) # Our rerandomization package
21
+ library(parallel) # For detecting CPU cores
22
+
23
+ # For production apps, ensure fastrerandomize is installed:
24
+ # install.packages("devtools")
25
+ # devtools::install_github("cjerzak/fastrerandomize-software/fastrerandomize")
26
 
27
+ # ---------------------------------------------------------
28
+ # HELPER FUNCTIONS (BASE R)
29
+ # ---------------------------------------------------------
30
 
31
+ # 1) Compute Hotelling's T^2 in base R
32
+ baseR_hotellingT2 <- function(X, W) {
33
+ # For a single assignment W:
34
+ # T^2 = (n0 * n1 / (n0 + n1)) * (xbar1 - xbar0)^T * S_inv * (xbar1 - xbar0)
35
+ n <- length(W)
36
+ n1 <- sum(W)
37
+ n0 <- n - n1
38
+ if (n1 == 0 || n0 == 0) return(NA_real_) # invalid scenario
39
+ xbar_treat <- colMeans(X[W == 1, , drop = FALSE])
40
+ xbar_control <- colMeans(X[W == 0, , drop = FALSE])
41
+ diff_vec <- (xbar_treat - xbar_control)
42
+
43
+ # covariance (pooled) – we just use cov(X)
44
+ S <- cov(X)
45
+ Sinv <- tryCatch(solve(S), error = function(e) NULL)
46
+ if (is.null(Sinv)) {
47
+ # fallback: diagonal approximation if solve fails
48
+ Sinv <- diag(1 / diag(S), ncol(S))
49
+ }
50
+
51
+ out <- (n0 * n1 / (n0 + n1)) * c(t(diff_vec) %*% Sinv %*% diff_vec)
52
+ out
53
+ }
54
+
55
+ # 2) Generate randomizations in base R, filtering by acceptance probability
56
+ # using T^2 and keep the best (lowest) fraction.
57
+ baseR_generate_randomizations <- function(n_units, n_treated, X, accept_prob, random_type,
58
+ max_draws, batch_size) {
59
+
60
+ # For safety, check if exact enumerations will explode:
61
+ if (random_type == "exact") {
62
+ n_comb_total <- choose(n_units, n_treated)
63
+ if (n_comb_total > 1e6) {
64
+ warning(
65
+ sprintf("Exact randomization is requested, but that is %s combinations.
66
+ This may be infeasible in terms of memory/time.
67
+ Consider Monte Carlo instead.", format(n_comb_total, big.mark=",")),
68
+ immediate. = TRUE
69
+ )
70
+ }
71
+ }
72
+
73
+ if (random_type == "exact") {
74
+ # -------------- EXACT RANDOMIZATIONS --------------
75
+ cidx <- combn(n_units, n_treated)
76
+ # Build assignment matrix
77
+ n_comb <- ncol(cidx)
78
+ assignment_mat <- matrix(0, nrow = n_comb, ncol = n_units)
79
+ for (i in seq_len(n_comb)) {
80
+ assignment_mat[i, cidx[, i]] <- 1
81
+ }
82
+ # Compute T^2 for each row
83
+ T2vals <- apply(assignment_mat, 1, function(w) baseR_hotellingT2(X, w))
84
+ # Drop any NA (in pathological cases)
85
+ keep_idx <- which(!is.na(T2vals))
86
+ assignment_mat <- assignment_mat[keep_idx, , drop = FALSE]
87
+ T2vals <- T2vals[keep_idx]
88
+
89
+ # acceptance threshold
90
+ cutoff <- quantile(T2vals, probs = accept_prob)
91
+ keep_final <- (T2vals < cutoff)
92
+ assignment_mat_accepted <- assignment_mat[keep_final, , drop = FALSE]
93
+ T2vals_accepted <- T2vals[keep_final]
94
+
95
+ } else {
96
+ # -------------- MONTE CARLO RANDOMIZATIONS --------------
97
+ # We'll sample max_draws permutations
98
+ base_assign <- c(rep(1, n_treated), rep(0, n_units - n_treated))
99
+
100
+ # We'll store T^2's in chunks to reduce memory overhead
101
+ batch_count <- ceiling(max_draws / batch_size)
102
+ all_assign <- list()
103
+ all_T2 <- numeric(0)
104
+
105
+ cur_draw <- 0
106
+ for (b in seq_len(batch_count)) {
107
+ ndraws_here <- min(batch_size, max_draws - cur_draw)
108
+ cur_draw <- cur_draw + ndraws_here
109
+
110
+ # sample permutations
111
+ perms <- matrix(nrow = ndraws_here, ncol = n_units)
112
+ for (j in seq_len(ndraws_here)) {
113
+ perms[j, ] <- sample(base_assign)
114
+ }
115
+ # T^2 for each
116
+ T2vals_batch <- apply(perms, 1, function(w) baseR_hotellingT2(X, w))
117
+
118
+ # collect
119
+ all_assign[[b]] <- perms
120
+ all_T2 <- c(all_T2, T2vals_batch)
121
+ }
122
+ assignment_mat <- do.call(rbind, all_assign)
123
+
124
+ # remove any NA
125
+ keep_idx <- which(!is.na(all_T2))
126
+ assignment_mat <- assignment_mat[keep_idx, , drop = FALSE]
127
+ all_T2 <- all_T2[keep_idx]
128
+
129
+ # acceptance threshold
130
+ cutoff <- quantile(all_T2, probs = accept_prob)
131
+ keep_final <- (all_T2 < cutoff)
132
+ assignment_mat_accepted <- assignment_mat[keep_final, , drop = FALSE]
133
+ T2vals_accepted <- all_T2[keep_final]
134
+ }
135
+
136
+ list(randomizations = assignment_mat_accepted, balance = T2vals_accepted)
137
+ }
138
 
139
+ # Helper: compute difference in means quickly
140
+ diff_in_means <- function(Y, W) {
141
+ mean(Y[W == 1]) - mean(Y[W == 0])
142
+ }
143
 
144
+ # Helper: for a given tau, relabel outcomes and compute the difference in means for a single permutation
145
+ compute_diff_at_tau_for_oneW <- function(Wprime, obsY, obsW, tau) {
146
+ # Y0_under_null = obsY - obsW * tau
147
+ Y0 <- obsY - obsW * tau
148
+ # Y1_under_null = Y0 + tau
149
+ # But in practice, for assignment Wprime, the observed outcome is:
150
+ # Y'(i) = Y0(i) if Wprime(i) = 0, or Y0(i) + tau if Wprime(i)=1
151
+ Yprime <- Y0
152
+ Yprime[Wprime == 1] <- Y0[Wprime == 1] + tau
153
+ diff_in_means(Yprime, Wprime)
154
+ }
155
 
156
+ # 3a) For base R randomization test: difference in means + optional p-value
157
+ # *without* fiducial interval
158
+ # (We will incorporate the FI logic below.)
159
+ baseR_randomization_test <- function(obsW, obsY, allW, findFI = FALSE, alpha = 0.05) {
160
+ # Observed diff in means
161
+ tau_obs <- diff_in_means(obsY, obsW)
162
+
163
+ # for each candidate assignment, compute diff in means on obsY
164
+ diffs <- apply(allW, 1, function(w) diff_in_means(obsY, w))
165
+
166
+ # p-value = fraction whose absolute diff >= observed
167
+ pval <- mean(abs(diffs) >= abs(tau_obs))
168
+
169
+ # optionally compute a fiducial interval
170
+ FI <- NULL
171
+ if (findFI) {
172
+ FI <- baseR_find_fiducial_interval(obsW, obsY, allW, tau_obs, alpha = alpha)
173
+ }
174
+
175
+ list(p_value = pval, tau_obs = tau_obs, FI = FI)
176
+ }
177
 
178
+ # 3b) The fiducial interval logic for base R, mirroring the approach in fastrerandomize:
179
+ # 1) Attempt to find a wide lower and upper bracket via random updates
180
+ # 2) Then a grid search in [lowerBound-1, upperBound*2] for which tau are accepted.
181
+ baseR_find_fiducial_interval <- function(obsW, obsY, allW, tau_obs, alpha = 0.05, c_initial = 2,
182
+ n_search_attempts = 500) {
183
+
184
+ # random bracket approach
185
+ lowerBound_est <- tau_obs - 3*tau_obs
186
+ upperBound_est <- tau_obs + 3*tau_obs
187
+
188
+ z_alpha <- qnorm(1 - alpha)
189
+ k <- 2 / (z_alpha * (2 * pi)^(-1/2) * exp(-z_alpha^2 / 2))
190
+
191
+ # For each iteration, pick one random assignment from allW
192
+ # then see how the implied difference changes, and update the bracket
193
+ n_allW <- nrow(allW)
194
+ for (step_t in seq_len(n_search_attempts)) {
195
+ # pick random assignment
196
+ idx <- sample.int(n_allW, 1)
197
+ Wprime <- allW[idx, ]
198
+
199
+ # ~~~~~ update lowerBound ~~~~~
200
+ # Y0 = obsY - obsW * lowerBound_est
201
+ # Y'(Wprime) = ...
202
+ lowerY0 <- obsY - obsW * lowerBound_est
203
+ Yprime_lower <- lowerY0
204
+ Yprime_lower[Wprime == 1] <- lowerY0[Wprime == 1] + lowerBound_est
205
+
206
+ tau_at_step_lower <- diff_in_means(Yprime_lower, Wprime)
207
+
208
+ c_step <- c_initial
209
+ # difference from obs
210
+ delta <- tau_obs - tau_at_step_lower
211
+
212
+ if (tau_at_step_lower < tau_obs) {
213
+ # move lowerBound up
214
+ lowerBound_est <- lowerBound_est + k * delta * (alpha/2) / step_t
215
+ } else {
216
+ # move it down
217
+ lowerBound_est <- lowerBound_est - k * (-delta) * (1 - alpha/2) / step_t
218
+ }
219
+
220
+ # ~~~~~ update upperBound ~~~~~
221
+ upperY0 <- obsY - obsW * upperBound_est
222
+ Yprime_upper <- upperY0
223
+ Yprime_upper[Wprime == 1] <- upperY0[Wprime == 1] + upperBound_est
224
+
225
+ tau_at_step_upper <- diff_in_means(Yprime_upper, Wprime)
226
+ delta2 <- tau_at_step_upper - tau_obs
227
+
228
+ if (tau_at_step_upper > tau_obs) {
229
+ # move upperBound down
230
+ upperBound_est <- upperBound_est - k * delta2 * (alpha/2) / step_t
231
+ } else {
232
+ # move it up
233
+ upperBound_est <- upperBound_est + k * (-delta2) * (1 - alpha/2) / step_t
234
+ }
235
+ }
236
+
237
+ # Now we do a grid search from (lowerBound_est - 1) to (upperBound_est * 2)
238
+ # in e.g. 100 steps, seeing which tau is "accepted".
239
+ # We'll define "accepted" if the min of:
240
+ # fraction(tau_obs >= distribution_of(tau_pseudo))
241
+ # fraction(tau_obs <= distribution_of(tau_pseudo))
242
+ # is > alpha, i.e. do not reject
243
+ grid_lower <- lowerBound_est - 1
244
+ grid_upper <- upperBound_est * 2
245
+ tau_seq <- seq(grid_lower, grid_upper, length.out = 100)
246
+
247
+ accepted <- logical(length(tau_seq))
248
+ for (i in seq_along(tau_seq)) {
249
+ tau_pseudo <- tau_seq[i]
250
+ # for each row in allW, compute the diff in means if the true effect = tau_pseudo
251
+ # distribution_of(tau_pseudo)
252
+ diffs_pseudo <- apply(allW, 1, function(wp) compute_diff_at_tau_for_oneW(wp, obsY, obsW, tau_pseudo))
253
+ # Then see how often diffs_pseudo >= tau_obs (or <= tau_obs)
254
+ frac_ge <- mean(diffs_pseudo >= tau_obs)
255
+ frac_le <- mean(diffs_pseudo <= tau_obs)
256
+ # min(...) is the typical "two-sided" approach
257
+ accepted[i] <- (min(frac_ge, frac_le) > alpha / 2) # or 0.05 if we want 5% test
258
+ }
259
+
260
+ if (!any(accepted)) {
261
+ # no values accepted => degenerate?
262
+ # We'll return the bracket we found, or NA.
263
+ return(c(NA, NA))
264
+ }
265
+
266
+ c(min(tau_seq[accepted]), max(tau_seq[accepted]))
267
+ }
268
 
269
+ # ---------------------------------------------------------
270
+ # UI Section
271
+ # ---------------------------------------------------------
272
  ui <- dashboardPage(
 
 
273
 
274
+ # ========== Header =================
275
+ dashboardHeader(
276
+ title = tags$span(
277
+ "fastrerandomize Demo",
278
+ style = "font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif;"
279
+ )
280
+ ),
281
+
282
+ # ========== Sidebar ================
283
  dashboardSidebar(
284
  sidebarMenu(
285
+ menuItem("Data & Covariates", tabName = "datatab", icon = icon("database")),
286
+ menuItem("Generate Randomizations", tabName = "gennet", icon = icon("random")),
287
+ menuItem("Randomization Test", tabName = "randtest", icon = icon("flask"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
  )
289
  ),
290
 
291
+ # ========== Body ===================
292
  dashboardBody(
293
+
294
+ # A little CSS to keep the design timeless and clean
295
  tags$head(
 
296
  tags$style(HTML("
297
+ .smalltext { font-size: 90%; color: #555; }
298
+ .shiny-output-error { color: red; }
299
+ .shiny-input-container { margin-bottom: 15px; }
300
  "))
301
  ),
302
+
303
  tabItems(
304
+
305
+ # ------------------------------------------------
306
+ # 1) Data & Covariates Tab
307
+ # ------------------------------------------------
308
  tabItem(
309
+ tabName = "datatab",
310
+
 
 
 
 
 
311
  fluidRow(
312
+ box(width = 5, title = "Covariate Data: Upload or Simulate",
313
+ status = "primary", solidHeader = TRUE,
314
+
315
+ radioButtons("data_source", "Data Source:",
316
+ choices = c("Upload CSV" = "upload",
317
+ "Simulate data" = "simulate"),
318
+ selected = "simulate"),
319
+
320
+ conditionalPanel(
321
+ condition = "input.data_source == 'upload'",
322
+ fileInput("file_covariates", "Choose CSV File",
323
+ accept = c(".csv")),
324
+ helpText("Columns = features/covariates, rows = units.")
325
+ ),
326
+
327
+ conditionalPanel(
328
+ condition = "input.data_source == 'simulate'",
329
+ numericInput("sim_n", "Number of units (rows)", value = 100, min = 2),
330
+ numericInput("sim_p", "Number of covariates (columns)", value = 50, min = 1),
331
+ actionButton("simulate_btn", "Simulate X")
332
+ )
333
  ),
334
+
335
+ box(width = 7, title = "Preview of Covariates (X)",
336
+ status = "info", solidHeader = TRUE,
337
+ DTOutput("covariates_table"))
 
 
 
 
 
 
 
 
 
 
 
 
 
338
  )
339
  ),
340
 
341
+ # ------------------------------------------------
342
+ # 2) Generate Randomizations Tab
343
+ # ------------------------------------------------
344
  tabItem(
345
+ tabName = "gennet",
346
+
347
  fluidRow(
348
+ box(width = 4, title = "Rerandomization Parameters",
349
+ status = "primary", solidHeader = TRUE,
350
+
351
+ numericInput("n_treated", "Number Treated (n_treated)", value = 10, min = 1),
352
+ selectInput("random_type", "Randomization Type:",
353
+ choices = c("Monte Carlo" = "monte_carlo",
354
+ "Exact" = "exact"),
355
+ selected = "monte_carlo"),
356
+ numericInput("accept_prob", "Acceptance Probability (stringency)",
357
+ value = 0.01, min = 0.0001, max = 1),
358
+ conditionalPanel(
359
+ condition = "input.random_type == 'monte_carlo'",
360
+ numericInput("max_draws", "Max Draws (MC)", value = 1e5, min = 1e3),
361
+ numericInput("batch_size", "Batch Size (MC)", value = 1e3, min = 1e2)
362
+ ),
363
+ actionButton("generate_btn", "Generate Randomizations")
364
+ ),
365
+
366
+ box(width = 8, title = "Summary of Accepted Randomizations",
367
+ status = "info", solidHeader = TRUE,
368
+
369
+ # First row of boxes: accepted randomizations and min balance measure
370
+ fluidRow(
371
+ column(width = 6, valueBoxOutput("n_accepted_box", width = 12)),
372
+ column(width = 6, valueBoxOutput("balance_min_box", width = 12))
373
+ ),
374
+
375
+ # Second row of boxes: fastrerandomize time & base R time
376
+ fluidRow(
377
+ column(width = 6, valueBoxOutput("fastrerand_time_box", width = 12)),
378
+ column(width = 6, valueBoxOutput("baseR_time_box", width = 12))
379
+ ),
380
+
381
+ br(),
382
+ plotOutput("balance_hist", height = "250px"),
383
+
384
+ # Hardware info note
385
+ br(),
386
+ uiOutput("hardware_info")
387
  )
388
  )
389
  ),
390
 
391
+ # ------------------------------------------------
392
+ # 3) Randomization Test Tab
393
+ # ------------------------------------------------
394
  tabItem(
395
+ tabName = "randtest",
396
+
397
  fluidRow(
398
+ box(width = 4, title = "Randomization Test Setup",
399
+ status = "primary", solidHeader = TRUE,
400
+
401
+ radioButtons("outcome_source", "Outcome Data (Y):",
402
+ choices = c("Simulate Y" = "simulate",
403
+ "Upload CSV" = "uploadY"),
404
+ selected = "simulate"),
405
+
406
+ conditionalPanel(
407
+ condition = "input.outcome_source == 'simulate'",
408
+ numericInput("true_tau", "True Effect (simulate)", 1, step = 0.5),
409
+ numericInput("noise_sd", "Noise SD for Y", 0.5, step = 0.1),
410
+ actionButton("simulateY_btn", "Simulate Y")
411
+ ),
412
+ conditionalPanel(
413
+ condition = "input.outcome_source == 'uploadY'",
414
+ fileInput("file_outcomes", "Choose CSV File with outcome vector Y",
415
+ accept = c(".csv")),
416
+ helpText("Single column with length = #units.")
417
+ ),
418
+
419
+ br(),
420
+ actionButton("run_randtest_btn", "Run Randomization Test"),
421
+ checkboxInput("findFI", "Compute Fiducial Interval?", value = FALSE)
422
+ ),
423
+
424
+ box(width = 8, title = "Test Results", status = "info", solidHeader = TRUE,
425
+
426
+ # First row: p-value and observed effect (fastrerandomize)
427
+ fluidRow(
428
+ column(width = 6, valueBoxOutput("pvalue_box", width = 12)),
429
+ column(width = 6, valueBoxOutput("tauobs_box", width = 12))
430
+ ),
431
+
432
+ # Second row: fastrerandomize test time & base R test time
433
+ fluidRow(
434
+ column(width = 6, valueBoxOutput("fastrerand_test_time_box", width = 12)),
435
+ column(width = 6, valueBoxOutput("baseR_test_time_box", width = 12))
436
+ ),
437
+
438
+ # Show fastrerandomize FI
439
+ uiOutput("fi_text"),
440
+
441
+ # Now show Base R results in a separate row
442
+ tags$hr(),
443
+ fluidRow(
444
+ column(width = 6, valueBoxOutput("pvalue_box_baseR", width = 12)),
445
+ column(width = 6, valueBoxOutput("tauobs_box_baseR", width = 12))
446
+ ),
447
+ fluidRow(
448
+ column(width = 12, uiOutput("fi_text_baseR"))
449
+ ),
450
+
451
+ br(),
452
+ plotOutput("test_plot", height = "280px")
453
  )
454
  )
455
  )
456
+
457
+ ) # end tabItems
458
+ ) # end dashboardBody
459
+ ) # end dashboardPage
460
 
461
+ # ---------------------------------------------------------
462
+ # SERVER
463
+ # ---------------------------------------------------------
464
  server <- function(input, output, session) {
465
 
466
+ # -------------------------------------------------------
467
+ # 1. Covariate Data Handling
468
+ # -------------------------------------------------------
469
+ # We store the covariate matrix X in a reactiveVal for convenient reuse
470
+ X_data <- reactiveVal(NULL)
471
 
472
+ # Observe file input or simulation for X
473
+ observeEvent(input$file_covariates, {
474
+ req(input$file_covariates)
475
+ inFile <- input$file_covariates
476
+ df <- tryCatch(read.csv(inFile$datapath, header = TRUE),
477
+ error = function(e) NULL)
478
+ if (!is.null(df)) {
479
+ X_data(as.matrix(df))
480
+ }
481
  })
482
 
483
+ # If the user clicks "Simulate X"
484
+ observeEvent(input$simulate_btn, {
485
+ n <- input$sim_n
486
+ p <- input$sim_p
487
+ # Basic simulation of N(0,1) data
488
+ simX <- matrix(rnorm(n * p), nrow = n, ncol = p)
489
+ X_data(simX)
 
 
 
 
 
 
 
 
 
 
 
 
490
  })
491
 
492
+ # Show X in table
493
+ output$covariates_table <- renderDT({
494
+ req(X_data())
495
+ datatable(as.data.frame(X_data()),
496
+ options = list(scrollX = TRUE, pageLength = 5))
 
 
 
 
 
 
 
 
 
 
 
 
497
  })
498
 
499
+ # -------------------------------------------------------
500
+ # 2. Generate Rerandomizations
501
+ # -------------------------------------------------------
502
+ # We'll keep the accepted randomizations from fastrerandomize in RerandResult
503
+ # and from base R in RerandResult_base.
504
+ RerandResult <- reactiveVal(NULL)
505
+ RerandResult_base <- reactiveVal(NULL)
 
 
 
 
506
 
507
+ # We also store their run times
508
+ fastrand_time <- reactiveVal(NULL)
509
+ baseR_time <- reactiveVal(NULL)
510
+
511
+ observeEvent(input$generate_btn, {
512
+ req(X_data())
513
+ validate(
514
+ need(nrow(X_data()) >= input$n_treated,
515
+ "Number treated cannot exceed total units.")
516
+ )
517
+
518
+ withProgress(message = "Computing results...", value = 0, {
519
 
520
+ # =========== 1) fastrerandomize generation timing ===========
521
+ t0_fast <- Sys.time()
522
+ out <- tryCatch({
523
+ generate_randomizations(
524
+ n_units = nrow(X_data()),
525
+ n_treated = input$n_treated,
526
+ X = X_data(),
527
+ randomization_accept_prob= input$accept_prob,
528
+ randomization_type = input$random_type,
529
+ max_draws = if (input$random_type == "monte_carlo") input$max_draws else NULL,
530
+ batch_size = if (input$random_type == "monte_carlo") input$batch_size else NULL,
531
+ verbose = FALSE
532
+ )
533
+ }, error = function(e) e)
534
+ t1_fast <- Sys.time()
535
 
536
+ if (inherits(out, "error")) {
537
+ showNotification(paste("Error generating randomizations (fastrerandomize):", out$message), type = "error")
538
+ RerandResult(NULL)
 
 
 
539
  } else {
540
+ RerandResult(out)
 
541
  }
542
+ fastrand_time(difftime(t1_fast, t0_fast, units = "secs"))
543
+
544
+ # =========== 2) base R generation timing ===========
545
+ t0_base <- Sys.time()
546
+ out_base <- tryCatch({
547
+ baseR_generate_randomizations(
548
+ n_units = nrow(X_data()),
549
+ n_treated = input$n_treated,
550
+ X = X_data(),
551
+ accept_prob= input$accept_prob,
552
+ random_type= input$random_type,
553
+ max_draws = if (input$random_type == "monte_carlo") input$max_draws else NULL,
554
+ batch_size = if (input$random_type == "monte_carlo") input$batch_size else NULL
555
+ )
556
+ }, error = function(e) e)
557
+ t1_base <- Sys.time()
558
+
559
+ if (inherits(out_base, "error")) {
560
+ showNotification(paste("Error generating randomizations (base R):", out_base$message), type = "error")
561
+ RerandResult_base(NULL)
562
+ } else {
563
+ RerandResult_base(out_base)
564
+ }
565
+ baseR_time(difftime(t1_base, t0_base, units = "secs"))
566
+ })
567
+ })
568
+
569
+ # Summaries of accepted randomizations
570
+ output$n_accepted_box <- renderValueBox({
571
+ rr <- RerandResult()
572
+ if (is.null(rr) || is.null(rr$randomizations)) {
573
+ valueBox("0", "Accepted Randomizations", icon = icon("ban"), color = "red")
574
+ } else {
575
+ nAcc <- nrow(rr$randomizations)
576
+ valueBox(nAcc, "Accepted Randomizations", icon = icon("check"), color = "green")
577
  }
578
  })
579
 
580
+ output$balance_min_box <- renderValueBox({
581
+ rr <- RerandResult()
582
+ if (is.null(rr) || is.null(rr$balance)) {
583
+ valueBox("---", "Min Balance Measure", icon = icon("question"), color = "orange")
584
+ } else {
585
+ minBal <- round(min(rr$balance), 4)
586
+ valueBox(minBal, "Min Balance Measure", icon = icon("thumbs-up"), color = "blue")
 
587
  }
588
+ })
589
+
590
+ # Timings for generation: fastrerandomize and base R
591
+ output$fastrerand_time_box <- renderValueBox({
592
+ tm <- fastrand_time()
593
+ if (is.null(tm)) {
594
+ valueBox("---", "fastrerandomize generation time (secs)", icon = icon("clock"), color = "teal")
595
+ } else {
596
+ valueBox(round(as.numeric(tm), 3), "fastrerandomize generation time (secs)",
597
+ icon = icon("clock"), color = "teal")
598
+ }
599
+ })
600
+
601
+ output$baseR_time_box <- renderValueBox({
602
+ tm <- baseR_time()
603
+ if (is.null(tm)) {
604
+ valueBox("---", "base R generation time (secs)", icon = icon("clock"), color = "lime")
605
+ } else {
606
+ valueBox(round(as.numeric(tm), 3), "base R generation time (secs)",
607
+ icon = icon("clock"), color = "lime")
608
+ }
609
+ })
610
+
611
+ # Plot histogram of the balance measure (fastrerandomize result)
612
+ output$balance_hist <- renderPlot({
613
+ rr <- RerandResult()
614
+ req(rr, rr$balance)
615
+ df <- data.frame(balance = rr$balance)
616
+ ggplot(df, aes(x = balance)) +
617
+ geom_histogram(binwidth = diff(range(df$balance))/30, fill = "darkblue", alpha = 0.7) +
618
+ labs(title = "Distribution of Balance Measure",
619
+ x = "Balance (e.g. T^2)",
620
+ y = "Frequency") +
621
+ theme_minimal(base_size = 14)
622
+ })
623
+
624
+ # Hardware info (CPU cores, GPU note)
625
+ output$hardware_info <- renderUI({
626
+ num_cores <- detectCores(logical = TRUE)
627
+ HTML(paste(
628
+ "<strong>System Hardware Info:</strong><br/>",
629
+ "Number of CPU cores detected:", num_cores, "<br/>",
630
+ "With additional CPU or GPU, greater speedups can be expected."
631
+ ))
632
+ })
633
+
634
+ # -------------------------------------------------------
635
+ # 3. Randomization Test
636
+ # -------------------------------------------------------
637
+ Y_data <- reactiveVal(NULL)
638
+
639
+ # (A) If user simulates Y
640
+ observeEvent(input$simulateY_btn, {
641
+ req(RerandResult())
642
+ rr <- RerandResult()
643
+ if (is.null(rr$randomizations) || nrow(rr$randomizations) < 1) {
644
+ showNotification("No accepted randomizations found. Cannot simulate Y for the 'observed' assignment.", type = "error")
645
+ return(NULL)
646
  }
647
 
648
+ obsW <- rr$randomizations[1, ]
649
+ nunits <- length(obsW)
650
 
651
+ # Basic data generation: Y = X * beta + tau * W + noise
652
+ Xval <- X_data()
653
+ if (is.null(Xval)) {
654
+ showNotification("No covariate data found to help simulate outcomes. Using intercept-only model.", type="warning")
655
+ Xval <- matrix(0, nrow = nunits, ncol = 1)
656
+ }
657
+ # random coefficients
658
+ beta <- rnorm(ncol(Xval), 0, 1)
659
+ linear_part <- Xval %*% beta
660
+ Ysim <- as.numeric(linear_part + obsW * input$true_tau + rnorm(nunits, 0, input$noise_sd))
661
+
662
+ Y_data(Ysim)
663
  })
664
 
665
+ # (B) If user uploads Y
666
+ observeEvent(input$file_outcomes, {
667
+ req(input$file_outcomes)
668
+ inFile <- input$file_outcomes
669
+ dfy <- tryCatch(read.csv(inFile$datapath, header = FALSE), error=function(e) NULL)
670
+ if (!is.null(dfy)) {
671
+ if (ncol(dfy) > 1) {
672
+ showNotification("Please provide a single-column CSV for Y.", type="error")
673
+ } else {
674
+ Y_data(as.numeric(dfy[[1]]))
675
+ }
676
+ }
677
  })
678
 
679
+ # The randomization test result:
680
+ RandTestResult <- reactiveVal(NULL)
681
+ RandTestResult_base <- reactiveVal(NULL)
682
+
683
+ # We'll store their times:
684
+ fastrand_test_time <- reactiveVal(NULL)
685
+ baseR_test_time <- reactiveVal(NULL)
686
+
687
+ observeEvent(input$run_randtest_btn, {
688
+ withProgress(message = "Computing results...", value = 0, {
689
+
690
+ req(RerandResult())
691
+ rr <- RerandResult()
692
+ req(rr$randomizations)
693
+ if (is.null(Y_data())) {
694
+ showNotification("No outcome data Y found. Upload or simulate first.", type="error")
695
+ return(NULL)
696
+ }
697
+
698
+ obsW <- rr$randomizations[1, ]
699
+ obsY <- Y_data()
700
+
701
+ # =========== 1) fastrerandomize randomization_test timing ===========
702
+ t0_testfast <- Sys.time()
703
+ outTest <- tryCatch({
704
+ randomization_test(
705
+ obsW = obsW,
706
+ obsY = obsY,
707
+ candidate_randomizations = rr$randomizations,
708
+ findFI = input$findFI
709
+ )
710
+ }, error=function(e) e)
711
+ t1_testfast <- Sys.time()
712
+
713
+ if (inherits(outTest, "error")) {
714
+ showNotification(paste("Error in randomization_test (fastrerandomize):", outTest$message), type="error")
715
+ RandTestResult(NULL)
716
+ } else {
717
+ RandTestResult(outTest)
718
+ }
719
+ fastrand_test_time(difftime(t1_testfast, t0_testfast, units = "secs"))
720
+
721
+ # =========== 2) base R randomization test timing ===========
722
+ req(RerandResult_base())
723
+ rr_base <- RerandResult_base()
724
+ if (is.null(rr_base$randomizations) || nrow(rr_base$randomizations) < 1) {
725
+ showNotification("No base R randomizations found. Cannot run base R test.", type = "error")
726
+ RandTestResult_base(NULL)
727
+ return(NULL)
728
+ }
729
+
730
+ t0_testbase <- Sys.time()
731
+ outTestBase <- tryCatch({
732
+ baseR_randomization_test(
733
+ obsW = obsW,
734
+ obsY = obsY,
735
+ allW = rr_base$randomizations,
736
+ findFI = input$findFI # if user wants the FI, do so
737
+ )
738
+ }, error = function(e) e)
739
+ t1_testbase <- Sys.time()
740
+
741
+ if (inherits(outTestBase, "error")) {
742
+ showNotification(paste("Error in randomization_test (base R):", outTestBase$message), type="error")
743
+ RandTestResult_base(NULL)
744
+ } else {
745
+ RandTestResult_base(outTestBase)
746
+ }
747
+ baseR_test_time(difftime(t1_testbase, t0_testbase, units = "secs"))
748
+ })
749
  })
750
 
751
+ # Display p-value and observed tau (from the fastrerandomize test)
752
+ output$pvalue_box <- renderValueBox({
753
+ rt <- RandTestResult()
754
+ if (is.null(rt)) {
755
+ valueBox("---", "p-value (fastrerandomize)", icon = icon("question"), color = "blue")
756
+ } else {
757
+ valueBox(round(rt$p_value, 4), "p-value (fastrerandomize)", icon = icon("list-check"), color = "purple")
758
+ }
759
  })
760
 
761
+ output$tauobs_box <- renderValueBox({
762
+ rt <- RandTestResult()
763
+ if (is.null(rt)) {
764
+ valueBox("---", "Observed Effect (fastrerandomize)", icon = icon("question"), color = "maroon")
765
+ } else {
766
+ valueBox(round(rt$tau_obs, 4), "Observed Effect (fastrerandomize)", icon = icon("bullseye"), color = "maroon")
767
+ }
 
768
  })
769
 
770
+ # Times for randomization test
771
+ output$fastrerand_test_time_box <- renderValueBox({
772
+ tm <- fastrand_test_time()
773
+ if (is.null(tm)) {
774
+ valueBox("---", "fastrerandomize test time (secs)", icon = icon("clock"), color = "teal")
775
+ } else {
776
+ valueBox(round(as.numeric(tm), 3), "fastrerandomize test time (secs)",
777
+ icon = icon("clock"), color = "teal")
778
+ }
779
  })
780
 
781
+ output$baseR_test_time_box <- renderValueBox({
782
+ tm <- baseR_test_time()
783
+ if (is.null(tm)) {
784
+ valueBox("---", "base R test time (secs)", icon = icon("clock"), color = "lime")
785
+ } else {
786
+ valueBox(round(as.numeric(tm), 3), "base R test time (secs)",
787
+ icon = icon("clock"), color = "lime")
788
+ }
789
+ })
790
+
791
+ # If we have a fiducial interval from fastrerandomize, display it
792
+ output$fi_text <- renderUI({
793
+ rt <- RandTestResult()
794
+ if (is.null(rt) || is.null(rt$FI)) {
795
+ return(NULL)
796
+ }
797
+ fi_lower <- round(rt$FI[1], 4)
798
+ fi_upper <- round(rt$FI[2], 4)
799
+
800
+ tagList(
801
+ strong("Fiducial Interval (fastrerandomize, 95%):"),
802
+ p(sprintf("[%.4f, %.4f]", fi_lower, fi_upper))
803
  )
804
  })
805
 
806
+ # If we have a fiducial interval from base R, display it
807
+ output$fi_text_baseR <- renderUI({
808
+ rt <- RandTestResult_base()
809
+ if (is.null(rt) || is.null(rt$FI)) {
810
+ return(NULL)
 
 
811
  }
812
+ fi_lower <- round(rt$FI[1], 4)
813
+ fi_upper <- round(rt$FI[2], 4)
814
+
815
+ tagList(
816
+ strong("Fiducial Interval (base R, 95%):"),
817
+ p(sprintf("[%.4f, %.4f]", fi_lower, fi_upper))
 
 
 
818
  )
819
+ })
820
+
821
+ # A simple plot for the randomization distribution (for demonstration).
822
+ # In this app, we do not store the entire distribution from either method,
823
+ # so we simply show the observed effect as a point.
824
+ output$test_plot <- renderPlot({
825
+ rt <- RandTestResult()
826
+ if (is.null(rt)) {
827
+ plot.new()
828
+ title("No test results yet.")
829
+ return(NULL)
830
+ }
831
+ # Just display the observed effect from fastrerandomize
832
+ obs_val <- rt$tau_obs
833
 
834
+ ggplot(data.frame(x = obs_val, y = 0), aes(x, y)) +
835
+ geom_point(size=4, color="red") +
836
+ xlim(c(obs_val - abs(obs_val)*2 - 1, obs_val + abs(obs_val)*2 + 1)) +
837
+ labs(title = "Observed Treatment Effect (fastrerandomize)",
838
+ x = "Effect Size", y = "") +
 
 
 
 
839
  theme_minimal(base_size = 14) +
840
+ geom_vline(xintercept = 0, linetype="dashed", color="gray40")
841
  })
842
  }
843
 
844
+ # ---------------------------------------------------------
845
+ # Run the Application
846
+ # ---------------------------------------------------------
847
  shinyApp(ui = ui, server = server)