File size: 10,142 Bytes
c395f6a
 
 
d2449aa
 
c395f6a
d2449aa
 
c395f6a
 
 
 
 
 
 
 
 
d2449aa
 
c395f6a
 
 
 
 
 
 
d2449aa
c395f6a
 
 
 
 
 
 
 
d2449aa
 
 
c395f6a
d2449aa
 
 
 
 
c395f6a
d2449aa
c395f6a
 
d2449aa
c395f6a
 
 
d2449aa
c395f6a
 
 
 
 
 
 
d2449aa
c395f6a
 
 
 
 
 
 
d2449aa
c395f6a
 
 
 
 
 
 
 
 
d2449aa
c395f6a
 
 
 
 
 
 
 
 
d2449aa
 
 
 
c395f6a
d2449aa
 
c395f6a
d2449aa
c395f6a
d2449aa
bdc9954
c395f6a
 
 
 
 
 
 
 
 
 
 
c70e17d
d2449aa
 
 
bdc9954
c395f6a
 
bdc9954
c395f6a
 
 
 
 
c70e17d
d2449aa
 
 
 
 
c70e17d
 
 
 
 
 
bdc9954
c70e17d
 
 
 
d2449aa
c395f6a
bdc9954
c70e17d
c395f6a
 
 
 
 
 
d2449aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c395f6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d2449aa
c395f6a
 
 
 
 
 
 
 
 
d2449aa
c395f6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d2449aa
c395f6a
 
d2449aa
c395f6a
 
 
 
 
d2449aa
c395f6a
 
 
 
 
 
 
 
 
 
 
 
d2449aa
c395f6a
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
#!/usr/bin/env python3
"""
GridMind-RL Unsloth GRPO Training Script
----------------------------------------------
Fine-tunes Qwen2.5-1.5B-Instruct using Unsloth's 4-bit LoRA and TRL's GRPOTrainer.
The environment rewards are gathered by hitting the OpenEnv HTTP server directly.

FIXED: Removed reward hacking, added entropy bonus, diverse seeds, proper normalization.
"""

import argparse
import json
import os
import random
import re
import sys
import zlib
from collections import Counter

# Unsloth must be imported before trl/transformers so its patches are applied.
from unsloth import FastLanguageModel, is_bfloat16_supported

import pandas as pd
import requests
from datasets import Dataset
from transformers import TrainerCallback
from trl import GRPOConfig, GRPOTrainer

# Create the default log directory (the default --output-csv lives under
# "results/"). This runs at import time.
os.makedirs("results", exist_ok=True)

# System message included in every training prompt. It states the JSON action
# schema and the value ranges; get_reward_env_interaction clamps parsed actions
# to these same ranges before sending them to the environment.
SYSTEM_PROMPT = """You are an expert industrial building energy controller.
Each turn you receive the current building state and must respond with 
ONLY a valid JSON action object.

Action format:
{"hvac_power_level": <0.0-1.0>, "thermal_charge_rate": <-1.0 to 1.0>, 
 "batch_job_slot": <0-4>, "load_shed_fraction": <0.0-0.5>, "building_id": 0}

Strategy:
- Always respond with valid JSON containing all required keys
- Vary your actions - don't repeat the same pattern
- Optimize for low cost + comfort maintenance + grid response"""

def make_prompt(i, obs=None, task_desc=""):
    """Build the chat messages for training prompt number *i*.

    Args:
        i: Zero-based prompt index; rendered to the model as ``Episode {i+1}``.
        obs: Optional observation dict. Its fields are appended to the system
            message only when ``task_desc`` is also non-empty; missing fields
            fall back to the defaults shown below.
        task_desc: Optional task description placed in the user message.

    Returns:
        A two-element list of ``{"role": ..., "content": ...}`` dicts:
        the system message followed by the user message.
    """
    system_content = SYSTEM_PROMPT
    if obs and task_desc:
        observation_lines = [
            "Current observation:",
            f"- Temperature: {obs.get('indoor_temperature', 21):.1f}°C",
            f"- Price: ${obs.get('current_price', 0.10):.3f}/kWh",
            f"- Grid stress: {obs.get('grid_stress_signal', 0):.2f}",
            f"- Hour: {obs.get('hour_of_day', 12)}",
            f"- Storage: {obs.get('thermal_storage_level', 0.5):.1%}",
        ]
        system_content = system_content + "\n\n" + "\n".join(observation_lines)

    system_message = {"role": "system", "content": system_content}
    user_message = {
        "role": "user",
        "content": f"Episode {i+1}: {task_desc}\nOutput action as JSON.",
    }
    return [system_message, user_message]

