# Small_Cyber / app.py
# madox81's picture
# Update app.py
# 8f4334d verified
import os
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
# --- Configuration ---
# Model identifier and GGUF file name; both are passed to LLM(...) when the
# module-level model wrapper is created.
MODEL_ID = "madox81/SmolLM2-Cyber-Insight-GGUF"
FILENAME = "SmolLM2-1.7b-Instruct.Q4_K_M.gguf"
# Instruction text used for the "MITRE Mapping" task. It is placed verbatim in
# front of the user's input when the prompt is built, so editing any character
# here changes what the model receives.
MITRE_INSTRUCT = """You are a cybersecurity threat analysis assistant.
Analyze the following sequence of security events as a single coordinated attack.
Tasks:
1. Identify all relevant MITRE ATT&CK tactics
2. Identify all relevant MITRE ATT&CK techniques (use names, not IDs)
3. Consider the full sequence, not individual events
4. Include only behaviors clearly supported by the events
Strict Rules:
- Only include techniques that are directly observable from the events
- Do NOT assume or infer techniques without clear evidence
- Do NOT include phishing unless an email is explicitly mentioned
- Do NOT include PowerShell unless explicitly stated
- Distinguish carefully between:
• Command and Control (communication)
• Exfiltration (data theft)
- If uncertain, omit the technique rather than guess
Allowed Tactics (use only these):
- Initial Access
- Execution
- Persistence
- Privilege Escalation
- Defense Evasion
- Credential Access
- Discovery
- Lateral Movement
- Collection
- Command and Control
- Exfiltration
- Impact
Return ONLY valid JSON:
{
"tactics": [...],
"techniques": [...]
}
Validation Step:
Before producing the final answer, internally verify that each technique is directly supported by at least one event.
"""
# --- LLM Class ---
class LLM:
    """Wrapper around a GGUF causal language model and its tokenizer.

    The tokenizer and model are loaded once in ``__init__``;
    ``generate_resp`` then builds a single-turn chat prompt for the selected
    task and returns the decoded, greedily generated continuation.
    """

    # Returned by generate_resp when there is no text to analyze, so the
    # model is never run on an empty prompt.
    EMPTY_INPUT_MESSAGE = "Please provide input data to analyze."

    def __init__(self, model_id, filename,
                 tokenizer_id='madox81/SmolLM2-Cyber-Insight'):
        """Load the tokenizer and the GGUF model.

        Args:
            model_id: Identifier passed to
                ``AutoModelForCausalLM.from_pretrained``.
            filename: GGUF file name, passed as ``gguf_file``.
            tokenizer_id: Identifier passed to
                ``AutoTokenizer.from_pretrained``. Defaults to the value
                that used to be hard-coded here.
        """
        print("Loading model...")

        # 1. Device & dtype: half precision with automatic placement when
        #    CUDA is available, otherwise float32 on the CPU.
        if torch.cuda.is_available():
            dtype = torch.float16
            device_map = "auto"
        else:
            dtype = torch.float32  # CPU stability
            device_map = "cpu"

        # 2. Load tokenizer; fall back to EOS as the padding token when the
        #    tokenizer does not define one.
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, use_fast=True)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # 3. Load model from the GGUF file
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            gguf_file=filename,
            torch_dtype=dtype,
            device_map=device_map,
        )
        print("Model loaded!")

    @staticmethod
    def _format_events(user_input):
        """Join the non-blank lines of ``user_input`` with single spaces."""
        return " ".join(
            line.strip() for line in user_input.split("\n") if line.strip()
        )

    @staticmethod
    def _instruction_for(task_type):
        """Return the instruction text for ``task_type``.

        Unknown task types get a generic instruction rather than an error.
        """
        # Per the original author's note, these strings are meant to match
        # the instructions used when the training data was generated, so
        # they are kept verbatim.
        if task_type == "MITRE Mapping":
            return MITRE_INSTRUCT
        if task_type == "Severity Assessment":
            return "Assess the severity and business risk of the following incident."
        return "Analyze the following:"

    def generate_resp(self, user_input, task_type, max_new_tokens=256):
        """Generate the model's answer for one piece of user input.

        Args:
            user_input: Raw text to analyze; may span several lines.
            task_type: ``"MITRE Mapping"`` or ``"Severity Assessment"``;
                any other value uses a generic instruction.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            The decoded continuation with special tokens removed and
            surrounding whitespace stripped, or ``EMPTY_INPUT_MESSAGE`` when
            ``user_input`` is ``None``, empty or whitespace-only.
        """
        # Guard first: avoids an AttributeError on None and a pointless
        # generation pass on blank input.
        if not user_input or not user_input.strip():
            return self.EMPTY_INPUT_MESSAGE

        instruction = self._instruction_for(task_type)

        # Format: "Instruction...\n\nInput: ..."
        formatted_message = (
            f"{instruction}\n\nInput: {self._format_events(user_input)}"
        )
        messages = [{"role": "user", "content": formatted_message}]

        # Apply the chat template and move the tensors to the model's device.
        inputs = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            return_dict=True,
            return_tensors='pt',
            add_generation_prompt=True,
        ).to(self.model.device)

        # Greedy decoding (sampling disabled); no gradients are needed.
        with torch.no_grad():
            output = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                temperature=None,
                repetition_penalty=1.15,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        # The output sequence starts with the prompt tokens; slice them off
        # so only the newly generated tokens are decoded.
        prompt_length = inputs['input_ids'].shape[1]
        generated_tokens = output[0][prompt_length:]
        response = self.tokenizer.decode(
            generated_tokens,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        return response.strip()
# --- Initialize ---
# Create the single shared model wrapper when the module is imported.
# Constructing LLM loads the tokenizer and the model, and process_input
# reuses this instance for every request.
llm_instance = LLM(MODEL_ID, FILENAME)
# --- Gradio Interface ---
def process_input(user_input, task_type):
    """Gradio callback: run the shared model on the submitted text.

    Args:
        user_input: Text taken from the input box.
        task_type: Task name selected in the dropdown.

    Returns:
        Whatever ``llm_instance.generate_resp`` returns for these arguments.
    """
    response = llm_instance.generate_resp(user_input, task_type)
    return response
# Gradio UI: a task dropdown, an input text box, an "Analyze" button and an
# output text box, plus clickable examples.
# NOTE(review): the nesting below was reconstructed from a copy of this file
# whose indentation was lost. In particular, whether output_box belongs in
# the Row, in the Column, or directly under Blocks could not be determined —
# confirm the intended layout.
with gr.Blocks(title="SmolLM2-Cyber-Insight") as demo:
    gr.Markdown("# 🛡️ SmolLM2-Cyber-Insight (Dual Task - GGUF Optimized)")
    with gr.Row():
        with gr.Column(scale=2):
            # The choices are the task names that LLM.generate_resp compares
            # task_type against.
            task_selector = gr.Dropdown(
                label="Select Task Type",
                choices=["MITRE Mapping", "Severity Assessment"],
                value="MITRE Mapping"
            )
            input_box = gr.Textbox(
                label="Input Data",
                placeholder="Paste log, procedure, or incident description here...",
                lines=5
            )
            submit_btn = gr.Button("Analyze")
        output_box = gr.Textbox(label="Model Response (JSON)", lines=5)
    # Examples matching the new training data distribution
    gr.Markdown("### Examples")
    # Each example row is [task, text], in the same order as
    # inputs=[task_selector, input_box] below.
    gr.Examples(
        examples=[
            # MITRE Example (Blue Team style)
            ["MITRE Mapping", "selection: CommandLine contains 'Invoke-Expression'"],
            # MITRE Example (Playbook style)
            ["MITRE Mapping", "Incident Type: Ransomware\nTarget: Finance Server"],
            # Severity Example
            ["Severity Assessment", "Incident: Ransomware affecting Finance Server."]
        ],
        inputs=[task_selector, input_box]
    )
    # The click handler passes the text first and the task second, matching
    # process_input(user_input, task_type); the result goes to output_box.
    submit_btn.click(fn=process_input, inputs=[input_box, task_selector], outputs=output_box)
# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()