# File size: 11,399 Bytes
# af6094d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
import hashlib
import json
import time
import os


def hash_tool(data):
    """
    Hash tool: Return the SHA-256 hex digest of a deterministic text form of the input.

    Dicts are serialized as JSON with sorted keys, so key insertion order does
    not affect the digest. Every other value is hashed via its ``str()`` form.
    """
    serialized = (
        json.dumps(data, sort_keys=True) if isinstance(data, dict) else str(data)
    )

    hasher = hashlib.sha256()
    hasher.update(serialized.encode())
    return hasher.hexdigest()


class MerkleTree:
    """
    Merkle Tree: Maintains a list of leaves and recalculates root on each update.

    Leaves are strings (hex digests in this file). A parent node is the SHA-256
    hex digest of the concatenation ``left + right``; a node that has no right
    sibling on its level is paired with itself.
    """
    def __init__(self):
        self.leaves = []
        self.root = None

    @staticmethod
    def _hash_pair(left, right):
        """Return the parent hash of two child hashes."""
        return hashlib.sha256((left + right).encode()).hexdigest()

    @classmethod
    def _next_level(cls, level):
        """Build the level above ``level``, duplicating a trailing unpaired node."""
        parents = []
        for i in range(0, len(level), 2):
            left = level[i]
            right = level[i + 1] if i + 1 < len(level) else left
            parents.append(cls._hash_pair(left, right))
        return parents

    def _calculate_root(self, leaves):
        """
        Calculate Merkle root from a list of leaves.
        If empty, return None. If single leaf, return it. Otherwise, build tree bottom-up.
        """
        if not leaves:
            return None

        current_level = leaves[:]
        while len(current_level) > 1:
            current_level = self._next_level(current_level)

        return current_level[0]

    def update(self, new_hash):
        """
        Add a new leaf and recalculate the Merkle root.
        Returns the new root.
        """
        self.leaves.append(new_hash)
        self.root = self._calculate_root(self.leaves)
        return self.root

    def get_proof(self, index):
        """
        Generate Merkle proof for a leaf at given index.

        Returns a list of ``{"hash": sibling_hash, "position": "left" | "right"}``
        dicts ordered from the leaf level upwards, or None when ``index`` is out
        of range. Folding the leaf with each step in order reproduces ``self.root``.
        """
        if index < 0 or index >= len(self.leaves):
            return None

        proof = []
        current_index = index
        current_level = self.leaves[:]

        while len(current_level) > 1:
            # Odd positions are right children, even positions are left children.
            is_right = current_index % 2 == 1
            sibling_index = current_index - 1 if is_right else current_index + 1

            if sibling_index < len(current_level):
                sibling_hash = current_level[sibling_index]
                position = "left" if is_right else "right"
            else:
                # Unpaired trailing node: the root calculation hashes it with
                # itself, so the proof must record that step too or the
                # verifier's path would skip a level and never match the root.
                sibling_hash = current_level[current_index]
                position = "right"
            proof.append({"hash": sibling_hash, "position": position})

            # Move to next level
            current_index //= 2
            current_level = self._next_level(current_level)

        return proof


def worm_write_tool(step_data, hash_value, merkle_root, filename="worm_log.jsonl"):
    """
    WORM (Write Once, Read Many) storage: Append one JSON record to a JSONL file.

    The record id is the number of lines already present in ``filename``
    (0 when the file does not exist yet). The record carries the current
    ``time.time()`` timestamp alongside the step, its hash and the Merkle root.
    Returns the record written.
    """
    # Existing line count doubles as the next sequential id.
    existing_lines = 0
    if os.path.exists(filename):
        with open(filename, "r") as log:
            for existing_lines, _ in enumerate(log, start=1):
                pass

    record = dict(
        id=existing_lines,
        timestamp=time.time(),
        step=step_data,
        hash=hash_value,
        root=merkle_root,
    )

    serialized = json.dumps(record) + "\n"
    with open(filename, "a") as log:
        log.write(serialized)

    return record


def proof_generate_tool(record_id, filename="worm_log.jsonl"):
    """
    Generate a Merkle proof for a specific record in the WORM log.

    Rehydrates a MerkleTree from every record hash in the log, in file order,
    so the proof is against the root of the log as it currently stands.

    Returns a dict with ``record_id``, ``hash``, ``merkle_proof``,
    ``merkle_root``, ``timestamp`` and ``step_details``, or None (after
    printing an error) when the file is missing or no record has the id.
    """
    if not os.path.exists(filename):
        print(f"[PROOF] Error: {filename} does not exist.")
        return None

    hashes = []
    target_record = None
    target_index = None

    with open(filename, "r") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines instead of failing in json.loads.
                continue
            record = json.loads(line)
            if record["id"] == record_id:
                target_record = record
                # Remember the leaf position: the tree is indexed by position
                # in the log, which need not equal the stored id.
                target_index = len(hashes)
            hashes.append(record["hash"])

    if target_record is None:
        print(f"[PROOF] Error: Record with ID {record_id} not found.")
        return None

    # Rehydrate Merkle Tree from hashes
    tree = MerkleTree()
    for h in hashes:
        tree.update(h)

    # Prove the leaf at the record's actual position in the log.
    proof = tree.get_proof(target_index)

    return {
        "record_id": record_id,
        "hash": target_record["hash"],
        "merkle_proof": proof,
        "merkle_root": tree.root,
        "timestamp": target_record["timestamp"],
        "step_details": target_record["step"],
    }