def reward_valid_json(completions, **kwargs):
    """Score each completion 0.25 if it contains a parseable JSON object, else 0.0.

    The first ``{...}`` span (non-greedy, allowed to cross newlines) is passed
    to ``json.loads``; any error while searching or parsing scores 0.0.

    Args:
        completions: Plain strings, or chat-style lists whose first element is
            a message dict with a ``"content"`` key.

    Returns:
        One float per completion, in the same order.
    """
    object_pattern = re.compile(r'\{.*?\}', re.DOTALL)

    def score(completion):
        # Chat-format completions arrive as [{"role": ..., "content": ...}].
        text = completion[0]["content"] if isinstance(completion, list) else completion
        try:
            found = object_pattern.search(text)
            if not found:
                return 0.0
            json.loads(found.group())
        except Exception:
            return 0.0
        return 0.25

    return [score(c) for c in completions]

def reward_has_required_keys(completions, **kwargs):
    """Score how completely each completion's JSON action is specified.

    Returns, per completion:
        0.25 when the first JSON object contains all four action keys,
        0.1  when it parses but at least one of those keys is missing,
        0.0  when no object is found or it fails to parse.

    Args:
        completions: Plain strings, or chat-style lists whose first element is
            a message dict with a ``"content"`` key.
    """
    required = {"hvac_power_level", "thermal_charge_rate", "batch_job_slot", "load_shed_fraction"}
    object_pattern = re.compile(r'\{.*?\}', re.DOTALL)

    def score(completion):
        text = completion[0]["content"] if isinstance(completion, list) else completion
        try:
            found = object_pattern.search(text)
            if found is None:
                return 0.0
            action = json.loads(found.group())
            return 0.25 if required.issubset(action.keys()) else 0.1
        except Exception:
            return 0.0

    return list(map(score, completions))

def get_reward_env_interaction(env_url):
    """Episode-level reward from /grade endpoint with diverse seeds.
    
    FIXED: Uses raw /grade score directly (0.0-1.0), no normalization that causes reward hacking.
    Each sample gets a different seed/task to prevent mode collapse.
    """
    last_observations = []
    
    def reward_env_interaction(completions, **kwargs):
        nonlocal last_observations
        rewards = []
        
        for i, completion in enumerate(completions):
            text = completion[0]["content"] if isinstance(completion, list) else completion
            try:
                match = re.search(r'\{.*?\}', text, re.DOTALL)
                action = json.loads(match.group()) if match else {}
                step_action = {
                    "hvac_power_level": float(max(0, min(1, action.get("hvac_power_level", 0.5)))),
                    "thermal_charge_rate": float(max(-1, min(1, action.get("thermal_charge_rate", 0.0)))),
                    "batch_job_slot": int(max(0, min(4, action.get("batch_job_slot", 0)))),
                    "load_shed_fraction": float(max(0, min(0.5, action.get("load_shed_fraction", 0.0)))),
                    "building_id": 0
                }

                # Diverse seeds to prevent mode collapse
                seed = 2000 + (i * 17) % 500
                task_id = (i % 3) + 1

                reset_resp = requests.post(
                    f"{env_url}/reset",
                    json={"task_id": task_id, "seed": seed},
                    timeout=30
                )
                if reset_resp.status_code != 200:
                    rewards.append(0.0)
                    continue

                obs = reset_resp.json().get("observations", [{}])[0] if reset_resp.json().get("observations") else {}
                last_observations.append(obs)

                # 4-step mini-rollout for faster training
                for _ in range(4):
                    step_resp = requests.post(
                        f"{env_url}/step",
                        json=[step_action],
                        timeout=30
                    )
                    if step_resp.status_code != 200:
                        break

                grade_resp = requests.get(f"{env_url}/grade", timeout=30)
                if grade_resp.status_code == 200:
                    episode_score = float(grade_resp.json().get("score", 0.5))
                    rewards.append(episode_score)
                else:
                    rewards.append(0.0)

            except Exception as e:
                print(f"Env error: {e}", file=sys.stderr)
                rewards.append(0.0)
        return rewards
    return reward_env_interaction

def reward_entropy_bonus(completions, **kwargs):
    """Reward completions whose action is rare within the batch, to discourage mode collapse.

    A completion with a parseable JSON object gets ``0.1 / k``, where ``k`` is
    the number of completions in the batch that produced exactly the same
    action (compared after parsing, with sorted keys). A unique action earns
    the full 0.1; duplicates share it. Completions without a parseable object
    get 0.0.

    The score is per completion rather than one batch-wide value because GRPO
    normalizes rewards within each prompt group. A reward that is identical for
    every completion contributes no learning signal.

    Returns:
        Exactly one float per completion, in order, as the trainer requires.
    """
    object_pattern = re.compile(r'\{.*?\}', re.DOTALL)

    action_keys = []
    for completion in completions:
        text = completion[0]["content"] if isinstance(completion, list) else completion
        try:
            match = object_pattern.search(text)
            key = json.dumps(json.loads(match.group()), sort_keys=True) if match else None
        except (TypeError, ValueError, RecursionError):
            key = None
        action_keys.append(key)

    counts = Counter(key for key in action_keys if key is not None)
    return [0.1 / counts[key] if key is not None else 0.0 for key in action_keys]

