Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| GridMind-RL Unsloth GRPO Training Script | |
| ---------------------------------------------- | |
| Fine-tunes Qwen2.5-1.5B-Instruct using Unsloth's 4-bit LoRA and TRL's GRPOTrainer. | |
| The environment rewards are gathered by hitting the OpenEnv HTTP server directly. | |
| FIXED: Removed reward hacking, added entropy bonus, diverse seeds, proper normalization. | |
| """ | |
| import argparse | |
| import json | |
| import os | |
| import re | |
| import sys | |
| import requests | |
| import pandas as pd | |
| import random | |
| from collections import Counter | |
| from datasets import Dataset | |
| from trl import GRPOTrainer, GRPOConfig | |
| from unsloth import FastLanguageModel | |
| from transformers import TrainerCallback | |
# Default location for training metrics; created at import time so the
# default --output-csv path is writable.
os.makedirs("results", exist_ok=True)

# System message shared by every training prompt. It describes the action
# schema the model must emit; the lines are joined with single newlines.
SYSTEM_PROMPT = "\n".join((
    "You are an expert industrial building energy controller.",
    "Each turn you receive the current building state and must respond with",
    "ONLY a valid JSON action object.",
    "Action format:",
    '{"hvac_power_level": <0.0-1.0>, "thermal_charge_rate": <-1.0 to 1.0>,',
    '"batch_job_slot": <0-4>, "load_shed_fraction": <0.0-0.5>, "building_id": 0}',
    "Strategy:",
    "- Always respond with valid JSON containing all required keys",
    "- Vary your actions - don't repeat the same pattern",
    "- Optimize for low cost + comfort maintenance + grid response",
))
def make_prompt(i, obs=None, task_desc=""):
    """Build the two-message chat prompt for training episode ``i``.

    Args:
        i: Zero-based episode index; shown to the model as ``i + 1``.
        obs: Optional observation mapping. When both ``obs`` and
            ``task_desc`` are truthy, a summary of the observation is
            appended to the system message. Keys read (with the defaults
            used when a key is absent): ``indoor_temperature`` (21),
            ``current_price`` (0.10), ``grid_stress_signal`` (0),
            ``hour_of_day`` (12), ``thermal_storage_level`` (0.5).
        task_desc: Free-text task description placed in the user message.

    Returns:
        A list of two chat messages: the system message followed by the
        user message, each a dict with ``role`` and ``content`` keys.
    """
    system_text = SYSTEM_PROMPT

    # The observation summary is only included when a task description is
    # supplied as well; either one alone leaves the system prompt unchanged.
    include_observation = bool(obs) and bool(task_desc)
    if include_observation:
        system_text = system_text + f"\n\nCurrent observation:\n- Temperature: {obs.get('indoor_temperature', 21):.1f}°C\n- Price: ${obs.get('current_price', 0.10):.3f}/kWh\n- Grid stress: {obs.get('grid_stress_signal', 0):.2f}\n- Hour: {obs.get('hour_of_day', 12)}\n- Storage: {obs.get('thermal_storage_level', 0.5):.1%}"

    system_message = {"role": "system", "content": system_text}
    user_message = {
        "role": "user",
        "content": f"Episode {i+1}: {task_desc}\nOutput action as JSON.",
    }
    return [system_message, user_message]
def reward_valid_json(completions, **kwargs):
    """Reward 0.25 for every completion that contains a parseable JSON object.

    Each completion is either a plain string or a chat-style list whose
    first element is a dict with a ``"content"`` string. The text is scanned
    for the first ``{`` from which a complete JSON object can be decoded, so
    objects with nested braces, or with ``}`` inside string values, are
    recognised.

    Args:
        completions: Sequence of completions to score.
        **kwargs: Extra keyword arguments passed by the trainer; ignored.

    Returns:
        A list with one float per completion: 0.25 if a JSON object was
        found, otherwise 0.0.
    """
    decoder = json.JSONDecoder()
    rewards = []
    for completion in completions:
        text = completion[0]["content"] if isinstance(completion, list) else completion

        found = False
        # Non-string content cannot contain JSON text; score it as 0.0.
        start = text.find("{") if isinstance(text, str) else -1
        while start != -1:
            try:
                # raw_decode parses one complete JSON value starting at
                # `start` and tolerates trailing text after it.
                decoder.raw_decode(text, start)
            except ValueError:  # json.JSONDecodeError is a ValueError
                start = text.find("{", start + 1)
            else:
                found = True
                break

        rewards.append(0.25 if found else 0.0)
    return rewards
def reward_has_required_keys(completions, **kwargs):
    """Reward completions whose JSON object carries all required action keys.

    The required keys are ``hvac_power_level``, ``thermal_charge_rate``,
    ``batch_job_slot`` and ``load_shed_fraction``. The text is scanned for
    the first ``{`` from which a complete JSON object can be decoded, so
    objects with nested braces, or with ``}`` inside string values, are
    handled.

    Args:
        completions: Sequence of completions; each is a string or a
            chat-style list whose first element has a ``"content"`` string.
        **kwargs: Extra keyword arguments passed by the trainer; ignored.

    Returns:
        A list with one float per completion:
        0.25 if the object has every required key,
        0.1 if an object was found but some required key is missing,
        0.0 if no JSON object could be decoded.
    """
    required = {"hvac_power_level", "thermal_charge_rate", "batch_job_slot", "load_shed_fraction"}
    decoder = json.JSONDecoder()
    rewards = []
    for completion in completions:
        text = completion[0]["content"] if isinstance(completion, list) else completion

        action = None
        # Non-string content cannot contain JSON text; score it as 0.0.
        start = text.find("{") if isinstance(text, str) else -1
        while start != -1:
            try:
                # Decoding from a "{" always yields a dict on success.
                action, _ = decoder.raw_decode(text, start)
                break
            except ValueError:  # json.JSONDecodeError is a ValueError
                start = text.find("{", start + 1)

        if action is None:
            rewards.append(0.0)
        elif required.issubset(action.keys()):
            rewards.append(0.25)
        else:
            # Partial credit: valid JSON object, incomplete action.
            rewards.append(0.1)
    return rewards
def get_reward_env_interaction(env_url):
    """Build a reward function that scores completions against the env server.

    The returned callable parses one action per completion, replays it in a
    short rollout over HTTP and uses the server's grade as the reward.

    Args:
        env_url: Base URL of the environment server. The endpoints used are
            ``POST {env_url}/reset``, ``POST {env_url}/step`` and
            ``GET {env_url}/grade``.

    Returns:
        A function ``reward_env_interaction(completions, **kwargs)`` that
        returns one float per completion.
    """
    decoder = json.JSONDecoder()

    def _parse_action(text):
        """Return the first JSON object embedded in ``text``, or None."""
        if not isinstance(text, str):
            return None
        start = text.find("{")
        while start != -1:
            try:
                # Handles nested objects and "}" inside string values.
                parsed, _ = decoder.raw_decode(text, start)
                return parsed
            except ValueError:  # json.JSONDecodeError is a ValueError
                start = text.find("{", start + 1)
        return None

    def reward_env_interaction(completions, **kwargs):
        """Score each completion with the ``score`` field of ``/grade``.

        A completion without a decodable JSON object scores 0.0 and the
        server is not contacted. Otherwise the action values are clamped to
        their valid ranges (missing keys fall back to defaults), the
        environment is reset, the same action is stepped up to four times,
        and the grade is read. Any failure along the way yields 0.0.
        """
        rewards = []
        for i, completion in enumerate(completions):
            text = completion[0]["content"] if isinstance(completion, list) else completion

            action = _parse_action(text)
            if action is None:
                # Unparseable output must not earn the score of the
                # default action.
                rewards.append(0.0)
                continue

            try:
                step_action = {
                    "hvac_power_level": float(max(0, min(1, action.get("hvac_power_level", 0.5)))),
                    "thermal_charge_rate": float(max(-1, min(1, action.get("thermal_charge_rate", 0.0)))),
                    "batch_job_slot": int(max(0, min(4, action.get("batch_job_slot", 0)))),
                    "load_shed_fraction": float(max(0, min(0.5, action.get("load_shed_fraction", 0.0)))),
                    "building_id": 0,
                }

                # Seed and task depend only on the completion's position in
                # this call: seeds are 2000 + (17 * i mod 500), tasks cycle
                # through 1..3.
                seed = 2000 + (i * 17) % 500
                task_id = (i % 3) + 1

                reset_resp = requests.post(
                    f"{env_url}/reset",
                    json={"task_id": task_id, "seed": seed},
                    timeout=30,
                )
                if reset_resp.status_code != 200:
                    rewards.append(0.0)
                    continue

                # 4-step mini-rollout repeating the same action. A failed
                # step ends the rollout early; the episode is still graded.
                for _ in range(4):
                    step_resp = requests.post(
                        f"{env_url}/step",
                        json=[step_action],
                        timeout=30,
                    )
                    if step_resp.status_code != 200:
                        break

                grade_resp = requests.get(f"{env_url}/grade", timeout=30)
                if grade_resp.status_code == 200:
                    # The raw score is used as-is; 0.5 is the fallback when
                    # the response has no "score" field.
                    rewards.append(float(grade_resp.json().get("score", 0.5)))
                else:
                    rewards.append(0.0)
            except Exception as e:
                # Best-effort boundary: a bad value type or an environment
                # hiccup costs this sample its reward but must not abort
                # training.
                print(f"Env error: {e}", file=sys.stderr)
                rewards.append(0.0)
        return rewards

    return reward_env_interaction
def reward_entropy_bonus(completions, **kwargs):
    """Reward action diversity across the completions of one call.

    Each completion's first decodable JSON object is canonicalised with
    ``json.dumps(..., sort_keys=True)`` so key order does not matter. With
    ``n`` parsed actions of which ``u`` are distinct:

    * if ``n > 1``: every parsed completion gets ``0.1 * u / n`` and every
      unparsed completion gets 0.0;
    * otherwise: every completion gets a flat 0.05.

    Args:
        completions: Sequence of completions; each is a string or a
            chat-style list whose first element has a ``"content"`` string.
        **kwargs: Extra keyword arguments passed by the trainer; ignored.

    Returns:
        A list with exactly one float per completion, in input order.
    """
    decoder = json.JSONDecoder()

    # One slot per completion so rewards stay aligned with their
    # completions; None marks a completion with no decodable JSON object.
    canonical = []
    for completion in completions:
        text = completion[0]["content"] if isinstance(completion, list) else completion

        parsed = None
        start = text.find("{") if isinstance(text, str) else -1
        while start != -1:
            try:
                # Handles nested objects and "}" inside string values.
                parsed, _ = decoder.raw_decode(text, start)
                break
            except ValueError:  # json.JSONDecodeError is a ValueError
                start = text.find("{", start + 1)

        canonical.append(None if parsed is None else json.dumps(parsed, sort_keys=True))

    actions_seen = [entry for entry in canonical if entry is not None]
    if len(actions_seen) > 1:
        diversity_ratio = len(set(actions_seen)) / len(actions_seen)
        bonus = 0.1 * diversity_ratio
        return [bonus if entry is not None else 0.0 for entry in canonical]

    # Too few parsed actions to measure diversity: flat bonus for everyone.
    return [0.05] * len(completions)
class CSVLogCallback(TrainerCallback):
    """Trainer callback that mirrors logged training metrics into a CSV file.

    Every log record containing a ``"loss"`` entry is stored together with
    the current global step, and the whole history is rewritten to
    ``output_path`` so the file is always complete and readable mid-run.
    """

    def __init__(self, output_path):
        """Remember the CSV destination and make sure its directory exists.

        Args:
            output_path: Path of the CSV file to write.
        """
        self.output_path = output_path
        self.log_history = []
        # Create the parent directory up front so the first write cannot
        # fail on a missing directory after training has already started.
        parent_dir = os.path.dirname(output_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Append a loss-bearing log record and rewrite the CSV file."""
        if logs is None or "loss" not in logs:
            return
        # Copy so the trainer's own dict is not mutated by adding "step".
        record = dict(logs)
        record["step"] = state.global_step
        self.log_history.append(record)
        # The full history is rewritten each time because later records may
        # introduce columns that earlier ones did not have.
        pd.DataFrame(self.log_history).to_csv(self.output_path, index=False)
