File size: 6,686 Bytes
9eeb647
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
"""
CyberGuide Model Loading and Inference
Handles a Llama model with a LoRA adapter (4-bit quantized on GPU, float32 on CPU)
"""

import os
import torch
from dotenv import load_dotenv
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)
from peft import PeftModel
from threading import Thread

# Load environment variables (e.g. from a local .env file) so HF_TOKEN can be
# provided without exporting it in the shell.
load_dotenv()
# Hugging Face access token passed to every Hub download below; None when the
# variable is unset.
HF_TOKEN = os.getenv("HF_TOKEN")

# Model configuration (Hugging Face Hub repository ids)
# Base causal LM; CyberGuideModel.load_model loads it 4-bit on CUDA and
# float32 on CPU.
BASE_MODEL = "samch183/abliterated-model-fp16"
# LoRA adapter applied on top of BASE_MODEL with PeftModel.from_pretrained.
ADAPTER_MODEL = "samch183/cyber-llama-adapter"
# Subfolder inside the ADAPTER_MODEL repo that holds the adapter weights.
ADAPTER_SUBFOLDER = "final_model"

# System prompt inserted into every request by CyberGuideModel.format_message.
# NOTE: the literal is sent verbatim, including trailing whitespace on some
# lines; edit with care.
SYSTEM_PROMPT = """You are CyberGuide, an AI assistant built for the cyber security team. 

You help junior analysts with:
- Penetration testing procedures
- CTF challenge solving  
- Security tool usage (nmap, metasploit, burpsuite etc)
- SOC analysis and incident response
- Defensive security techniques

Rules:
- Always give step by step answers
- Include actual commands and code
- Mention the tools needed
- Keep answers simple and clear
- This tool is for authorized testing only"""


class CyberGuideModel:
    """Base model + LoRA adapter wrapper with streaming text generation.

    Construction is eager: ``__init__`` resolves the device and immediately
    loads the model weights and tokenizer, so creating an instance can be slow.
    """

    def __init__(self, device_preference: str = "auto"):
        """Resolve the compute device and load the model and tokenizer.

        Args:
            device_preference: ``"auto"``, ``"gpu"`` or ``"cpu"`` (case and
                surrounding whitespace are ignored). Any other value is
                handled like ``"auto"`` by ``_resolve_device``.
        """
        self.device_preference = device_preference.lower().strip()
        self.device = self._resolve_device(self.device_preference)
        self.model = None
        self.tokenizer = None
        self.load_model()

    def _resolve_device(self, device_preference: str) -> str:
        """Map a compute preference to an available torch device string.

        Returns ``"cuda"`` for ``"gpu"``/``"auto"`` when CUDA is available and
        ``"cpu"`` otherwise. An explicit ``"gpu"`` request without CUDA prints
        a warning and falls back to CPU. Unrecognized values behave like
        ``"auto"``.
        """
        if device_preference == "gpu":
            if torch.cuda.is_available():
                return "cuda"
            print("GPU requested but CUDA is not available. Falling back to CPU.")
            return "cpu"

        if device_preference == "cpu":
            return "cpu"

        # auto mode (also reached for any unrecognized preference)
        return "cuda" if torch.cuda.is_available() else "cpu"

    def load_model(self):
        """Load the base model, apply the LoRA adapter and load the tokenizer.

        On CUDA the base model is quantized to 4-bit NF4 (double quantization,
        bfloat16 compute) via bitsandbytes. On CPU it is loaded unquantized in
        float32. Sets ``self.model`` (a ``PeftModel``) and ``self.tokenizer``.
        """
        print(f"Loading CyberGuide model in {self.device.upper()} mode...")

        if self.device == "cuda":
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL,
                quantization_config=bnb_config,
                device_map="auto",
                token=HF_TOKEN,
            )
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL,
                device_map={"": "cpu"},
                dtype=torch.float32,  # newer transformers name for `torch_dtype`
                token=HF_TOKEN,
                low_cpu_mem_usage=True,
            )

        # Apply the LoRA adapter on top of the (possibly quantized) base model.
        self.model = PeftModel.from_pretrained(
            self.model,
            ADAPTER_MODEL,
            subfolder=ADAPTER_SUBFOLDER,
            token=HF_TOKEN,
        )

        self.tokenizer = AutoTokenizer.from_pretrained(
            BASE_MODEL,
            token=HF_TOKEN,
        )
        # Some checkpoints ship without a pad token; reuse EOS so padding works.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        print(f"✓ Model loaded on {self.device.upper()}")

    def format_message(self, user_message: str) -> str:
        """Wrap *user_message* and SYSTEM_PROMPT in the Llama 3.1 chat layout.

        The string ends with an open assistant header so the model continues
        as the assistant. ``<|begin_of_text|>`` is not included here.
        Presumably the tokenizer adds it as a special token; confirm this
        for the base model's tokenizer.
        """
        formatted = f"<|start_header_id|>system<|end_header_id|>\n\n{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{user_message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        return formatted

    @staticmethod
    def _sampling_kwargs(temperature: float, top_p: float) -> dict:
        """Return the decoding-strategy kwargs for ``generate``.

        A temperature of 0 or below selects greedy decoding. transformers
        rejects ``temperature <= 0`` (and non-float temperatures) when
        ``do_sample=True``, so sampling is disabled instead. Temperature and
        top_p are omitted in that case.
        """
        if temperature > 0:
            return {
                "do_sample": True,
                "temperature": float(temperature),
                "top_p": top_p,
            }
        return {"do_sample": False}

    @staticmethod
    def _generate_into(generate_fn, generation_kwargs: dict, streamer, errors: list):
        """Thread target: run ``generate_fn(**generation_kwargs)``.

        ``TextIteratorStreamer`` blocks its reader until ``end()`` is called.
        ``generate`` only does that when it finishes normally. On failure the
        exception is appended to *errors* and the stream is ended here, so
        the reader wakes up and can re-raise it.
        """
        try:
            generate_fn(**generation_kwargs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()

    def generate_streaming(
        self,
        message: str,
        temperature: float = 0.7,
        max_tokens: int = 512,
        top_p: float = 0.9,
    ):
        """Generate a reply to *message*, yielding decoded text as it arrives.

        ``model.generate`` runs on a background thread that feeds a
        ``TextIteratorStreamer``, which this generator drains. If generation
        fails, the exception is re-raised here rather than blocking the
        caller forever.

        Args:
            message: The user's message; wrapped by ``format_message``.
            temperature: Sampling temperature; ``<= 0`` selects greedy decoding.
            max_tokens: Maximum number of new tokens to generate.
            top_p: Nucleus-sampling probability mass (unused when greedy).

        Yields:
            Decoded text chunks, with the prompt and special tokens skipped.

        Raises:
            Any exception raised by ``model.generate`` in the worker thread.
        """
        formatted_input = self.format_message(message)

        # NOTE(review): with the tokenizer's default right-side truncation, a
        # prompt longer than 4096 tokens loses its trailing assistant header.
        # Confirm whether inputs that long are possible.
        inputs = self.tokenizer(
            formatted_input,
            return_tensors="pt",
            truncation=True,
            max_length=4096,
        )
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)

        streamer = TextIteratorStreamer(
            self.tokenizer,
            skip_special_tokens=True,
            skip_prompt=True,
        )

        generation_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "max_new_tokens": max_tokens,
            "streamer": streamer,
            "pad_token_id": self.tokenizer.eos_token_id,
            **self._sampling_kwargs(temperature, top_p),
        }

        errors = []
        thread = Thread(
            target=self._generate_into,
            args=(self.model.generate, generation_kwargs, streamer, errors),
        )
        thread.start()

        # NOTE: if the caller stops iterating early, the worker thread keeps
        # generating until max_new_tokens or EOS.
        for text in streamer:
            yield text

        thread.join()
        if errors:
            raise errors[0]

    def generate(
        self,
        message: str,
        temperature: float = 0.7,
        max_tokens: int = 512,
        top_p: float = 0.9,
    ) -> str:
        """Return the complete reply as one string (non-streaming convenience).

        Same arguments as ``generate_streaming``; the streamed chunks are
        concatenated.
        """
        return "".join(self.generate_streaming(message, temperature, max_tokens, top_p))


