File size: 2,651 Bytes
e2485ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c64ec47
 
 
 
 
 
b49335b
 
 
 
7a78f7e
e2485ba
f259f2b
 
c64ec47
e2485ba
 
f259f2b
 
 
 
 
 
e2485ba
f259f2b
 
c64ec47
e2485ba
 
 
 
 
 
 
f259f2b
e2485ba
 
 
 
 
 
 
f259f2b
e2485ba
 
 
 
 
 
 
 
f259f2b
 
 
e2485ba
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from src.environment import TrafficEnv
from src.models import State, StepResult
from typing import Dict, Any, Optional

class BaseTask:
    """Bundle a ``TrafficEnv`` with the config it was built from and score a run.

    ``reset``, ``step`` and ``state`` forward directly to the environment;
    ``evaluate`` turns the environment's counters into one float in
    ``[0.0, 1.0]``.
    """

    def __init__(self, config: Dict[str, Any]):
        """Keep ``config`` on the task and build the environment from it."""
        self.config = config
        self.env = TrafficEnv(config)

    def reset(self, seed: Optional[int] = None) -> State:
        """Forward to ``env.reset(seed=seed)`` and return whatever it returns."""
        fresh_state = self.env.reset(seed=seed)
        return fresh_state

    def step(self, action_type: int) -> StepResult:
        """Forward ``action_type`` to ``env.step`` and return its result."""
        outcome = self.env.step(action_type)
        return outcome

    def state(self) -> State:
        """Return the value produced by ``env.state()``."""
        current = self.env.state()
        return current

    def evaluate(self) -> float:
        """Score the run from the environment's counters.

        The score blends:

        * a clearance score: cleared vehicles over the smaller of the
          vehicles seen (cleared plus the four queue counters) and the
          arrivals projected from the config, capped at 1.0;
        * a waiting score: ``1 - average_wait / 30.0``, floored at 0.0;
        * an emergency score (handled / generated), included only when the
          config's ``emergency_prob`` is greater than zero.

        Weights are 0.5 / 0.3 / 0.2 with the emergency component and
        0.625 / 0.375 without it.

        Returns:
            A float clamped to ``[0.0, 1.0]``; exactly ``1.0`` when no
            vehicle has been cleared and the queue counters do not sum to a
            positive number.
        """
        env = self.env
        settings = self.config

        cleared = env.total_cleared
        queued = env.north + env.south + env.east + env.west
        # A negative queue sum is treated as an empty queue.
        seen = cleared + max(0, queued)
        if seen == 0 and cleared == 0:
            return 1.0

        # Bound the clearance denominator by the projected arrivals so that
        # light traffic is not penalised and heavy traffic is not inflated.
        # The factor 4 presumably stands for the four approaches
        # (north/south/east/west) -- confirm against TrafficEnv.
        projected = (
            env.max_time
            * settings.get("arrival_rate", 2.0)
            * 4
            * settings.get("congestion_multiplier", 1.0)
        )
        ceiling = min(float(seen), float(projected))
        clearance = min(1.0, cleared / max(1.0, ceiling))

        # An average wait at or beyond this limit scores zero.
        wait_limit = 30.0
        mean_wait = env.total_waiting_time / max(1, cleared)
        waiting = max(0.0, 1.0 - (mean_wait / wait_limit))

        if not (settings.get("emergency_prob", 0) > 0):
            # No emergency component: 0.5/0.3 renormalised to sum to one.
            blended = (0.625 * clearance) + (0.375 * waiting)
            return min(1.0, max(0.0, blended))

        handled = env.emergencies_handled
        generated = env.total_emergencies_generated
        if generated > 0:
            emergency = handled / max(1, generated)
        else:
            # No emergency was generated, so none can have been missed.
            emergency = 1.0

        blended = (0.5 * clearance) + (0.3 * waiting) + (0.2 * emergency)
        return min(1.0, max(0.0, blended))

class EasyTask(BaseTask):
    """Task preset: ``max_time`` 100, ``arrival_rate`` 2.0, no congestion scaling.

    ``emergency_prob`` is 0.0, so ``BaseTask.evaluate`` scores this task on
    clearance and waiting time only.
    """

    def __init__(self):
        """Build the task from the fixed preset below; takes no arguments."""
        preset = {
            "max_time": 100,
            "arrival_rate": 2.0,
            "congestion_multiplier": 1.0,
            "emergency_prob": 0.0,
        }
        super().__init__(preset)

class MediumTask(BaseTask):
    """Task preset: ``max_time`` 200, ``arrival_rate`` 2.2, congestion x1.5.

    ``emergency_prob`` is 0.0, so ``BaseTask.evaluate`` scores this task on
    clearance and waiting time only.
    """

    def __init__(self):
        """Build the task from the fixed preset below; takes no arguments."""
        preset = {
            "max_time": 200,
            "arrival_rate": 2.2,
            "congestion_multiplier": 1.5,
            "emergency_prob": 0.0,
        }
        super().__init__(preset)

class HardTask(BaseTask):
    """Task preset: ``max_time`` 300, ``arrival_rate`` 2.0, congestion x1.75.

    ``emergency_prob`` is 0.08 (greater than zero), so ``BaseTask.evaluate``
    includes the emergency-handling component when scoring this task.
    """

    def __init__(self):
        """Build the task from the fixed preset below; takes no arguments."""
        preset = {
            "max_time": 300,
            "arrival_rate": 2.0,
            "congestion_multiplier": 1.75,
            "emergency_prob": 0.08,
        }
        super().__init__(preset)