Spaces:
Sleeping
Sleeping
File size: 6,686 Bytes
"""
CyberGuide Model Loading and Inference

Loads the base Llama model (4-bit quantized on GPU, float32 on CPU), applies
the LoRA adapter, and exposes streaming and non-streaming chat helpers.
"""
import os
import torch
from dotenv import load_dotenv
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
TextIteratorStreamer,
)
from peft import PeftModel
from threading import Thread
# Load environment variables
# Must run before os.getenv() below so the lookup sees whatever load_dotenv()
# added to the process environment.
load_dotenv()
# Hugging Face access token; None when the variable is not set. It is passed
# as `token=` to every from_pretrained() call in this module.
HF_TOKEN = os.getenv("HF_TOKEN")
# Model configuration
# Hub repository id of the base causal LM (also used to load the tokenizer).
BASE_MODEL = "samch183/abliterated-model-fp16"
# Hub repository id of the LoRA adapter applied on top of BASE_MODEL.
ADAPTER_MODEL = "samch183/cyber-llama-adapter"
# Passed as `subfolder=` when the adapter is loaded from ADAPTER_MODEL.
ADAPTER_SUBFOLDER = "final_model"
# System turn placed in front of every user message by
# CyberGuideModel.format_message(). This text is sent to the model verbatim.
SYSTEM_PROMPT = """You are CyberGuide, an AI assistant built for the cyber security team.
You help junior analysts with:
- Penetration testing procedures
- CTF challenge solving
- Security tool usage (nmap, metasploit, burpsuite etc)
- SOC analysis and incident response
- Defensive security techniques
Rules:
- Always give step by step answers
- Include actual commands and code
- Mention the tools needed
- Keep answers simple and clear
- This tool is for authorized testing only"""
class CyberGuideModel:
    """Wrap the base model, LoRA adapter and tokenizer for chat inference.

    Constructing an instance immediately loads the model and tokenizer
    (see ``load_model``), so creation is slow and memory-hungry.
    """

    def __init__(self, device_preference: str = "auto"):
        """Resolve the compute device and load the model.

        Args:
            device_preference: ``"auto"``, ``"gpu"`` or ``"cpu"``. Matching is
                case-insensitive and ignores surrounding whitespace; any
                other value behaves like ``"auto"``.
        """
        # Normalised form of what the caller asked for.
        self.device_preference = device_preference.lower().strip()
        # Device actually used for inputs and loading: "cuda" or "cpu".
        self.device = self._resolve_device(self.device_preference)
        # Both are populated by load_model().
        self.model = None
        self.tokenizer = None
        self.load_model()

    def _resolve_device(self, device_preference: str) -> str:
        """Resolve desired compute mode to an available runtime device.

        Args:
            device_preference: Already-normalised mode string.

        Returns:
            ``"cuda"`` or ``"cpu"``. A ``"gpu"`` request without CUDA prints
            a notice and falls back to ``"cpu"``.
        """
        if device_preference == "cpu":
            return "cpu"

        cuda_available = torch.cuda.is_available()
        if device_preference == "gpu" and not cuda_available:
            print("GPU requested but CUDA is not available. Falling back to CPU.")
            return "cpu"

        # Explicit "gpu" with CUDA present, or auto mode (any other value).
        return "cuda" if cuda_available else "cpu"

    def load_model(self):
        """Load the base model, attach the LoRA adapter and load the tokenizer.

        On ``"cuda"`` the base model is loaded 4-bit quantized (NF4, double
        quantization, bfloat16 compute). On ``"cpu"`` it is loaded
        unquantized in float32. Populates ``self.model`` and
        ``self.tokenizer``.
        """
        print(f"Loading CyberGuide model in {self.device.upper()} mode...")

        if self.device == "cuda":
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL,
                quantization_config=bnb_config,
                device_map="auto",
                token=HF_TOKEN,
            )
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL,
                device_map={"": "cpu"},  # keep every module on the CPU
                dtype=torch.float32,  # this call previously used `torch_dtype`
                token=HF_TOKEN,
                low_cpu_mem_usage=True,
            )

        # Load LoRA adapter on top of the base weights.
        self.model = PeftModel.from_pretrained(
            self.model,
            ADAPTER_MODEL,
            subfolder=ADAPTER_SUBFOLDER,
            token=HF_TOKEN,
        )

        # The tokenizer comes from the base model repository.
        self.tokenizer = AutoTokenizer.from_pretrained(
            BASE_MODEL,
            token=HF_TOKEN,
        )
        # Reuse EOS as the padding token when the tokenizer defines none.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        print(f"✓ Model loaded on {self.device.upper()}")

    def format_message(self, user_message: str) -> str:
        """Format message using Llama 3.1 chat template.

        Args:
            user_message: Raw text of the user's turn.

        Returns:
            A prompt made of the system turn (``SYSTEM_PROMPT``), the user
            turn, and an opened assistant header for the model to continue.
        """
        return (
            f"<|start_header_id|>system<|end_header_id|>\n\n{SYSTEM_PROMPT}<|eot_id|>"
            f"<|start_header_id|>user<|end_header_id|>\n\n{user_message}<|eot_id|>"
            "<|start_header_id|>assistant<|end_header_id|>\n\n"
        )

    def generate_streaming(
        self,
        message: str,
        temperature: float = 0.7,
        max_tokens: int = 512,
        top_p: float = 0.9,
    ):
        """Generate a response and yield decoded text as it is produced.

        Generation runs in a background thread while this generator drains
        the streamer.

        Args:
            message: User message to answer.
            temperature: Sampling temperature. ``None`` or a value ``<= 0``
                selects greedy decoding instead of sampling.
            max_tokens: Maximum number of new tokens to generate.
            top_p: Nucleus-sampling threshold (only used when sampling).

        Yields:
            Text chunks, excluding the prompt and special tokens.

        Raises:
            Exception: Whatever ``model.generate`` raised in the worker
                thread, re-raised here once the stream has been drained.
        """
        prompt = self.format_message(message)

        # NOTE(review): if the tokenizer truncates on the right (the usual
        # default), an over-long message would lose the trailing assistant
        # header - confirm the expected maximum input size.
        encoded = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=4096,
        )

        streamer = TextIteratorStreamer(
            self.tokenizer,
            skip_special_tokens=True,
            skip_prompt=True,
        )

        # Sampling needs a strictly positive temperature; otherwise decode
        # greedily rather than letting generate() fail inside the thread.
        use_sampling = temperature is not None and temperature > 0

        generation_kwargs = {
            "input_ids": encoded["input_ids"].to(self.device),
            "attention_mask": encoded["attention_mask"].to(self.device),
            "max_new_tokens": max_tokens,
            "do_sample": use_sampling,
            "streamer": streamer,
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        if use_sampling:
            generation_kwargs["temperature"] = temperature
            generation_kwargs["top_p"] = top_p

        worker_errors = []

        def _run_generation():
            """Run generate(); on failure record the error and end the stream."""
            try:
                self.model.generate(**generation_kwargs)
            except Exception as exc:  # thread boundary: re-raised by the consumer
                worker_errors.append(exc)
                # Without the end signal the consumer loop would block forever.
                streamer.end()

        worker = Thread(target=_run_generation, daemon=True)
        worker.start()

        # Yield tokens as they arrive.
        for text in streamer:
            yield text

        # The end signal has been received, so the worker is finishing and
        # this join is short. It is deliberately not in a `finally`: closing
        # the generator early must not wait for the rest of the generation.
        worker.join()
        if worker_errors:
            raise worker_errors[0]

    def generate(
        self,
        message: str,
        temperature: float = 0.7,
        max_tokens: int = 512,
        top_p: float = 0.9,
    ) -> str:
        """Generate response (non-streaming, for compatibility).

        Takes the same arguments as ``generate_streaming`` and returns the
        concatenation of every chunk it yields.
        """
        return "".join(
            self.generate_streaming(message, temperature, max_tokens, top_p)
        )
# Global model instance
# model_instance: the shared CyberGuideModel, or None until first requested.
# model_mode: the normalised mode string most recently passed to get_model().
model_instance = None
model_mode = None


def get_model(device_preference: str = "auto") -> CyberGuideModel:
    """Get or initialize the global model instance.

    The cached instance is reused whenever the requested mode resolves to
    the device it already runs on, so switching between "auto" and "gpu" on
    a CUDA host does not reload the model. When the device really changes,
    the old instance is released before the new one is loaded so both are
    not held in memory at once.

    Args:
        device_preference: ``"auto"``, ``"gpu"`` or ``"cpu"``
            (case-insensitive). ``None`` or any unknown value is treated as
            ``"auto"``.

    Returns:
        The shared ``CyberGuideModel``.
    """
    global model_instance, model_mode

    normalized_mode = (device_preference or "auto").lower().strip()
    if normalized_mode not in {"auto", "gpu", "cpu"}:
        normalized_mode = "auto"

    # Same resolution rule as CyberGuideModel._resolve_device: only an
    # explicit "cpu" request, or missing CUDA, selects the CPU.
    cuda_available = torch.cuda.is_available()
    if normalized_mode != "cpu" and cuda_available:
        target_device = "cuda"
    else:
        target_device = "cpu"

    if model_instance is not None and model_instance.device == target_device:
        model_mode = normalized_mode
        return model_instance

    if model_instance is not None:
        import gc

        # Drop the old weights first. If the load below fails, the next call
        # simply starts from an empty cache.
        model_instance = None
        model_mode = None
        gc.collect()
        if cuda_available:
            torch.cuda.empty_cache()

    model_instance = CyberGuideModel(device_preference=normalized_mode)
    model_mode = normalized_mode
    return model_instance
def chat(
    message: str,
    temperature: float = 0.7,
    max_tokens: int = 512,
    device_preference: str = "auto",
    top_p: float = 0.9,
) -> str:
    """Simple chat interface.

    Args:
        message: User message to answer.
        temperature: Sampling temperature forwarded to the model.
        max_tokens: Maximum number of new tokens to generate.
        device_preference: Compute mode passed to ``get_model``.
        top_p: Nucleus-sampling threshold forwarded to the model. The
            default matches ``CyberGuideModel.generate``.

    Returns:
        The complete generated response.
    """
    model = get_model(device_preference=device_preference)
    return model.generate(
        message,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
    )
def chat_streaming(
    message: str,
    temperature: float = 0.7,
    max_tokens: int = 512,
    device_preference: str = "auto",
    top_p: float = 0.9,
):
    """Streaming chat interface.

    This is a generator, so the model is only fetched (and loaded if
    necessary) when iteration starts.

    Args:
        message: User message to answer.
        temperature: Sampling temperature forwarded to the model.
        max_tokens: Maximum number of new tokens to generate.
        device_preference: Compute mode passed to ``get_model``.
        top_p: Nucleus-sampling threshold forwarded to the model. The
            default matches ``CyberGuideModel.generate_streaming``.

    Yields:
        Text chunks of the response as they are generated.
    """
    model = get_model(device_preference=device_preference)
    yield from model.generate_streaming(
        message,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
    )
|