Spaces:
Runtime error
Runtime error
dlokesha
Part 2 final: local plasticity mean rank 37.1 — continual backprop rank 33.1 confirms reactive methods worse than standard
c3a6dfa | """ | |
| db.py — Supabase client for logging all experiment data. | |
| Stores: | |
| - runs — accuracy curves, final numbers per experiment | |
| - ablation — whole/center/periphery breakdown | |
| - spike_cache — bio-preprocessed vectors (avoid recomputing) | |
| - checkpoints — references to saved HF model files | |
| """ | |
| import io | |
| import os | |
| import numpy as np | |
| from supabase import create_client, Client | |
def get_client() -> Client:
    """Create a Supabase client from environment configuration.

    Reads the project URL from ``SUPABASE_URL`` and the API key from
    ``SUPABASE_KEY``.

    Returns:
        The object produced by ``create_client(url, key)``.

    Raises:
        KeyError: If either environment variable is not set.
    """
    # Arguments are evaluated left to right, so a missing URL is reported
    # before a missing key.
    return create_client(os.environ["SUPABASE_URL"], os.environ["SUPABASE_KEY"])
def log_run(
    n_samples: int,
    epochs: int,
    baseline_curve: list,
    bio_curve: list,
) -> str:
    """Insert a completed run into the ``runs`` table and return its id.

    The last entry of each curve is stored as that model's final accuracy,
    and their difference (bio minus baseline) is stored as ``improvement``.
    Final values and the improvement are rounded to 6 decimal places; the
    full curves are stored unrounded.

    Args:
        n_samples: Value written to the ``n_samples`` field.
        epochs: Value written to the ``epochs`` field.
        baseline_curve: Accuracy curve of the baseline model.
        bio_curve: Accuracy curve of the bio model.

    Returns:
        The ``id`` field of the first row returned by the insert.

    Raises:
        IndexError: If either curve is empty, or the insert returns no rows.
    """
    client = get_client()

    # Convert every value to a builtin float before building the payload.
    # NumPy scalars such as np.float32 stay NumPy scalars through round() and
    # are rejected by the standard JSON encoder; builtin floats pass through
    # float() unchanged, so existing callers see identical stored values.
    baseline_values = [float(value) for value in baseline_curve]
    bio_values = [float(value) for value in bio_curve]
    baseline_final = baseline_values[-1]
    bio_final = bio_values[-1]

    payload = {
        "n_samples": n_samples,
        "epochs": epochs,
        "baseline_final_acc": round(baseline_final, 6),
        "bio_final_acc": round(bio_final, 6),
        "improvement": round(bio_final - baseline_final, 6),
        "baseline_curve": baseline_values,
        "bio_curve": bio_values,
    }
    response = client.table("runs").insert(payload).execute()
    run_id = response.data[0]["id"]
    print(f"Logged run → {run_id}")
    return run_id
def log_ablation(run_id: str, ablation_results: dict):
    """Insert ablation rows (whole/center/periphery) linked to a run.

    One row is written to the ``ablation_results`` table per entry of
    ``ablation_results``, in a single insert call.

    Args:
        run_id: Value written to the ``run_id`` field of every row.
        ablation_results: Mapping of region name to accuracy. Each accuracy
            is rounded to 6 decimal places before being stored.
    """
    client = get_client()
    rows = [
        # float() first: round() on a NumPy scalar (e.g. np.float32) yields
        # another NumPy scalar, which the standard JSON encoder rejects.
        {"run_id": run_id, "region": region, "accuracy": round(float(acc), 6)}
        for region, acc in ablation_results.items()
    ]
    client.table("ablation_results").insert(rows).execute()
    print(f"Logged ablation results for run {run_id}")
def cache_spike_vectors(split: str, n_samples: int, vectors: np.ndarray):
    """Store spike-rate vectors in the ``spike_cache`` table.

    The array is written in ``.npy`` format to an in-memory buffer and saved
    as a hex string. If a row already exists for the same ``split`` and
    ``n_samples`` it is updated in place; otherwise a new row is inserted.
    Caching avoids rerunning the slow reservoir computation on every
    training run.

    Args:
        split: 'train' or 'test'
        n_samples: number of samples (used as cache key)
        vectors: (N, n_units) float array
    """
    db = get_client()

    # Serialize once; both the update and the insert path store the same text.
    stream = io.BytesIO()
    np.save(stream, vectors)
    encoded = stream.getvalue().hex()

    # Is there already a cache entry for this split + size?
    lookup = db.table("spike_cache").select("id")
    lookup = lookup.eq("split", split).eq("n_samples", n_samples)
    found = lookup.execute()

    if not found.data:
        # Nothing cached yet: create the row.
        new_row = {"split": split, "n_samples": n_samples, "vectors": encoded}
        db.table("spike_cache").insert(new_row).execute()
        print(f"Saved spike cache: {split} / {n_samples} samples")
        return

    # Overwrite the vectors of the existing row(s) matching the cache key.
    refresh = db.table("spike_cache").update({"vectors": encoded})
    refresh.eq("split", split).eq("n_samples", n_samples).execute()
    print(f"Updated spike cache: {split} / {n_samples} samples")
def load_spike_vectors(split: str, n_samples: int) -> np.ndarray | None:
    """Load cached spike-rate vectors from Supabase.

    Queries ``spike_cache`` for rows matching both ``split`` and
    ``n_samples`` and decodes the hex-encoded ``.npy`` payload of the first
    row returned.

    Args:
        split: Value matched against the ``split`` field.
        n_samples: Value matched against the ``n_samples`` field.

    Returns:
        The decoded array, or None if no cache exists for this split+size.

    Raises:
        ValueError: If the stored text is not valid hex, or the payload
            would require unpickling.
    """
    client = get_client()
    response = (
        client.table("spike_cache")
        .select("vectors")
        .eq("split", split)
        .eq("n_samples", n_samples)
        .execute()
    )
    if not response.data:
        return None

    vector_bytes = bytes.fromhex(response.data[0]["vectors"])
    # The payload comes from a remote table, so refuse pickled content
    # explicitly instead of relying on the installed NumPy's default.
    vectors = np.load(io.BytesIO(vector_bytes), allow_pickle=False)
    print(f"Loaded spike cache from Supabase: {split} / {n_samples} samples")
    return vectors
def log_checkpoint(run_id: str, model_type: str, hf_repo: str, final_acc: float):
    """Record that a model checkpoint was saved to HuggingFace.

    Inserts a single row into the ``checkpoints`` table.

    Args:
        run_id: Value written to the ``run_id`` field.
        model_type: Value written to the ``model_type`` field.
        hf_repo: Value written to the ``hf_repo`` field.
        final_acc: Accuracy to store, rounded to 6 decimal places.
    """
    client = get_client()
    row = {
        "run_id": run_id,
        "model_type": model_type,
        "hf_repo": hf_repo,
        # float() first: a NumPy scalar (e.g. np.float32) would remain a
        # NumPy scalar after round() and fail standard JSON encoding.
        "final_acc": round(float(final_acc), 6),
    }
    client.table("checkpoints").insert(row).execute()
    print(f"Logged checkpoint: {model_type} → {hf_repo}")
def fetch_all_runs() -> list[dict]:
    """Return every row of the ``runs`` table for the Gradio dashboard.

    Rows are requested with all columns, ordered by ``created_at``
    descending.
    """
    runs_query = get_client().table("runs").select("*")
    newest_first = runs_query.order("created_at", desc=True)
    return newest_first.execute().data