| |
| """ |
HKLM — Hierarchical Knowledge-grounded Log Mapper
API Inference Version — No GPU Required

Uses HuggingFace Inference API for LLM calls.
Deployable locally, on Google Colab, or HuggingFace Spaces (free CPU tier).

Log streaming: slow work runs in background threads while a polling loop keeps
pushing log output to the UI, so logs stream continuously.
| """ |
|
|
| import gradio as gr |
| import pandas as pd |
| from pathlib import Path |
| import json |
| import time |
| import threading |
| from datetime import datetime |
| import sys |
| import io |
| from mitre_analyzer_api import MITRELogAnalyzerAPI |
|
|
|
|
class LiveLogger:
    """Thread-safe stdout/stderr tee that streams captured output to Gradio.

    Everything written is mirrored to the stream that was ``sys.stdout`` when
    the logger was created. Each completed, non-blank line is also stored with
    an ``[HH:MM:SS]`` timestamp for display.

    Output is line-buffered. ``print`` issues separate writes for each
    argument, each separator and the line terminator, so text that does not
    end in a newline is held until the line is completed. One logical line
    therefore yields exactly one log entry.
    """

    def __init__(self):
        self.logs = []  # completed "[HH:MM:SS] text" entries, oldest first
        self.original_stdout = sys.stdout  # every write is mirrored here
        self._lock = threading.Lock()  # guards logs and _pending
        self._pending = ""  # current line's text not yet terminated by a newline

    def write(self, text):
        """Mirror *text* to the original stdout and record every completed line.

        Blank lines are not recorded. Returns the number of characters
        written, matching the ``io.TextIOBase.write`` contract.
        """
        if not text:
            return 0

        self.original_stdout.write(text)
        self.original_stdout.flush()

        timestamp = datetime.now().strftime("%H:%M:%S")
        with self._lock:
            # The final piece is whatever follows the last newline (possibly
            # ""); keep it pending until a later write completes the line.
            *complete, self._pending = (self._pending + text).split("\n")
            for line in complete:
                line = line.rstrip("\r")
                if line.strip():
                    self.logs.append(f"[{timestamp}] {line}")
        return len(text)

    def flush(self):
        """Flush the mirrored stream; a pending partial line stays buffered."""
        self.original_stdout.flush()

    def isatty(self):
        """Report that this stream is not an interactive terminal."""
        return False

    def get_logs(self):
        """Return all captured lines joined by newlines, or a placeholder.

        A trailing partial line (not yet newline-terminated) is included so
        in-progress output is visible immediately.
        """
        with self._lock:
            lines = list(self.logs)
            pending = self._pending.rstrip("\r")
        if pending.strip():
            lines.append(f"[{datetime.now().strftime('%H:%M:%S')}] {pending}")
        if not lines:
            return "Waiting for output...\n"
        return "\n".join(lines)

    def clear(self):
        """Discard all captured lines, including any pending partial line."""
        with self._lock:
            self.logs = []
            self._pending = ""
