3v324v23 committed on
Commit
daa0358
·
1 Parent(s): 91382db

updated docker

Files changed (5)
  1. .dockerignore +2 -1
  2. README.md +95 -70
  3. apps/start.sh +9 -0
  4. dockerfile +11 -7
  5. grpo_train.py +200 -102
.dockerignore CHANGED
@@ -1,3 +1,4 @@
1
  venv/
2
  __pycache__/
3
- *.pyc
 
 
1
  venv/
2
  __pycache__/
3
+ *.pyc
4
+ apps/start_all.bat
README.md CHANGED
@@ -1,110 +1,135 @@
1
 
2
- # MetaGuard: Enterprise Ad-Policy RL Sandbox
3
 
4
- [](https://www.google.com/search?q=https://github.com/openenv/openenv)
5
- [](https://opensource.org/licenses/MIT)
6
- [](https://www.python.org/)
7
- [](https://github.com/unslothai/unsloth)
8
 
9
- **MetaGuard** is a high-fidelity Reinforcement Learning (RL) environment designed to train and evaluate AI agents on complex, multi-step ad-policy moderation workflows. Developed for the **Meta x Scaler Hackathon**, this project tackles the challenge of ensuring LLM agents follow strict Standard Operating Procedures (SOPs) while navigating adversarial multimodal "traps."
 
 
10
 
11
- -----
12
 
13
- ## 🏆 Hackathon Submission Details
 
14
 
15
- - **Theme:** 3.1 (Multi-Step Reasoning & Policy Compliance)
16
- - **Bonus Track:** AI Scaler Lab
17
- - **Team Members:** Parth Singhal, Mehakveer Kaur, Kartik Goyal
18
 
19
- -----
 
20
 
21
- ## 🏗️ System Architecture: Distributed Microservices
22
 
23
- MetaGuard mimics a real-world enterprise ecosystem by decoupling environment logic from policy and data services. This ensures that the agent must interact with live APIs to gather context before making terminal decisions.
 
24
 
 
25
  ```mermaid
26
- flowchart LR
27
- A[Agent / LLM Policy] -->|/reset, /step| B[OpenEnv Environment Server :8000]
28
- B -->|query_regulations| C[Regulatory API :8001]
29
- B -->|check_history| D[CRM API :8002]
30
- B -->|submit_audit| E[Audit API :8003]
31
- B -->|observation + reward| A
32
  ```
33
 
34
- ### Integrated Services
 
 
 
 
 
 
35
 
36
- * **Environment Hub (`:8000`)**: Orchestrates the episode lifecycle using **OpenEnv** and enforces procedural phase gates.
37
- * **Regulatory API (`:8001`)**: Provides category-specific policy constraints (e.g., Healthcare, Finance).
38
- * **Advertiser CRM (`:8002`)**: Manages trust scores and historical violation records to simulate risk-based decision-making.
39
- * **Audit API (`:8003`)**: Persists the "Chain of Thought" (CoT) and decision logs for full traceability.
40
 
41
- -----
 
42
 
43
- ## 🧠 Methodology: GRPO + Unsloth
 
 
 
 
44
 
45
- To move beyond simple instruction following, we utilize **Group Relative Policy Optimization (GRPO)** for training. This allows the model to optimize its decision-making based on relative performance within a group, eliminating the need for a separate Critic model.
46
 
47
- * **Efficiency:** Powered by **Unsloth**, enabling 8B model training on consumer-grade GPUs with a significantly reduced VRAM footprint.
48
- * **Live Environment Interaction:** The training loop interacts directly with the microservice stack, allowing the model to learn from real-time API feedback and reward signals.
49
- * **Critic-less RL:** GRPO calculates rewards based on group relative performance, ensuring stable and efficient policy updates.
50
 
51
- -----
 
 
 
52
 
53
- ## 🚦 Procedural Action Space & Reward Logic
 
 
 
54
 
55
- The environment enforces a strict **Standard Operating Procedure (SOP)**. Terminal actions (`approve`/`reject`) are blocked by "Phase Gates" until mandatory steps are completed.
56
 
57
- | Step | Action | Description | Requirement |
58
- | :--- | :--- | :--- | :--- |
59
- | 1 | `query_regulations` | Fetch category-specific policy constraints. | **Mandatory** |
60
- | 2 | `analyze_image` | Inspect visual assets for policy "dog whistles." | Required for Multimodal Tasks |
61
- | 3 | `submit_audit` | Log reasoning to the Audit API for traceability. | **Mandatory** |
62
- | 4 | `approve` / `reject` | Final terminal action. | Allowed after Gates 1-3 |
63
 
64
- **Reward Signal:** Correct decisions yield `+1.0`, while incorrect decisions or procedural violations (skipping a gate) result in heavy negative rewards (up to `-0.3` per violation).
 
 
 
 
65
 
66
- -----
67
 
68
  ## 🚀 Getting Started
69
 
70
- ### 1\. Setup Environment
71
-
72
  ```bash
73
- pip install -e .
 
74
  pip install -r requirements.txt
75
  ```
76
 
77
- ### 2\. Launch the Microservice Stack
78
-
79
  ```bash
80
- # Run the background services
81
- python apps/regulatory_api.py
82
- python apps/crm_api.py
83
- python apps/audit_api.py
84
-
85
- # Start the OpenEnv Hub
86
- uvicorn server.app:app --host 0.0.0.0 --port 8000
87
  ```
88
 
89
- ### 3\. Run GRPO Training
90
-
91
  ```bash
92
- python grpo_train.py
93
  ```
94
 
95
- -----
96
 
97
- ## 📊 Adversarial Task Families
98
-
99
- MetaGuard evaluates agents across four distinct challenge categories:
100
-
101
- * **Healthcare**: Unapproved medical claims and pharma violations.
102
- * **Financial**: Predatory services and high-pressure tactics.
103
- * **Multimodal**: Violations hidden within imagery (e.g., visual text bypass).
104
- * **Targeting**: Illegal demographic or age-restricted policy violations.
105
-
106
- -----
107
 
108
- ## 📜 License
109
 
110
- Distributed under the **MIT License**. See `LICENSE` for more information.
 
 
1
 
2
+ # 🚀 MetaGuard: Procedural RL for Automated Ad Moderation
3
 
4
+ > **Transforming "Black Box" AI into auditable, multi-step regulatory workflows.**
 
 
 
5
 
6
+ ![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)
7
+ ![Python 3.9+](https://img.shields.io/badge/Python-3.9%2B-blue.svg)
8
+ ![RL-Framework: GRPO](https://img.shields.io/badge/Framework-GRPO-success.svg)
9
 
10
+ ---
11
 
12
+ ## ⚠️ The Problem: "Single-Shot" Failures
13
+ Traditional AI moderation models treat policy enforcement as a simple classification task (Approve/Reject). This approach fails in enterprise environments because it lacks:
14
 
15
+ * **Traceability:** No explanation for *why* a decision was made.
16
+ * **Contextual Awareness:** Decisions are made without checking advertiser history or regional regulations.
17
+ * **Risk Management:** High-risk content can be approved blindly, without a verified audit trail.
18
 
19
+ ## ✅ The MetaGuard Solution
20
+ MetaGuard redefines moderation as a **step-by-step investigative process** powered by Reinforcement Learning. The agent is trained not just to provide the right answer, but to follow the **correct investigative procedure** required by global compliance standards.
21
 
22
+ ---
23
 
24
+ ## 🏗️ System Architecture
25
+ MetaGuard operates as a microservice ecosystem to simulate real-world API latency, data silos, and procedural constraints.
26
 
27
+ ### 🔄 Interaction Flow
28
  ```mermaid
29
+ graph LR
30
+ subgraph "Intelligent Agent"
31
+ A[RL Policy Agent]
32
+ end
33
+
34
+ subgraph "MetaGuard Core"
35
+ B(Environment Hub :8000)
36
+ end
37
+
38
+ subgraph "External Policy APIs"
39
+ C[[Regulatory API :8001]]
40
+ D[[CRM API :8002]]
41
+ E[[Audit API :8003]]
42
+ end
43
+
44
+ A -- "1. Action Selection" --> B
45
+ B -- "2. API Request" --> C
46
+ B -- "2. API Request" --> D
47
+
48
+ C -- "3. Policy Signal" --> B
49
+ D -- "3. Trust Score" --> B
50
+
51
+ B -- "4. State Update + Reward" --> A
52
+
53
+ A -- "5. Final Decision" --> B
54
+ B -- "6. Immutable Log" --> E
55
  ```
56
 
57
+ ### 🗂️ Microservice Responsibility Map
58
+ | Service | Endpoint | Responsibility |
59
+ | :--- | :--- | :--- |
60
+ | **Core Env** | `:8000` | State orchestration & Reward calculation |
61
+ | **Regulatory API**| `:8001` | Dynamic policy lookup & legal constraints |
62
+ | **CRM API** | `:8002` | Advertiser historical risk & trust scoring |
63
+ | **Audit API** | `:8003` | Immutable logging for decision accountability |
64
 
65
+ ---
 
 
 
66
 
67
+ ## 🧠 Methodology: GRPO & Procedural RL
68
+ We use **Group Relative Policy Optimization (GRPO)** to train the agent. Instead of emitting a single-shot verdict, the agent learns an ordered **Action Sequence**:
69
 
70
+ 1. 📥 **Ingest:** Fetch policy constraints via `query_regulations`.
71
+ 2. 🔍 **Inspect:** Scan creative assets via `analyze_image`.
72
+ 3. 🛡️ **Validate:** Cross-reference advertiser reliability via `check_advertiser_history`.
73
+ 4. 📝 **Certify:** Generate an immutable record via `submit_audit`.
74
+ 5. ⚖️ **Decide:** Execute final `approve` or `reject` action.
75
 
76
+ ---
77
 
78
+ ## 🎬 Evaluation Trace
79
+ Running `demo.py` compares a baseline "Naive" agent against the MetaGuard-trained agent to show the difference in procedural behavior.
 
80
 
81
+ ### 📉 Scenario 1: Naive Agent
82
+ * **Behavior:** Attempts to approve content without performing due diligence.
83
+ * **Outcome:** Procedural penalties triggered; audit trail missing.
84
+ * **Final Compliance Rating:** `0/10` 🚨
85
 
86
+ ### 📈 Scenario 2: MetaGuard Agent
87
+ * **Behavior:** Systematically investigates all signals before acting.
88
+ * **Trace:** `REGULATIONS` ➔ `IMAGE_SCAN` ➔ `CRM_CHECK` ➔ `AUDIT_LOG` ➔ `REJECT`.
89
+ * **Final Compliance Rating:** `9/10` 🌟
90
 
91
+ ---
92
 
93
+ ## 📊 Performance Metrics
 
 
 
 
 
94
 
95
+ | Metric | Pre-Training (Naive) | Post-Training (MetaGuard) |
96
+ | :--- | :--- | :--- |
97
+ | **Success Rate** | 43% | **77%** |
98
+ | **Procedural Compliance** | 12% | **94%** |
99
+ | **Avg. Reward Score** | -2.1 | **+1.35** |
100
 
101
+ ---
102
 
103
  ## 🚀 Getting Started
104
 
105
+ ### 1. Environment Setup
 
106
  ```bash
107
+ git clone https://github.com/Parth380/meta-ad-policy-sandbox.git
108
+ cd meta-ad-policy-sandbox
109
  pip install -r requirements.txt
110
  ```
111
 
112
+ ### 2. Launch Microservices
113
+ Open three separate terminal windows and start the mock API infrastructure:
114
  ```bash
115
+ python apps/regulatory_api.py # Port 8001
116
+ python apps/crm_api.py # Port 8002
117
+ python apps/audit_api.py # Port 8003
 
 
 
 
118
  ```
119
 
120
+ ### 3. Run the Evaluation Demo
 
121
  ```bash
122
+ python demo.py
123
  ```
124
 
125
+ ---
126
 
127
+ ## 🏆 Hackathon Submission Details
128
+ * **Theme:** 3.1 Multi-Step Reasoning & Policy Compliance
129
+ * **Bonus Track:** AI Scaler Lab
130
+ * **Team Members:** Parth Singhal, Mehakveer Kaur, Kartik Goyal
 
 
 
 
 
 
131
 
132
+ ---
133
 
134
+ ### 📜 License
135
+ This project is licensed under the MIT License.
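
Review note: the five-step action sequence in the new README's Methodology section (and the phase-gate table in the removed version) implies an ordering constraint: `approve` and `reject` should only be accepted once the investigative steps are done. A minimal sketch of such a check is shown below. The names `missing_gates` and `is_allowed` and the `multimodal` flag are hypothetical, and the required set is an assumption taken from the removed README table (`query_regulations` and `submit_audit` always, `analyze_image` for multimodal tasks); the actual environment server may enforce its gates differently.

```python
from typing import Iterable, List

# Terminal actions named in the README.
TERMINAL_ACTIONS = {"approve", "reject"}


def missing_gates(history: Iterable[str], multimodal: bool = False) -> List[str]:
    """Return the investigative steps still missing before a terminal action.

    Assumption: the mandatory steps mirror the removed README table.
    """
    required = ["query_regulations", "submit_audit"]
    if multimodal:
        # Image analysis sits between the policy lookup and the audit log.
        required.insert(1, "analyze_image")
    done = set(history)
    return [step for step in required if step not in done]


def is_allowed(action: str, history: Iterable[str], multimodal: bool = False) -> bool:
    """Non-terminal actions always pass; terminal ones need every gate cleared."""
    if action not in TERMINAL_ACTIONS:
        return True
    return not missing_gates(history, multimodal)
```

For example, `is_allowed("approve", [])` is rejected because no gate has been cleared, while the same call with `["query_regulations", "submit_audit"]` as history passes for a non-multimodal task.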
apps/start.sh ADDED
@@ -0,0 +1,9 @@
1
+ #!/bin/bash
2
+
3
+ # Start the background microservices
4
+ python apps/regulatory_api.py &
5
+ python apps/crm_api.py &
6
+ python apps/audit_api.py &
7
+
8
+ # Start the main environment server in the foreground
9
+ uvicorn server.app:app --host 0.0.0.0 --port 8000
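
Review note: `apps/start.sh` backgrounds the three API processes and starts `uvicorn` right away, so the hub on `:8000` may begin serving before ports 8001-8003 are listening. One possible readiness probe is sketched below in Python. The helper names `port_open` and `wait_for_ports`, the default host, and the retry counts are all assumptions for illustration; nothing like this exists in the commit.

```python
import socket
import time
from typing import Iterable, List


def port_open(host: str, port: int, timeout: float = 0.5) -> bool:
    """Return True if a TCP connection to host:port succeeds."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers connection refused, timeouts, and unreachable hosts.
        return False


def wait_for_ports(ports: Iterable[int], host: str = "127.0.0.1",
                   attempts: int = 20, delay: float = 1.0) -> List[int]:
    """Poll until every port accepts connections.

    Returns the ports that are still closed after the last attempt
    (an empty list means everything is up).
    """
    pending = list(ports)
    for _ in range(attempts):
        pending = [p for p in pending if not port_open(host, p)]
        if not pending:
            break
        time.sleep(delay)
    return pending
```

A caller could run `wait_for_ports([8001, 8002, 8003])` before launching the hub, or the training script could include port 8000 in the list and abort when the returned list is non-empty.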
dockerfile CHANGED
@@ -1,17 +1,21 @@
1
- # Use a lightweight Python image
2
  FROM python:3.11-slim
3
 
4
- # Set the working directory
5
  WORKDIR /app
6
 
7
- # Copy all your project files into the container
8
  COPY . .
9
 
10
- # Install dependencies directly from the new pyproject.toml
11
  RUN pip install --no-cache-dir .
 
12
 
13
- # Expose the port Uvicorn uses
 
 
 
14
  EXPOSE 8000
15
 
16
- # Start the server, pointing it to the new folder structure!
17
- CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "8000"]
 
1
+ # 1. Use a lightweight Python image
2
  FROM python:3.11-slim
3
 
4
+ # 2. Set the working directory inside the container
5
  WORKDIR /app
6
 
7
+ # 3. Copy all your project files into the container
8
  COPY . .
9
 
10
+ # 4. Install dependencies
11
  RUN pip install --no-cache-dir .
12
+ RUN pip install --no-cache-dir -r requirements.txt
13
 
14
+ # 5. Make the startup script executable (the executable bit can be lost when the repo is committed from Windows)
15
+ RUN chmod +x apps/start.sh
16
+
17
+ # 6. Expose the port the main server uses
18
  EXPOSE 8000
19
 
20
+ # 7. Start all services using the bash script
21
+ CMD ["./apps/start.sh"]
grpo_train.py CHANGED
@@ -1,152 +1,250 @@
 
 
 
 
1
  import json
2
- import torch
3
  import requests
 
 
4
  from datasets import Dataset
5
  from unsloth import FastLanguageModel, PatchFastRL
6
  from trl import GRPOTrainer, GRPOConfig
7
 
8
- # MUST be called before trainer instantiation
9
  PatchFastRL("GRPO", FastLanguageModel)
10
 
11
- ENV_URL = "http://localhost:8000"
12
- TASKS = ["task_1_healthcare", "task_2_financial",
13
- "task_3_multimodal", "task_4_targeting"]
14
 
15
- SYSTEM_PROMPT = """You are an enterprise Ad Policy Compliance Agent.
16
- Always respond with ONLY valid JSON, no markdown.
17
 
18
- REQUIRED PHASE ORDER:
19
- 1. query_regulations — always first
20
- 2. analyze_image — required for multimodal tasks
21
- 3. submit_audit — always before final decision
22
- 4. approve or reject — only after audit
 
 
 
23
 
24
- Format: {"action_type": "<action>", "reasoning": "<reason>"}"""
 
 
25
 
26
- # ── DATASET ───────────────────────────────────────────────────────────────────
27
 
28
  def build_dataset():
29
  rows = []
30
- for task_id in TASKS:
31
- res = requests.post(f"{ENV_URL}/reset", json={"task_id": task_id})
32
- obs = res.json()
33
- prompt = (
34
- f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
35
- f"{SYSTEM_PROMPT}<|eot_id|>"
36
- f"<|start_header_id|>user<|end_header_id|>\n"
37
- f"Task: {task_id}\n"
38
- f"Ad: {obs.get('headline','N/A')} — {obs.get('body_text','N/A')}\n"
39
- f"Trust Score: {obs.get('advertiser_trust_score','N/A')}\n"
40
- f"Status: {obs.get('status_message','')}\n"
41
- f"What is your next action?"
42
- f"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
43
- )
44
- rows.append({"prompt": prompt, "task_id": task_id})
45
- # 25x repetition = 100 rows, enough for 1 epoch
46
- return Dataset.from_list(rows * 25)
47
-
48
- # ── REWARD FUNCTION (actually calls the environment) ──────────────────────────
49
-
50
- def reward_environment(prompts, completions, task_id, **kwargs):
51
- """
52
- This is the real reward — model outputs an action,
53
- we send it to the environment, environment returns the reward.
54
- """
 
55
  rewards = []
56
- # Notice we zip with task_id (from the dataset) and use t_id inside the loop
57
- for completion, t_id in zip(completions, task_id):
58
- try:
59
- # Parse model output
60
- content = completion.strip()
61
- if content.startswith("```"):
62
- content = content.split("```")[1]
63
- if content.startswith("json"):
64
- content = content[4:]
65
- action = json.loads(content.strip())
66
- action_type = action.get("action_type", "query_regulations")
67
- except Exception:
68
- # Malformed JSON = penalty
69
- rewards.append(-0.5)
70
  continue
71
 
72
- try:
73
- # Fresh episode for each reward calculation
74
- requests.post(f"{ENV_URL}/reset", json={"task_id": t_id})
75
-
76
- # Run a minimal sequence: if model says query_regulations,
77
- # run that then check what reward it generates
78
- step_res = requests.post(
79
- f"{ENV_URL}/step",
80
- json={"action": {"action_type": action_type,
81
- "reasoning": action.get("reasoning", "")}},
82
- timeout=5
83
- )
84
- data = step_res.json()
85
- rewards.append(float(data.get("reward", -0.1)))
86
- except Exception:
87
- rewards.append(-0.1)
88
 
89
- return rewards
 
 
 
 
 
 
 
90
 
91
- def reward_json_format(prompts, completions, **kwargs):
92
- """Bonus reward for valid JSON output."""
93
- rewards = []
94
- for completion in completions:
95
  try:
96
- content = completion.strip()
97
- if content.startswith("```"):
98
- content = content.split("```")[1]
99
- if content.startswith("json"):
100
- content = content[4:]
101
- json.loads(content.strip())
102
- rewards.append(0.5)
103
- except Exception:
104
- rewards.append(-0.5)
 
 
 
 
 
105
  return rewards
106
 
107
- # ── MODEL SETUP ───────────────────────────────────────────────────────────────
 
 
108
 
109
  model, tokenizer = FastLanguageModel.from_pretrained(
110
  model_name="unsloth/Llama-3.1-8B-Instruct",
111
- max_seq_length=1024,
112
  load_in_4bit=True,
 
113
  )
 
114
  model = FastLanguageModel.get_peft_model(
115
  model,
116
  r=16,
117
  target_modules=["q_proj", "v_proj"],
118
  lora_alpha=16,
119
- lora_dropout=0.0,
120
- use_gradient_checkpointing="unsloth",
121
  )
122
 
123
- # ── TRAINER ───────────────────────────────────────────────────────────────────
 
 
124
 
125
  dataset = build_dataset()
126
 
127
  trainer = GRPOTrainer(
128
  model=model,
129
- reward_funcs=[reward_environment, reward_json_format],
130
  args=GRPOConfig(
131
- output_dir="outputs/meta-ad-agent",
132
  learning_rate=5e-6,
133
  num_train_epochs=1,
134
- per_device_train_batch_size=2,
135
- gradient_accumulation_steps=4,
 
136
  max_prompt_length=512,
137
- max_completion_length=128,
138
- num_generations=4, # lower = faster, enough for demo
139
- logging_steps=5,
140
- save_steps=50,
141
- report_to="none",
142
  ),
143
  train_dataset=dataset,
144
- tokenizer=tokenizer,
145
  )
146
 
 
 
 
 
147
  if __name__ == "__main__":
148
- print("Starting GRPO training — environment must be running on :8000")
 
 
149
  trainer.train()
150
- model.save_pretrained("outputs/meta-ad-agent-final")
151
- tokenizer.save_pretrained("outputs/meta-ad-agent-final")
152
- print("Done. Model saved to outputs/meta-ad-agent-final")
 
 
 
1
+ # grpo_train.py
2
+
3
+ import os
4
+ import time
5
  import json
6
+ import random
7
  import requests
8
+ import torch
9
+
10
  from datasets import Dataset
11
  from unsloth import FastLanguageModel, PatchFastRL
12
  from trl import GRPOTrainer, GRPOConfig
13
 
14
+ # 🔥 MUST come before trainer
15
  PatchFastRL("GRPO", FastLanguageModel)
16
 
17
+ # =========================
18
+ # CONFIG
19
+ # =========================
20
 
21
+ ENV_URL = os.getenv("ENV_URL", "http://localhost:8000")
 
22
 
23
+ ALLOWED_ACTIONS = [
24
+ "query_regulations",
25
+ "analyze_image",
26
+ "check_advertiser_history",
27
+ "submit_audit",
28
+ "approve",
29
+ "reject"
30
+ ]
31
 
32
+ # =========================
33
+ # HEALTH CHECK
34
+ # =========================
35
 
36
+ def ensure_env_ready():
37
+ for _ in range(20):
38
+ try:
39
+ r = requests.post(
40
+ f"{ENV_URL}/reset",
41
+ json={"task_id": "task_1_healthcare"},
42
+ timeout=5
43
+ )
44
+ if r.status_code == 200:
45
+ print("✅ Environment ready")
46
+ return
47
+ except requests.RequestException:
48
+ pass
49
+ time.sleep(1)
50
+ raise RuntimeError("❌ ENV not reachable")
51
+
52
+ # =========================
53
+ # SAFE CLIENT
54
+ # =========================
55
+
56
+ class EnvClient:
57
+ def __init__(self, url):
58
+ self.url = url
59
+
60
+ def reset(self, task_id):
61
+ return requests.post(
62
+ f"{self.url}/reset",
63
+ json={"task_id": task_id},
64
+ timeout=8
65
+ ).json()
66
+
67
+ def step(self, action):
68
+ return requests.post(
69
+ f"{self.url}/step",
70
+ json={"action": action},
71
+ timeout=8
72
+ ).json()
73
+
74
+ def safe_step(client, action):
75
+ for _ in range(3):
76
+ try:
77
+ return client.step(action)
78
+ except (requests.RequestException, ValueError):
79
+ time.sleep(0.5)
80
+ return {"reward": -0.3}
81
+
82
+ # =========================
83
+ # JSON PARSER
84
+ # =========================
85
+
86
+ def extract_json(text):
87
+ try:
88
+ if "```" in text:
89
+ text = text.split("```")[1]
90
+ if text.startswith("json"):
91
+ text = text[4:]
92
+ return json.loads(text.strip())
93
+ except (ValueError, TypeError, AttributeError):
94
+ return None
95
+
96
+ # =========================
97
+ # DATASET (WITH SETUP ACTIONS)
98
+ # =========================
99
+
100
+ BASE_SCENARIOS = [
101
+ # 🔹 Fresh state
102
+ {
103
+ "task_id": "task_1_healthcare",
104
+ "text": "Ad: miracle supplement cures disease. Initial review.",
105
+ "setup_actions": []
106
+ },
107
+
108
+ # 🔹 Mid state
109
+ {
110
+ "task_id": "task_1_healthcare",
111
+ "text": "Ad: pharma product. Policy already checked. Next step?",
112
+ "setup_actions": [
113
+ {"action_type": "query_regulations", "reasoning": "step1"}
114
+ ]
115
+ },
116
+
117
+ # 🔹 Late state
118
+ {
119
+ "task_id": "task_2_financial",
120
+ "text": "Ad: investment scheme. Policy + history checked. Final decision?",
121
+ "setup_actions": [
122
+ {"action_type": "query_regulations", "reasoning": "step1"},
123
+ {"action_type": "check_advertiser_history", "reasoning": "step2"}
124
+ ]
125
+ }
126
+ ]
127
 
128
  def build_dataset():
129
  rows = []
130
+
131
+ for s in BASE_SCENARIOS:
132
+ prompt = f"""
133
+ You are an Ad Policy Agent.
134
+
135
+ Respond ONLY JSON:
136
+ {{"action_type": "...", "reasoning": "..."}}
137
+
138
+ {s['text']}
139
+ Next action?
140
+ """
141
+ rows.append({
142
+ "prompt": prompt,
143
+ "task_id": s["task_id"],
144
+ "setup_actions": s["setup_actions"]
145
+ })
146
+
147
+ return Dataset.from_list(rows * 20) # 3 scenarios x 20 repeats = 60 rows
148
+
149
+ # =========================
150
+ # REWARD FUNCTION (FIXED)
151
+ # =========================
152
+
153
+ def reward_environment(prompts, completions, task_id=None, setup_actions=None, **kwargs):
154
+ client = EnvClient(ENV_URL)
155
+
156
  rewards = []
157
+
158
+ for completion, t_id, setup in zip(completions, task_id, setup_actions):
159
+
160
+ parsed = extract_json(completion)
161
+
162
+ if not parsed:
163
+ rewards.append(-1.0)
 
 
 
 
 
 
 
164
  continue
165
 
166
+ action_type = parsed.get("action_type")
167
 
168
+ if action_type not in ALLOWED_ACTIONS:
169
+ rewards.append(-1.0)
170
+ continue
171
+
172
+ action = {
173
+ "action_type": action_type,
174
+ "reasoning": parsed.get("reasoning", "")
175
+ }
176
 
 
 
 
 
177
  try:
178
+ client.reset(t_id)
179
+
180
+ # 🔥 FAST-FORWARD STATE
181
+ for s in setup:
182
+ safe_step(client, s)
183
+
184
+ result = safe_step(client, action)
185
+
186
+ reward = float(result.get("reward", -0.2))
187
+ rewards.append(reward)
188
+
189
+ except Exception:
190
+ rewards.append(-0.3)
191
+
192
  return rewards
193
 
194
+ # =========================
195
+ # MODEL
196
+ # =========================
197
 
198
  model, tokenizer = FastLanguageModel.from_pretrained(
199
  model_name="unsloth/Llama-3.1-8B-Instruct",
 
200
  load_in_4bit=True,
201
+ max_seq_length=1024,
202
  )
203
+
204
  model = FastLanguageModel.get_peft_model(
205
  model,
206
  r=16,
207
  target_modules=["q_proj", "v_proj"],
208
  lora_alpha=16,
209
+ lora_dropout=0,
 
210
  )
211
 
212
+ # =========================
213
+ # TRAINER
214
+ # =========================
215
 
216
  dataset = build_dataset()
217
 
218
  trainer = GRPOTrainer(
219
  model=model,
220
+ reward_funcs=[reward_environment],
221
  args=GRPOConfig(
222
+ output_dir="outputs",
223
  learning_rate=5e-6,
224
  num_train_epochs=1,
225
+ per_device_train_batch_size=1,
226
+ gradient_accumulation_steps=2,
227
+ num_generations=2,
228
  max_prompt_length=512,
229
+ max_completion_length=64,
230
+ logging_steps=2,
231
+ report_to="none"
 
 
232
  ),
233
  train_dataset=dataset,
234
+ tokenizer=tokenizer
235
  )
236
 
237
+ # =========================
238
+ # RUN
239
+ # =========================
240
+
241
  if __name__ == "__main__":
242
+ ensure_env_ready()
243
+
244
+ print("🚀 Starting training...")
245
  trainer.train()
246
+
247
+ model.save_pretrained("outputs/final")
248
+ tokenizer.save_pretrained("outputs/final")
249
+
250
+ print("✅ Done")
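
Review note: this commit lowers `num_generations` from 4 to 2. GRPO scores each completion relative to the other completions sampled for the same prompt, commonly described as subtracting the group mean reward and dividing by the group standard deviation. The sketch below illustrates that normalization in plain Python. It is not TRL's internal implementation; the function name `group_advantages`, the `eps` value, and the use of the population standard deviation are assumptions made for the example.

```python
from statistics import mean, pstdev
from typing import List, Sequence


def group_advantages(rewards: Sequence[float], eps: float = 1e-4) -> List[float]:
    """Normalize rewards within one prompt's group of sampled completions.

    Each advantage is (reward - group mean) / (group std + eps).
    Assumption: population std; a library may use the sample std instead.
    """
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]
```

With a group of two, the normalized advantages have the same magnitude no matter how far apart the two rewards are, so only the sign of the difference survives; if both completions earn the same reward, both advantages are zero and that prompt contributes no learning signal. Larger groups keep more of the reward scale, at the cost of more environment calls per step.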