# Lazily created, process-wide cache maintained by get_model(): the loaded
# CyberGuideModel and the normalized device mode it was built with.
model_instance: "CyberGuideModel | None" = None
model_mode: "str | None" = None


def get_model(device_preference: str = "auto") -> CyberGuideModel:
    """Return the cached model, (re)loading it when the device mode changes.

    Args:
        device_preference: ``"auto"``, ``"gpu"`` or ``"cpu"`` (case and
            whitespace ignored). Unrecognized values fall back to ``"auto"``.

    Returns:
        The process-wide CyberGuideModel for the normalized mode.
    """
    import gc

    global model_instance, model_mode

    normalized_mode = device_preference.lower().strip()
    if normalized_mode not in {"auto", "gpu", "cpu"}:
        normalized_mode = "auto"

    if model_instance is not None and model_mode == normalized_mode:
        return model_instance

    if model_instance is not None:
        # Release the previous model before loading the new one so two full
        # copies are not held at once. Memory is only reclaimed if no caller
        # still holds a reference (e.g. an in-flight streaming generator).
        model_instance = None
        model_mode = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    model_instance = CyberGuideModel(device_preference=normalized_mode)
    model_mode = normalized_mode
    return model_instance


def chat(
    message: str,
    temperature: float = 0.7,
    max_tokens: int = 512,
    device_preference: str = "auto",
    top_p: float = 0.9,
) -> str:
    """Return CyberGuide's complete reply to *message*.

    Args:
        message: The user's message.
        temperature: Sampling temperature forwarded to the model.
        max_tokens: Maximum number of new tokens to generate.
        device_preference: ``"auto"``, ``"gpu"`` or ``"cpu"``; see get_model.
        top_p: Nucleus-sampling mass. It defaults to 0.9, the value the model
            used implicitly before this parameter existed.
    """
    model = get_model(device_preference=device_preference)
    return model.generate(message, temperature, max_tokens, top_p)


def chat_streaming(
    message: str,
    temperature: float = 0.7,
    max_tokens: int = 512,
    device_preference: str = "auto",
    top_p: float = 0.9,
):
    """Yield CyberGuide's reply to *message* as decoded text chunks.

    Because this is a generator, the model is fetched/loaded via get_model
    only when iteration starts.

    Args:
        message: The user's message.
        temperature: Sampling temperature forwarded to the model.
        max_tokens: Maximum number of new tokens to generate.
        device_preference: ``"auto"``, ``"gpu"`` or ``"cpu"``; see get_model.
        top_p: Nucleus-sampling mass. It defaults to 0.9, the value the model
            used implicitly before this parameter existed.
    """
    model = get_model(device_preference=device_preference)
    yield from model.generate_streaming(message, temperature, max_tokens, top_p)