File size: 20,183 Bytes
082d5c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39fb0c4
 
082d5c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39fb0c4
 
082d5c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
#!/usr/bin/env python3
"""
HKLM — MITRE ATT&CK Log Analyzer (Groq API Version)

Uses Groq API for fast LLM inference.
Everything else is identical:
  - Same prompts
  - Same JSON parsing
  - Same KB lookup
  - Same caching
  - Same output format

Requires: pip install groq
Set Groq token: export GROQ_API_KEY=gsk_xxxxx (or pass to constructor)

IMPROVED: Better logging messages for UI display during processing
"""

import json
import pandas as pd
from pathlib import Path
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
import re
from tqdm import tqdm
import warnings
import hashlib
import os
import time
import sys
import os
from dotenv import load_dotenv

# Load the variables from .env into the environment
load_dotenv()


try:
    from groq import Groq
except ImportError:
    raise ImportError("Install groq: pip install groq")


@dataclass
class MITREPrediction:
    """One classified log entry: tactic, technique and KB strategies.

    Fields are positional in the order declared below; instances compare by
    value, as generated by ``@dataclass``.
    """

    # Tactic name chosen for the log entry.
    tactic: str
    # Technique label; the analyzer in this file fills it with the technique ID.
    technique: str
    # ATT&CK technique identifier selected for the entry.
    technique_id: str
    # Human-readable technique name.
    technique_name: str
    # Overall confidence; the analyzer stores the lower of the tactic and
    # technique confidences here.
    confidence_score: float
    # Mitigation entries looked up in the knowledge base.
    mitigation_strategies: List[Dict]
    # Detection entries looked up in the knowledge base.
    detection_strategies: List[Dict]
    # Free-text explanation combining tactic and technique reasoning.
    reasoning: str


