Spaces:
Sleeping
Sleeping
| import sys | |
| from pathlib import Path | |
| sys.path.insert(0, str(Path(__file__).parent)) | |
| import gradio as gr | |
| from server.network_forensics_environment import NetworkForensicsEnvironment | |
| from models import NetworkForensicsAction | |
# Episode state shared by the Gradio callbacks below.
# `env` is the active environment and `current_obs` its latest observation;
# both stay None until reset_env() has been called.
env = current_obs = None
def reset_env(task_name):
    """Start a new episode for the selected task.

    Builds a NetworkForensicsEnvironment with ``task_id=task_name``, resets
    it, and stores the environment and its first observation in the
    module-level ``env`` / ``current_obs`` globals.

    Returns:
        Whatever ``format_obs`` produces for the first observation.
    """
    global env, current_obs

    env = NetworkForensicsEnvironment(task_id=task_name)
    initial_obs = env.reset()
    current_obs = initial_obs
    return format_obs(initial_obs)
def format_obs(obs):
    """Render an observation as a status summary plus packet rows.

    Args:
        obs: Observation object exposing ``step_number``, ``steps_remaining``,
            ``current_score_estimate``, ``total_packets``,
            ``flagged_packet_ids``, ``grouped_sessions``, ``tagged_patterns``
            and ``visible_packets``.

    Returns:
        A 2-tuple ``(summary, packet_rows)``:

        * ``summary`` -- Markdown text, one status field per line. The
          "Sessions" and "Tags" lines appear only when those fields are
          non-empty.
        * ``packet_rows`` -- a list with one row (itself a list) per packet,
          for at most the first 20 visible packets. Each row holds, in
          order: ID, source IP, destination IP, destination port, protocol,
          TTL, payload size, payload preview.
    """
    summary_lines = [
        f"**Step**: {obs.step_number}/{obs.steps_remaining}",
        f"**Score**: {obs.current_score_estimate:.2f}",
        f"**Total Packets**: {obs.total_packets}",
        f"**Flagged**: {len(obs.flagged_packet_ids)} packets",
    ]
    if obs.grouped_sessions:
        summary_lines.append(f"**Sessions**: {', '.join(obs.grouped_sessions.keys())}")
    if obs.tagged_patterns:
        summary_lines.append(f"**Tags**: {obs.tagged_patterns}")

    # Rows are returned as structured data rather than a pipe-delimited
    # string: the caller feeds them to a tabular widget, and payload text
    # containing "|" or newlines can no longer corrupt the table.
    packet_rows = [
        [
            p.packet_id,
            p.src_ip,
            p.dst_ip,
            p.dst_port,
            p.protocol,
            p.ttl,
            p.payload_size,
            # Show the full payload only once the packet has been revealed
            # and a payload is actually present; otherwise the short preview.
            p.full_payload if p.is_revealed and p.full_payload else p.payload_preview,
        ]
        for p in obs.visible_packets[:20]
    ]
    return "\n".join(summary_lines), packet_rows
def step(action_type, packet_id, packet_ids, session_name, pattern_type, claimed_entry_point):
    """Apply one user-selected action to the running episode.

    The arguments are the raw values of the action form. Empty values are
    sent to the environment as ``None``; ``packet_ids`` is a comma-separated
    string that is split into a list of trimmed, non-empty IDs.

    Returns:
        A 2-tuple ``(summary, result)``: the first element of
        ``format_obs`` for the new observation, and a one-line message
        with either the step reward or, once the episode is done, the final
        score. If no episode has been started, a prompt message and an
        empty string are returned instead.
    """
    global env, current_obs

    if env is None:
        return "Please select a task and click Run Episode first", ""

    # "a, b,,c " -> ["a", "b", "c"]; a blank field yields an empty list.
    id_list = []
    for chunk in (packet_ids or "").split(","):
        cleaned = chunk.strip()
        if cleaned:
            id_list.append(cleaned)

    action = NetworkForensicsAction(
        action_type=action_type,
        packet_id=packet_id or None,
        packet_ids=id_list or None,
        session_name=session_name or None,
        pattern_type=pattern_type or None,
        claimed_entry_point=claimed_entry_point or None,
    )

    current_obs = env.step(action)
    result = (
        f"Episode complete! Final score: {current_obs.current_score_estimate:.2f}"
        if current_obs.done
        else f"Step {current_obs.step_number}: reward = {current_obs.reward:.2f}"
    )
    summary = format_obs(current_obs)[0]
    return summary, result
# Gradio UI: task picker and status summary on top, the packet table in the
# middle, and the action form at the bottom.
with gr.Blocks(title="Network Forensics") as demo:
    gr.Markdown("# Network Packet Forensics RL Environment")
    gr.Markdown("Analyze network packet captures to identify attack patterns")

    with gr.Row():
        with gr.Column():
            task_select = gr.Radio(["easy", "medium", "hard"], label="Task", value="easy")
            run_btn = gr.Button("Run Episode", variant="primary")
        with gr.Column():
            output_text = gr.Markdown("Click Run Episode to start")

    gr.Markdown("### Packet Stream")
    # Columns mirror the eight per-packet fields format_obs emits, including
    # the payload preview.
    packet_display = gr.Dataframe(
        headers=["ID", "Src IP", "Dst IP", "Port", "Protocol", "TTL", "Size", "Preview"],
        datatype=["str", "str", "str", "number", "str", "number", "number", "str"],
        interactive=False,
    )

    gr.Markdown("### Actions")
    with gr.Row():
        action_type = gr.Dropdown(
            [
                "inspect_packet",
                "flag_as_suspicious",
                "group_into_session",
                "tag_pattern",
                "identify_entry_point",
                "submit_report",
            ],
            label="Action",
            value="inspect_packet",
        )
        packet_id = gr.Textbox(label="Packet ID", placeholder="pkt_0001")
        packet_ids = gr.Textbox(label="Packet IDs", placeholder="pkt_0001,pkt_0002")
        session_name = gr.Textbox(label="Session Name", placeholder="session_1")
        pattern_type = gr.Textbox(label="Pattern", placeholder="ddos / web_xss / heartbleed")
        claimed_entry_point = gr.Textbox(label="Claimed Entry Point", placeholder="pkt_0001")

    step_btn = gr.Button("Execute Action")
    result_display = gr.Markdown("")

    # "Run Episode" starts a new episode and fills the summary and the table.
    run_btn.click(reset_env, task_select, [output_text, packet_display])
    # "Execute Action" sends the form to the environment and refreshes the
    # summary and the per-step result message.
    step_btn.click(
        step,
        [action_type, packet_id, packet_ids, session_name, pattern_type, claimed_entry_point],
        [output_text, result_display],
    )

# Launch exactly once, and only when run as a script: an unguarded launch
# would start a server on import and be followed by a second launch here.
if __name__ == "__main__":
    demo.launch(server_port=7860, server_name="0.0.0.0")