File size: 6,630 Bytes
08731ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
#!/usr/bin/env python3
"""
Diagnostic cell to verify reward function is working before training.
Run this BEFORE training to catch zero-loss issues early.
"""

import json
import numpy as np
import random
import re
import requests

ENV_URL = "https://prajwal782007-gridmind.hf.space"


def _clamp(value, low, high):
    """Clamp ``value`` into the closed interval ``[low, high]``."""
    return max(low, min(high, value))


def _rollout_reward(text, call_count, env_url):
    """Score one completion by replaying its action in the environment.

    Args:
        text: Raw completion text expected to contain a flat JSON object.
        call_count: 1-based position of the completion in the batch; used to
            derive the episode seed and task id.
        env_url: Base URL of the environment server.

    Returns:
        -1.0 if ``text`` contains no ``{...}`` span, -0.8 if that span is not
        valid JSON, -0.5 if ``/reset`` does not answer 200, otherwise the
        episode reward described in :func:`gridmind_reward_fn`.

    Raises:
        Whatever ``requests`` or the value conversions raise; the caller maps
        those to an error reward.
    """
    # Non-greedy match: grabs the first ``{`` up to the first ``}``, which is
    # sufficient for the flat action object but not for nested JSON.
    match = re.search(r'\{.*?\}', text, re.DOTALL)
    if not match:
        return -1.0

    # Only the completion's own JSON earns the formatting penalty; decode
    # errors in environment responses further down are infrastructure
    # failures and are deliberately not caught here.
    try:
        action = json.loads(match.group())
    except json.JSONDecodeError:
        return -0.8

    # Clamp every field to the range sent to the environment; missing fields
    # fall back to the defaults below.
    step_action = {
        "hvac_power_level": float(_clamp(action.get("hvac_power_level", 0.5), 0, 1)),
        "thermal_charge_rate": float(_clamp(action.get("thermal_charge_rate", 0.0), -1, 1)),
        "batch_job_slot": int(_clamp(action.get("batch_job_slot", 0), 0, 4)),
        "load_shed_fraction": float(_clamp(action.get("load_shed_fraction", 0.0), 0, 0.5)),
        "building_id": 0,
    }

    # Each position in the batch gets its own seed and cycles through task
    # ids 1-3. The counter restarts on every call of the reward function, so
    # the same seeds recur from batch to batch.
    # NOTE(review): completions in one batch are therefore scored on
    # different tasks/seeds — confirm this is intended for GRPO groups.
    seed = 1000 + call_count
    task_id = (call_count % 3) + 1

    # Reset the environment so this completion starts a fresh episode.
    reset_resp = requests.post(
        f"{env_url}/reset",
        json={"task_id": task_id, "seed": seed},
        timeout=30,
    )
    if reset_resp.status_code != 200:
        return -0.5

    # Apply the same action for up to ``num_steps`` steps, stopping at the
    # first non-200 response.
    num_steps = 8
    total_reward = 0.0
    completed_steps = 0
    for _ in range(num_steps):
        step_resp = requests.post(
            f"{env_url}/step",
            json=[step_action],
            timeout=30,
        )
        if step_resp.status_code != 200:
            break
        step_data = step_resp.json()
        if isinstance(step_data, list):
            step_data = step_data[0]
        total_reward += float(step_data.get("reward", 0))
        completed_steps += 1

    # Preferred signal: the episode score from /grade, mapped linearly from
    # [0.4, 0.72] onto [0, 1] and clipped outside that band.
    grade_resp = requests.get(f"{env_url}/grade", timeout=30)
    if grade_resp.status_code == 200:
        episode_score = float(grade_resp.json().get("score", 0.5))
        return max(0.0, min(1.0, (episode_score - 0.4) / 0.32))

    # Fallback: mean per-step reward over the steps that actually ran (not
    # the planned count, which would dilute the mean after an early break),
    # scaled by 1/10 and clipped to [-1, 1].
    avg_reward = total_reward / completed_steps if completed_steps else 0.0
    return max(-1.0, min(1.0, avg_reward / 10.0))


def gridmind_reward_fn(completions, env_url=ENV_URL, **kwargs):
    """Compute one reward per completion by rolling out its action remotely.

    Each completion is either a plain string or a list whose first element is
    a dict with a ``"content"`` string. The first ``{...}`` span in that text
    is parsed as a JSON action, clamped, and applied for up to 8 steps after
    resetting the environment at ``env_url``.

    Args:
        completions: Iterable of completions (strings or message lists).
        env_url: Base URL of the environment server exposing ``/reset``,
            ``/step`` and ``/grade``.
        **kwargs: Accepted and ignored, so callers may pass extra fields.

    Returns:
        A list of floats, one per completion, in input order:
        ``-1.0`` no JSON object found; ``-0.8`` JSON object malformed;
        ``-0.5`` reset failed or any other error occurred; otherwise a value
        in ``[0, 1]`` from ``/grade`` or, if grading fails, a value in
        ``[-1, 1]`` from the mean step reward.
    """
    rewards = []
    for call_count, completion in enumerate(completions, start=1):
        try:
            # Extracted inside the try so that a malformed completion (empty
            # list, missing "content") penalises only itself instead of
            # aborting the whole batch.
            text = completion[0]["content"] if isinstance(completion, list) else completion
            reward = _rollout_reward(text, call_count, env_url)
        except Exception as e:
            # Deliberate catch-all at the batch boundary: one failing rollout
            # must not stop the remaining completions from being scored.
            print(f"Reward error: {e}")
            reward = -0.5
        rewards.append(reward)

    return rewards


