# Source: cyberguide / model.py (author: samch183)
# Last commit: "Update model.py" (9eeb647, verified)
"""
CyberGuide Model Loading and Inference
Loads a Llama base model plus LoRA adapter (4-bit NF4 quantized on CUDA, float32 on CPU) and runs streaming inference
"""
import os
import torch
from dotenv import load_dotenv
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
TextIteratorStreamer,
)
from peft import PeftModel
from threading import Thread
# Populate os.environ from a local .env file (if one exists) so that the
# settings below can be supplied without exporting them in the shell.
load_dotenv()

# Hugging Face access token passed to every Hub download in this module;
# None when the variable is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Hub locations of the base weights and the LoRA adapter. The adapter files
# live in a subfolder of the adapter repository.
BASE_MODEL = "samch183/abliterated-model-fp16"
ADAPTER_MODEL = "samch183/cyber-llama-adapter"
ADAPTER_SUBFOLDER = "final_model"

# System turn placed at the start of every prompt built by
# CyberGuideModel.format_message.
SYSTEM_PROMPT = """You are CyberGuide, an AI assistant built for the cyber security team.
You help junior analysts with:
- Penetration testing procedures
- CTF challenge solving
- Security tool usage (nmap, metasploit, burpsuite etc)
- SOC analysis and incident response
- Defensive security techniques
Rules:
- Always give step by step answers
- Include actual commands and code
- Mention the tools needed
- Keep answers simple and clear
- This tool is for authorized testing only"""
class CyberGuideModel:
    """Base model + LoRA adapter wrapper with streaming text generation.

    The model and tokenizer are loaded eagerly when the instance is created.
    """

    def __init__(self, device_preference: str = "auto"):
        """Resolve the compute device and load the model and tokenizer.

        Args:
            device_preference: ``"auto"``, ``"gpu"`` or ``"cpu"``; case and
                surrounding whitespace are ignored. Any other value behaves
                like ``"auto"`` (see ``_resolve_device``).
        """
        self.device_preference = device_preference.lower().strip()
        self.device = self._resolve_device(self.device_preference)
        self.model = None
        self.tokenizer = None
        self.load_model()

    def _resolve_device(self, device_preference: str) -> str:
        """Map a normalized preference to ``"cuda"`` or ``"cpu"``.

        ``"gpu"`` falls back to ``"cpu"`` (with a printed notice) when CUDA is
        unavailable; ``"cpu"`` always yields ``"cpu"``; anything else is
        treated as ``"auto"`` and picks CUDA when it is available.
        """
        if device_preference == "gpu":
            if torch.cuda.is_available():
                return "cuda"
            print("GPU requested but CUDA is not available. Falling back to CPU.")
            return "cpu"
        if device_preference == "cpu":
            return "cpu"
        # auto mode
        return "cuda" if torch.cuda.is_available() else "cpu"

    def load_model(self):
        """Load the base model, attach the LoRA adapter and load the tokenizer.

        On CUDA the base weights are quantized to 4-bit NF4 (double quant,
        bfloat16 compute) via bitsandbytes; on CPU they are loaded unquantized
        in float32. Sets ``self.model`` (a ``PeftModel``) and
        ``self.tokenizer``. If the tokenizer defines no pad token, its EOS
        token is used instead.
        """
        print(f"Loading CyberGuide model in {self.device.upper()} mode...")
        if self.device == "cuda":
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL,
                quantization_config=bnb_config,
                device_map="auto",
                token=HF_TOKEN,
            )
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL,
                device_map={"": "cpu"},
                # `dtype` is the newer transformers name for `torch_dtype`.
                dtype=torch.float32,
                token=HF_TOKEN,
                low_cpu_mem_usage=True,
            )
        # Load LoRA adapter on top of the base weights.
        self.model = PeftModel.from_pretrained(
            self.model,
            ADAPTER_MODEL,
            subfolder=ADAPTER_SUBFOLDER,
            token=HF_TOKEN,
        )
        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            BASE_MODEL,
            token=HF_TOKEN,
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        print(f"✓ Model loaded on {self.device.upper()}")

    def format_message(self, user_message: str) -> str:
        """Wrap *user_message* in Llama 3.1 chat headers after SYSTEM_PROMPT.

        The returned prompt ends with an open assistant header so that
        generation continues as the assistant turn.
        """
        # NOTE(review): no <|begin_of_text|> here; presumably the tokenizer
        # prepends BOS itself — verify for this tokenizer.
        formatted = f"<|start_header_id|>system<|end_header_id|>\n\n{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{user_message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        return formatted

    @staticmethod
    def _sampling_kwargs(temperature: float, top_p: float) -> dict:
        """Return the decoding-strategy kwargs for ``model.generate``.

        transformers raises ``ValueError`` for a non-positive (or non-float)
        temperature when ``do_sample=True``. So a temperature ``<= 0`` selects
        greedy decoding, and positive values are cast to float.
        """
        if temperature <= 0:
            return {"do_sample": False}
        return {
            "do_sample": True,
            "temperature": float(temperature),
            "top_p": float(top_p),
        }

    def _generate_into_streamer(self, generation_kwargs: dict, streamer, errors: list) -> None:
        """Thread target: run ``model.generate`` and capture any exception.

        If ``generate`` raises, nothing else signals end-of-stream and the
        consumer iterating *streamer* would block forever. So the exception is
        appended to *errors* and the streamer is ended explicitly; the
        consuming thread re-raises it.
        """
        try:
            self.model.generate(**generation_kwargs)
        except Exception as exc:  # re-raised by generate_streaming
            errors.append(exc)
            streamer.end()

    def generate_streaming(
        self,
        message: str,
        temperature: float = 0.7,
        max_tokens: int = 512,
        top_p: float = 0.9,
    ):
        """Generate a reply to *message*, yielding text chunks as they decode.

        Args:
            message: Raw user message; wrapped via ``format_message``.
            temperature: Sampling temperature; ``<= 0`` means greedy decoding.
            max_tokens: Maximum number of new tokens to generate.
            top_p: Nucleus-sampling threshold (ignored when greedy).

        Yields:
            Decoded text chunks, excluding the prompt and special tokens.

        Raises:
            Whatever exception ``model.generate`` raised in the worker thread,
            once the stream has been drained.
        """
        formatted_input = self.format_message(message)
        # NOTE(review): with the tokenizer's (presumably right-side) truncation,
        # a prompt longer than 4096 tokens loses its trailing assistant header.
        inputs = self.tokenizer(
            formatted_input,
            return_tensors="pt",
            truncation=True,
            max_length=4096,
        )
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)

        streamer = TextIteratorStreamer(
            self.tokenizer,
            skip_special_tokens=True,
            skip_prompt=True,
        )
        generation_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "max_new_tokens": max_tokens,
            "streamer": streamer,
            "pad_token_id": self.tokenizer.eos_token_id,
            **self._sampling_kwargs(temperature, top_p),
        }

        # generate() blocks, so it runs in a worker thread while this
        # generator relays decoded text from the streamer's queue.
        errors = []
        worker = Thread(
            target=self._generate_into_streamer,
            args=(generation_kwargs, streamer, errors),
        )
        worker.start()
        yield from streamer
        # The stream ended, so generation has finished or failed; reap it.
        worker.join()
        if errors:
            raise errors[0]

    def generate(
        self,
        message: str,
        temperature: float = 0.7,
        max_tokens: int = 512,
        top_p: float = 0.9,
    ) -> str:
        """Return the full reply to *message* (non-streaming convenience).

        Arguments are forwarded unchanged to ``generate_streaming``.
        """
        return "".join(self.generate_streaming(message, temperature, max_tokens, top_p))