|
|
|
|
| |
# Example telemetry that pre-fills the "Paste Logs" text box and is restored by
# the "Load Sample Logs" button. One event per line, made of " | "-separated
# key=value fields (type, pid, cmd, path, bytes). Raw string so the Windows
# paths keep their backslashes; .strip() drops surrounding whitespace.
# NOTE(review): the pid=4364 line's -logdir value ends in "...AppData\Roami"
# and looks truncated in the source data -- confirm this is intentional.
SAMPLE_LOGS = r"""type=EVENT_CONNECT | pid=8428 | cmd=ssh admin@128.55.12.56
type=EVENT_READ | pid=3488 | cmd=scp -r C:\Users\admin\Documents admin@128.55.12.106:./files/ | path=\REGISTRY\MACHINE\SOFTWARE\Microsoft\Windows\CurrentVersion\SideBySide\
type=EVENT_READ | pid=7980 | cmd="C:\Program Files\OpenSSH-Win64\sshd.exe" | path=\REGISTRY\MACHINE\SYSTEM\ControlSet001\Services\Tcpip\Parameters\Winsock\
type=EVENT_CONNECT | pid=9448 | cmd="C:\Program Files\OpenSSH-Win64\ssh.exe" "-x" "-oForwardAgent=no" "-oPermitLocalCommand=no" "-oClearAllForwardings=yes"
type=EVENT_READ | pid=4364 | cmd="C:\Program Files\TightVNC\tvnserver.exe" -desktopserver -logdir "C:\WINDOWS\system32\config\systemprofile\AppData\Roami | path=\REGISTRY\MACHINE\SOFTWARE\Microsoft\Windows\CurrentVersion\SideBySide\
type=EVENT_SENDTO | bytes=68
type=EVENT_MODIFY_FILE_ATTRIBUTES | pid=7104 | cmd="C:\Program Files\OpenSSH-Win64\ssh.exe" "-x" "-oForwardAgent=no" "-oPermitLocalCommand=no" "-oClearAllForwardings=yes" | path=\REGISTRY\MACHINE\SYSTEM\ControlSet001\Services\WinSock2\Parameters\NameSpace_Catalog5\
type=EVENT_READ | path=\REGISTRY\MACHINE\SYSTEM\ControlSet001\Services\WinSock2\Parameters\NameSpace_Catalog5\Catalog_Entries64\000000000001\
type=EVENT_READ | path=\REGISTRY\MACHINE\SAM\SAM\Domains\Account\Users\Names\admin\
type=EVENT_MODIFY_FILE_ATTRIBUTES | pid=3784 | cmd=scp -r C:\Users\admin\Documents admin@128.55.12.51:./test/ | path=\REGISTRY\MACHINE\SYSTEM\ControlSet001\Services\WinSock2\Parameters\Protocol_Catalog9\
type=EVENT_READ | pid=3784 | cmd=scp -r C:\Users\admin\Documents admin@128.55.12.51:./test/ | path=\REGISTRY\MACHINE\SOFTWARE\Microsoft\Windows NT\CurrentVersion\Time Zones\Eastern Standard Time\Dynamic DST\
type=EVENT_READ | pid=3784 | cmd=scp -r C:\Users\admin\Documents admin@128.55.12.51:./test/ | path=\REGISTRY\MACHINE\SOFTWARE\Microsoft\Windows NT\CurrentVersion\GRE_Initialize\
type=EVENT_READ | pid=8668 | cmd=C:\WINDOWS\system32\cmd.exe | path=\REGISTRY\USER\S-1-5-21-231540947-922634896-4161786520-1004\Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders\
type=EVENT_READ | pid=8736 | cmd=C:\WINDOWS\system32\cmd.exe | path=\REGISTRY\USER\S-1-5-21-231540947-922634896-4161786520-1004\Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders\
type=EVENT_READ | pid=8460 | cmd=C:\WINDOWS\system32\cmd.exe | path=\REGISTRY\USER\S-1-5-21-231540947-922634896-4161786520-1004\Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders\
""".strip()
|
|
|
|
| |
# Process-wide cache of the most recently built analyzer, reused across Gradio
# requests so it is not rebuilt on every run.
_analyzer = None
_current_model = None
_current_caching = None


def get_analyzer(model_name, use_caching):
    """Return a MITRELogAnalyzerAPI for *model_name*, reusing the cached one.

    The cached instance is reused only when BOTH the model and the caching
    flag match the previous call. ``use_caching`` is a constructor argument,
    so reusing an analyzer built with a different flag would silently ignore
    the UI's caching toggle.

    Args:
        model_name: Model identifier passed to the analyzer.
        use_caching: Caching flag passed to the analyzer's constructor.

    Returns:
        The (possibly cached) analyzer, always constructed with verbose=True.
    """
    global _analyzer, _current_model, _current_caching

    caching = bool(use_caching)
    if (
        _analyzer is not None
        and _current_model == model_name
        and _current_caching == caching
    ):
        return _analyzer

    _analyzer = MITRELogAnalyzerAPI(
        mitre_kb_path="mitre_detection_kb.json",
        model_name=model_name,
        use_caching=use_caching,
        verbose=True,
    )
    _current_model = model_name
    _current_caching = caching
    return _analyzer
