import random
from typing import List, Tuple
from server.contract_validation_environment import ContractValidationEnvironment
from models import ContractValidationAction
class ContractRLAgent:
    """Tabular Q-learning agent for the contract-validation environment.

    State:  frozenset of the (clause_id, risk_type) pairs currently flagged
            in the observation.
    Action: a (clause_id, risk_type) tuple; the sentinel (0, "submit")
            submits the final answer and ends the episode.
    """

    def __init__(self, clauses: List[int], risk_types: List[str]):
        self.clauses = clauses
        self.risk_types = risk_types
        self.q_table = {}           # state -> {action: Q-value}
        self.alpha = 0.1            # learning rate
        self.gamma = 0.9            # discount factor
        self.epsilon = 1.0          # exploration rate, decayed per episode
        self.epsilon_decay = 0.95
        self.min_epsilon = 0.01
        # The action space is fixed for the agent's lifetime: build it once
        # instead of rebuilding it on every choose_action/learn call.
        self._actions: List[Tuple[int, str]] = [
            (c, r) for c in clauses for r in risk_types
        ] + [(0, "submit")]

    def _get_state(self, obs) -> frozenset:
        """Hashable state key: the set of currently flagged (clause, risk) pairs."""
        return frozenset(obs.flagged_risks.items())

    def _get_possible_actions(self) -> List[Tuple[int, str]]:
        """Return the (fixed) action space; a copy, so callers cannot mutate the cache."""
        return list(self._actions)

    def _ensure_state(self, state: frozenset) -> None:
        """Lazily initialise the Q-row for a previously unseen state."""
        if state not in self.q_table:
            self.q_table[state] = {a: 0.0 for a in self._actions}

    def choose_action(self, state: frozenset) -> Tuple[int, str]:
        """Epsilon-greedy selection: explore with probability epsilon, else greedy."""
        self._ensure_state(state)
        if random.random() < self.epsilon:
            return random.choice(self._actions)
        return max(self.q_table[state], key=self.q_table[state].get)

    def learn(self, state: frozenset, action: Tuple[int, str], reward: float, next_state: frozenset, done: bool):
        """One Q-learning update: Q(s,a) += alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))."""
        self._ensure_state(state)
        self._ensure_state(next_state)
        # If the episode is done, there is no future reward to bootstrap from.
        best_next_q = 0.0 if done else max(self.q_table[next_state].values())
        current_q = self.q_table[state][action]
        self.q_table[state][action] = current_q + self.alpha * \
            (reward + self.gamma * best_next_q - current_q)

    def _to_env_action(self, action_tuple: Tuple[int, str], explanation: str):
        """Translate an agent action tuple into the environment's action object.

        Returns (env_action, is_submit); (0, "submit") maps to a final submit
        with risk_type "none".
        """
        clause_id, risk_type = action_tuple
        is_submit = (risk_type == "submit")
        env_action = ContractValidationAction(
            clause_id=clause_id,
            risk_type="none" if is_submit else risk_type,
            submit_final=is_submit,
            explanation=explanation
        )
        return env_action, is_submit

    def train(self, env: "ContractValidationEnvironment", episodes: int = 300):
        """Run epsilon-greedy Q-learning for `episodes` episodes on the easy task.

        Episodes are capped at 15 steps; epsilon decays once per episode down
        to `min_epsilon`.
        """
        print(f"Starting training for {episodes} episodes...")
        for episode in range(episodes):
            obs = env.reset(task_level="easy")
            state = self._get_state(obs)
            done = obs.done
            while not done and obs.step_count < 15:
                action_tuple = self.choose_action(state)
                env_action, _ = self._to_env_action(
                    action_tuple, "RL Agent exploration")
                next_obs = env.step(env_action)
                next_state = self._get_state(next_obs)
                # The environment returns the exact per-step reward directly.
                step_reward = next_obs.reward
                # Pass the 'done' flag so the update stops bootstrapping at
                # episode end.
                self.learn(state, action_tuple, step_reward,
                           next_state, next_obs.done)
                state = next_state
                done = next_obs.done
                obs = next_obs
            if self.epsilon > self.min_epsilon:
                self.epsilon *= self.epsilon_decay
        print("Training complete!")

    def test(self, env: "ContractValidationEnvironment"):
        """Run one greedy (epsilon=0) episode and print the agent's choices."""
        print("\n--- Testing Trained Agent ---")
        self.epsilon = 0.0
        obs = env.reset(task_level="easy")
        state = self._get_state(obs)
        # Respect a terminal reset observation, mirroring train().
        done = obs.done
        while not done and obs.step_count < 15:
            action_tuple = self.choose_action(state)
            clause_id, risk_type = action_tuple
            env_action, is_submit = self._to_env_action(
                action_tuple, "Agent's learned optimal choice")
            print(
                f"Agent Action -> Clause: {clause_id} | Flagged Risk: {risk_type} | Submitted: {is_submit}")
            obs = env.step(env_action)
            state = self._get_state(obs)
            done = obs.done
        print(f"\nFinal Environment Reward: {obs.reward}")
        print(f"Final Flagged Risks: {obs.flagged_risks}")
if __name__ == "__main__":
    # Smoke-run: train the agent on the easy task, then show the greedy policy.
    environment = ContractValidationEnvironment()
    rl_agent = ContractRLAgent(
        clauses=[1],
        risk_types=["liability", "payment", "none"],
    )
    rl_agent.train(environment, episodes=300)
    rl_agent.test(environment)