File size: 4,882 Bytes
e650f0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import random
from typing import List, Tuple

from server.contract_validation_environment import ContractValidationEnvironment
from models import ContractValidationAction


class ContractRLAgent:
    """Tabular Q-learning agent for the contract-validation environment.

    States are frozensets of (clause_id, risk_type) pairs flagged so far;
    actions are (clause_id, risk_type) tuples, with the sentinel
    ``(0, "submit")`` meaning "submit the final validation".
    """

    def __init__(self, clauses: List[int], risk_types: List[str]):
        """
        Args:
            clauses: clause ids the agent may flag.
            risk_types: risk labels the agent may assign to a clause.
        """
        self.clauses = clauses
        self.risk_types = risk_types
        # Q-table: state (frozenset) -> {action tuple: Q-value}.
        self.q_table = {}
        self.alpha = 0.1           # learning rate
        self.gamma = 0.9           # discount factor
        self.epsilon = 1.0         # exploration rate, decayed per episode
        self.epsilon_decay = 0.95
        self.min_epsilon = 0.01
        # The action space never changes, so build it once instead of
        # rebuilding the list on every single step (was a per-step cost).
        self._actions: List[Tuple[int, str]] = [
            (c, r) for c in self.clauses for r in self.risk_types
        ] + [(0, "submit")]

    def _get_state(self, obs) -> frozenset:
        """Hashable state: the set of (clause, risk) pairs flagged so far."""
        return frozenset(obs.flagged_risks.items())

    def _get_possible_actions(self) -> List[Tuple[int, str]]:
        """Return a copy of the full action space (all flags + submit)."""
        return list(self._actions)

    def _ensure_state(self, state: frozenset) -> None:
        """Lazily initialise the Q-row for a previously unseen state."""
        if state not in self.q_table:
            self.q_table[state] = {a: 0.0 for a in self._actions}

    @staticmethod
    def _make_env_action(clause_id: int, risk_type: str, explanation: str):
        """Translate an agent action tuple into an environment action."""
        is_submit = (risk_type == "submit")
        return ContractValidationAction(
            clause_id=clause_id,
            risk_type="none" if is_submit else risk_type,
            submit_final=is_submit,
            explanation=explanation,
        )

    def choose_action(self, state: frozenset) -> Tuple[int, str]:
        """Epsilon-greedy selection: explore with prob. epsilon, else argmax Q."""
        self._ensure_state(state)
        if random.random() < self.epsilon:
            return random.choice(self._actions)
        return max(self.q_table[state], key=self.q_table[state].get)

    def learn(self, state: frozenset, action: Tuple[int, str], reward: float, next_state: frozenset, done: bool):
        """One Q-learning update: Q += alpha * (r + gamma * max Q' - Q)."""
        self._ensure_state(state)
        self._ensure_state(next_state)

        # Terminal transitions have no future value to bootstrap from.
        best_next_q = 0.0 if done else max(self.q_table[next_state].values())

        current_q = self.q_table[state][action]
        self.q_table[state][action] = current_q + self.alpha * (
            reward + self.gamma * best_next_q - current_q
        )

    def train(self, env: "ContractValidationEnvironment", episodes: int = 300, max_steps: int = 15):
        """Train with epsilon-greedy Q-learning on the easy task level.

        Args:
            env: environment providing ``reset()`` / ``step()``.
            episodes: number of training episodes.
            max_steps: safety cap on steps per episode (previously the loop
                guard read the stale reset observation, so this cap never
                actually triggered — fixed by reassigning ``obs`` each step).
        """
        print(f"Starting training for {episodes} episodes...")
        for _ in range(episodes):
            obs = env.reset(task_level="easy")
            state = self._get_state(obs)
            done = obs.done

            while not done and obs.step_count < max_steps:
                action_tuple = self.choose_action(state)
                clause_id, risk_type = action_tuple

                env_action = self._make_env_action(
                    clause_id, risk_type, "RL Agent exploration")

                # Reassign obs so the step_count guard sees fresh progress.
                obs = env.step(env_action)
                next_state = self._get_state(obs)

                # Per-step reward comes straight from the environment; pass
                # done so terminal updates do not bootstrap a future value.
                self.learn(state, action_tuple, obs.reward, next_state, obs.done)

                state = next_state
                done = obs.done

            if self.epsilon > self.min_epsilon:
                self.epsilon *= self.epsilon_decay

        print("Training complete!")

    def test(self, env: "ContractValidationEnvironment", max_steps: int = 15):
        """Run one greedy episode, printing each action and the final outcome."""
        print("\n--- Testing Trained Agent ---")
        # Act greedily without permanently clobbering the trained epsilon
        # (previously this left epsilon at 0.0 forever).
        saved_epsilon = self.epsilon
        self.epsilon = 0.0
        try:
            obs = env.reset(task_level="easy")
            state = self._get_state(obs)
            done = False

            while not done and obs.step_count < max_steps:
                clause_id, risk_type = self.choose_action(state)
                is_submit = (risk_type == "submit")

                env_action = self._make_env_action(
                    clause_id, risk_type, "Agent's learned optimal choice")

                print(
                    f"Agent Action -> Clause: {clause_id} | Flagged Risk: {risk_type} | Submitted: {is_submit}")
                obs = env.step(env_action)
                state = self._get_state(obs)
                done = obs.done

            print(f"\nFinal Environment Reward: {obs.reward}")
            print(f"Final Flagged Risks: {obs.flagged_risks}")
        finally:
            self.epsilon = saved_epsilon


if __name__ == "__main__":
    # Single-clause, easy-level setup: the agent learns which risks to
    # flag on clause 1 before submitting its final validation.
    environment = ContractValidationEnvironment()

    agent = ContractRLAgent(
        clauses=[1],
        risk_types=["liability", "payment", "none"],
    )
    agent.train(environment, episodes=300)
    agent.test(environment)