def run_diagnostic():
    """Check that the reward function yields varied rewards before training.

    Sends a fixed set of hand-written completions through
    ``gridmind_reward_fn``, prints a per-sample table, and evaluates the
    variance of the rewards obtained for the parseable actions. The
    deliberately malformed sample is shown in the table but excluded from the
    variance: it always receives a fixed parse penalty, which on its own
    would push the variance past the threshold even if every
    environment-derived reward were identical.

    Returns:
        ``False`` if the environment cannot be reached, if the variance is
        below 0.01, or if fewer than two parseable samples were scored;
        ``True`` otherwise (with a warning printed below 0.05).
    """
    print("=== PRE-TRAINING REWARD FUNCTION DIAGNOSTIC ===")
    print("Testing reward variance with 8 random actions...\n")

    # Pre-flight request: fail with a readable message instead of a traceback
    # when the environment is unreachable.
    try:
        requests.post(f"{ENV_URL}/reset", json={"task_id": 1}, timeout=10)
    except requests.RequestException as exc:
        print(f"❌ CRITICAL: Could not reach environment at {ENV_URL}: {exc}")
        return False

    # (label, completion text, whether the text contains a parseable action)
    cases = [
        # Good action — efficient
        ("Good (efficient)",
         '{"hvac_power_level": 0.3, "thermal_charge_rate": 0.8, "batch_job_slot": 2, "load_shed_fraction": 0.0, "building_id": 0}',
         True),
        # Bad action — wasteful
        ("Bad (wasteful)",
         '{"hvac_power_level": 1.0, "thermal_charge_rate": -1.0, "batch_job_slot": 0, "load_shed_fraction": 0.5, "building_id": 0}',
         True),
        # Medium action
        ("Medium",
         '{"hvac_power_level": 0.5, "thermal_charge_rate": 0.0, "batch_job_slot": 1, "load_shed_fraction": 0.1, "building_id": 0}',
         True),
        # Invalid JSON — should get -1.0
        ("Invalid JSON",
         'I will set the HVAC to medium power level',
         False),
        # Another good action
        ("Good (store)",
         '{"hvac_power_level": 0.2, "thermal_charge_rate": 0.6, "batch_job_slot": 3, "load_shed_fraction": 0.0, "building_id": 0}',
         True),
        # Another bad action
        ("Bad (discharge peak)",
         '{"hvac_power_level": 0.9, "thermal_charge_rate": -0.8, "batch_job_slot": 0, "load_shed_fraction": 0.4, "building_id": 0}',
         True),
        # Good charge during cheap hours
        ("Good (charge cheap)",
         '{"hvac_power_level": 0.4, "thermal_charge_rate": 0.9, "batch_job_slot": 2, "load_shed_fraction": 0.0, "building_id": 0}',
         True),
        # Bad during peak
        ("Bad (peak demand)",
         '{"hvac_power_level": 0.8, "thermal_charge_rate": -0.5, "batch_job_slot": 0, "load_shed_fraction": 0.3, "building_id": 0}',
         True),
    ]

    test_completions = [completion for _, completion, _ in cases]
    test_rewards = gridmind_reward_fn(test_completions)

    print("Completion type            → Reward")
    print("-" * 45)
    for (label, _, _), reward in zip(cases, test_rewards):
        # Bar length is proportional to the reward magnitude; the sign is
        # already carried by the ``+`` format flag.
        bar = "█" * int(abs(reward) * 20)
        print(f"  {label:<25}{reward:+.4f}  {bar}")

    # Only rewards of parseable actions reflect environment behaviour.
    env_rewards = [
        reward
        for (_, _, parseable), reward in zip(cases, test_rewards)
        if parseable
    ]

    if len(env_rewards) > 1:
        variance = np.var(env_rewards)
        reward_range = max(env_rewards) - min(env_rewards)
        print(f"\nReward variance:  {variance:.4f}")
        print(f"Reward range:     {reward_range:.4f}")
        print(f"  (over {len(env_rewards)} parseable actions; invalid-JSON sample excluded)")

        if variance < 0.01:
            print("\n❌ CRITICAL: Reward variance is near zero!")
            print("   GRPO cannot learn from this. Fix the reward function before training.")
            print("   Check that the environment is being reset between calls.")
            return False
        elif variance < 0.05:
            print("\n⚠️  WARNING: Low reward variance. Training may be slow.")
            print("   Consider amplifying reward differences.")
            return True
        else:
            print("\n✓ Reward variance is sufficient for GRPO training.")
            print("  Proceed to training.")
            return True

    return False


if __name__ == "__main__":
    # Exit status 0 when the diagnostic passes, 1 otherwise. SystemExit is
    # raised directly because the ``exit`` helper is injected by the ``site``
    # module for interactive sessions and is not available under ``python -S``.
    success = run_diagnostic()
    raise SystemExit(0 if success else 1)