neuroforge / db.py
dlokesha
Part 2 final: local plasticity mean rank 37.1 — continual backprop rank 33.1 confirms reactive methods worse than standard
c3a6dfa
"""
db.py — Supabase client for logging all experiment data.
Stores:
- runs → accuracy curves, final numbers per experiment
- ablation_results → whole/center/periphery breakdown
- spike_cache → bio-preprocessed vectors (avoid recomputing)
- checkpoints → references to saved HF model files
"""
import io
import os
import numpy as np
from supabase import create_client, Client
def get_client() -> Client:
    """Build a Supabase client from environment configuration.

    Reads ``SUPABASE_URL`` and ``SUPABASE_KEY`` from the process environment
    and hands them to ``create_client``. A new client is constructed on every
    call; nothing is cached here.

    Returns:
        The client object produced by ``create_client``.

    Raises:
        KeyError: If either environment variable is not set.
    """
    env = os.environ
    # The URL is looked up before the key, so a missing URL is reported first.
    return create_client(env["SUPABASE_URL"], env["SUPABASE_KEY"])
def log_run(
    n_samples: int,
    epochs: int,
    baseline_curve: list,
    bio_curve: list,
) -> str:
    """Insert a completed run into the ``runs`` table and return its id.

    The last entry of each curve is stored as that model's final accuracy,
    and the difference (bio minus baseline) is stored as ``improvement``.
    All three are rounded to 6 decimal places; the full curves are stored
    unrounded.

    Args:
        n_samples: Value written to the ``n_samples`` column.
        epochs: Value written to the ``epochs`` column.
        baseline_curve: Non-empty sequence of accuracy values for the
            baseline model. Entries may be builtin numbers or NumPy scalars.
        bio_curve: Non-empty sequence of accuracy values for the bio model.
            Entries may be builtin numbers or NumPy scalars.

    Returns:
        The ``id`` field of the first row returned by the insert.

    Raises:
        ValueError: If either curve is empty.
    """
    # Coerce to builtin floats: NumPy scalars such as np.float32 survive
    # round() unchanged and are not JSON-serializable, so the insert would
    # otherwise fail only once the request is being built.
    baseline = [float(value) for value in baseline_curve]
    bio = [float(value) for value in bio_curve]

    # Validate before creating a client so bad input never reaches the network
    # and the error names the actual problem instead of a bare IndexError.
    if not baseline or not bio:
        raise ValueError(
            "baseline_curve and bio_curve must each contain at least one value"
        )

    baseline_final = baseline[-1]
    bio_final = bio[-1]

    client = get_client()
    payload = {
        "n_samples": n_samples,
        "epochs": epochs,
        "baseline_final_acc": round(baseline_final, 6),
        "bio_final_acc": round(bio_final, 6),
        "improvement": round(bio_final - baseline_final, 6),
        "baseline_curve": baseline,
        "bio_curve": bio,
    }
    response = client.table("runs").insert(payload).execute()
    run_id = response.data[0]["id"]
    print(f"Logged run β†’ {run_id}")
    return run_id
def log_ablation(run_id: str, ablation_results: dict):
    """Store one ``ablation_results`` row per region for the given run.

    Args:
        run_id: Identifier written to each row's ``run_id`` column.
        ablation_results: Mapping of region name to accuracy. Each accuracy
            is rounded to 6 decimal places before being stored.
    """
    client = get_client()

    rows = []
    for region, acc in ablation_results.items():
        rows.append(
            {
                "run_id": run_id,
                "region": region,
                "accuracy": round(acc, 6),
            }
        )

    # All rows go out in a single insert call.
    client.table("ablation_results").insert(rows).execute()
    print(f"Logged ablation results for run {run_id}")
