File size: 4,242 Bytes
4440ec1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent))

import gradio as gr
from server.network_forensics_environment import NetworkForensicsEnvironment
from models import NetworkForensicsAction


# Module-level episode state. reset_env() creates the environment and stores
# the first observation here; step() reads `env` and overwrites `current_obs`.
# NOTE(review): as module globals these are shared by every caller of this
# module — presumably a single episode at a time is intended; confirm.
env = None  # environment instance, or None until reset_env() has run
current_obs = None  # observation returned by the latest env.reset() / env.step()


def reset_env(task_name):
    """Start a new episode for ``task_name`` and return its formatted first observation.

    Replaces the module-level ``env`` with a freshly constructed
    ``NetworkForensicsEnvironment`` and stores the observation produced by
    its ``reset()`` call in ``current_obs``.

    Returns:
        Whatever ``format_obs`` returns for the initial observation.
    """
    global env, current_obs
    env = NetworkForensicsEnvironment(task_id=task_name)
    first_obs = env.reset()
    current_obs = first_obs
    return format_obs(first_obs)


def format_obs(obs):
    """Render an observation as a Markdown summary plus a packet table.

    Args:
        obs: Observation object exposing ``step_number``, ``steps_remaining``,
            ``current_score_estimate``, ``total_packets``,
            ``flagged_packet_ids``, ``grouped_sessions``, ``tagged_patterns``
            and ``visible_packets``.

    Returns:
        A ``(summary, table)`` tuple of strings. ``summary`` is a
        newline-joined list of bold-labelled fields; ``table`` is a
        pipe-delimited table (header, separator, then one row per packet)
        in which every line ends with a newline.
    """
    summary = [
        f"**Step**: {obs.step_number}/{obs.steps_remaining}",
        f"**Score**: {obs.current_score_estimate:.2f}",
        f"**Total Packets**: {obs.total_packets}",
        f"**Flagged**: {len(obs.flagged_packet_ids)} packets",
    ]
    # Optional fields are listed only when they hold something.
    if obs.grouped_sessions:
        summary.append(f"**Sessions**: {', '.join(obs.grouped_sessions.keys())}")
    if obs.tagged_patterns:
        summary.append(f"**Tags**: {obs.tagged_patterns}")

    table_rows = [
        "ID|Src|Dst|Port|Protocol|TTL|Size|Preview",
        "-|-|-|-|-|-|-|-",
    ]
    # Only the first 20 visible packets are listed.
    for pkt in obs.visible_packets[:20]:
        # Show the full payload once the packet is revealed and a payload is
        # present; otherwise fall back to the short preview.
        if pkt.is_revealed and pkt.full_payload:
            shown = pkt.full_payload
        else:
            shown = pkt.payload_preview
        table_rows.append(
            f"{pkt.packet_id}|{pkt.src_ip}|{pkt.dst_ip}|{pkt.dst_port}"
            f"|{pkt.protocol}|{pkt.ttl}|{pkt.payload_size}|{shown}"
        )

    return "\n".join(summary), "\n".join(table_rows) + "\n"


def step(action_type, packet_id, packet_ids, session_name, pattern_type, claimed_entry_point):
    """Apply one UI-selected action to the current episode.

    Args:
        action_type: Name of the action to perform.
        packet_id: Single packet identifier; empty values are sent as ``None``.
        packet_ids: Comma-separated packet identifiers; blank entries are
            dropped, and an empty list is sent as ``None``.
        session_name: Session label; empty values are sent as ``None``.
        pattern_type: Pattern label; empty values are sent as ``None``.
        claimed_entry_point: Entry-point packet id; empty values are sent as
            ``None``.

    Returns:
        A ``(summary, message)`` tuple: the Markdown summary of the new
        observation and a one-line result. If no episode has been started,
        a prompt to start one and an empty message are returned instead.
    """
    global env, current_obs
    if env is None:
        return "Please select a task and click Run Episode first", ""

    # Split the comma-separated text box into trimmed, non-empty ids.
    id_list = []
    for piece in (packet_ids or "").split(","):
        trimmed = piece.strip()
        if trimmed:
            id_list.append(trimmed)

    # Empty widget values are normalised to None before building the action.
    action = NetworkForensicsAction(
        action_type=action_type,
        packet_id=packet_id or None,
        packet_ids=id_list or None,
        session_name=session_name or None,
        pattern_type=pattern_type or None,
        claimed_entry_point=claimed_entry_point or None,
    )
    current_obs = env.step(action)
    obs = current_obs

    if not obs.done:
        outcome = f"Step {obs.step_number}: reward = {obs.reward:.2f}"
    else:
        outcome = f"Episode complete! Final score: {obs.current_score_estimate:.2f}"

    # Only the summary half of format_obs() is returned; the packet table is
    # discarded here.
    return format_obs(obs)[0], outcome


# Build the Gradio UI: a task picker, a summary panel, the packet listing and
# a row of inputs describing the next action to send to the environment.
with gr.Blocks(title="Network Forensics") as demo:
    gr.Markdown("# Network Packet Forensics RL Environment")
    gr.Markdown("Analyze network packet captures to identify attack patterns")

    with gr.Row():
        with gr.Column():
            task_select = gr.Radio(["easy", "medium", "hard"], label="Task", value="easy")
            run_btn = gr.Button("Run Episode", variant="primary")

        with gr.Column():
            output_text = gr.Markdown("Click Run Episode to start")

    gr.Markdown("### Packet Stream")

    # format_obs() returns the packet listing as a pipe-delimited Markdown
    # table string (eight columns, including Preview). gr.Dataframe interprets
    # a str value as the path of a CSV file, so that string cannot be shown in
    # a Dataframe; render it with a Markdown component instead.
    packet_display = gr.Markdown("")

    gr.Markdown("### Actions")

    with gr.Row():
        action_type = gr.Dropdown(
            ["inspect_packet", "flag_as_suspicious", "group_into_session", "tag_pattern", "identify_entry_point", "submit_report"],
            label="Action",
            value="inspect_packet",
        )
        packet_id = gr.Textbox(label="Packet ID", placeholder="pkt_0001")
        packet_ids = gr.Textbox(label="Packet IDs", placeholder="pkt_0001,pkt_0002")
        session_name = gr.Textbox(label="Session Name", placeholder="session_1")
        pattern_type = gr.Textbox(label="Pattern", placeholder="ddos / web_xss / heartbleed")
        claimed_entry_point = gr.Textbox(label="Claimed Entry Point", placeholder="pkt_0001")

    step_btn = gr.Button("Execute Action")

    result_display = gr.Markdown("")

    # "Run Episode" resets the environment and fills both the summary and the
    # packet table.
    run_btn.click(reset_env, task_select, [output_text, packet_display])
    # NOTE: step() returns only (summary, result), so the packet table is
    # refreshed on "Run Episode" but not after each action.
    step_btn.click(
        step,
        [action_type, packet_id, packet_ids, session_name, pattern_type, claimed_entry_point],
        [output_text, result_display],
    )


# Launch exactly once, and only when executed as a script. An additional
# unguarded launch() at module level would start the server on import and
# attempt a second launch after the first one returns.
if __name__ == "__main__":
    demo.launch(server_port=7860, server_name="0.0.0.0")