|
|
|
|
class GradioMITREAnalyzer:
    """Gradio-facing wrapper around the MITRE analyzer with live log streaming.

    ``analyze`` is a generator used as a Gradio event handler. Each ``yield``
    pushes a ``(results_table, status, statistics, live_logs)`` tuple to the
    UI. Slow steps (building the analyzer, per-event model calls) run in
    daemon threads while the generator keeps yielding, so the log view stays
    live.

    NOTE: ``sys.stdout`` and ``sys.stderr`` are redirected process-wide for
    the duration of a run.
    """

    def __init__(self):
        # LiveLogger of the most recent run; replaced at the start of each run.
        self.logger = None

    def _build_dataframe_from_text(self, raw_text):
        """Convert pasted log text into a DataFrame with a ``raw_text`` column.

        If the first line looks like a CSV header mentioning ``raw_text``, the
        text is parsed as CSV. Otherwise every non-blank line becomes one
        event; a lone ``raw_text`` header line is dropped.

        Returns:
            The DataFrame, or None when the text has no events.
        """
        lines = [line.strip() for line in raw_text.strip().split("\n") if line.strip()]
        if not lines:
            return None

        header = lines[0].lower()
        if "raw_text" in header and "," in header:
            try:
                df = pd.read_csv(io.StringIO(raw_text))
            except Exception:
                # Deliberate best-effort: text that only resembles a CSV falls
                # back to one-event-per-line below.
                df = None
            if df is not None and "raw_text" in df.columns:
                return df
        elif header.strip('"') == "raw_text":
            # Single-column CSV pasted together with its header row.
            lines = lines[1:]
            if not lines:
                return None

        return pd.DataFrame({"raw_text": lines})

    def _load_input(self, file_path, text_input):
        """Resolve this run's input into a DataFrame with a ``raw_text`` column.

        An uploaded file takes precedence over the text box. The text box is
        pre-filled with sample logs, so preferring text would make uploads
        unreachable.

        Returns:
            ``(df, source, error)``. ``source`` is ``"file"`` or ``"text"``.
            On failure ``df`` and ``source`` are None and ``error`` is the
            message to show the user.
        """
        if file_path is not None:
            # Gradio passes a path string, or a tempfile-like wrapper with .name.
            path = file_path if isinstance(file_path, (str, Path)) else getattr(file_path, "name", file_path)
            sep = "\t" if str(path).lower().endswith(".tsv") else ","
            try:
                df = pd.read_csv(path, sep=sep)
            except Exception as e:
                return None, None, f"⚠️ Error reading CSV: {e}"
            source = "file"
        elif text_input and text_input.strip():
            df = self._build_dataframe_from_text(text_input)
            if df is None or len(df) == 0:
                return None, None, "⚠️ Could not parse any log events from text input!"
            source = "text"
        else:
            return None, None, "⚠️ Please upload a CSV file or paste logs in the text field!"

        if "raw_text" not in df.columns:
            if len(df.columns) != 1:
                return None, None, "⚠️ CSV must contain a 'raw_text' column!"
            df.columns = ["raw_text"]
        return df, source, None

    def _format_statistics(self, df, processed, total, elapsed):
        """Render the plain-text statistics panel.

        Args:
            df: Results DataFrame (one row per analyzed event).
            processed: Number of events with results.
            total: Number of events scheduled for this run.
            elapsed: Seconds since processing started.
        """
        pct = (processed / total * 100) if total > 0 else 0
        rate = processed / elapsed if elapsed > 0 else 0
        stats = [
            "=" * 50,
            "📊 STATISTICS",
            "=" * 50,
            f"Progress: {processed:,} / {total:,} ({pct:.1f}%)",
            f"Rate: {rate:.2f} events/sec",
            f"Time: {elapsed:.1f}s",
        ]

        tactic_col = next((c for c in ("tactic", "predicted_tactic") if c in df.columns), None)
        if tactic_col is not None and len(df):
            stats.append("\n🎯 Tactic Distribution:")
            for tactic, count in df[tactic_col].value_counts().head(5).items():
                stats.append(f" • {tactic}: {count} ({count / len(df) * 100:.1f}%)")

        conf_col = next((c for c in ("confidence_score", "confidence") if c in df.columns), None)
        if conf_col is not None:
            stats.append(f"\n📈 Avg Confidence: {df[conf_col].mean():.1%}")

        stats.append("=" * 50)
        return "\n".join(stats)

    def _progress_update(self, logger, all_results, num_events, start_time, status, empty_status=None):
        """Build a UI update tuple from the results gathered so far.

        With no results yet, the table and statistics stay empty and
        ``empty_status`` (falling back to ``status``) is shown instead.
        """
        if not all_results:
            return None, empty_status or status, "", logger.get_logs()
        current_df = pd.DataFrame(all_results)
        elapsed = time.time() - start_time
        stats = self._format_statistics(current_df, len(all_results), num_events, elapsed)
        return current_df, status, stats, logger.get_logs()

    @staticmethod
    def _run_in_background(func, heartbeat, poll_interval=0.5):
        """Run ``func()`` in a daemon thread, yielding ``heartbeat()`` while it runs.

        Use with ``yield from``. The expression evaluates to
        ``(result, error)``: ``error`` is the exception ``func`` raised, or
        None on success.
        """
        box = {"result": None, "error": None}

        def _target():
            try:
                box["result"] = func()
            except Exception as exc:  # handed back to the caller, not swallowed
                box["error"] = exc

        worker = threading.Thread(target=_target, daemon=True)
        worker.start()
        while True:
            worker.join(timeout=poll_interval)
            if not worker.is_alive():
                break
            yield heartbeat()
        return box["result"], box["error"]

    def analyze(
        self,
        file_path,
        text_input,
        model_name,
        max_logs,
        use_caching,
        verbose,
        progress=gr.Progress(),
    ):
        """Analyze logs through the inference API, streaming progress to the UI.

        Generator used as a Gradio event handler. Each yielded tuple is
        ``(results_df_or_None, status_message, statistics_text, live_logs)``.
        Results are also written to ``gradio_results/hklm_results_<ts>.csv``.

        Args:
            file_path: Uploaded CSV/TSV, or None. Takes precedence over
                ``text_input``.
            text_input: Pasted log events, one per line (or CSV text with a
                ``raw_text`` header).
            model_name: Model identifier passed to ``get_analyzer``.
            max_logs: Optional cap on the number of events processed.
            use_caching: Whether to reuse predictions for identical log lines.
            verbose: Value applied to ``analyzer.verbose`` for this run.
            progress: Gradio progress tracker, injected by Gradio.
        """
        df, input_source, input_error = self._load_input(file_path, text_input)
        if input_error:
            yield None, input_error, "", ""
            return

        old_stdout, old_stderr = sys.stdout, sys.stderr
        # Keep a local reference: self.logger is shared by every session using
        # this instance, so a concurrent run could otherwise swap it mid-run.
        logger = self.logger = LiveLogger()

        try:
            sys.stdout = logger
            sys.stderr = logger

            yield None, "🚀 Starting analysis...", "", logger.get_logs()

            # 1. Build (or reuse) the analyzer in the background.
            progress(0.0, desc="🔧 Initializing...")
            print("=" * 80)
            print("🤖 HKLM — API Inference Mode")
            print("=" * 80)
            print(f"Model: {model_name}")
            print(f"Caching: {'Enabled' if use_caching else 'Disabled'}")
            print(f"Input source: {'Text field' if input_source == 'text' else 'CSV file'}")
            print("⚡ Using HF Inference API — no local GPU required")
            print("=" * 80)

            connecting = "🔧 Connecting to API..."
            yield None, connecting, "", logger.get_logs()
            analyzer, load_error = yield from self._run_in_background(
                lambda: get_analyzer(model_name, use_caching),
                lambda: (None, connecting, "", logger.get_logs()),
            )
            if load_error is not None:
                raise load_error

            # The analyzer is cached across runs and built verbose, so apply
            # this run's setting explicitly; otherwise unchecking has no effect.
            analyzer.verbose = bool(verbose)
            if verbose:
                print("🔍 Verbose mode enabled")

            print("✅ Analyzer ready!")
            yield None, "✅ Analyzer ready!", "", logger.get_logs()

            # 2. Report and optionally truncate the input.
            progress(0.05, desc="📂 Loading log data...")
            total_logs = len(df)
            print("")
            print("=" * 80)
            print("📂 LOADING LOG DATA")
            print("=" * 80)
            print(f"Input source: {'Pasted text' if input_source == 'text' else 'CSV file'}")
            print(f"Total events: {total_logs:,}")
            if max_logs and max_logs > 0:
                df = df.head(int(max_logs))
                print(f"Processing first {len(df):,} events (max_logs setting)")
            print(f"✅ Ready to process {len(df):,} events")
            print("=" * 80)

            num_events = len(df)
            yield None, f"📂 Loaded {num_events:,} events", "", logger.get_logs()

            # 3. Analyze event by event, streaming partial results.
            progress(0.1, desc="🔄 Processing events...")
            all_results = []
            start_time = time.time()

            for i, (idx, row) in enumerate(df.iterrows()):
                position = f"Event {i + 1}/{num_events}"
                progress(0.1 + 0.85 * (i / num_events), desc=f"⚡ {position}")

                raw = row["raw_text"]
                # CSV cells may be empty (NaN) or non-string; normalize first.
                log_entry = "" if pd.isna(raw) else str(raw)

                print("")
                print("─" * 80)
                print(f"🔍 EVENT {i + 1}/{num_events}")
                print("─" * 80)
                if not log_entry.strip():
                    print(" ⚠️ Skipping empty event")
                    continue
                preview = log_entry if len(log_entry) <= 120 else log_entry[:120] + "..."
                print(f"Log: {preview}")

                cache_key = None
                if use_caching and analyzer.use_caching:
                    cache_key = analyzer._get_cache_key(log_entry)
                    if cache_key in analyzer.cache and analyzer.cache[cache_key]:
                        prediction = analyzer.cache[cache_key]
                        analyzer.stats["cache_hits"] += 1
                        all_results.append(analyzer._create_result_dict(idx, row, prediction))
                        print(f" ⚡ CACHE HIT → {prediction.tactic} / {prediction.technique_id}")
                        yield self._progress_update(
                            logger,
                            all_results,
                            num_events,
                            start_time,
                            f"✅ {len(all_results)} results | {position} ⚡ Cached",
                        )
                        continue
                    analyzer.stats["cache_misses"] += 1

                def heartbeat(position=position):
                    return self._progress_update(
                        logger,
                        all_results,
                        num_events,
                        start_time,
                        f"✅ {len(all_results)} results | {position} analyzing...\n"
                        "Check Results Table for updates.",
                        f"⏳ {position} analyzing...",
                    )

                prediction, event_error = yield from self._run_in_background(
                    lambda entry=log_entry: analyzer._analyze_single(entry),
                    heartbeat,
                )
                if event_error is not None:
                    print(f" ❌ Failed: {event_error}")
                    prediction = None

                if prediction:
                    if cache_key is not None:
                        analyzer.cache[cache_key] = prediction
                    all_results.append(analyzer._create_result_dict(idx, row, prediction))
                    print("")
                    print(f" ✅ {prediction.tactic} → {prediction.technique_id} ({prediction.technique_name})")
                    print(f" Confidence: {prediction.confidence_score:.2f}")
                    print(f" Mitigations: {len(prediction.mitigation_strategies)}")

                yield self._progress_update(
                    logger,
                    all_results,
                    num_events,
                    start_time,
                    f"✅ {len(all_results)} results | {position} complete",
                    f"⏳ {position} complete",
                )

            # 4. Persist and summarize.
            progress(0.95, desc="💾 Done!")

            if not all_results:
                yield None, "❌ No results produced!", "", logger.get_logs()
                return

            final_df = pd.DataFrame(all_results)

            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            output_dir = Path("gradio_results")
            output_dir.mkdir(exist_ok=True)
            output_path = output_dir / f"hklm_results_{timestamp}.csv"
            final_df.to_csv(output_path, index=False)

            total_time = time.time() - start_time
            # Guard: an all-cache-hit run can finish within clock resolution.
            rate = len(final_df) / total_time if total_time > 0 else 0.0
            final_stats = self._format_statistics(final_df, len(final_df), num_events, total_time)

            print("")
            print("=" * 80)
            print("✅ ANALYSIS COMPLETE")
            print("=" * 80)
            print(f"Total events: {len(final_df):,}")
            print(f"Time: {total_time:.1f}s ({rate:.1f} events/sec)")
            print(f"API calls: {analyzer.stats['api_calls']}")
            print(f"Cache hits: {analyzer.stats['cache_hits']}")
            print(f"Results saved: {output_path}")
            print("=" * 80)

            yield (
                final_df,
                f"✅ Complete! {len(final_df):,} events in {total_time:.1f}s",
                final_stats,
                logger.get_logs(),
            )

        except Exception as e:
            import traceback

            print("")
            print(f"❌ ERROR: {e}")
            print(traceback.format_exc())
            yield None, f"❌ Error: {e}", "", logger.get_logs()

        finally:
            sys.stdout = old_stdout
            sys.stderr = old_stderr
