{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Orbital LoRA - MRPC Benchmark Example\n",
    "\n",
    "**Expected:** performance parity with baseline + adaptive behavior\n"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# %pip (not !pip) installs into the kernel's own environment.\n",
    "%pip install -q transformers datasets evaluate scikit-learn accelerate"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import torch\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n",
    "from torch.utils.data import DataLoader\n",
    "import evaluate\n",
    "\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "from nested_lora import inject_nested_lora\n",
    "from orbital_controller import OrbitalController\n",
    "from controller import set_rank\n",
    "\n",
    "# Seed for reproducibility: DataLoader shuffling and dropout are stochastic.\n",
    "SEED = 42\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(device)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Load GLUE/MRPC (sentence-pair paraphrase detection) and tokenize.\n",
    "dataset = load_dataset('glue','mrpc')\n",
    "tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')\n",
    "\n",
    "def tok(x):\n",
    "    \"\"\"Tokenize a batch of sentence pairs to fixed-length (128) inputs.\"\"\"\n",
    "    return tokenizer(x['sentence1'], x['sentence2'], truncation=True, padding='max_length', max_length=128)\n",
    "\n",
    "train = dataset['train'].map(tok, batched=True)\n",
    "val = dataset['validation'].map(tok, batched=True)\n",
    "\n",
    "train.set_format(type='torch', columns=['input_ids','attention_mask','label'])\n",
    "val.set_format(type='torch', columns=['input_ids','attention_mask','label'])\n",
    "\n",
    "train_loader = DataLoader(train, batch_size=16, shuffle=True)\n",
    "val_loader = DataLoader(val, batch_size=16)\n",
    "\n",
    "metric = evaluate.load('glue','mrpc')"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "def eval_model(model):\n",
    "    \"\"\"Run the model over the validation split and return the MRPC F1 score.\"\"\"\n",
    "    model.eval()\n",
    "    preds, labels = [], []\n",
    "    with torch.no_grad():\n",
    "        for b in val_loader:\n",
    "            x = b['input_ids'].to(device)\n",
    "            m = b['attention_mask'].to(device)\n",
    "            y = b['label'].to(device)\n",
    "            p = model(input_ids=x, attention_mask=m).logits.argmax(-1)\n",
    "            preds.extend(p.cpu().numpy()); labels.extend(y.cpu().numpy())\n",
    "    return metric.compute(predictions=preds, references=labels)['f1']"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# BASELINE: fixed rank-16 LoRA for ~200 steps.\n",
    "model = AutoModelForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=2)\n",
    "model = inject_nested_lora(model,16).to(device)\n",
    "set_rank(model,16)\n",
    "\n",
    "opt = torch.optim.AdamW(model.parameters(), lr=5e-5)\n",
    "\n",
    "# from_pretrained returns the model in eval mode; enable dropout for training.\n",
    "model.train()\n",
    "for step,b in enumerate(train_loader):\n",
    "    if step>200: break\n",
    "    x=b['input_ids'].to(device); m=b['attention_mask'].to(device); y=b['label'].to(device)\n",
    "    loss=model(input_ids=x,attention_mask=m,labels=y).loss\n",
    "    loss.backward(); opt.step(); opt.zero_grad()\n",
    "\n",
    "f1_base = eval_model(model)\n",
    "print('Baseline F1:', round(f1_base,3))"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# ORBITAL: controller adapts the LoRA rank per step, clamped to [4, 16].\n",
    "model = AutoModelForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=2)\n",
    "model = inject_nested_lora(model,16).to(device)\n",
    "\n",
    "ctrl = OrbitalController(warmup=10, stable_window=6)\n",
    "set_rank(model,4)\n",
    "\n",
    "opt = torch.optim.AdamW(model.parameters(), lr=5e-5)\n",
    "\n",
    "# from_pretrained returns the model in eval mode; enable dropout for training.\n",
    "model.train()\n",
    "for step,b in enumerate(train_loader):\n",
    "    if step>200: break\n",
    "    x=b['input_ids'].to(device); m=b['attention_mask'].to(device); y=b['label'].to(device)\n",
    "    loss=model(input_ids=x,attention_mask=m,labels=y).loss\n",
    "    loss.backward()\n",
    "\n",
    "    # Controller proposes a rank from the loss signal; clamp to valid range.\n",
    "    r = ctrl.step(loss.item())\n",
    "    r = max(4,min(16,r))\n",
    "    set_rank(model,r)\n",
    "\n",
    "    opt.step(); opt.zero_grad()\n",
    "\n",
    "f1_orb = eval_model(model)\n",
    "print('Orbital F1:', round(f1_orb,3))"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "print('\\nBaseline:', round(f1_base,3))\n",
    "print('Orbital:', round(f1_orb,3))\n",
    "print('Delta:', round(f1_orb-f1_base,3))"
   ],
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}