class CSVLogCallback(TrainerCallback):
    """Trainer callback that mirrors every loss-bearing log entry into a CSV file.

    The whole accumulated history is rewritten on each recorded ``on_log``
    call. The file therefore always contains the union of all logged columns
    and can be read while training is still running.
    """

    def __init__(self, output_path):
        """
        Args:
            output_path: Destination CSV path. Its parent directory is created
                here if missing, so a bad path fails at construction instead of
                at the first log event after training steps have already run.
        """
        self.output_path = output_path
        self.log_history = []
        parent = os.path.dirname(output_path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Record *logs* tagged with ``state.global_step`` and rewrite the CSV.

        Only log dicts that contain a ``"loss"`` key are recorded; all others
        are ignored.
        """
        if logs is None or "loss" not in logs:
            return
        row = logs.copy()
        row["step"] = state.global_step
        self.log_history.append(row)
        pd.DataFrame(self.log_history).to_csv(self.output_path, index=False)

def main():
    """CLI entry point: load the 4-bit LoRA model, build prompts, and run GRPO training.

    Rewards come from the module's reward functions, one of which queries the
    OpenEnv server at ``--env-url``. Metrics are mirrored to ``--output-csv``
    and checkpoints are written to ``--output-dir``.
    """
    parser = argparse.ArgumentParser(description="Train GridMind-RL agent with Unsloth GRPO")
    parser.add_argument("--env-url", type=str, default="http://localhost:7860", help="OpenEnv server URL")
    parser.add_argument("--model-name", type=str, default="unsloth/Qwen2.5-1.5B-Instruct", help="Base model")
    parser.add_argument("--prompts", type=int, default=300, help="Number of training prompts")
    parser.add_argument("--epochs", type=int, default=1, help="Training epochs")
    parser.add_argument("--max-steps", type=int, default=-1, help="Max steps (overrides epochs if > 0)")
    parser.add_argument("--output-csv", type=str, default="results/training_log.csv", help="Metrics output")
    parser.add_argument("--output-dir", type=str, default="gridmind-grpo-unsloth", help="Model save dir")
    args = parser.parse_args()

    # The CSV callback writes during training. Create its directory now, since
    # only "results/" is pre-created at import, rather than failing mid-run.
    csv_dir = os.path.dirname(args.output_csv)
    if csv_dir:
        os.makedirs(csv_dir, exist_ok=True)

    print(f"🚀 Loading model: {args.model_name}")
    max_seq_length = 512
    lora_rank = 16  # Increased for better learning capacity

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model_name,
        max_seq_length=max_seq_length,
        load_in_4bit=True,
    )

    model = FastLanguageModel.get_peft_model(
        model,
        r=lora_rank,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        lora_alpha=lora_rank * 2,
        use_gradient_checkpointing="unsloth",
        random_state=42,
    )
    print("✅ Model loaded with Unsloth 4-bit LoRA")

    dataset = Dataset.from_dict({
        "prompt": [make_prompt(i) for i in range(args.prompts)]
    })
    print(f"✅ Dataset ready: {len(dataset)} training prompts")

    # No dtype is passed to from_pretrained, so Unsloth auto-selects bf16 on GPUs
    # that support it. Unsloth rejects fp16 training of a bf16 model, so the
    # mixed-precision mode is matched to the hardware instead of forcing fp16.
    use_bf16 = is_bfloat16_supported()

    training_args = GRPOConfig(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        max_steps=args.max_steps,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        num_generations=4,
        max_prompt_length=256,
        max_completion_length=128,
        learning_rate=3e-6,  # Lower LR for stability
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        logging_steps=5,
        save_steps=100,
        fp16=not use_bf16,
        bf16=use_bf16,
        report_to="none",
        seed=42,
    )

    # TRL's GRPOTrainer takes the tokenizer via `processing_class`; it has no
    # `tokenizer` keyword argument.
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        args=training_args,
        train_dataset=dataset,
        reward_funcs=[
            reward_valid_json,
            reward_has_required_keys,
            get_reward_env_interaction(args.env_url),
            reward_entropy_bonus,
        ],
        callbacks=[CSVLogCallback(args.output_csv)]
    )

    print("🚀 Starting GRPO training...")
    trainer.train()

    print(f"✅ Training complete! Checkpoints saved to {args.output_dir}")
    print(f"✅ Logs saved to {args.output_csv}")

if __name__ == "__main__":
    main()