{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ff445bd1d37c1021",
   "metadata": {},
   "source": [
    "# SAGE Control UI via ngrok\n",
    "\n",
    "This notebook clones the repo if needed, starts the real FastAPI server, protects the control UI with `SAGE_WEB_PASSWORD`, and exposes it through ngrok. It also bundles an optional dataset downloader.\n",
    "\n",
    "The public URL serves:\n",
    "\n",
    "- `GET /` for the browser control panel\n",
    "- `GET /health`\n",
    "- `POST /generate` (GPU runtime only; the CPU server does not provide it)\n",
    "\n",
    "Cells, in order:\n",
    "\n",
    "1. **Dataset downloader**: defines `main()` and its helpers; inside a notebook it does not download anything by itself.\n",
    "2. **Server launcher**: needs an `NGROK_AUTHTOKEN` Colab secret (or environment variable).\n",
    "3. **Small download run**: calls `main()` at 1% of the 5B-token budget, writing into the cloned repo's `data/raw`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5751afbf64858f98",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "SAGE — 5 Billion Token Dataset Downloader\n",
    "==========================================\n",
    "Downloads ~5B tokens from free Hugging Face datasets and saves them\n",
    "as JSONL files in your data/raw/ directory, ready for the SAGE pipeline.\n",
    "\n",
    "Token budget breakdown (tokens are estimated as len(text) // 4):\n",
    "  general_web.jsonl   → 2.5B tokens  (FineWeb sample-10BT)\n",
    "  code.jsonl          → 1.0B tokens  (The Stack v2 - Python, JS/TS, Rust, Go, C++, Java, Bash, SQL, HTML)\n",
    "  math_science.jsonl  → 0.5B tokens  (OpenWebMath)\n",
    "  multilingual.jsonl  → 0.5B tokens  (Wikipedia, 20 languages)\n",
    "  synthetic.jsonl     → 0.5B tokens  (OpenHermes 2.5 instruction data)\n",
    "  ─────────────────────────────────────\n",
    "  TOTAL               → ~5.0B tokens\n",
    "\n",
    "Usage (script):\n",
    "  pip install datasets huggingface_hub tqdm\n",
    "  python debug/download_5b_tokens.py --output-dir data/raw\n",
    "  python debug/download_5b_tokens.py --output-dir data/raw --resume\n",
    "\n",
    "Usage (notebook): a kernel's sys.argv holds the kernel's own flags, so call\n",
    "main() with an explicit list, e.g. main([\"--output-dir\", \"data/raw\", \"--scale\", \"0.01\"]).\n",
    "\"\"\"\n",
    "\n",
    "import argparse\n",
    "import json\n",
    "import sys\n",
    "import time\n",
    "from pathlib import Path\n",
    "\n",
    "missing = []\n",
    "try:\n",
    "    from datasets import load_dataset\n",
    "except ImportError:\n",
    "    missing.append(\"datasets\")\n",
    "try:\n",
    "    from tqdm.auto import tqdm  # widget progress bars in Jupyter, plain text in a terminal\n",
    "except ImportError:\n",
    "    missing.append(\"tqdm\")\n",
    "\n",
    "if missing:\n",
    "    print(f\"[ERROR] Missing packages: {', '.join(missing)}\")\n",
    "    print(f\"        Run: pip install {' '.join(missing)}\")\n",
    "    sys.exit(1)\n",
    "\n",
    "\n",
    "def estimate_tokens(text: str) -> int:\n",
    "    \"\"\"Cheap token estimate: about 4 characters per token, never less than 1.\"\"\"\n",
    "    return max(1, len(text) // 4)\n",
    "\n",
    "\n",
    "def human_tokens(n: int) -> str:\n",
    "    \"\"\"Format a token count as e.g. '2.50B', '12.3M' or '9,999'.\"\"\"\n",
    "    if n >= 1_000_000_000:\n",
    "        return f\"{n / 1_000_000_000:.2f}B\"\n",
    "    if n >= 1_000_000:\n",
    "        return f\"{n / 1_000_000:.1f}M\"\n",
    "    return f\"{n:,}\"\n",
    "\n",
    "\n",
    "def human_bytes(n: float) -> str:\n",
    "    \"\"\"Format a byte count with binary (1024-based) units, e.g. '18.6 GB'.\"\"\"\n",
    "    for unit in [\"B\", \"KB\", \"MB\", \"GB\"]:\n",
    "        if n < 1024:\n",
    "            return f\"{n:.1f} {unit}\"\n",
    "        n /= 1024\n",
    "    return f\"{n:.1f} TB\"\n",
    "\n",
    "\n",
    "class JSONLWriter:\n",
    "    \"\"\"JSONL sink that tracks estimated tokens and reports when a budget is met.\n",
    "\n",
    "    Records whose ``text`` is missing, or shorter than 50 characters after\n",
    "    stripping, are skipped. With ``resume=True`` an existing file is re-scanned\n",
    "    to recover its token/record counts and new records are appended; otherwise\n",
    "    the file is (re)created from scratch.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, path: Path, target_tokens: int, resume: bool = False):\n",
    "        self.path = path\n",
    "        self.target_tokens = target_tokens\n",
    "        self.tokens_written = 0\n",
    "        self.records_written = 0\n",
    "\n",
    "        if resume and path.exists():\n",
    "            print(f\"  [resume] Counting existing tokens in {path.name}...\")\n",
    "            with open(path, \"r\", encoding=\"utf-8\") as f:\n",
    "                for line in f:\n",
    "                    try:\n",
    "                        rec = json.loads(line)\n",
    "                    except json.JSONDecodeError:\n",
    "                        continue  # e.g. a blank line or one cut short by an interrupted run\n",
    "                    self.tokens_written += estimate_tokens(rec.get(\"text\", \"\"))\n",
    "                    self.records_written += 1\n",
    "            print(f\"  [resume] Already have {human_tokens(self.tokens_written)} / {human_tokens(target_tokens)}\")\n",
    "            needs_newline = self._ends_without_newline(path)\n",
    "            self._file = open(path, \"a\", encoding=\"utf-8\", buffering=1024 * 1024)\n",
    "            if needs_newline:\n",
    "                # Terminate a partial last line so the next record starts on its own line\n",
    "                # instead of being glued onto the corrupt fragment.\n",
    "                self._file.write(\"\\n\")\n",
    "        else:\n",
    "            path.parent.mkdir(parents=True, exist_ok=True)\n",
    "            self._file = open(path, \"w\", encoding=\"utf-8\", buffering=1024 * 1024)\n",
    "\n",
    "    @staticmethod\n",
    "    def _ends_without_newline(path: Path) -> bool:\n",
    "        \"\"\"True when ``path`` is non-empty and its last byte is not a newline.\"\"\"\n",
    "        if path.stat().st_size == 0:\n",
    "            return False\n",
    "        with open(path, \"rb\") as f:\n",
    "            f.seek(-1, 2)  # whence=2: relative to end of file\n",
    "            return f.read(1) != b\"\\n\"\n",
    "\n",
    "    @property\n",
    "    def done(self) -> bool:\n",
    "        \"\"\"True once the estimated tokens written reach the target.\"\"\"\n",
    "        return self.tokens_written >= self.target_tokens\n",
    "\n",
    "    def write(self, record: dict) -> int:\n",
    "        \"\"\"Append one record; return its estimated token count, or 0 if skipped.\"\"\"\n",
    "        text = record.get(\"text\", \"\")\n",
    "        if not text or len(text.strip()) < 50:\n",
    "            return 0\n",
    "        toks = estimate_tokens(text)\n",
    "        self._file.write(json.dumps(record, ensure_ascii=False) + \"\\n\")\n",
    "        self.tokens_written += toks\n",
    "        self.records_written += 1\n",
    "        return toks\n",
    "\n",
    "    def close(self):\n",
    "        \"\"\"Flush the write buffer and close the file.\"\"\"\n",
    "        self._file.flush()\n",
    "        self._file.close()\n",
    "\n",
    "    def __enter__(self):\n",
    "        return self\n",
    "\n",
    "    def __exit__(self, *_):\n",
    "        self.close()\n",
    "\n",
    "\n",
    "def download_general_web(writer):\n",
    "    \"\"\"Stream FineWeb (sample-10BT) into ``writer`` until its budget is met.\"\"\"\n",
    "    print(\"\\n[1/5] general_web.jsonl — FineWeb\")\n",
    "    bar = tqdm(total=writer.target_tokens, initial=writer.tokens_written,\n",
    "               unit=\"tok\", unit_scale=True, desc=\"  web tokens\")\n",
    "    ds = load_dataset(\"HuggingFaceFW/fineweb\", name=\"sample-10BT\",\n",
    "                      split=\"train\", streaming=True)\n",
    "    for sample in ds:\n",
    "        if writer.done:\n",
    "            break\n",
    "        bar.update(writer.write({\"text\": sample[\"text\"], \"source\": \"fineweb\",\n",
    "                                 \"url\": sample.get(\"url\", \"\"), \"language\": \"en\"}))\n",
    "    bar.close()\n",
    "    print(f\"  ✓ {human_tokens(writer.tokens_written)} tokens | {writer.records_written:,} records\")\n",
    "\n",
    "\n",
    "def download_code(writer):\n",
    "    \"\"\"Stream The Stack v2 language by language, giving each an equal share of the budget.\n",
    "\n",
    "    A language that fails to load, or that yields no usable text for\n",
    "    MAX_EMPTY_STREAK consecutive rows, is skipped with a warning, so the file\n",
    "    can end up short of its target.\n",
    "    \"\"\"\n",
    "    print(\"\\n[2/5] code.jsonl — The Stack v2\")\n",
    "    LANGUAGES = [(\"python\", \"Python\"), (\"javascript\", \"JavaScript\"), (\"typescript\", \"TypeScript\"),\n",
    "                 (\"rust\", \"Rust\"), (\"go\", \"Go\"), (\"cpp\", \"C++\"), (\"java\", \"Java\"),\n",
    "                 (\"bash\", \"Bash\"), (\"sql\", \"SQL\"), (\"html\", \"HTML\")]\n",
    "    MAX_EMPTY_STREAK = 10_000  # give up on a language that keeps producing nothing\n",
    "    bar = tqdm(total=writer.target_tokens, initial=writer.tokens_written,\n",
    "               unit=\"tok\", unit_scale=True, desc=\"  code tokens\")\n",
    "    tokens_per_lang = writer.target_tokens // len(LANGUAGES)\n",
    "    for lang_id, lang_name in LANGUAGES:\n",
    "        if writer.done:\n",
    "            break\n",
    "        lang_tokens = 0\n",
    "        empty_streak = 0\n",
    "        print(f\"  → {lang_name}...\")\n",
    "        try:\n",
    "            # NOTE(review): the `*-ids` flavours of The Stack v2 are published as Software\n",
    "            # Heritage blob IDs; confirm this split really carries a `content`/`text` column.\n",
    "            ds = load_dataset(\"bigcode/the-stack-v2-train-smol-ids\",\n",
    "                              data_dir=f\"data/{lang_id}\", split=\"train\",\n",
    "                              streaming=True, trust_remote_code=True)\n",
    "            for sample in ds:\n",
    "                if writer.done or lang_tokens >= tokens_per_lang:\n",
    "                    break\n",
    "                content = sample.get(\"content\") or sample.get(\"text\") or \"\"\n",
    "                t = writer.write({\"text\": content, \"source\": \"the_stack_v2\",\n",
    "                                  \"language\": lang_id})\n",
    "                if t:\n",
    "                    bar.update(t)\n",
    "                    lang_tokens += t\n",
    "                    empty_streak = 0\n",
    "                else:\n",
    "                    empty_streak += 1\n",
    "                    if empty_streak >= MAX_EMPTY_STREAK:\n",
    "                        print(f\"  [warn] {lang_name}: {MAX_EMPTY_STREAK:,} consecutive rows had no usable text, skipping.\")\n",
    "                        break\n",
    "        except Exception as e:\n",
    "            print(f\"  [warn] {lang_name} failed ({e}), skipping.\")\n",
    "    bar.close()\n",
    "    print(f\"  ✓ {human_tokens(writer.tokens_written)} tokens | {writer.records_written:,} records\")\n",
    "\n",
    "\n",
    "def download_math(writer):\n",
    "    \"\"\"Stream OpenWebMath into ``writer`` until its budget is met.\"\"\"\n",
    "    print(\"\\n[3/5] math_science.jsonl — OpenWebMath\")\n",
    "    bar = tqdm(total=writer.target_tokens, initial=writer.tokens_written,\n",
    "               unit=\"tok\", unit_scale=True, desc=\"  math tokens\")\n",
    "    ds = load_dataset(\"open-web-math/open-web-math\", split=\"train\", streaming=True)\n",
    "    for sample in ds:\n",
    "        if writer.done:\n",
    "            break\n",
    "        bar.update(writer.write({\"text\": sample[\"text\"], \"source\": \"open_web_math\",\n",
    "                                 \"url\": sample.get(\"url\", \"\")}))\n",
    "    bar.close()\n",
    "    print(f\"  ✓ {human_tokens(writer.tokens_written)} tokens | {writer.records_written:,} records\")\n",
    "\n",
    "\n",
    "def download_multilingual(writer):\n",
    "    \"\"\"Stream Wikipedia (20231101 dump) for 20 languages, an equal budget share each.\n",
    "\n",
    "    A language that fails to load is reported and skipped.\n",
    "    \"\"\"\n",
    "    print(\"\\n[4/5] multilingual.jsonl — Wikipedia (20 languages)\")\n",
    "    LANGUAGES = [(\"en\", \"English\"), (\"es\", \"Spanish\"), (\"fr\", \"French\"), (\"de\", \"German\"),\n",
    "                 (\"zh\", \"Chinese\"), (\"ja\", \"Japanese\"), (\"pt\", \"Portuguese\"), (\"ar\", \"Arabic\"),\n",
    "                 (\"ru\", \"Russian\"), (\"hi\", \"Hindi\"), (\"it\", \"Italian\"), (\"ko\", \"Korean\"),\n",
    "                 (\"nl\", \"Dutch\"), (\"pl\", \"Polish\"), (\"sv\", \"Swedish\"), (\"tr\", \"Turkish\"),\n",
    "                 (\"vi\", \"Vietnamese\"), (\"id\", \"Indonesian\"), (\"uk\", \"Ukrainian\"), (\"fa\", \"Persian\")]\n",
    "    bar = tqdm(total=writer.target_tokens, initial=writer.tokens_written,\n",
    "               unit=\"tok\", unit_scale=True, desc=\"  multilingual tokens\")\n",
    "    tokens_per_lang = writer.target_tokens // len(LANGUAGES)\n",
    "    for lang_code, lang_name in LANGUAGES:\n",
    "        if writer.done:\n",
    "            break\n",
    "        lang_tokens = 0\n",
    "        try:\n",
    "            ds = load_dataset(\"wikimedia/wikipedia\", f\"20231101.{lang_code}\",\n",
    "                              split=\"train\", streaming=True, trust_remote_code=True)\n",
    "            for sample in ds:\n",
    "                if writer.done or lang_tokens >= tokens_per_lang:\n",
    "                    break\n",
    "                text = sample.get(\"text\", \"\")\n",
    "                if not text:\n",
    "                    continue\n",
    "                t = writer.write({\"text\": text, \"source\": \"wikipedia\",\n",
    "                                  \"language\": lang_code, \"title\": sample.get(\"title\", \"\")})\n",
    "                bar.update(t)\n",
    "                lang_tokens += t\n",
    "        except Exception as e:\n",
    "            print(f\"\\n  [warn] Wikipedia {lang_name} failed: {e}\")\n",
    "    bar.close()\n",
    "    print(f\"  ✓ {human_tokens(writer.tokens_written)} tokens | {writer.records_written:,} records\")\n",
    "\n",
    "\n",
    "def download_synthetic(writer):\n",
    "    \"\"\"Flatten OpenHermes 2.5 conversations into '### Instruction' / '### Response' text.\n",
    "\n",
    "    The stream is re-read up to 10 times, so records repeat when a single pass\n",
    "    is not enough to reach the budget.\n",
    "    \"\"\"\n",
    "    print(\"\\n[5/5] synthetic.jsonl — OpenHermes 2.5\")\n",
    "    bar = tqdm(total=writer.target_tokens, initial=writer.tokens_written,\n",
    "               unit=\"tok\", unit_scale=True, desc=\"  synthetic tokens\")\n",
    "    ds = load_dataset(\"teknium/OpenHermes-2.5\", split=\"train\", streaming=True)\n",
    "    rounds = 0\n",
    "    while not writer.done and rounds < 10:\n",
    "        for sample in ds:\n",
    "            if writer.done:\n",
    "                break\n",
    "            parts = []\n",
    "            for turn in sample.get(\"conversations\") or []:\n",
    "                role, value = turn.get(\"from\", \"\"), turn.get(\"value\", \"\")\n",
    "                if role == \"human\":\n",
    "                    parts.append(f\"### Instruction\\n{value}\")\n",
    "                elif role == \"gpt\":\n",
    "                    parts.append(f\"### Response\\n{value}\")\n",
    "            text = \"\\n\\n\".join(parts) or sample.get(\"text\", \"\")\n",
    "            if not text:\n",
    "                continue\n",
    "            bar.update(writer.write({\"text\": text, \"source\": \"openhermes_2.5\",\n",
    "                                     \"task\": \"instruction_following\"}))\n",
    "        rounds += 1\n",
    "    bar.close()\n",
    "    print(f\"  ✓ {human_tokens(writer.tokens_written)} tokens | {writer.records_written:,} records\")\n",
    "\n",
    "\n",
    "TARGETS = {\n",
    "    \"general_web.jsonl\": 2_500_000_000,\n",
    "    \"code.jsonl\": 1_000_000_000,\n",
    "    \"math_science.jsonl\": 500_000_000,\n",
    "    \"multilingual.jsonl\": 500_000_000,\n",
    "    \"synthetic.jsonl\": 500_000_000,\n",
    "}\n",
    "DOWNLOADERS = {\n",
    "    \"general_web.jsonl\": download_general_web,\n",
    "    \"code.jsonl\": download_code,\n",
    "    \"math_science.jsonl\": download_math,\n",
    "    \"multilingual.jsonl\": download_multilingual,\n",
    "    \"synthetic.jsonl\": download_synthetic,\n",
    "}\n",
    "\n",
    "\n",
    "def main(argv=None):\n",
    "    \"\"\"Run the downloads described by ``argv`` (defaults to ``sys.argv[1:]``).\n",
    "\n",
    "    Pass an explicit list when calling from a notebook, where ``sys.argv``\n",
    "    contains the kernel's own arguments.\n",
    "    \"\"\"\n",
    "    parser = argparse.ArgumentParser(description=\"Download ~5B tokens for SAGE training.\")\n",
    "    parser.add_argument(\"--output-dir\", default=\"data/raw\")\n",
    "    parser.add_argument(\"--resume\", action=\"store_true\")\n",
    "    parser.add_argument(\"--only\", nargs=\"+\", choices=list(TARGETS.keys()))\n",
    "    parser.add_argument(\"--scale\", type=float, default=1.0)\n",
    "    args = parser.parse_args(argv)\n",
    "    if args.scale <= 0:\n",
    "        # A zero budget would open (and truncate) every file, then report it as complete.\n",
    "        parser.error(\"--scale must be positive\")\n",
    "\n",
    "    out_dir = Path(args.output_dir)\n",
    "    out_dir.mkdir(parents=True, exist_ok=True)\n",
    "    files_to_run = args.only or list(TARGETS.keys())\n",
    "    total_target = sum(int(TARGETS[f] * args.scale) for f in files_to_run)\n",
    "\n",
    "    print(\"=\" * 60)\n",
    "    print(\"  SAGE — 5 Billion Token Downloader\")\n",
    "    print(\"=\" * 60)\n",
    "    print(f\"  Output dir : {out_dir.resolve()}\")\n",
    "    print(f\"  Resume     : {args.resume}\")\n",
    "    print(f\"  Scale      : {args.scale}x\")\n",
    "    print(f\"  Target     : {human_tokens(total_target)} tokens\")\n",
    "    # estimate_tokens() assumes ~4 characters per token, i.e. roughly 4 bytes of text each\n",
    "    # (more for non-ASCII scripts, plus JSON overhead).\n",
    "    print(f\"  Est. disk  : ~{human_bytes(total_target * 4)}\")\n",
    "    print(\"=\" * 60)\n",
    "\n",
    "    grand_start = time.time()\n",
    "    grand_tokens = 0\n",
    "\n",
    "    for filename in files_to_run:\n",
    "        target = int(TARGETS[filename] * args.scale)\n",
    "        path = out_dir / filename\n",
    "        with JSONLWriter(path, target, resume=args.resume) as writer:\n",
    "            if writer.done:\n",
    "                print(f\"\\n[skip] {filename} already complete ({human_tokens(writer.tokens_written)} tokens)\")\n",
    "                grand_tokens += writer.tokens_written\n",
    "                continue\n",
    "            t0 = time.time()\n",
    "            DOWNLOADERS[filename](writer)\n",
    "            elapsed = time.time() - t0\n",
    "            grand_tokens += writer.tokens_written\n",
    "        # Stat only after the writer has closed, so its 1 MB write buffer is on disk.\n",
    "        size = path.stat().st_size\n",
    "        print(f\"  Time: {elapsed / 60:.1f} min | Size: {human_bytes(size)}\")\n",
    "\n",
    "    elapsed_total = time.time() - grand_start\n",
    "    print(\"\\n\" + \"=\" * 60)\n",
    "    print(f\"  DONE — {human_tokens(grand_tokens)} tokens downloaded\")\n",
    "    print(f\"  Total time: {elapsed_total / 3600:.2f} hours\")\n",
    "    print(f\"  Files: {out_dir.resolve()}/\")\n",
    "    print(\"=\" * 60)\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    # In a Jupyter/Colab kernel __name__ is also \"__main__\", but sys.argv carries the kernel's\n",
    "    # own flags (e.g. `-f kernel.json`), which argparse rejects. Only auto-run as a script.\n",
    "    if \"ipykernel\" in sys.modules:\n",
    "        print(\"Downloader loaded. Call main([...]) to run it, e.g. main(['--scale', '0.01']).\")\n",
    "    else:\n",
    "        main()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98ae55680033f413",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Colab one-cell launcher for the real SAGE server\n",
    "# Before running:\n",
    "# 1. In Colab, open the Secrets panel (key icon on the left) and add your NGROK_AUTHTOKEN\n",
    "# 2. If you want /generate, switch Colab to a T4 GPU runtime\n",
    "\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "import atexit\n",
    "import subprocess\n",
    "import importlib\n",
    "import secrets\n",
    "from pathlib import Path\n",
    "\n",
    "REPO_URL = \"https://huggingface.co/sage002/sage\"\n",
    "REPO_DIR = Path(\"/content/sage\")\n",
    "PORT = 8000\n",
    "RUN_GENERATE_SMOKE = False\n",
    "HEALTH_ATTEMPTS = 60   # polls of /health before giving up ...\n",
    "HEALTH_INTERVAL_S = 2  # ... spaced this many seconds apart (~2 minutes in total)\n",
    "\n",
    "\n",
    "def run(cmd, cwd=None):\n",
    "    \"\"\"Echo a command, then run it; raises CalledProcessError on a non-zero exit.\"\"\"\n",
    "    print(\"+\", \" \".join(cmd))\n",
    "    subprocess.run(cmd, cwd=cwd, check=True)\n",
    "\n",
    "\n",
    "# 0. When this cell is re-run in the same kernel, stop the previous server and tunnel first;\n",
    "#    otherwise the new uvicorn cannot bind PORT and exits immediately.\n",
    "_previous_cleanup = globals().get(\"cleanup\")\n",
    "if callable(_previous_cleanup):\n",
    "    _previous_cleanup()\n",
    "    atexit.unregister(_previous_cleanup)\n",
    "\n",
    "# 1. Fetch the ngrok token first, so a missing secret fails before the slow installs.\n",
    "#    Colab Secrets take priority; the NGROK_AUTHTOKEN environment variable is the fallback.\n",
    "NGROK_AUTHTOKEN = None\n",
    "try:\n",
    "    from google.colab import userdata\n",
    "    NGROK_AUTHTOKEN = userdata.get(\"NGROK_AUTHTOKEN\")\n",
    "except Exception:\n",
    "    pass  # not on Colab, secret not defined, or notebook access to it not granted\n",
    "NGROK_AUTHTOKEN = NGROK_AUTHTOKEN or os.environ.get(\"NGROK_AUTHTOKEN\")\n",
    "\n",
    "if not NGROK_AUTHTOKEN:\n",
    "    raise ValueError(\"Missing NGROK_AUTHTOKEN. Please add it to your Colab Secrets.\")\n",
    "\n",
    "# 2. Clone or update the repo\n",
    "if not REPO_DIR.exists():\n",
    "    run([\"git\", \"clone\", REPO_URL, str(REPO_DIR)])\n",
    "else:\n",
    "    run([\"git\", \"-C\", str(REPO_DIR), \"pull\", \"--ff-only\"])\n",
    "\n",
    "# 3. Install dependencies\n",
    "run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", \"-U\", \"pip\"])\n",
    "run([\n",
    "    sys.executable, \"-m\", \"pip\", \"install\", \"-q\",\n",
    "    \"fastapi>=0.110.0\", \"uvicorn>=0.29.0\", \"python-multipart>=0.0.9\",\n",
    "    \"pydantic>=2.7.0\", \"pyyaml>=6.0.1\", \"psutil>=5.9.8\",\n",
    "    \"pyngrok>=7.2.0\", \"requests>=2.31.0\",\n",
    "])\n",
    "\n",
    "try:\n",
    "    import torch\n",
    "except ImportError:\n",
    "    run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", \"torch>=2.1.0\"])\n",
    "    import torch\n",
    "\n",
    "# Refresh import caches so this same cell can import the freshly installed modules\n",
    "importlib.invalidate_caches()\n",
    "import requests\n",
    "from pyngrok import ngrok\n",
    "\n",
    "# 4. SAGE environment variables for the server (values already in the environment win)\n",
    "env = os.environ.copy()\n",
    "env[\"SAGE_WEB_PASSWORD\"] = env.get(\"SAGE_WEB_PASSWORD\") or secrets.token_urlsafe(12)\n",
    "env[\"SAGE_MODEL_CONFIG\"] = env.get(\"SAGE_MODEL_CONFIG\", \"configs/model/1b.yaml\")\n",
    "env[\"SAGE_CHECKPOINT_DIR\"] = env.get(\"SAGE_CHECKPOINT_DIR\", \"runs/sage-1b\")\n",
    "env[\"SAGE_TOKENIZER_MODEL\"] = env.get(\"SAGE_TOKENIZER_MODEL\", \"tokenizer/tokenizer.model\")\n",
    "\n",
    "USE_GPU_SERVER = torch.cuda.is_available()\n",
    "APP_TARGET = \"serve.server:app\" if USE_GPU_SERVER else \"serve.server_cpu:app\"\n",
    "\n",
    "print(f\"GPU available: {USE_GPU_SERVER}\")\n",
    "print(f\"Starting app target: {APP_TARGET}\")\n",
    "print(f\"SAGE_WEB_PASSWORD: {env['SAGE_WEB_PASSWORD']}  <-- use this to log in to the control UI\")\n",
    "\n",
    "# 5. Start uvicorn in the background, sending stdout and stderr to a log file\n",
    "log_path = REPO_DIR / \"uvicorn.log\"\n",
    "log_file = open(log_path, \"w\", encoding=\"utf-8\")\n",
    "public_url = None  # set once the ngrok tunnel is up; read by cleanup()\n",
    "\n",
    "server_proc = subprocess.Popen(\n",
    "    [\n",
    "        sys.executable, \"-m\", \"uvicorn\",\n",
    "        APP_TARGET,\n",
    "        \"--host\", \"0.0.0.0\",\n",
    "        \"--port\", str(PORT),\n",
    "    ],\n",
    "    cwd=str(REPO_DIR),\n",
    "    env=env,  # required: passes the SAGE_* variables to uvicorn\n",
    "    stdout=log_file,\n",
    "    stderr=subprocess.STDOUT,\n",
    ")\n",
    "\n",
    "\n",
    "def cleanup():\n",
    "    \"\"\"Best-effort teardown: close the ngrok tunnel, stop uvicorn, close the log file.\"\"\"\n",
    "    print(\"Cleaning up...\")\n",
    "    if public_url:\n",
    "        try:\n",
    "            ngrok.disconnect(public_url)\n",
    "        except Exception:\n",
    "            pass\n",
    "    try:\n",
    "        ngrok.kill()  # also runs when no tunnel URL was ever obtained\n",
    "    except Exception:\n",
    "        pass\n",
    "    if server_proc and server_proc.poll() is None:\n",
    "        server_proc.terminate()\n",
    "        try:\n",
    "            server_proc.wait(timeout=10)\n",
    "        except subprocess.TimeoutExpired:\n",
    "            server_proc.kill()\n",
    "    try:\n",
    "        log_file.close()\n",
    "    except Exception:\n",
    "        pass\n",
    "    print(\"Cleanup complete.\")\n",
    "\n",
    "\n",
    "atexit.register(cleanup)\n",
    "\n",
    "\n",
    "def _read_server_log():\n",
    "    \"\"\"Return the uvicorn log written so far (undecodable bytes are dropped).\"\"\"\n",
    "    log_file.flush()\n",
    "    return log_path.read_text(encoding=\"utf-8\", errors=\"ignore\")\n",
    "\n",
    "\n",
    "# 6. Wait for the local health check to pass\n",
    "health_url = f\"http://127.0.0.1:{PORT}/health\"\n",
    "for _ in range(HEALTH_ATTEMPTS):\n",
    "    if server_proc.poll() is not None:\n",
    "        log_text = _read_server_log()\n",
    "        cleanup()\n",
    "        raise RuntimeError(\"Uvicorn exited early.\\n\\n\" + log_text)\n",
    "    try:\n",
    "        r = requests.get(health_url, timeout=2)\n",
    "    except requests.RequestException:\n",
    "        r = None  # not listening yet\n",
    "    if r is not None and r.ok:\n",
    "        try:\n",
    "            print(\"Local health OK:\", r.json())\n",
    "        except ValueError:  # healthy, but the body is not JSON\n",
    "            print(\"Local health OK:\", r.text)\n",
    "        break\n",
    "    time.sleep(HEALTH_INTERVAL_S)\n",
    "else:\n",
    "    log_text = _read_server_log()\n",
    "    cleanup()  # otherwise the unhealthy server keeps holding PORT\n",
    "    raise TimeoutError(\"Server did not become healthy.\\n\\n\" + log_text)\n",
    "\n",
    "# 7. Start the ngrok HTTPS tunnel\n",
    "try:\n",
    "    ngrok.kill()\n",
    "    ngrok.set_auth_token(NGROK_AUTHTOKEN)\n",
    "    # HTTPS-only public URL, so the browser UI is not hit by mixed-content blocking\n",
    "    tunnel = ngrok.connect(addr=PORT, proto=\"http\", bind_tls=True)\n",
    "    public_url = tunnel.public_url\n",
    "\n",
    "    banner = \"=\" * 46\n",
    "    print(\"\\n\" + banner)\n",
    "    print(\"SAGE DASHBOARD\".center(46))\n",
    "    print(banner)\n",
    "    print(f\"URL: {public_url}\")\n",
    "    print(f\"PWD: {env['SAGE_WEB_PASSWORD']}\")\n",
    "    print(banner + \"\\n\")\n",
    "\n",
    "    if USE_GPU_SERVER:\n",
    "        print(\"Generate :\", f\"{public_url}/generate\")\n",
    "    else:\n",
    "        print(\"Note: /generate is not available on the CPU server in this repo.\")\n",
    "        print(\"Switch Colab to a GPU runtime if you want /generate.\")\n",
    "except Exception as e:\n",
    "    print(\"Could not start ngrok:\", e)\n",
    "\n",
    "# Optional /generate smoke test against the local server\n",
    "if USE_GPU_SERVER and RUN_GENERATE_SMOKE:\n",
    "    print(\"\\nRunning /generate smoke test...\")\n",
    "    try:\n",
    "        resp = requests.post(\n",
    "            f\"http://127.0.0.1:{PORT}/generate\",\n",
    "            json={\"input_ids\": [1, 42, 99], \"max_new_tokens\": 4},\n",
    "            timeout=300,\n",
    "        )\n",
    "        print(\"Generate response:\", resp.json())\n",
    "    except Exception as e:\n",
    "        print(\"Generate timeout or failure:\", e)\n",
    "\n",
    "print(f\"\\nServer log path: {log_path}\")\n",
    "print(\"The server keeps running in the background until the runtime restarts or you call cleanup().\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cdcdbf0001d4933",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Small run of the downloader defined in the second cell: 1% of the 5B-token budget (~50M tokens),\n",
    "# written under the repo cloned by the launcher cell. Arguments are passed explicitly because a\n",
    "# notebook kernel's sys.argv holds the kernel's own flags, not ours.\n",
    "main([\"--output-dir\", str(REPO_DIR / \"data\" / \"raw\"), \"--scale\", \"0.01\"])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}