def main():
    """Parse CLI options, load the 4-bit LoRA model and run GRPO training.

    Steps, in order: parse the command line, load the base model with
    ``load_in_4bit=True``, attach rank-16 LoRA adapters, build a dataset of
    ``--prompts`` chat prompts, configure ``GRPOConfig``, and train with
    four reward functions while logging metrics through ``CSVLogCallback``.
    """
    parser = argparse.ArgumentParser(description="Train GridMind-RL agent with Unsloth GRPO")
    # (flag, type, default, help) for every supported option.
    option_table = (
        ("--env-url", str, "http://localhost:7860", "OpenEnv server URL"),
        ("--model-name", str, "unsloth/Qwen2.5-1.5B-Instruct", "Base model"),
        ("--prompts", int, 300, "Number of training prompts"),
        ("--epochs", int, 1, "Training epochs"),
        ("--max-steps", int, -1, "Max steps (overrides epochs if > 0)"),
        ("--output-csv", str, "results/training_log.csv", "Metrics output"),
        ("--output-dir", str, "gridmind-grpo-unsloth", "Model save dir"),
    )
    for flag, value_type, default_value, description in option_table:
        parser.add_argument(flag, type=value_type, default=default_value, help=description)
    args = parser.parse_args()

    print(f"🚀 Loading model: {args.model_name}")
    max_seq_length = 512
    lora_rank = 16  # Increased for better learning capacity

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model_name,
        max_seq_length=max_seq_length,
        load_in_4bit=True,
    )

    # Attention and MLP projection layers that receive LoRA adapters.
    lora_targets = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ]
    model = FastLanguageModel.get_peft_model(
        model,
        r=lora_rank,
        target_modules=lora_targets,
        lora_alpha=lora_rank * 2,
        use_gradient_checkpointing="unsloth",
        random_state=42,
    )
    print("✅ Model loaded with Unsloth 4-bit LoRA")

    prompt_rows = [make_prompt(index) for index in range(args.prompts)]
    dataset = Dataset.from_dict({"prompt": prompt_rows})
    print(f"✅ Dataset ready: {len(dataset)} training prompts")

    grpo_settings = {
        "output_dir": args.output_dir,
        "num_train_epochs": args.epochs,
        "max_steps": args.max_steps,
        "per_device_train_batch_size": 1,
        "gradient_accumulation_steps": 4,
        "num_generations": 4,
        "max_prompt_length": 256,
        "max_completion_length": 128,
        "learning_rate": 3e-6,  # Lower LR for stability
        "lr_scheduler_type": "cosine",
        "warmup_ratio": 0.1,
        "logging_steps": 5,
        "save_steps": 100,
        "fp16": True,
        "report_to": "none",
        "seed": 42,
    }
    training_args = GRPOConfig(**grpo_settings)

    # Order matters only for how the trainer labels each reward component.
    reward_functions = [
        reward_valid_json,
        reward_has_required_keys,
        get_reward_env_interaction(args.env_url),
        reward_entropy_bonus,
    ]
    trainer = GRPOTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=dataset,
        reward_funcs=reward_functions,
        callbacks=[CSVLogCallback(args.output_csv)],
    )

    print("🚀 Starting GRPO training...")
    trainer.train()
    print(f"✅ Training complete! Checkpoints saved to {args.output_dir}")
    print(f"✅ Logs saved to {args.output_csv}")


if __name__ == "__main__":
    main()