{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 🛡️ MetaGuard — GRPO Training Notebook\n",
    "\n",
    "**Team:** Parth Singhal, Mehakveer Kaur, Kartik Goyal  \n",
    "**HF Space:** https://huggingface.co/spaces/parth-1/MetaGuard  \n",
    "**Hackathon:** OpenEnv — Meta × Scaler  \n",
    "\n",
    "This notebook trains **Llama 3.1 8B** with GRPO on the MetaGuard Ad Policy Compliance environment.\n",
    "\n",
    "### What this trains:\n",
    "- The agent learns to follow a structured SOP: `query_regulations → gather signals → submit_audit → decide`\n",
    "- The reward is shaped by correctness, sequence compliance, and recovery from API failures\n",
    "- The environment runs locally inside the notebook (fast); the GPU is used only for the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 1 — Install Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# %pip (rather than !pip) installs into the environment of the running kernel\n",
    "%pip install unsloth trl transformers datasets accelerate peft -q\n",
    "%pip install openenv-core==0.2.1 --no-deps -q\n",
    "%pip install fastapi uvicorn pydantic requests openai matplotlib -q\n",
    "print('✅ Dependencies installed')"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 2 — Clone Repo"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import os\n",
    "\n",
    "REPO_DIR = 'meta-ad-policy-sandbox'\n",
    "\n",
    "# Idempotent: clone once, and only change directory when not already inside the repo\n",
    "# (otherwise a re-run from inside the repo would clone a nested copy and cd into it).\n",
    "if os.path.basename(os.getcwd()) != REPO_DIR:\n",
    "    if not os.path.exists(REPO_DIR):\n",
    "        !git clone https://github.com/Parth380/meta-ad-policy-sandbox.git\n",
    "    os.chdir(REPO_DIR)\n",
    "\n",
    "%pip install -e . -q\n",
    "os.makedirs('outputs', exist_ok=True)\n",
    "print('Repo installed & outputs/ ready')"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 3 — Config (SET THESE)"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import os\n",
    "\n",
    "os.environ['ENV_URL'] = 'http://localhost:8000'  # local env (fast); change to HF Space URL if needed\n",
    "os.environ['HF_REPO'] = 'parth-1/metaguard-llama3.1-8b-grpo'\n",
    "\n",
    "# Paste your HF write token below, or export HF_TOKEN before starting the kernel.\n",
    "# Leaving the paste slot empty keeps an already-exported token instead of wiping it.\n",
    "# Never commit a real token in this cell.\n",
    "_PASTED_HF_TOKEN = ''\n",
    "if _PASTED_HF_TOKEN:\n",
    "    os.environ['HF_TOKEN'] = _PASTED_HF_TOKEN\n",
    "\n",
    "ENV_URL = os.environ['ENV_URL']\n",
    "HF_TOKEN = os.environ.get('HF_TOKEN', '')\n",
    "HF_REPO = os.environ['HF_REPO']\n",
    "\n",
    "print(f'ENV_URL : {ENV_URL}')\n",
    "print(f'HF_REPO : {HF_REPO}')\n",
    "print(f'HF_TOKEN : {\"set\" if HF_TOKEN else \"MISSING -- set above before Cell 12\"}')"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 4 — Boot Local Environment"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import subprocess, sys, time, threading, requests\n",
    "import uvicorn\n",
    "\n",
    "# Start the backing API stubs with the kernel's own interpreter (the one %pip installed into).\n",
    "procs = [\n",
    "    subprocess.Popen([sys.executable, 'apps/regulatory_api.py']),\n",
    "    subprocess.Popen([sys.executable, 'apps/crm_api.py']),\n",
    "    subprocess.Popen([sys.executable, 'apps/audit_api.py']),\n",
    "]\n",
    "time.sleep(3)\n",
    "\n",
    "# Serve the env app from a daemon thread so the kernel stays usable.\n",
    "from server.app import app as _env_app\n",
    "threading.Thread(\n",
    "    target=uvicorn.run,\n",
    "    kwargs={'app': _env_app, 'host': '0.0.0.0', 'port': 8000, 'log_level': 'warning'},\n",
    "    daemon=True,\n",
    ").start()\n",
    "time.sleep(2)\n",
    "\n",
    "# Poll /reset until the env answers with HTTP 200 (at most 20 attempts).\n",
    "for i in range(20):\n",
    "    try:\n",
    "        r = requests.post(f'{ENV_URL}/reset', json={'task_id': 'task_1_healthcare'}, timeout=5)\n",
    "        if r.status_code == 200:\n",
    "            print(f'Environment ready (attempt {i+1})')\n",
    "            break\n",
    "    except requests.exceptions.RequestException:\n",
    "        pass  # server not accepting connections yet; retry after a short sleep\n",
    "    time.sleep(1)\n",
    "else:\n",
    "    raise RuntimeError('ENV not reachable after 20 attempts')"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
"## Cell 5 — Imports + Helpers"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import json\n",
    "import random\n",
    "import time\n",
    "import torch\n",
    "import requests\n",
    "import matplotlib.pyplot as plt\n",
    "from collections import defaultdict\n",
    "\n",
    "# Unsloth asks to be imported before trl/transformers so its patches take effect.\n",
    "from unsloth import FastLanguageModel, PatchFastRL\n",
    "from datasets import Dataset\n",
    "from trl import GRPOTrainer, GRPOConfig\n",
    "\n",
    "PatchFastRL('GRPO', FastLanguageModel)\n",
    "\n",
    "ALLOWED_ACTIONS = [\n",
    "    'query_regulations', 'analyze_image', 'check_advertiser_history',\n",
    "    'request_landing_page', 'request_id_verification',\n",
    "    'submit_audit', 'approve', 'reject',\n",
    "]\n",
    "\n",
    "# Marker key that safe_step puts in its fallback result when every retry failed.\n",
    "ENV_UNREACHABLE_KEY = '_env_unreachable'\n",
    "\n",
    "class EnvClient:\n",
    "    \"\"\"Minimal HTTP client for the environment server's /reset and /step endpoints.\"\"\"\n",
    "\n",
    "    def __init__(self, url):\n",
    "        self.url = url\n",
    "\n",
    "    def reset(self, task_id):\n",
    "        return requests.post(f'{self.url}/reset', json={'task_id': task_id}, timeout=8).json()\n",
    "\n",
    "    def step(self, action):\n",
    "        return requests.post(f'{self.url}/step', json={'action': action}, timeout=8).json()\n",
    "\n",
    "def safe_step(client, action, retries=3):\n",
    "    \"\"\"Send one action via client.step, retrying HTTP and JSON-decode failures.\n",
    "\n",
    "    Returns the env's JSON response. If all `retries` attempts fail, returns\n",
    "    {'reward': -0.3, ENV_UNREACHABLE_KEY: True} so callers can tell an\n",
    "    infrastructure failure apart from a genuine env reward.\n",
    "    \"\"\"\n",
    "    for _ in range(retries):\n",
    "        try:\n",
    "            return client.step(action)\n",
    "        except (requests.exceptions.RequestException, ValueError):\n",
    "            time.sleep(0.5)\n",
    "    return {'reward': -0.3, ENV_UNREACHABLE_KEY: True}\n",
    "\n",
    "def extract_json(text):\n",
    "    \"\"\"Parse a completion as a JSON object, optionally wrapped in a ``` / ```json fence.\n",
    "\n",
    "    Returns the parsed dict, or None when the text is not valid JSON or decodes to\n",
    "    something other than an object (string, list, number), since callers use .get().\n",
    "    \"\"\"\n",
    "    try:\n",
    "        if '```' in text:\n",
    "            text = text.split('```')[1]\n",
    "            if text.startswith('json'):\n",
    "                text = text[4:]\n",
    "        parsed = json.loads(text.strip())\n",
    "    except (AttributeError, TypeError, ValueError):\n",
    "        return None\n",
    "    return parsed if isinstance(parsed, dict) else None\n",
    "\n",
    "print('✅ Helpers loaded')"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 6 — Dataset"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "BASE_SCENARIOS = [\n",
    "    {'task_id': 'task_1_healthcare',\n",
    "     'text': \"Healthcare ad: 'miracle supplement cures disease'. No actions taken yet.\",\n",
    "     'actions_already_taken': [], 'setup_actions': []},\n",
    "    {'task_id': 'task_2_financial',\n",
    "     'text': \"Financial ad: 'guaranteed 500% returns, zero risk'. No actions taken yet.\",\n",
    "     'actions_already_taken': [], 'setup_actions': []},\n",
    "    {'task_id': 'task_3_multimodal',\n",
    "     'text': 'Multimodal ad: image may contain hidden violation. No actions taken yet.',\n",
    "     'actions_already_taken': [], 'setup_actions': []},\n",
    "    {'task_id': 'task_6_conflict',\n",
    "     'text': 'High-trust advertiser but policy borderline. No actions taken yet.',\n",
    "     'actions_already_taken': [], 'setup_actions': []},\n",
    "    {'task_id': 'task_7_ambiguous',\n",
    "     'text': 'Ambiguous wellness ad. Policy confidence low. No actions taken yet.',\n",
    "     'actions_already_taken': [], 'setup_actions': []},\n",
    "    {'task_id': 'task_8_adversarial',\n",
    "     'text': 'Natural supplement ad. Image may hide violation. No actions taken yet.',\n",
    "     'actions_already_taken': [], 'setup_actions': []},\n",
    "    {'task_id': 'task_9_dependency_trap',\n",
    "     'text': 'Certified wellness product. Regulations and CRM look clean. No actions taken yet.',\n",
    "     'actions_already_taken': [], 'setup_actions': []},\n",
    "    {'task_id': 'task_10_failure',\n",
    "     'text': 'Healthcare ad. First API call may fail. No actions taken yet.',\n",
    "     'actions_already_taken': [], 'setup_actions': []},\n",
    "    # task_4 targeting — fresh\n",
    "    {'task_id': 'task_4_targeting',\n",
    "     'text': \"Financial ad targeting young users: 'Start Your First Investment Portfolio'. No actions taken yet.\",\n",
    "     'actions_already_taken': [], 'setup_actions': []},\n",
    "    # task_4 targeting — mid state\n",
    "    {'task_id': 'task_4_targeting',\n",
    "     'text': 'Financial ad targeting young users. Policy queried, need to verify age targeting.',\n",
    "     'actions_already_taken': ['query_regulations'],\n",
    "     'setup_actions': [{'action_type': 'query_regulations', 'reasoning': 'policy lookup'}]},\n",
    "    # task_4 targeting — audit ready\n",
    "    {'task_id': 'task_4_targeting',\n",
    "     'text': 'Financial ad targeting minors. Policy, history, and ID verification done. Submit audit.',\n",
    "     'actions_already_taken': ['query_regulations', 'check_advertiser_history', 'request_id_verification'],\n",
    "     'setup_actions': [\n",
    "         {'action_type': 'query_regulations', 'reasoning': 'policy lookup'},\n",
    "         {'action_type': 'check_advertiser_history', 'reasoning': 'trust score'},\n",
    "         {'action_type': 'request_id_verification', 'reasoning': 'age check'},\n",
    "     ]},\n",
    "    {'task_id': 'task_1_healthcare',\n",
    "     'text': 'Healthcare ad. Policy already queried.',\n",
    "     'actions_already_taken': ['query_regulations'],\n",
    "     'setup_actions': [{'action_type': 'query_regulations', 'reasoning': 'policy lookup'}]},\n",
    "    {'task_id': 'task_2_financial',\n",
    "     'text': 'Financial ad. Policy and history checked. Submit audit next.',\n",
    "     'actions_already_taken': ['query_regulations', 'check_advertiser_history'],\n",
    "     'setup_actions': [\n",
    "         {'action_type': 'query_regulations', 'reasoning': 'policy lookup'},\n",
    "         {'action_type': 'check_advertiser_history', 'reasoning': 'trust score'},\n",
    "     ]},\n",
    "    {'task_id': 'task_2_financial',\n",
    "     'text': 'Financial ad. Policy, history, audit all complete. Make final decision.',\n",
    "     'actions_already_taken': ['query_regulations', 'check_advertiser_history', 'submit_audit'],\n",
    "     'setup_actions': [\n",
    "         {'action_type': 'query_regulations', 'reasoning': 'policy lookup'},\n",
    "         {'action_type': 'check_advertiser_history', 'reasoning': 'trust score'},\n",
    "         {'action_type': 'submit_audit', 'reasoning': 'audit log'},\n",
    "     ]},\n",
    "]\n",
    "\n",
    "PROMPT_TEMPLATE = \"\"\"You are an enterprise Ad Policy Compliance Agent.\n",
    "\n",
    "You MUST choose exactly ONE action_type from this list (any other value is invalid):\n",
    "- query_regulations\n",
    "- analyze_image\n",
    "- check_advertiser_history\n",
    "- request_landing_page\n",
    "- request_id_verification\n",
    "- submit_audit\n",
    "- approve\n",
    "- reject\n",
    "\n",
    "REQUIRED PHASE ORDER:\n",
    "1. query_regulations -> always first\n",
    "2. analyze_image / check_advertiser_history -> gather signals\n",
    "3. submit_audit -> always before final decision\n",
    "4. approve OR reject -> only after audit\n",
    "\n",
    "HARD RULES:\n",
    "- NEVER repeat an action listed in `actions_already_taken`.\n",
    "- Respond with ONLY a valid JSON object. No markdown, no prose.\n",
    "\n",
    "Required format:\n",
    "{{\\\"action_type\\\": \\\"\\\", \\\"reasoning\\\": \\\"\\\"}}\n",
    "\n",
    "Scenario: {text}\n",
    "actions_already_taken: {actions_already_taken}\n",
    "\n",
    "Your next action?\"\"\"\n",
    "\n",
    "def build_dataset():\n",
    "    \"\"\"Render every scenario into a prompt row (each repeated 10x) for GRPO training.\"\"\"\n",
    "    rows = []\n",
    "    for s in BASE_SCENARIOS:\n",
    "        prompt = PROMPT_TEMPLATE.format(\n",
    "            text=s['text'],\n",
    "            actions_already_taken=json.dumps(s['actions_already_taken']),\n",
    "        )\n",
    "        rows.append({'prompt': prompt, 'task_id': s['task_id'], 'setup_actions': s['setup_actions']})\n",
    "    return Dataset.from_list(rows * 10)\n",
    "\n",
    "dataset = build_dataset()\n",
    "print(f'✅ Dataset: {len(dataset)} examples')"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 7 — Reward Function"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Track rewards for plotting\n",
    "reward_log = []\n",
    "step_log = []\n",
    "global_step_counter = [0]\n",
    "\n",
    "def reward_environment(prompts, completions, task_id=None, setup_actions=None, **kwargs):\n",
    "    \"\"\"GRPO reward: score each completion by replaying its action in a fresh env episode.\n",
    "\n",
    "    `task_id` and `setup_actions` are the dataset columns of the same name, one entry\n",
    "    per completion. Scores per completion:\n",
    "      -1.0          not a JSON object, or action_type not in ALLOWED_ACTIONS\n",
    "      -0.3          env unreachable (reset/step raised, or safe_step exhausted its retries)\n",
    "      -0.5          env refused the action (status mentions api failure / invalid action / must call)\n",
    "      0.5 + reward  otherwise (a response without 'reward' counts as -0.2)\n",
    "    The batch mean is appended to reward_log / step_log for the Cell 10 plots.\n",
    "    \"\"\"\n",
    "    client = EnvClient(ENV_URL)\n",
    "    rewards = []\n",
    "\n",
    "    for completion, t_id, setup in zip(completions, task_id, setup_actions):\n",
    "        parsed = extract_json(completion)\n",
    "        if not parsed:\n",
    "            rewards.append(-1.0)\n",
    "            continue\n",
    "\n",
    "        action_type = parsed.get('action_type')\n",
    "        if action_type not in ALLOWED_ACTIONS:\n",
    "            rewards.append(-1.0)\n",
    "            continue\n",
    "\n",
    "        action = {\n",
    "            'action_type': action_type,\n",
    "            'reasoning': parsed.get('reasoning', 'format-compliant'),\n",
    "        }\n",
    "\n",
    "        try:\n",
    "            client.reset(t_id)\n",
    "            # Replay the scenario prefix so env state matches the prompt's actions_already_taken;\n",
    "            # stop at the first step the env could not be reached for.\n",
    "            setup_ok = True\n",
    "            for s in setup:\n",
    "                if ENV_UNREACHABLE_KEY in safe_step(client, s):\n",
    "                    setup_ok = False\n",
    "                    break\n",
    "            result = safe_step(client, action) if setup_ok else None\n",
    "\n",
    "            if result is None or ENV_UNREACHABLE_KEY in result:\n",
    "                # Infrastructure failure, not a model decision: score it like the except branch.\n",
    "                # (safe_step's fallback reward of -0.3 would otherwise become 0.5 - 0.3 = +0.2.)\n",
    "                rewards.append(-0.3)\n",
    "                continue\n",
    "\n",
    "            env_reward = float(result.get('reward', -0.2))\n",
    "            status_msg = (result.get('status_message') or '').lower()\n",
    "\n",
    "            rejected = (\n",
    "                'api failure' in status_msg\n",
    "                or 'invalid action' in status_msg\n",
    "                or 'must call' in status_msg\n",
    "            )\n",
    "            shaped = -0.5 if rejected else 0.5 + env_reward\n",
    "            rewards.append(shaped)\n",
    "\n",
    "        except Exception:\n",
    "            rewards.append(-0.3)\n",
    "\n",
    "    # Log for plot\n",
    "    avg = sum(rewards) / len(rewards) if rewards else 0.0\n",
    "    global_step_counter[0] += 1\n",
    "    reward_log.append(avg)\n",
    "    step_log.append(global_step_counter[0])\n",
    "\n",
    "    return rewards\n",
    "\n",
    "print('✅ Reward function ready')"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 8 — Load Model"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "if torch.cuda.is_available():\n",
    "    _props = torch.cuda.get_device_properties(0)\n",
    "    _vram = _props.total_memory\n",
    "    _cc = (_props.major, _props.minor)\n",
    "    print(f'GPU: {_props.name} VRAM: {_vram / 1024**3:.1f} GB Compute: {_cc[0]}.{_cc[1]}')\n",
    "else:\n",
    "    _vram, _cc = 0, (0, 0)\n",
    "\n",
    "# 40 GB cards (e.g. A100-40GB) report slightly less than 40 GiB of total_memory, so the\n",
    "# cutoff sits at 35 GiB: T4/L4 → 4-bit; A100 40GB/80GB → full precision.\n",
    "USE_4BIT = _vram < 35 * 1024**3\n",
    "USE_BF16 = _cc >= (8, 0) and not USE_4BIT  # bf16 only with full-precision weights; 4-bit LoRA uses fp16\n",
    "print(f'4-bit: {USE_4BIT} bf16: {USE_BF16}')\n",
    "\n",
    "model, tokenizer = FastLanguageModel.from_pretrained(\n",
    "    model_name='unsloth/Llama-3.1-8B-Instruct',\n",
    "    load_in_4bit=USE_4BIT,\n",
    "    max_seq_length=2048,\n",
    "    dtype=torch.float16 if USE_4BIT else None,\n",
    ")\n",
    "\n",
    "model = FastLanguageModel.get_peft_model(\n",
    "    model,\n",
    "    r=32,\n",
    "    target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],\n",
    "    lora_alpha=64,\n",
    "    lora_dropout=0,\n",
    "    bias='none',\n",
    "    use_gradient_checkpointing='unsloth',\n",
    "    random_state=3407,\n",
    ")\n",
    "print('✅ Model loaded')"
   ],
   "execution_count": null,
   "outputs":
[]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 9 — Train"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Per-device batch and generations per prompt; smaller when the model was loaded in 4-bit.\n",
    "BATCH_SIZE = 1 if USE_4BIT else 2\n",
    "NUM_GENERATIONS = 2 if USE_4BIT else 4\n",
    "\n",
    "trainer = GRPOTrainer(\n",
    "    model=model,\n",
    "    reward_funcs=[reward_environment],\n",
    "    args=GRPOConfig(\n",
    "        output_dir='outputs',\n",
    "        learning_rate=2e-5,\n",
    "        num_train_epochs=3,\n",
    "        per_device_train_batch_size=BATCH_SIZE,\n",
    "        gradient_accumulation_steps=4,\n",
    "        num_generations=NUM_GENERATIONS,\n",
    "        max_prompt_length=768,\n",
    "        max_completion_length=128,\n",
    "        logging_steps=5,\n",
    "        warmup_steps=10,\n",
    "        bf16=USE_BF16,\n",
    "        fp16=not USE_BF16,\n",
    "        report_to='none',\n",
    "    ),\n",
    "    train_dataset=dataset,\n",
    "    # TRL's GRPOTrainer names the tokenizer argument `processing_class`.\n",
    "    processing_class=tokenizer,\n",
    ")\n",
    "\n",
    "print('Starting GRPO training...')\n",
    "print(f' bf16={USE_BF16} fp16={not USE_BF16} batch={BATCH_SIZE} gens={NUM_GENERATIONS}')\n",
    "trainer.train()\n",
    "print('Training complete')"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 10 — Plot Reward Curve"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "\n",
    "os.makedirs('outputs', exist_ok=True)\n",
    "\n",
    "def moving_avg(data, window=5):\n",
    "    \"\"\"Simple moving average ('valid' mode); returns data unchanged if shorter than window.\"\"\"\n",
    "    if len(data) < window:\n",
    "        return data\n",
    "    return list(np.convolve(data, np.ones(window)/window, mode='valid'))\n",
    "\n",
    "hist = pd.DataFrame(trainer.state.log_history)\n",
    "\n",
    "fig, axes = plt.subplots(1, 3, figsize=(18, 5))\n",
    "\n",
    "# --- Plot 1: Reward curve (from our custom log) ---\n",
    "ax = axes[0]\n",
    "ax.plot(step_log, reward_log, alpha=0.3, color='steelblue', label='Raw')\n",
    "smoothed = moving_avg(reward_log)\n",
    "# A 'valid' moving average drops the first window-1 points; pair each smoothed value with\n",
    "# the last step of its window so both lines share the same x-axis.\n",
    "smoothed_steps = step_log[len(step_log) - len(smoothed):]\n",
    "ax.plot(smoothed_steps, smoothed, color='steelblue', linewidth=2, label='Smoothed (MA-5)')\n",
    "ax.axhline(y=0, color='gray', linestyle='--', linewidth=0.8)\n",
    "ax.set_xlabel('Reward Eval Step')\n",
    "ax.set_ylabel('Avg Reward per Batch')\n",
    "ax.set_title('Reward Curve')\n",
    "ax.legend()\n",
    "ax.grid(alpha=0.3)\n",
    "\n",
    "# --- Plot 2: Loss curve (from trainer logs) ---\n",
    "ax = axes[1]\n",
    "loss_rows = hist.dropna(subset=['loss']) if 'loss' in hist.columns else pd.DataFrame()\n",
    "if not loss_rows.empty:\n",
    "    ax.plot(loss_rows['step'], loss_rows['loss'], color='#7c3aed', linewidth=2)\n",
    "    ax.set_xlabel('Training Step')\n",
    "    ax.set_ylabel('Loss')\n",
    "    ax.set_title('GRPO Loss')\n",
    "    ax.grid(alpha=0.3)\n",
    "else:\n",
    "    ax.text(0.5, 0.5, 'No loss data logged', ha='center', va='center', transform=ax.transAxes)\n",
    "    ax.set_title('GRPO Loss')\n",
    "\n",
    "# --- Plot 3: Reward from trainer logs (if available) ---\n",
    "ax = axes[2]\n",
    "reward_cols = [c for c in hist.columns if 'reward' in c.lower() and 'std' not in c.lower()]\n",
    "if reward_cols:\n",
    "    col = reward_cols[0]\n",
    "    rr = hist.dropna(subset=[col])\n",
    "    ax.plot(rr['step'], rr[col], color='#16a34a', linewidth=2)\n",
    "    ax.axhline(y=0, color='gray', linestyle='--', linewidth=0.8)\n",
    "    ax.set_xlabel('Training Step')\n",
    "    ax.set_ylabel(col)\n",
    "    ax.set_title('Trainer Reward Log')\n",
    "    ax.grid(alpha=0.3)\n",
    "else:\n",
    "    ax.text(0.5, 0.5, 'No trainer reward data', ha='center', va='center', transform=ax.transAxes)\n",
    "    ax.set_title('Trainer Reward Log')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('outputs/training_plots.png', dpi=150)\n",
    "plt.show()\n",
    "print('Saved to outputs/training_plots.png')\n",
    "\n",
    "n = len(reward_log)\n",
    "if n == 0:\n",
    "    # Nothing logged (e.g. training never called the reward function): avoid dividing by zero.\n",
    "    print('\\n--- Before vs After ---')\n",
    "    print('No rewards were logged; nothing to compare.')\n",
    "else:\n",
    "    first_10 = reward_log[:min(10, n)]\n",
    "    last_10 = reward_log[max(0, n-10):]\n",
    "    print(f'\\n--- Before vs After ---')\n",
    "    print(f'Avg reward (first 10 steps): {sum(first_10)/len(first_10):.3f}')\n",
    "    print(f'Avg reward (last 10 steps) : {sum(last_10)/len(last_10):.3f}')"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 11 — Trained Model: Sample Outputs"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "FastLanguageModel.for_inference(model)\n",
    "\n",
    "test_scenarios = [\n",
    "    ('task_1_healthcare', \"Healthcare ad: 'miracle cure'. No actions taken yet.\", []),\n",
    "    ('task_2_financial', \"Financial ad: 'guaranteed returns'. No actions taken yet.\", []),\n",
    "    ('task_4_targeting', \"Financial ad targeting teens. No actions taken yet.\", []),\n",
    "    ('task_2_financial', \"Financial ad. Policy, history, audit done. Decide.\",\n",
    "     ['query_regulations', 'check_advertiser_history', 'submit_audit']),\n",
    "]\n",
    "\n",
    "print('=== Trained Model Outputs ===\\n')\n",
    "for task, text, taken in test_scenarios:\n",
    "    prompt = PROMPT_TEMPLATE.format(text=text, actions_already_taken=json.dumps(taken))\n",
    "    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)\n",
    "    # Same completion budget as training (max_completion_length=128) so the JSON isn't cut off.\n",
    "    out = model.generate(**inputs, max_new_tokens=128, temperature=0.1, do_sample=True)\n",
    "    decoded = tokenizer.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)\n",
    "    parsed = extract_json(decoded) or decoded.strip()[:120]\n",
    "    print(f'[{task}] taken={taken}')\n",
    "    print(f' -> {parsed}\\n')"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell 12 — Save + Push to HF Hub"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "model.save_pretrained('outputs/lora_adapter')\n",
    "tokenizer.save_pretrained('outputs/lora_adapter')\n",
    "print('LoRA adapter saved')\n",
    "\n",
    "# Merge directly from the trained in-memory PEFT model instead of loading a second\n",
    "# 16-bit copy of the 8B base next to it, which can exhaust VRAM on T4/L4-class GPUs.\n",
    "print('Merging adapter into base model...')\n",
    "model.save_pretrained_merged(\n",
    "    'outputs/merged',\n",
    "    tokenizer,\n",
    "    save_method='merged_16bit',\n",
    ")\n",
    "print('Merged model saved to outputs/merged')\n",
    "\n",
    "if HF_REPO and HF_TOKEN:\n",
    "    print(f'Pushing to {HF_REPO}...')\n",
    "    model.push_to_hub_merged(\n",
    "        HF_REPO,\n",
    "        tokenizer,\n",
    "        save_method='merged_16bit',\n",
    "        token=HF_TOKEN,\n",
    "    )\n",
    "    print(f'Model live at https://huggingface.co/{HF_REPO}')\n",
    "else:\n",
    "    print('Set HF_REPO and HF_TOKEN in Cell 3 to push to Hub')\n",
    "\n",
    "print('Done.')"
   ],
   "execution_count": null,
   "outputs": []
  }
 ],
 "metadata": {
  "colab": {
   "provenance": [],
   "gpuType": "A100"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  },
  "accelerator": "GPU"
 },
 "nbformat": 4,
 "nbformat_minor": 0
}