|
|
|
def create_interface():
    """Build the Gradio Blocks UI and wire it to a GradioMITREAnalyzer.

    Returns:
        The ``gr.Blocks`` interface; the caller is responsible for launching it.
    """
    analyzer = GradioMITREAnalyzer()

    with gr.Blocks(
        title="HKLM — MITRE ATT&CK Log Mapper",
        theme=gr.themes.Soft(),
    ) as interface:

        gr.Markdown(
            "# 🛡️ HKLM — Hierarchical Knowledge-grounded Log Mapper\n"
            "Map any raw system log to MITRE ATT&CK tactics and techniques using open-source LLMs\n\n"
            "*CIS 544-01: Cyber Defense and Operations — Minal Ali & Fnu Mahnoor*\n\n"
            "> ⚡ **Demo mode:** Uses HuggingFace Inference API — no GPU required. "
            "For full-speed GPU inference, see the "
            "[GitHub repo](https://github.com/mahnoor-khalid9/hierarchical-knowledge-grounded-log-mapper)."
        )

        with gr.Row():
            # Left column: log input and run settings.
            with gr.Column(scale=1):
                gr.Markdown("### 📝 Log Input")
                with gr.Tabs():
                    with gr.Tab("📝 Paste Logs"):
                        text_input = gr.Textbox(
                            label="Paste log events (one per line)",
                            lines=8,
                            max_lines=20,
                            value=SAMPLE_LOGS,
                        )
                        load_sample_btn = gr.Button("📋 Load Sample Logs", size="sm")

                    with gr.Tab("📁 Upload CSV"):
                        file_input = gr.File(
                            label="Upload CSV/TSV with a 'raw_text' column",
                            file_types=[".csv", ".tsv"],
                        )

                gr.Markdown("### ⚙️ Settings")
                model_choice = gr.Dropdown(
                    choices=[
                        "llama-3.1-8b-instant",
                        "allam-2-7b",
                    ],
                    value="llama-3.1-8b-instant",
                    label="🤖 Model",
                    # The previous text described Mistral/Qwen/Phi models that
                    # are not offered here; describe the actual choices.
                    info="Model used for every LLM call (default: llama-3.1-8b-instant)",
                )
                max_logs = gr.Number(
                    value=None,
                    label="Max Events",
                    info="Leave empty to process all events",
                )
                use_caching = gr.Checkbox(
                    value=True,
                    label="Semantic Caching",
                    info="Skip duplicate events via MD5 hash",
                )
                verbose = gr.Checkbox(
                    value=True,
                    label="Verbose Logging",
                    info="Show per-event model outputs",
                )

            # Right column: run button, status line and output tabs.
            with gr.Column(scale=2):
                with gr.Row(variant="panel"):
                    with gr.Column(scale=1):
                        gr.Markdown("### 📊 Results")
                    with gr.Column(scale=0, min_width=120):
                        analyze_btn = gr.Button(
                            "🚀 RUN",
                            variant="primary",
                            size="lg",
                        )

                status_box = gr.Textbox(
                    label="Status",
                    value="Ready — paste logs or upload a CSV to begin",
                    max_lines=1,
                )

                with gr.Tabs():
                    with gr.Tab("📜 Live Logs"):
                        logs_box = gr.Textbox(
                            label="Processing Logs (Live Stream)",
                            lines=25,
                            max_lines=50,
                            autoscroll=True,
                        )
                    with gr.Tab("📋 Results Table"):
                        results_table = gr.Dataframe(
                            label="Analysis Results",
                            headers=["raw_text", "tactic", "technique_id", "technique_name", "confidence_score"],
                            max_height=400,
                        )
                    with gr.Tab("📈 Statistics"):
                        stats_box = gr.Textbox(
                            label="Statistics",
                            lines=20,
                            max_lines=30,
                        )

        load_sample_btn.click(fn=lambda: SAMPLE_LOGS, outputs=[text_input])

        # analyzer.analyze is a generator, so each yield streams an update to
        # these four outputs in order.
        analyze_btn.click(
            fn=analyzer.analyze,
            inputs=[
                file_input, text_input, model_choice,
                max_logs, use_caching, verbose,
            ],
            outputs=[results_table, status_box, stats_box, logs_box],
        )

        gr.Markdown("""
### 📖 How to Use
1. **Paste logs** directly into the text field (one event per line) **or upload a CSV** with a `raw_text` column
2. Select a model and adjust settings
3. Click **🚀 RUN** — watch live progress in the Live Logs tab
4. View results in the Results Table tab

### 🔑 Key Concepts
- **Log-source agnostic:** works with any raw text log — Windows events, syslog, firewall, cloud, etc.
- **Post-analysis framework:** processes collected logs in batch, not real-time
- **Knowledge-constrained:** LLM picks from ATT&CK KB options — doesn't hallucinate from memory
- **API demo mode:** uses HF Inference API — for GPU-accelerated batch processing, use the [local version](https://github.com/mahnoor-khalid9/hierarchical-knowledge-grounded-log-mapper)
""")

    return interface
|
|
|
|
if __name__ == "__main__":
    print("🚀 Starting HKLM — API Inference Mode...")
    print("=" * 80)

    interface = create_interface()
    # 0.0.0.0 listens on all interfaces; share=True additionally requests a
    # public Gradio share link, so anyone with the URL can run analyses.
    interface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_error=True,
        inbrowser=True,
    )