# Process-wide cache: the loaded model and the normalized mode it was last
# requested with.
model_instance = None
model_mode = None


def _target_device(mode: str) -> str:
    """Return the device a normalized *mode* resolves to, without logging.

    Mirrors CyberGuideModel's device resolution: ``"auto"`` and ``"gpu"`` use
    CUDA when available; everything else runs on the CPU.
    """
    if mode in ("auto", "gpu") and torch.cuda.is_available():
        return "cuda"
    return "cpu"


def get_model(device_preference: str = "auto") -> CyberGuideModel:
    """Return the cached model, (re)loading it only when the device changes.

    Args:
        device_preference: ``"auto"``, ``"gpu"`` or ``"cpu"`` (case-insensitive).
            Unknown values and ``None`` are treated as ``"auto"``.

    Preferences that resolve to the same device as the cached model (e.g.
    ``"auto"`` and ``"gpu"`` on a CUDA host) reuse it instead of reloading.
    """
    global model_instance, model_mode
    normalized_mode = (device_preference or "auto").lower().strip()
    if normalized_mode not in {"auto", "gpu", "cpu"}:
        normalized_mode = "auto"

    if model_instance is not None and model_instance.device == _target_device(normalized_mode):
        model_mode = normalized_mode
        return model_instance

    if model_instance is not None:
        # Drop the old weights before loading new ones so both copies never
        # have to fit in (GPU) memory at the same time.
        import gc

        model_instance = None
        model_mode = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    model_instance = CyberGuideModel(device_preference=normalized_mode)
    model_mode = normalized_mode
    return model_instance
def chat(
    message: str,
    temperature: float = 0.7,
    max_tokens: int = 512,
    device_preference: str = "auto",
    top_p: float = 0.9,
) -> str:
    """Return a complete CyberGuide reply to *message* (blocking).

    Args:
        message: Raw user message.
        temperature: Sampling temperature forwarded to the model.
        max_tokens: Maximum number of new tokens to generate.
        device_preference: ``"auto"``, ``"gpu"`` or ``"cpu"``; selects or loads
            the cached model via ``get_model``.
        top_p: Nucleus-sampling threshold forwarded to the model.
    """
    model = get_model(device_preference=device_preference)
    return model.generate(
        message,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
    )
def chat_streaming(
    message: str,
    temperature: float = 0.7,
    max_tokens: int = 512,
    device_preference: str = "auto",
    top_p: float = 0.9,
):
    """Yield a CyberGuide reply to *message* chunk by chunk as it is generated.

    Args:
        message: Raw user message.
        temperature: Sampling temperature forwarded to the model.
        max_tokens: Maximum number of new tokens to generate.
        device_preference: ``"auto"``, ``"gpu"`` or ``"cpu"``; selects or loads
            the cached model via ``get_model``.
        top_p: Nucleus-sampling threshold forwarded to the model.
    """
    model = get_model(device_preference=device_preference)
    yield from model.generate_streaming(
        message,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
    )