class MITRELogAnalyzerAPI:
    """
    HKLM analyzer using Groq API.

    Same 3-stage pipeline:
      Stage 1: Raw log → Tactic identification (API call)
      Stage 2: Log + Tactic → Technique identification (API call)
      Stage 3: KB lookup for mitigation/detection strategies

    No GPU required. Just a Groq API key.
    """

    SUPPORTED_MODELS = {
        "llama-3.3-70b-versatile": "Llama 3.3 70B β€” Best reasoning, fast",
        "llama-3.1-8b-instant": "Llama 3.1 8B β€” Fastest",
        "llama-3.2-90b-text-preview": "Llama 3.2 90B β€” Most capable",
        "mixtral-8x7b-32768": "Mixtral 8x7B β€” Good balance",
    }

    def __init__(
        self,
        mitre_kb_path: str,
        model_name: str = "llama-3.1-8b-instant",
        groq_api_key: str = None,
        use_caching: bool = True,
        verbose: bool = False,
    ):
        """Set up the Groq client, probe the API and load the MITRE KB.

        Args:
            mitre_kb_path: Path to the knowledge-base JSON file (read as UTF-8).
            model_name: Model identifier sent with every completion request.
            groq_api_key: API key; when falsy, the ``GROQ_API_KEY`` environment
                variable is used instead.
            use_caching: When true, an in-memory prediction cache is created.
            verbose: Stored on the instance for per-stage diagnostics.

        Raises:
            ValueError: If no API key is supplied and none is in the environment.
        """
        self.model_name = model_name
        self.use_caching = use_caching
        self.verbose = verbose

        # An explicitly passed key takes precedence over the environment.
        resolved_key = groq_api_key if groq_api_key else os.environ.get("GROQ_API_KEY")
        if not resolved_key:
            raise ValueError(
                "Groq API key required. Set GROQ_API_KEY env var or pass groq_api_key parameter.\n"
                "Get your key at: https://console.groq.com/keys"
            )

        # Start-up banner; only the last four key characters are shown.
        caching_label = "Enabled" if use_caching else "Disabled"
        masked_key = f"{'*' * 8}...{resolved_key[-4:]}"
        print(f"\nπŸ€– HKLM β€” Groq API Mode", flush=True)
        print(f"   Model: {model_name}", flush=True)
        print(f"   Caching: {caching_label}", flush=True)
        print(f"   Auth: API key provided ({masked_key})", flush=True)

        print(f"   Initializing Groq client...", flush=True)
        self.client = Groq(api_key=resolved_key)

        # One tiny completion as a connectivity probe. A failure is reported
        # but does not abort construction.
        print(f"   Testing API connection...", flush=True)
        try:
            self.client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": "Hello"}],
                max_tokens=5,
                temperature=0.1,
            )
            print(f"   βœ… API connection successful", flush=True)
        except Exception as probe_error:
            print(f"   ⚠️  API test failed: {probe_error}", flush=True)
            print(f"   Will retry on actual requests", flush=True)

        # Knowledge base: parsed once and kept in memory.
        print(f"\nπŸ“š Loading MITRE knowledge base from {mitre_kb_path}...", flush=True)
        with open(mitre_kb_path, "r", encoding="utf-8") as kb_file:
            self.mitre_kb = json.load(kb_file)

        tactic_count = len(self.mitre_kb.get("tactics", {}))
        print(f"   βœ… Loaded {tactic_count} tactics from knowledge base", flush=True)

        # The cache is a dict when caching is on, otherwise None.
        if use_caching:
            self.cache = {}
        else:
            self.cache = None

        # Counters, all starting at zero.
        self.stats = dict.fromkeys(
            ("cache_hits", "cache_misses", "total_processed", "api_calls", "api_errors"),
            0,
        )

        print(f"\nβœ… INITIALIZATION COMPLETE", flush=True)
        print(f"   Ready to analyze logs using {model_name}", flush=True)

    # ──────────────────────────────────────────────
    # PROMPT CONSTRUCTION
    # ──────────────────────────────────────────────

    def _create_tactic_prompt(self, log_entry: str) -> str:
        tactic_summaries = []
        for tactic_key, tactic_data in self.mitre_kb.get("tactics", {}).items():
            tactic_summaries.append(
                f"- Name: ({tactic_data['shortname']}) "
                f"Description: {tactic_data['description']}."
            )
        
        prompt = f"""
            You are a Senior Security Operations Center (SOC) Analyst and MITRE ATT&CK Specialist. Your task is to classify security logs into the most accurate Tactic.

            ### STEP 1: LOG DECONSTRUCTION
            Examine the log below. Even if field names are unfamiliar, identify:
            1. **The Actor (Subject):** Which process, user, or service is initiating the action?
            2. **The Action (Verb):** What is happening? (e.g., Is something being accessed, created, modified, executed, or moved?)
            3. **The Target (Object):** What is being acted upon? (e.g., a registry key, a file path, a network socket, a memory address, or a configuration setting.)

            ### STEP 2: INTENT INFERENCE
            Based on the Subject-Verb-Object relationship:
            - If the Verb is 'Read/Query' and the Object is 'System Config/Registry', the intent is **Discovery**.
            - If the Verb is 'Execute/Start' and the Object is 'Binary/Script', the intent is **Execution**.
            - If the Verb is 'Modify/Write' and the Object is 'Auto-run location/Service', the intent is **Persistence**.

            ### LOG ENTRY:
            {log_entry}

            ### MITRE ATT&CK TACTIC DEFINITIONS:
            {chr(10).join(tactic_summaries)}

            ### ANALYSIS GUIDELINES:
            - **Do not rely on field labels:** Focus on the *content* of the values.
            - **Context Matters:** Remote access tools performing routine reads are likely in a 'Discovery' phase.
            - **Match the Goal:** Choose the Tactic whose definition most closely matches the *primary goal* of the action.

            ### OUTPUT FORMAT (JSON):
            {{
                "tactic_name": "Exact name from provided list",
                "confidence": 0.0-1.0,
                "reasoning": "Explain in less than 50 words how the inferred action matches the specific MITRE description."
            }}

            JSON Response:
        """
        return prompt

    def _create_technique_prompt(
        self, log_entry: str, tactic_shortname: str
    ) -> str:
        tactic_shortname = tactic_shortname.lower()
        tactic_data = self.mitre_kb.get("tactics", {}).get(tactic_shortname, {})
        techniques = tactic_data.get("techniques", [])

        technique_summaries = []
        for tech in techniques:
            tech_summary = f"- (technique_id: {tech['attack_id']}, Name: {tech['name']})"
            technique_summaries.append(tech_summary)

        if not technique_summaries:
            print("No techniques available for this tactic.", flush=True)
        
        prompt = f"""
            You are a Senior Security Operations Center (SOC) Analyst and MITRE ATT&CK Specialist.

            ### CONTEXT
            You have already identified the MITRE ATT&CK Tactic for this log: **{tactic_shortname}**

            Now, **select the most specific Technique** from the list below that best explains *how* the attacker achieved this goal.

            ### LOG ENTRY:
            {log_entry}

            ### AVAILABLE TECHNIQUES FOR "{tactic_shortname}" TACTIC:
            {chr(10).join(technique_summaries)}

            ### SELECTION CRITERIA:
            1. **Match Technical Indicators:**
               - Does the log show a registry modification? β†’ Likely T1547 (Boot/Logon Autostart)
               - Does it show scheduled task creation? β†’ Likely T1053 (Scheduled Task/Job)
               - Does it access sensitive files (SAM, credentials)? β†’ Likely T1003 (Credential Dumping)

            2. **Be Specific:**
               - If the log shows "reading SAM database", choose T1003 (OS Credential Dumping) over generic Discovery techniques.

            3. **Confidence = Specificity:**
               - If the log precisely matches a technique's definition β†’ 0.8–1.0 confidence
               - If it's likely but ambiguous β†’ 0.5–0.7
               - If it's the best guess among poor fits β†’ 0.3–0.5

            ### OUTPUT FORMAT (JSON):
            {{
                "technique_id": "Exact technique_id from the list",
                "technique_name": "Exact name from the list",
                "confidence": 0.0-1.0,
                "reasoning": "In less than 50 words, explain which specific indicators in the log match this technique."
            }}

            JSON Response:
        """
        return prompt

    # ──────────────────────────────────────────────
    # JSON EXTRACTION
    # ──────────────────────────────────────────────

    def _extract_json_from_response(self, response: str, verbose: bool = False) -> Optional[Dict]:
        """
        Extract JSON from LLM response using multiple strategies.
        Returns parsed JSON dict or None if extraction fails.
        """
        try:
            # Strategy 1: Try direct JSON parse
            try:
                return json.loads(response)
            except json.JSONDecodeError:
                pass

            # Strategy 2: Extract JSON block from markdown
            json_pattern = r"```(?:json)?\s*(\{.*?\})\s*```"
            matches = re.findall(json_pattern, response, re.DOTALL)
            if matches:
                for match in matches:
                    try:
                        return json.loads(match)
                    except json.JSONDecodeError:
                        continue

            # Strategy 3: Find JSON object with regex
            json_obj_pattern = r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}'
            matches = re.findall(json_obj_pattern, response, re.DOTALL)
            if matches:
                for match in matches:
                    try:
                        parsed = json.loads(match)
                        if isinstance(parsed, dict):
                            return parsed
                    except json.JSONDecodeError:
                        continue

            # Strategy 4: Try cleaning common issues
            cleaned = response.strip()
            if not cleaned.startswith('{'):
                first_brace = cleaned.find('{')
                if first_brace != -1:
                    cleaned = cleaned[first_brace:]
            if not cleaned.endswith('}'):
                last_brace = cleaned.rfind('}')
                if last_brace != -1:
                    cleaned = cleaned[:last_brace + 1]
            try:
                return json.loads(cleaned)
            except json.JSONDecodeError:
                pass

            if verbose:
                print(f"  βœ— All JSON extraction strategies failed", flush=True)
            return None

        except Exception as e:
            if verbose:
                print(f"  βœ— JSON extraction error: {e}", flush=True)
            return None

    # ──────────────────────────────────────────────
    # GROQ API INFERENCE
    # ──────────────────────────────────────────────

    def _api_generate(self, prompt: str, retries: int = 3) -> str:
        """
        Generate a response via Groq API.
        """
        for attempt in range(retries):
            try:
                self.stats["api_calls"] += 1

                completion = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=512,
                    temperature=0.1,
                )
                return completion.choices[0].message.content

            except Exception as e:
                error_str = str(e)
                self.stats["api_errors"] += 1

                if "rate_limit" in error_str.lower() or "429" in error_str:
                    wait = (attempt + 1) * 10
                    print(f"  ⚠️  Rate limited, waiting {wait}s (attempt {attempt+1}/{retries})", flush=True)
                    time.sleep(wait)
                elif "503" in error_str or "unavailable" in error_str.lower():
                    wait = 30
                    print(f"  ⏳ Service unavailable, waiting {wait}s...", flush=True)
                    time.sleep(wait)
                else:
                    print(f"  ⚠️  API error: {error_str}", flush=True)
                    if attempt < retries - 1:
                        time.sleep(5)
                    else:
                        raise

        return ""

    # ──────────────────────────────────────────────
    # KB LOOKUP
    # ──────────────────────────────────────────────

    def _get_strategies(
        self, tactic_shortname: str, technique_id: str
    ) -> Tuple[List[Dict], List[Dict]]:
        mitigations = []
        detections = []
        try:
            tactic_data = self.mitre_kb.get("tactics", {}).get(tactic_shortname, {})
            techniques = tactic_data.get("techniques", [])
            for tech in techniques:
                if tech.get("attack_id") == technique_id:
                    raw_mitigations = tech.get("mitigations", [])
                    mitigations = [
                        {"name": m.get("name"), "strategy": m.get("description")}
                        for m in raw_mitigations
                    ]
                    detections = tech.get("detection_strategies", [])
                    break
        except Exception:
            pass
        return mitigations, detections

    def _get_cache_key(self, log_entry: str) -> str:
        normalized = log_entry.lower().strip()
        return hashlib.md5(normalized.encode()).hexdigest()

    # ──────────────────────────────────────────────
    # SINGLE EVENT ANALYSIS
    # ──────────────────────────────────────────────

    def _analyze_single(self, log_entry: str) -> Optional[MITREPrediction]:
        verbose = self.verbose

        if verbose:
            print(f"\n  [Stage 1/3] Identifying Tactic...", flush=True)
            print(f"    Log: {log_entry[:100]}...", flush=True)

        tactic_prompt = self._create_tactic_prompt(log_entry)
        tactic_response = self._api_generate(tactic_prompt)
        tactic_result = self._extract_json_from_response(tactic_response, verbose=verbose)

        if not tactic_result:
            if verbose:
                print(f"    βœ— Tactic extraction failed", flush=True)
            return None

        if verbose:
            print(f"    βœ“ Tactic: {tactic_result.get('tactic_name')} (conf: {tactic_result.get('confidence', 0):.2f})", flush=True)
            print(f"  [Stage 2/3] Identifying Technique...", flush=True)

        technique_prompt = self._create_technique_prompt(
            log_entry,
            tactic_result.get("tactic_name", ""),
        )
        technique_response = self._api_generate(technique_prompt)
        technique_result = self._extract_json_from_response(technique_response, verbose=verbose)

        if not technique_result:
            if verbose:
                print(f"    βœ— Technique extraction failed", flush=True)
            return None

        if verbose:
            print(f"    βœ“ Technique: {technique_result.get('technique_id')} β€” {technique_result.get('technique_name')} (conf: {technique_result.get('confidence', 0):.2f})", flush=True)
            print(f"  [Stage 3/3] Retrieving mitigation strategies...", flush=True)

        mitigation_strategies, detection_strategies = self._get_strategies(
            tactic_result.get("tactic_name", ""),
            technique_result.get("technique_id", ""),
        )

        if verbose:
            print(f"    βœ“ Retrieved {len(mitigation_strategies)} mitigations", flush=True)

        prediction = MITREPrediction(
            tactic=tactic_result.get("tactic_name", ""),
            technique=technique_result.get("technique_id", ""),
            technique_id=technique_result.get("technique_id", ""),
            technique_name=technique_result.get("technique_name", ""),
            confidence_score=min(
                tactic_result.get("confidence", 0),
                technique_result.get("confidence", 0),
            ),
            mitigation_strategies=mitigation_strategies,
            detection_strategies=detection_strategies,
            reasoning=f"Tactic: {tactic_result.get('reasoning', '')}; "
                      f"Technique: {technique_result.get('reasoning', '')}",
        )

        self.stats["total_processed"] += 1
        return prediction

    def _create_result_dict(self, idx, row, prediction: MITREPrediction) -> Dict:
        return {
            "log_index": idx,
            "raw_text": row["raw_text"],
            "tactic": prediction.tactic.title(),
            "technique_id": prediction.technique_id,
            "technique_name": prediction.technique_name,
            "confidence_score": prediction.confidence_score,
            "reasoning": prediction.reasoning,
            "num_mitigations": len(prediction.mitigation_strategies),
            "mitigation_strategies": json.dumps(prediction.mitigation_strategies),
        }

    # ──────────────────────────────────────────────
    # MAIN PROCESSING
    # ──────────────────────────────────────────────

    def _process_batched(self, df: pd.DataFrame) -> List[Dict]:
        """
        Process logs using Groq API inference.

        Rows are analysed one at a time. With caching enabled, an entry whose
        normalised text was already seen reuses the stored outcome (including
        a stored failure) without another API call. A failure on one row is
        reported and does not stop the run.

        Args:
            df: Frame with a ``raw_text`` column. Any index is accepted; the
                index label is only copied into the result as ``log_index``.

        Returns:
            One result dict per successfully analysed row, in input order.
        """
        all_results = []
        total = len(df)
        rows = tqdm(df.iterrows(), desc="Processing events", total=total)

        # iterrows() yields index *labels*, which need not be 0-based integers
        # (filtered or string-indexed frames), so progress uses its own counter.
        for position, (idx, row) in enumerate(rows, start=1):
            log_entry = row["raw_text"]

            if not isinstance(log_entry, str):
                try:
                    is_missing = log_entry is None or bool(pd.isna(log_entry))
                except (TypeError, ValueError):
                    is_missing = False
                if is_missing:
                    # Nothing to classify; skip rather than crash the batch.
                    print(f"  Skipping event {position}/{total}: raw_text is missing", flush=True)
                    continue
                log_entry = str(log_entry)

            # Hash once per row; None means caching is disabled.
            cache_key = self._get_cache_key(log_entry) if self.use_caching else None

            if cache_key is not None:
                if cache_key in self.cache:
                    self.stats["cache_hits"] += 1
                    prediction = self.cache[cache_key]
                    if prediction:
                        all_results.append(self._create_result_dict(idx, row, prediction))
                    continue
                self.stats["cache_misses"] += 1

            print(f"\nπŸ“ Event {position}/{total}", flush=True)
            try:
                prediction = self._analyze_single(log_entry)
            except Exception as e:
                print(f"  ❌ Failed: {e}", flush=True)
                prediction = None

            if cache_key is not None:
                self.cache[cache_key] = prediction

            if prediction:
                all_results.append(self._create_result_dict(idx, row, prediction))

        return all_results