# ra / app.R — last commit dc52259 ("Update app.R") by cjerzak
# setwd("~/Dropbox/ImageDeconfoundAid/BrokenExperiment/ShinyApp/"); Sys.setenv(RETICULATE_PYTHON = "/Users/cjerzak/miniconda3/bin/python")
# app.R — Remote Audit: Design-Based Tests of Randomization with Satellite Imagery
# ==============================================================================
# Performs conditional randomization tests to audit experimental integrity
# using pre-treatment satellite imagery features (NDVI, Nightlight)
# ==============================================================================
# For Hugging Face deployment, set secrets GEE_PROJECT, GEE_EMAIL, GEE_KEY (the service account key JSON as string)
# Accept uploads up to 50 MB and clear any custom `error` handler option.
options(
  shiny.maxRequestSize = 50 * 1024^2,
  error = NULL
)
suppressPackageStartupMessages({
library(shiny)
library(bslib)
library(DT)
library(shinyWidgets)
library(reticulate)
library(dplyr)
library(xgboost)
library(future)
library(future.apply)
library(parallel)
})
# ============================================================================
# Helper Functions
# ============================================================================
# Clamp probabilities into [eps, 1 - eps]: the lower bound is applied first,
# then the upper bound, element-wise.
.clamp01 <- function(p, eps = 1e-6) {
  floored <- pmax(p, eps)
  pmin(floored, 1 - eps)
}
# Bernoulli log-likelihood gain of the predictions `phat` over the constant
# prediction `a_bar` for a 0/1 vector A. Both probability inputs are clamped
# with .clamp01() so log() stays finite. Positive values mean `phat` fits A
# better than the constant model.
.lik_improvement <- function(A, phat, a_bar) {
  p_fit <- .clamp01(phat)
  p_const <- .clamp01(a_bar)
  bernoulli_ll <- function(p) sum(A * log(p) + (1 - A) * log(1 - p))
  bernoulli_ll(p_fit) - bernoulli_ll(p_const)
}
# Cross-fitted xgboost estimates of P(A = 1 | X).
#
# For each fold k, a binary:logistic booster is trained on the other folds
# and used to predict the held-out rows, so every unit receives an
# out-of-fold prediction.
#
# Arguments:
#   A     - 0/1 outcome vector.
#   X     - numeric matrix or data frame with one row per element of A.
#   K     - number of folds (used to draw `folds` when NULL, and to loop).
#   folds - optional fold ids in 1..K, one per row.
#   ntree - number of boosting rounds.
#   mtry  - features per tree, passed to xgboost as the fraction
#           colsample_bytree = mtry / p (default max(1, floor(sqrt(p)))).
# Training columns with non-finite or (near-)zero SD are dropped per fold;
# if none remain, the fold is predicted with the training mean of A. Folds
# with no rows are skipped.
# Returns list(phat = predictions clamped by .clamp01(), folds = fold ids).
.cf_xgboost_phat <- function(A, X, K = 5, folds = NULL, ntree = 300L, mtry = NULL) {
  X <- as.matrix(X)
  n <- length(A)
  if (is.null(folds)) folds <- sample(rep(seq_len(K), length.out = n))
  # availableCores() honours cgroup/CPU-quota limits and R options, unlike
  # parallel::detectCores(), which reports every host core. That matters when
  # this runs inside parallel future workers on a shared container.
  # NOTE(review): confirm the per-worker value with the installed future
  # version.
  n_threads <- max(1L, as.integer(future::availableCores()))
  phat <- rep(NA_real_, n)
  for (k in seq_len(K)) {
    idx_te <- which(folds == k)
    if (length(idx_te) == 0L) next  # empty fold (e.g. n < K): nothing to predict
    idx_tr <- which(folds != k)
    A_tr <- A[idx_tr]
    X_tr <- X[idx_tr, , drop = FALSE]
    X_te <- X[idx_te, , drop = FALSE]
    # Drop constant / undefined-variance columns; they carry no signal.
    sdv <- apply(X_tr, 2, sd)
    keep <- is.finite(sdv) & (sdv > 1e-12)
    if (!any(keep)) {
      phat[idx_te] <- mean(A_tr)
      next
    }
    # Train and predict on the SAME numeric matrix layout so the booster's
    # stored feature names always match the prediction input.
    X_tr_k <- X_tr[, keep, drop = FALSE]
    X_te_k <- X_te[, keep, drop = FALSE]
    p <- ncol(X_tr_k)
    mtry_use <- if (is.null(mtry)) {
      max(1L, floor(sqrt(p)))
    } else {
      max(1L, min(as.integer(mtry), p))
    }
    params <- list(
      objective = "binary:logistic",
      eval_metric = "logloss",
      eta = 0.1,
      max_depth = 6,
      subsample = 0.8,
      colsample_bytree = min(1, as.numeric(mtry_use) / p),
      nthread = n_threads
    )
    dtrain <- xgboost::xgb.DMatrix(data = X_tr_k, label = as.numeric(A_tr))
    fit <- xgboost::xgb.train(
      params = params,
      data = dtrain,
      nrounds = as.integer(ntree),
      verbose = 0
    )
    phat[idx_te] <- predict(fit, X_te_k)
  }
  list(phat = .clamp01(phat), folds = folds)
}
# Draw a complete-randomization assignment: an integer vector of length n
# with exactly m ones placed uniformly at random (one sample.int() call).
.draw_assign_fixed_m <- function(n, m) {
  treated_idx <- sample.int(n, m)
  assignment <- integer(n)
  assignment[treated_idx] <- 1L
  assignment
}
# Conditional randomization test of whether a 0/1 vector A is predictable
# from pre-treatment features X.
#
# Test statistic T: the cross-fitted (out-of-fold) Bernoulli log-likelihood
# gain of an xgboost classifier over the constant prediction mean(A).
# Null distribution: B redraws that place the same number of ones uniformly
# at random (.draw_assign_fixed_m). Each redraw recomputes T with the same
# folds.
#
# Arguments:
#   A      - assignment (or response) indicator coded 0/1, one per row of X.
#   X      - numeric matrix or data frame of features.
#   K      - number of cross-fitting folds; must be in 2..n after filtering.
#   B      - number of null redraws (>= 1).
#   seed   - passed to set.seed() before fold assignment and the observed fit.
#   label  - free-text label copied into the result.
#   xgboost_ntree, xgboost_mtry - forwarded to .cf_xgboost_phat().
# Rows where A or any feature is non-finite are dropped before anything else.
# Errors (stop) on degenerate input: fewer than two rows, non-0/1 codes, only
# one class, or an out-of-range K / B.
# Returns list(T_obs, T_null, p_value, a_bar, n, treated, K, B, label,
# learner). p_value is the add-one Monte Carlo p-value (1 + #{T_null >= T_obs}) / (B + 1).
# Side effects: set.seed(seed) changes the global RNG state. The future plan
# is switched to multisession for the resampling and restored afterwards.
remote_audit_crt <- function(A, X,
                             K = 5,
                             B = 1000,
                             seed = 123,
                             label = "",
                             xgboost_ntree = 300L,
                             xgboost_mtry = NULL) {
  stopifnot(length(A) == nrow(X))
  X <- as.matrix(X)
  keep <- is.finite(A) & rowSums(!is.finite(X)) == 0L
  A <- as.integer(A[keep])
  X <- X[keep, , drop = FALSE]
  n <- length(A)
  m <- sum(A)

  if (n < 2L) {
    stop("Need at least two complete cases to run the audit.", call. = FALSE)
  }
  if (!all(A %in% c(0L, 1L))) {
    stop("`A` must be coded 0/1.", call. = FALSE)
  }
  if (m == 0L || m == n) {
    stop("`A` has only one class after dropping incomplete rows; the test is undefined.",
         call. = FALSE)
  }
  if (!is.numeric(K) || length(K) != 1L || is.na(K) || K < 2 || K > n) {
    stop(sprintf("`K` must be between 2 and the number of complete cases (%d).", n),
         call. = FALSE)
  }
  if (!is.numeric(B) || length(B) != 1L || is.na(B) || B < 1) {
    stop("`B` must be a positive number of resamples.", call. = FALSE)
  }

  a_bar <- .clamp01(mean(A))
  set.seed(seed)
  folds <- sample(rep(seq_len(K), length.out = n))
  obs <- .cf_xgboost_phat(A, X, K = K, folds = folds,
                          ntree = xgboost_ntree, mtry = xgboost_mtry)
  T_obs <- .lik_improvement(A, obs$phat, a_bar)

  # plan() returns the previous plan; restore it on exit (also on error)
  # rather than leaving the session on multisession.
  old_plan <- plan(multisession, workers = availableCores())
  on.exit(plan(old_plan), add = TRUE)
  # future.seed = TRUE gives each redraw its own reproducible RNG stream,
  # derived from the RNG state left after set.seed(seed) and the fit above.
  T_null <- future_vapply(
    seq_len(B),
    FUN = function(b) {
      A_b <- .draw_assign_fixed_m(n, m)
      ph <- .cf_xgboost_phat(A_b, X, K = K, folds = folds,
                             ntree = xgboost_ntree, mtry = xgboost_mtry)$phat
      .lik_improvement(A_b, ph, a_bar)
    },
    FUN.VALUE = numeric(1),
    future.seed = TRUE
  )

  pval <- (1 + sum(T_null >= T_obs)) / (B + 1)
  list(
    T_obs = T_obs,
    T_null = T_null,
    p_value = pval,
    a_bar = a_bar,
    n = n,
    treated = m,
    K = K,
    B = B,
    label = label,
    learner = "xgboost"
  )
}
# ============================================================================
# UI
# ============================================================================
# Page theme: Bootswatch "flatly" via bslib.
theme <- bs_theme(bootswatch = "flatly")
# Page layout: all inputs live in the sidebar. The main column shows a
# preview of the loaded data and, once output.audit_complete is true, the
# results card.
ui <- page_sidebar(
tags$head(tags$title("Remote Audit")),
title = div(
span("Remote Audit", style = "font-weight:700;"),
span(" with Satellite Imagery", style = "color: #888;")
),
theme = theme,
sidebar = sidebar(
width = 360,
# --- Data source: bundled example data or a user-uploaded CSV ---
h5("Data Input"),
radioButtons("data_source", NULL,
choices = c("Upload CSV" = "upload",
"Use Example (Begum et al. 2022)" = "example"),
selected = "example"),
# File picker only shown when "Upload CSV" is selected.
conditionalPanel(
"input.data_source == 'upload'",
fileInput("file_csv", "Upload CSV", accept = ".csv")
),
# --- What to audit ---
h5("Audit Configuration"),
selectInput("audit_type", "Audit Type",
choices = c("Randomization" = "randomization",
"Missingness" = "missingness"),
selected = "randomization"),
# Randomization audit: treatment column plus the codes identifying control
# and treated units. Column choices start empty and are filled by the server
# once data is loaded.
conditionalPanel(
"input.audit_type == 'randomization'",
selectInput("treat_col", "Treatment Column", choices = NULL),
numericInput("control_val", "Control Value", value = 1, step = 1),
numericInput("treat_val", "Treatment Value", value = 2, step = 1)
),
# Missingness audit: the variable whose NA pattern is tested.
conditionalPanel(
"input.audit_type == 'missingness'",
selectInput("missing_col", "Variable to Check", choices = NULL)
),
# Coordinates and year window, used when a selected feature is not already a
# column of the data and has to be fetched from Google Earth Engine.
selectInput("lat_col", "Latitude Column", choices = NULL),
selectInput("long_col", "Longitude Column", choices = NULL),
numericInput("start_year", "Start Year", value = 2010, min = 1990, max = 2026),
numericInput("end_year", "End Year", value = 2011, min = 1990, max = 2026),
# Feature columns used as predictors in the audit.
checkboxGroupInput("features", "Features",
choices = c("NDVI Median" = "ndvi_median",
"Nightlight Median" = "ntl_median"),
selected = c("ndvi_median", "ntl_median")),
# --- Test settings: cross-fitting folds (K), null redraws (B), RNG seed,
# and xgboost boosting rounds (ntree) ---
h5("Parameters"),
numericInput("K", "K-Folds", value = 10, min = 2, max = 20),
numericInput("B", "Resamples", value = 1000, min = 100, max = 5000, step = 100),
numericInput("seed", "Random Seed", value = 987),
numericInput("ntree", "Number of Trees", value = 300, min = 50, max = 1000),
actionButton("run_audit", "Run Audit",
class = "btn-primary btn-lg",
icon = icon("play"),
style = "width: 100%;"),
tags$a(
href = "https://connorjerzak.com/linkorgs-summary/",
target = "_blank",
icon("circle-question"), " Technical Details"
)
),
layout_columns(
col_widths = c(12),
# First 100 rows of the active data set.
card(
card_header("Data Preview"),
card_body(
DTOutput("data_preview")
)
),
# Results card, revealed by the server-side output.audit_complete flag.
conditionalPanel(
"output.audit_complete",
card(
card_header("Audit Results"),
card_body(
uiOutput("results_summary"),
# NOTE(review): the server's renderPlot for audit_plot is wrapped in
# if (FALSE), so this area may stay blank. Confirm whether it should show.
plotOutput("audit_plot", height = "400px"),
downloadButton("download_results", "Download Results",
class = "btn-success")
)
)
)
)
)
# ============================================================================
# Server
# ============================================================================
server <- function(input, output, session) {
# Most recent audit result (the list returned by remote_audit_crt), or NULL.
audit_results <- reactiveVal(NULL)
# The active data set: either the bundled example (an .Rdata file that must
# contain an object named `data`) or the uploaded CSV. On a load failure it
# shows a notification and returns NULL.
data_loaded <- reactive({
  if (input$data_source == "example") {
    example_path <- "./Islam2019_WithGeocodesAndSatData.Rdata"
    tryCatch({
      # Load into a private environment so a missing `data` object cannot
      # silently resolve to utils::data() through the search path.
      holder <- new.env(parent = emptyenv())
      load(example_path, envir = holder)
      if (!exists("data", envir = holder, inherits = FALSE)) {
        stop("no object named 'data' in ", example_path, call. = FALSE)
      }
      get("data", envir = holder, inherits = FALSE)
    }, error = function(e) {
      showNotification(paste("Error loading example data:", conditionMessage(e)),
                       type = "error", duration = 10)
      NULL
    })
  } else {
    req(input$file_csv)
    tryCatch(
      read.csv(input$file_csv$datapath, stringsAsFactors = FALSE),
      error = function(e) {
        showNotification(paste("Error reading CSV:", e$message),
                         type = "error", duration = 10)
        NULL
      }
    )
  }
})
# Refresh every column picker whenever a new data set is loaded, and
# pre-select plausible treatment / latitude / longitude columns.
observe({
  df <- data_loaded()
  req(df)
  cols <- names(df)
  # Return the first column matching the earliest pattern that has any hit,
  # or NULL so the select input falls back to its first choice.
  # (grep(...)[1] gives NA, not NULL, when nothing matches.)
  first_match <- function(patterns) {
    for (pat in patterns) {
      hits <- grep(pat, cols, value = TRUE, ignore.case = TRUE)
      if (length(hits) > 0L) return(hits[[1L]])
    }
    NULL
  }
  updateSelectInput(session, "treat_col", choices = cols,
                    selected = if ("begum_treat" %in% cols) "begum_treat" else cols[1])
  updateSelectInput(session, "missing_col", choices = cols,
                    selected = cols[1])
  # Try exact names first so that e.g. "population" (which contains "lat")
  # is not chosen over "latitude". The plain substring match is the fallback.
  updateSelectInput(session, "lat_col", choices = cols,
                    selected = first_match(c("^lat(itude)?$", "^lat", "lat")))
  updateSelectInput(session, "long_col", choices = cols,
                    selected = first_match(c("^(lon|long|longitude|lng)$", "^lon", "lon|long")))
})
# On switching to "Upload CSV" before a file has been chosen, empty the
# column pickers so they stop listing the previous data set's columns.
observeEvent(input$data_source, {
  awaiting_upload <- input$data_source == "upload" && is.null(input$file_csv)
  if (awaiting_upload) {
    for (picker_id in c("treat_col", "missing_col", "lat_col", "long_col")) {
      updateSelectInput(session, picker_id, choices = character(0))
    }
  }
})
# Table preview of the first 100 rows of the active data set.
output$data_preview <- renderDT({
  preview_df <- data_loaded()
  req(preview_df)
  table_opts <- list(pageLength = 10, scrollX = TRUE, dom = "tip")
  datatable(head(preview_df, 100), options = table_opts, rownames = FALSE)
})
# Python source for the Google Earth Engine helpers. reticulate runs it each
# time features have to be fetched. It defines:
#   _ee_init(project, email, key_data) - authenticate with a service-account
#     key when one is supplied, otherwise try ee.Initialize() directly.
#   satellite_stats(points, start, end, sample_scale) - sample MODIS NDVI and
#     merged DMSP/VIIRS night-light composites at each point and return a
#     pandas DataFrame with a 'rowid' column plus ndvi_* / ntl_* columns.
# Python is indentation-sensitive, so keep the indentation inside this string.
gee_python_src <- "
import ee
import pandas as pd
import json
import os

def _ee_init(project, email=None, key_data=None):
    # Normalize empty strings to None
    project = project or None
    email = (email or None) if email else None
    key_data = (key_data or None) if key_data else None
    # Prefer service account if provided; otherwise try ADC (non-interactive)
    if key_data:
        # key_data must be a JSON string
        key_json = key_data if isinstance(key_data, str) else json.dumps(key_data)
        credentials = ee.ServiceAccountCredentials(email, key_data=key_json)
        ee.Initialize(credentials=credentials, project=project)
    else:
        try:
            ee.Initialize(project=project)
        except Exception as e:
            raise RuntimeError(
                'No service-account key provided and ADC not available. '
                'Set GOOGLE_APPLICATION_CREDENTIALS or provide GEE_EMAIL+GEE_KEY.'
            ) from e

def satellite_stats(points, start, end, sample_scale=250):
    # Build FeatureCollection from input points
    feats = [ee.Feature(ee.Geometry.Point([float(p['lon']), float(p['lat'])]),
                        {'rowid': str(p['rowid'])}) for p in points]
    fc = ee.FeatureCollection(feats)

    # MODIS NDVI (scaled by 0.0001), mask to SummaryQA == 0
    def mask_modis(img):
        qa = img.select('SummaryQA')
        return (img.updateMask(qa.eq(0))
                .select('NDVI').multiply(0.0001)
                .copyProperties(img, img.propertyNames()))

    modis = (ee.ImageCollection('MODIS/061/MOD13Q1')
             .filterDate(start, end)
             .map(mask_modis))
    ndvi_mean = modis.select('NDVI').mean().rename('ndvi_mean')
    ndvi_median = modis.select('NDVI').median().rename('ndvi_median')
    ndvi_max = modis.select('NDVI').max().rename('ndvi_max')

    # Night lights: DMSP (pre-2014) and VIIRS (2012+)
    dmsp = ee.ImageCollection('NOAA/DMSP-OLS/NIGHTTIME_LIGHTS').select('stable_lights')
    #viirs = ee.ImageCollection('NOAA/VIIRS/DNB/MONTHLY_V1/VCMSLCFG').select('avg_rad') # old
    viirs = ee.ImageCollection('NOAA/VIIRS/DNB/MONTHLY_V1/VCMCFG').select('avg_rad') # new
    dmsp_window = dmsp.filterDate(start, end)
    viirs_window = viirs.filterDate(start, end)

    # Overlap window for simple cross-calibration (DMSP↔VIIRS)
    #overlap_start = ee.Date('2012-01-01') # old
    overlap_start = ee.Date('2012-04-01') # new
    overlap_end = ee.Date('2013-12-31') # DMSP coverage effectively ends 2013
    dmsp_ov_img = dmsp.filterDate(overlap_start, overlap_end).mean()
    viirs_ov_img = viirs.filterDate(overlap_start, overlap_end).mean()

    # Buffer features to form a region-of-interest for overlap means
    fc_buffer = fc.map(lambda f: ee.Feature(f).buffer(5000))
    region_geom = fc_buffer.geometry()

    # Null-safe reducers via dictionaries + contains()
    dmsp_ov_dict = ee.Dictionary(dmsp_ov_img.reduceRegion(
        reducer=ee.Reducer.mean(), geometry=region_geom, scale=5000, maxPixels=1e13))
    viirs_ov_dict = ee.Dictionary(viirs_ov_img.reduceRegion(
        reducer=ee.Reducer.mean(), geometry=region_geom, scale=5000, maxPixels=1e13))
    dmsp_global_mean = ee.Number(ee.Image(dmsp_ov_img).reduceRegion(
        reducer=ee.Reducer.mean(), geometry=ee.Geometry.Rectangle([-180, -90, 180, 90]),
        scale=50000, maxPixels=1e13).get('stable_lights'))
    viirs_global_mean = ee.Number(ee.Image(viirs_ov_img).reduceRegion(
        reducer=ee.Reducer.mean(), geometry=ee.Geometry.Rectangle([-180, -90, 180, 90]),
        scale=50000, maxPixels=1e13).get('avg_rad'))
    dmsp_has = dmsp_ov_dict.contains('stable_lights')
    viirs_has = viirs_ov_dict.contains('avg_rad')
    dmsp_use = ee.Number(ee.Algorithms.If(dmsp_has, dmsp_ov_dict.get('stable_lights'), dmsp_global_mean))
    viirs_use = ee.Number(ee.Algorithms.If(viirs_has, viirs_ov_dict.get('avg_rad'), viirs_global_mean))

    # Guard divide-by-zero; produce a VIIRS-per-DMSP scale factor (ratio)
    k_viirs_per_dmsp = ee.Number(ee.Algorithms.If(dmsp_use.gt(0), viirs_use.divide(dmsp_use), 1))

    # Build a merged NTL series in VIIRS-equivalent units
    dmsp_equiv = dmsp_window.map(
        lambda img: img.select('stable_lights').multiply(k_viirs_per_dmsp).rename('ntl').toFloat()
    )
    viirs_prep = viirs_window.map(
        lambda img: img.select('avg_rad').rename('ntl').toFloat()
    )
    ntl_window = dmsp_equiv.merge(viirs_prep)
    ntl_mean = ntl_window.mean().rename('ntl_mean')
    ntl_median = ntl_window.median().rename('ntl_median')
    ntl_max = ntl_window.max().rename('ntl_max')

    # Stack all bands
    stacked = (ndvi_mean
               .addBands([ndvi_median, ndvi_max,
                          ntl_mean, ntl_median, ntl_max]))

    # IMPORTANT: use the intended pixel size (meters) for sampling
    samples = stacked.sampleRegions(collection=fc, properties=['rowid'], scale=sample_scale)

    # Bring results client-side
    info = samples.getInfo()
    rows = []
    for f in info.get('features', []):
        p = f.get('properties', {}) or {}
        rows.append({
            'rowid': p.get('rowid'),
            'ndvi_mean': p.get('ndvi_mean'),
            'ndvi_median': p.get('ndvi_median'),
            'ndvi_max': p.get('ndvi_max'),
            'ntl_mean': p.get('ntl_mean'),
            'ntl_median': p.get('ntl_median'),
            'ntl_max': p.get('ntl_max')
        })
    return pd.DataFrame(rows)
"

# Fetch `missing_feats` from Earth Engine for every row of `df` that has
# finite values in the selected latitude/longitude columns. Each feature is
# added as a numeric column, NA where no sample came back. Rows are matched by
# their position in `df`, so existing columns (including any named `rowid`)
# are left untouched. Returns the augmented data frame, or NULL after showing
# a notification when validation or the fetch fails.
fetch_satellite_features <- function(df, missing_feats) {
  lat_col <- input$lat_col
  long_col <- input$long_col
  if (!isTRUE(lat_col %in% names(df)) || !isTRUE(long_col %in% names(df))) {
    showNotification("Select valid latitude and longitude columns", type = "error")
    return(NULL)
  }
  start_year <- input$start_year
  end_year <- input$end_year
  if (!is.numeric(start_year) || !is.numeric(end_year) ||
      is.na(start_year) || is.na(end_year)) {
    showNotification("Start and end years are required", type = "error")
    return(NULL)
  }
  if (start_year > end_year) {
    showNotification("Start year must be <= end year", type = "error")
    return(NULL)
  }
  lat <- df[[lat_col]]
  lon <- df[[long_col]]
  row_ids <- which(is.finite(lat) & is.finite(lon))
  if (length(row_ids) == 0L) {
    showNotification("No valid geocoordinates found", type = "error")
    return(NULL)
  }

  showNotification("Fetching satellite features from GEE...", type = "message")
  # Whole calendar years: from Jan 1 of start_year up to Jan 1 of end_year + 1.
  start <- sprintf("%d-01-01", as.integer(start_year))
  end <- sprintf("%d-01-01", as.integer(end_year) + 1L)
  # Send at most 200 points per Earth Engine request.
  batches <- split(row_ids, ceiling(seq_along(row_ids) / 200L))

  sat_df <- tryCatch({
    py_run_string(gee_python_src)
    py$`_ee_init`(project = Sys.getenv("GEE_PROJECT"),
                  email = Sys.getenv("GEE_EMAIL"),
                  key_data = Sys.getenv("GEE_KEY"))
    chunks <- lapply(batches, function(ids) {
      points <- lapply(ids, function(i) {
        list(rowid = as.character(i), lon = lon[i], lat = lat[i])
      })
      py$satellite_stats(points, start, end, 250L)
    })
    bind_rows(Filter(function(d) !is.null(d) && nrow(d) > 0L, chunks))
  }, error = function(e) {
    showNotification(paste("Satellite data fetch failed:", conditionMessage(e)),
                     type = "error", duration = 10)
    NULL
  })
  if (is.null(sat_df)) return(NULL)
  if (nrow(sat_df) == 0L || !("rowid" %in% names(sat_df))) {
    showNotification("Failed to fetch satellite data", type = "error")
    return(NULL)
  }

  # Python None values can arrive as list columns; coerce each cell to a
  # single number (NA when absent).
  to_numeric <- function(x) {
    if (is.list(x)) {
      vapply(x, function(v) {
        if (length(v) == 1L) suppressWarnings(as.numeric(v)) else NA_real_
      }, numeric(1))
    } else {
      suppressWarnings(as.numeric(x))
    }
  }
  hit_rows <- as.integer(sat_df$rowid)
  for (feat in intersect(missing_feats, names(sat_df))) {
    vals <- rep(NA_real_, nrow(df))
    vals[hit_rows] <- to_numeric(sat_df[[feat]])
    df[[feat]] <- vals
  }
  still_missing <- setdiff(missing_feats, names(df))
  if (length(still_missing) > 0L) {
    showNotification(paste("Could not fetch:", paste(still_missing, collapse = ", ")),
                     type = "error")
    return(NULL)
  }
  df
}

# Turn `df` plus the sidebar settings into the audit inputs.
# A is a 0/1 vector: treated vs control, or observed vs missing. X is the
# feature matrix, restricted to rows whose features are all finite.
# Returns list(A, X), or NULL after a notification when the data cannot
# support an audit.
build_audit_inputs <- function(df) {
  if (input$audit_type == "randomization") {
    req(input$treat_col)
    if (!(input$treat_col %in% names(df))) {
      showNotification("Treatment column not found", type = "error")
      return(NULL)
    }
    arm_values <- c(input$control_val, input$treat_val)
    if (length(arm_values) != 2L || anyNA(arm_values)) {
      showNotification("Control and treatment values are required", type = "error")
      return(NULL)
    }
    tt <- df[[input$treat_col]]
    mask <- tt %in% arm_values
    if (!any(mask)) {
      showNotification("No units match control/treatment values",
                       type = "error")
      return(NULL)
    }
    A <- as.integer(tt[mask] == input$treat_val)
    X <- as.matrix(df[mask, input$features, drop = FALSE])
  } else {
    req(input$missing_col)
    if (!(input$missing_col %in% names(df))) {
      showNotification("Missing column not found", type = "error")
      return(NULL)
    }
    R <- as.integer(!is.na(df[[input$missing_col]]))
    if (all(R == 1L)) {
      showNotification(
        "No missingness detected in selected variable. Audit cannot proceed.",
        type = "warning", duration = 10
      )
      return(NULL)
    }
    if (all(R == 0L)) {
      showNotification(
        "All values are missing. Audit cannot proceed.",
        type = "warning", duration = 10
      )
      return(NULL)
    }
    A <- R
    X <- as.matrix(df[, input$features, drop = FALSE])
  }

  keep <- rowSums(!is.finite(X)) == 0L
  A <- A[keep]
  X <- X[keep, , drop = FALSE]
  if (length(A) < 10L) {
    showNotification("Too few complete cases (need >= 10)",
                     type = "error")
    return(NULL)
  }
  if (all(A == A[1L])) {
    showNotification(
      "Only one group remains after dropping rows with missing features. Audit cannot proceed.",
      type = "error", duration = 10
    )
    return(NULL)
  }
  list(A = A, X = X)
}

# "Run Audit": make sure the selected features exist (fetching them from Earth
# Engine if needed), build the audit inputs, run remote_audit_crt(), and
# store the result together with the settings it was run with.
observeEvent(input$run_audit, {
  df <- data_loaded()
  req(df)
  if (length(input$features) == 0L) {
    showNotification("Select at least one feature", type = "error")
    return()
  }
  missing_feats <- setdiff(input$features, names(df))
  if (length(missing_feats) > 0L) {
    df <- fetch_satellite_features(df, missing_feats)
    if (is.null(df)) return()
  }

  withProgress(message = "Running audit...", value = 0, {
    incProgress(0.2, detail = "Preparing data...")
    inputs <- build_audit_inputs(df)
    if (!is.null(inputs)) {
      incProgress(0.4, detail = "Running conditional randomization test...")
      # Drop the previous result so a failed run never leaves stale output.
      audit_results(NULL)
      audit_type <- input$audit_type
      label <- if (audit_type == "randomization") input$treat_col else input$missing_col
      results <- tryCatch(
        remote_audit_crt(
          A = inputs$A,
          X = inputs$X,
          K = input$K,
          B = input$B,
          seed = input$seed,
          label = label,
          xgboost_ntree = input$ntree
        ),
        error = function(e) {
          showNotification(paste("Audit failed:", conditionMessage(e)),
                           type = "error", duration = 10)
          NULL
        }
      )
      setProgress(1, detail = "Complete!")
      if (!is.null(results)) {
        # Record the settings used, so the summary and download stay correct
        # if the sidebar is changed afterwards.
        results$audit_type <- audit_type
        results$seed <- input$seed
        results$features <- input$features
        audit_results(results)
        showNotification("Audit complete!", type = "message", duration = 3)
      }
    }
  })
})
# HTML summary of the latest audit result.
output$results_summary <- renderUI({
  res <- audit_results()
  req(res)
  # Prefer the audit type stored with the result; fall back to the live input
  # for results that do not carry it.
  audit_type <- res[["audit_type"]]
  if (is.null(audit_type)) audit_type <- input$audit_type
  is_missingness <- identical(audit_type, "missingness")
  # In a missingness audit, A = 1 marks an observed value, so the two groups
  # are observed vs missing rather than treated vs control.
  group_labels <- if (is_missingness) c("Observed", "Missing") else c("Treated", "Control")
  significant <- isTRUE(res$p_value < 0.05)
  interpretation <- if (is_missingness) {
    if (significant) {
      "⚠️ Missingness is MORE predictable from satellite features than expected if values were missing completely at random (p < 0.05). This suggests missingness is related to pre-treatment characteristics."
    } else {
      "✓ Missingness is NOT significantly more predictable from satellite features than expected if values were missing completely at random (p >= 0.05). No evidence of feature-dependent missingness detected."
    }
  } else if (significant) {
    "⚠️ Assignment is MORE predictable from satellite features than expected under random assignment (p < 0.05). This suggests potential deviation from the stated randomization mechanism."
  } else {
    "✓ Assignment is NOT significantly more predictable from satellite features than expected under random assignment (p >= 0.05). No evidence of deviation detected."
  }
  HTML(sprintf(
    "<h4>%s Audit Results</h4>
<p><strong>Learner:</strong> %s</p>
<p><strong>Sample size:</strong> %d (%s: %d, %s: %d)</p>
<p><strong>Test statistic (T):</strong> %.4f</p>
<p><strong>P-value:</strong> %.4f</p>
<p><strong>Interpretation:</strong> %s</p>",
    tools::toTitleCase(audit_type),
    toupper(res$learner),
    res$n,
    group_labels[1], res$treated,
    group_labels[2], res$n - res$treated,
    res$T_obs,
    res$p_value,
    interpretation
  ))
})
# Histogram of the null statistics T_null with the observed T_obs marked.
# NOTE(review): this renderer is wrapped in if (FALSE), so output$audit_plot is
# never assigned and the UI's plotOutput("audit_plot") stays empty. Confirm
# whether it was disabled on purpose before re-enabling it.
if(FALSE){
output$audit_plot <- renderPlot({
res <- audit_results()
req(res)
# Null distribution: one T value per redraw.
hist(res$T_null, breaks = 50,
main = sprintf("%s Audit: %s Learner",
tools::toTitleCase(input$audit_type),
toupper(res$learner)),
xlab = "Out-of-sample log-likelihood improvement (T)",
ylab = "Count",
col = "lightblue",
border = "white")
# Dashed red line at the observed statistic.
abline(v = res$T_obs, col = "red", lwd = 3, lty = 2)
legend("topright",
legend = c("Null distribution", "Observed"),
col = c("lightblue", "red"),
lwd = c(10, 3),
lty = c(1, 2))
# Subtitle: sample size, treated count and (clamped) share, B and p-value.
mtext(sprintf("n=%d, treated=%d (%.1f%%), B=%d, p=%.4f",
res$n, res$treated, 100 * res$a_bar, res$B, res$p_value),
side = 3, line = 0.5, cex = 0.9)
})
}
# One-row CSV summary of the latest audit, timestamped in the file name.
output$download_results <- downloadHandler(
  filename = function() {
    sprintf("remote_audit_results_%s.csv", format(Sys.time(), "%Y%m%d_%H%M%S"))
  },
  content = function(file) {
    res <- audit_results()
    req(res)
    # Prefer settings stored with the result, so the file describes the run
    # that produced it; fall back to live inputs for results without them.
    stored_or <- function(key, current) {
      value <- res[[key]]
      if (is.null(value)) current else value
    }
    summary_df <- data.frame(
      audit_type = stored_or("audit_type", input$audit_type),
      learner = res$learner,
      n = res$n,
      treated = res$treated,
      treatment_rate = res$a_bar,
      K = res$K,
      B = res$B,
      T_observed = res$T_obs,
      p_value = res$p_value,
      seed = stored_or("seed", input$seed),
      features = paste(stored_or("features", input$features), collapse = ";"),
      variable = res[["label"]]
    )
    write.csv(summary_df, file, row.names = FALSE)
  }
)
# Flag read by the UI's conditionalPanel("output.audit_complete"): TRUE once
# an audit result has been stored. suspendWhenHidden = FALSE keeps it
# computed even though no visible output element displays it.
has_audit_result <- reactive({
  latest <- audit_results()
  !is.null(latest)
})
output$audit_complete <- has_audit_result
outputOptions(output, "audit_complete", suspendWhenHidden = FALSE)
}
# Build and launch the Shiny application from the UI and server above.
shinyApp(ui = ui, server = server)