def verify_proof_tool(target_hash, merkle_proof, merkle_root):
    """
    Check that target_hash is committed to by merkle_root, using merkle_proof.

    The proof is walked in order. Each item supplies a sibling hash and the
    side ("left" or "right") that sibling sits on; the running hash is
    concatenated with it in that order and re-hashed with SHA-256. The
    final value is compared with merkle_root.

    Returns True if valid, False otherwise (including a None proof or an
    item whose position is neither "left" nor "right").
    """
    if merkle_proof is None:
        return False

    running = target_hash

    for step in merkle_proof:
        neighbour = step["hash"]
        side = step["position"]

        if side not in ("left", "right"):
            return False

        # Concatenation order mirrors the node's place in the tree.
        pair = neighbour + running if side == "left" else running + neighbour
        running = hashlib.sha256(pair.encode()).hexdigest()

    return running == merkle_root


def secure_agent_action(action_type, details, merkle_tree):
    """
    Gatekeeper Logic: Cite-Before-Act mechanism.
    - READ: approved automatically
    - WRITE/MUTATE/DELETE: a human must answer "y" at the CLI prompt to approve
    - anything else: rejected without being logged

    READ and WRITE/MUTATE/DELETE decisions (approved or denied) are hashed,
    added to merkle_tree and appended to WORM storage.
    Returns True when the action is approved, False otherwise.
    """
    action_type = action_type.upper()
    needs_human = action_type in ("WRITE", "MUTATE", "DELETE")

    # Guard clause: unrecognised actions never reach the log.
    if action_type != "READ" and not needs_human:
        print(f"\n[GATEKEEPER] Unknown action type: {action_type}\n")
        return False

    if needs_human:
        print(f"\n[GATEKEEPER] ⚠️  WRITE/MUTATE action detected: {details}")
        print("[GATEKEEPER] This action requires human approval.")

        answer = input("[GATEKEEPER] Approve this action? (y/n): ").strip().lower()
        approved = answer == "y"

        if approved:
            print("[GATEKEEPER] ✓ Action APPROVED by user.")
        else:
            print("[GATEKEEPER] ✗ Action DENIED by user.")
        closing_line = "[GATEKEEPER] Audit logged.\n"
    else:
        print(f"\n[GATEKEEPER] READ action detected: {details}")
        print("[GATEKEEPER] Auto-approving READ action.")
        approved = True
        closing_line = "[GATEKEEPER] Action logged.\n"

    # Record the decision, whichever way it went, to keep the audit trail.
    entry = {
        "action_type": action_type,
        "details": details,
        "status": "APPROVED" if approved else "DENIED",
    }
    entry_hash = hash_tool(entry)
    new_root = merkle_tree.update(entry_hash)
    worm_write_tool(entry, entry_hash, new_root)

    print(f"[GATEKEEPER] Merkle Root: {new_root}")
    print(closing_line)
    return approved


if __name__ == "__main__":
    # Horizontal rule reused by every section banner below.
    rule = "=" * 70

    print(rule)
    print("SECURE REASONING MCP SERVER - TEST SCENARIO")
    print(rule)

    # One in-memory tree shared by all gatekeeper calls in this run.
    mt = MerkleTree()

    # Tests 1-3: an auto-approved READ, then two WRITEs that prompt the
    # operator (intended answers: 'y' for the first, 'n' for the second).
    scenarios = (
        ("\n[TEST 1] Simulating READ action...",
         "READ", "Query user database for profile info"),
        ("[TEST 2] Simulating WRITE action (approve with 'y')...",
         "WRITE", "Update user profile with new email address"),
        ("[TEST 3] Simulating WRITE action (deny with 'n')...",
         "WRITE", "Delete user account permanently"),
    )
    for announcement, kind, description in scenarios:
        print(announcement)
        secure_agent_action(kind, description, mt)

    print(rule)
    print("TEST SCENARIO COMPLETE")
    print(rule)
    print("\nWORM Log saved to: worm_log.jsonl")
    print("Review the file to verify all actions are logged with hashes and Merkle roots.\n")

    # Test 4: build a Merkle proof for the record with id 1.
    print(rule)
    print("PROOF GENERATION TEST")
    print(rule)
    print("\n[TEST 4] Generating Merkle proof for record_id=1...")
    proof = proof_generate_tool(1)
    if proof:
        print("\n[PROOF] Generated Merkle Proof:")
        print(json.dumps(proof, indent=2))
    print()

    # Test 5: check the proof verifies, then that tampering is detected.
    print(rule)
    print("PROOF VERIFICATION TEST")
    print(rule)
    if proof:
        leaf_hash = proof["hash"]
        path = proof["merkle_proof"]
        root_hash = proof["merkle_root"]

        print("\n[TEST 5a] Verifying proof with correct hash and root...")
        outcome = verify_proof_tool(leaf_hash, path, root_hash)
        print(f"[VERIFY] Verification Result (POSITIVE): {outcome}")

        # 5b: corrupt the last two characters of the leaf hash.
        print("\n[TEST 5b] Verifying proof with tampered hash (should fail)...")
        outcome = verify_proof_tool(leaf_hash[:-2] + "XX", path, root_hash)
        print(f"[VERIFY] Verification Result (NEGATIVE - tampered hash): {outcome}")

        # 5c: corrupt the last two characters of the root.
        print("\n[TEST 5c] Verifying proof with tampered root (should fail)...")
        outcome = verify_proof_tool(leaf_hash, path, root_hash[:-2] + "XX")
        print(f"[VERIFY] Verification Result (NEGATIVE - tampered root): {outcome}")

    print("\n" + rule)
    print("ALL TESTS COMPLETE - SECURE REASONING MCP SERVER OPERATIONAL")
    print(rule + "\n")