def cache_spike_vectors(split: str, n_samples: int, vectors: np.ndarray):
    """Serialize spike-rate vectors and store them in the ``spike_cache`` table.

    Avoids rerunning the slow reservoir computation on every training run.
    The array is written in ``.npy`` format, hex-encoded, and saved in the
    ``vectors`` column. If a row already exists for the same ``split`` and
    ``n_samples`` it is updated; otherwise a new row is inserted.

    Args:
        split: 'train' or 'test'.
        n_samples: Number of samples (used, together with ``split``, as the
            cache key).
        vectors: (N, n_units) float array.

    Raises:
        ValueError: If ``vectors`` has object dtype. Such arrays can only be
            saved via pickle, and the loader reads with NumPy's default
            ``allow_pickle=False``, so the entry could never be read back.
    """
    # Serialize first so an unserializable array fails before any network
    # call. allow_pickle=False keeps the writer in step with np.load's default.
    buf = io.BytesIO()
    np.save(buf, vectors, allow_pickle=False)
    encoded = buf.getvalue().hex()

    client = get_client()

    # Check whether a cache row already exists for this split+size.
    # NOTE(review): check-then-write is not atomic; concurrent writers could
    # both insert. An upsert would need a unique constraint on
    # (split, n_samples) — confirm the schema before switching.
    existing = (
        client.table("spike_cache")
        .select("id")
        .eq("split", split)
        .eq("n_samples", n_samples)
        .execute()
    )

    if existing.data:
        # Overwrite the stored vectors on the matching row(s).
        client.table("spike_cache").update({"vectors": encoded}).eq(
            "split", split
        ).eq("n_samples", n_samples).execute()
        print(f"Updated spike cache: {split} / {n_samples} samples")
    else:
        # No entry yet: create one.
        client.table("spike_cache").insert(
            {"split": split, "n_samples": n_samples, "vectors": encoded}
        ).execute()
        print(f"Saved spike cache: {split} / {n_samples} samples")
def load_spike_vectors(split: str, n_samples: int) -> np.ndarray | None:
    """Fetch cached spike-rate vectors from the ``spike_cache`` table.

    Looks up rows matching ``split`` and ``n_samples``. The ``vectors`` field
    of the first match is decoded from hex and read back with ``np.load``.

    Args:
        split: Value matched against the ``split`` column.
        n_samples: Value matched against the ``n_samples`` column.

    Returns:
        The decoded array, or None if no cache exists for this split+size.
    """
    lookup = (
        get_client()
        .table("spike_cache")
        .select("vectors")
        .eq("split", split)
        .eq("n_samples", n_samples)
    )
    rows = lookup.execute().data

    if rows:
        # Only the first matching row is used.
        hex_blob = rows[0]["vectors"]
        restored = np.load(io.BytesIO(bytes.fromhex(hex_blob)))
        print(f"Loaded spike cache from Supabase: {split} / {n_samples} samples")
        return restored

    return None
def log_checkpoint(run_id: str, model_type: str, hf_repo: str, final_acc: float):
    """Insert a row into ``checkpoints`` describing a saved model.

    Args:
        run_id: Identifier written to the ``run_id`` column.
        model_type: Value written to the ``model_type`` column.
        hf_repo: Value written to the ``hf_repo`` column (per the original
            docstring, where the checkpoint was saved on HuggingFace).
        final_acc: Final accuracy; rounded to 6 decimal places before storage.
    """
    client = get_client()
    record = {
        "run_id": run_id,
        "model_type": model_type,
        "hf_repo": hf_repo,
        "final_acc": round(final_acc, 6),
    }
    client.table("checkpoints").insert(record).execute()
    print(f"Logged checkpoint: {model_type} β†’ {hf_repo}")
def fetch_all_runs() -> list[dict]:
    """Return rows from the ``runs`` table, newest first.

    Used to populate the Gradio dashboard. Rows are ordered by
    ``created_at`` descending and every column is selected.

    Returns:
        The ``data`` attribute of the query response.
    """
    # NOTE(review): the server may cap rows returned per request — confirm
    # whether pagination is needed once the table grows.
    runs_query = get_client().table("runs").select("*")
    newest_first = runs_query.order("created_at", desc=True)
    return newest_first.execute().data