Spaces:
Sleeping
Sleeping
| """ | |
| CyberGuide Model Loading and Inference | |
| Handles 4-bit quantized Llama model with LoRA adapter | |
| """ | |
| import os | |
| import torch | |
| from dotenv import load_dotenv | |
| from transformers import ( | |
| AutoModelForCausalLM, | |
| AutoTokenizer, | |
| BitsAndBytesConfig, | |
| TextIteratorStreamer, | |
| ) | |
| from peft import PeftModel | |
| from threading import Thread | |
# Read variables from a .env file (python-dotenv) into the process environment
# before anything below looks them up.
load_dotenv()
# Optional Hugging Face access token; stays None when the variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# --- Model configuration -----------------------------------------------------
# Base checkpoint on the Hugging Face Hub.
BASE_MODEL = "samch183/abliterated-model-fp16"
# LoRA adapter repository and the subfolder inside it that holds the weights.
ADAPTER_MODEL = "samch183/cyber-llama-adapter"
ADAPTER_SUBFOLDER = "final_model"

# System turn placed at the start of every prompt by CyberGuideModel.format_message.
SYSTEM_PROMPT = """You are CyberGuide, an AI assistant built for the cyber security team.
You help junior analysts with:
- Penetration testing procedures
- CTF challenge solving
- Security tool usage (nmap, metasploit, burpsuite etc)
- SOC analysis and incident response
- Defensive security techniques
Rules:
- Always give step by step answers
- Include actual commands and code
- Mention the tools needed
- Keep answers simple and clear
- This tool is for authorized testing only"""
class CyberGuideModel:
    """Model wrapper for inference with streaming support.

    Loading happens in three steps:

    - The base model (BASE_MODEL) is loaded. It is 4-bit NF4-quantized on
      CUDA and float32 on CPU.
    - The LoRA adapter (ADAPTER_MODEL / ADAPTER_SUBFOLDER) is applied on top.
    - The matching tokenizer is loaded.

    ``generate_streaming`` yields decoded text chunks as they are produced;
    ``generate`` joins them into one string.
    """

    # Maximum prompt length in tokens: chat markup + system prompt + user
    # message. Longer user messages are shortened to fit.
    MAX_INPUT_TOKENS = 4096
    # Tokens held back from the user-message budget. Re-encoding a message
    # inside the template may merge/split a token at its edges, so the count
    # can differ slightly from encoding the message on its own.
    _TRUNCATION_SLACK = 8

    def __init__(self, device_preference: str = "auto"):
        """Resolve the compute device, then load model, adapter and tokenizer.

        Args:
            device_preference: "gpu", "cpu" or "auto" (case and surrounding
                whitespace are ignored). Any other value behaves like "auto".
        """
        self.device_preference = device_preference.lower().strip()
        self.device = self._resolve_device(self.device_preference)
        self.model = None
        self.tokenizer = None
        self.load_model()

    def _resolve_device(self, device_preference: str) -> str:
        """Map a normalized preference to an available device name.

        Rules:

        - "gpu" gives "cuda" when CUDA is available. Otherwise it prints a
          notice and gives "cpu".
        - "cpu" always gives "cpu".
        - Anything else is treated as "auto": "cuda" when available, else
          "cpu".
        """
        if device_preference == "gpu":
            if torch.cuda.is_available():
                return "cuda"
            print("GPU requested but CUDA is not available. Falling back to CPU.")
            return "cpu"
        if device_preference == "cpu":
            return "cpu"
        # auto mode
        return "cuda" if torch.cuda.is_available() else "cpu"

    def load_model(self):
        """Load the base model, apply the LoRA adapter and load the tokenizer.

        Sets ``self.model`` (wrapped by ``PeftModel``) and ``self.tokenizer``.
        """
        print(f"Loading CyberGuide model in {self.device.upper()} mode...")
        if self.device == "cuda":
            # bfloat16 compute is not available on every GPU; use float16
            # where torch reports bf16 as unsupported.
            compute_dtype = (
                torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
            )
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=compute_dtype,
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL,
                quantization_config=bnb_config,
                device_map="auto",
                token=HF_TOKEN,
            )
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL,
                device_map={"": "cpu"},
                dtype=torch.float32,  # `dtype` replaces the older `torch_dtype` argument
                token=HF_TOKEN,
                low_cpu_mem_usage=True,
            )
        # Apply the LoRA adapter on top of the base weights.
        self.model = PeftModel.from_pretrained(
            self.model,
            ADAPTER_MODEL,
            subfolder=ADAPTER_SUBFOLDER,
            token=HF_TOKEN,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            BASE_MODEL,
            token=HF_TOKEN,
        )
        # Reuse EOS as the pad token when the tokenizer defines none.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        print(f"✓ Model loaded on {self.device.upper()}")

    def format_message(self, user_message: str) -> str:
        """Wrap *user_message* in the Llama 3.1 chat markup.

        The prompt has three parts:

        - the system turn (SYSTEM_PROMPT),
        - the user turn,
        - an open assistant header for the model to complete.

        No ``<|begin_of_text|>`` is written here. Presumably the tokenizer
        adds BOS itself, since it is called with default
        ``add_special_tokens`` (verify for this tokenizer).
        """
        # Llama 3.1 uses this format
        formatted = f"<|start_header_id|>system<|end_header_id|>\n\n{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{user_message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        return formatted

    @staticmethod
    def _sampling_kwargs(temperature: float, top_p: float) -> dict:
        """Build the decoding arguments passed to ``generate``.

        A temperature <= 0 selects greedy decoding. Sampling requires a
        strictly positive float temperature. Both values are coerced to float
        so integer slider values (e.g. ``2``) are accepted.
        """
        temperature = float(temperature)
        if temperature <= 0:
            return {"do_sample": False}
        return {"do_sample": True, "temperature": temperature, "top_p": float(top_p)}

    def _truncate_message(self, message: str) -> str:
        """Shorten *message* so the formatted prompt fits MAX_INPUT_TOKENS.

        Truncating the whole formatted prompt from the right would drop the
        trailing assistant header. The model would then continue the user's
        text instead of answering. Instead, only the user message is cut,
        keeping its head, so the chat markup is preserved.
        """
        overhead = len(self.tokenizer(self.format_message(""))["input_ids"])
        budget = max(self.MAX_INPUT_TOKENS - overhead - self._TRUNCATION_SLACK, 0)
        message_ids = self.tokenizer(message, add_special_tokens=False)["input_ids"]
        if len(message_ids) <= budget:
            return message
        return self.tokenizer.decode(message_ids[:budget])

    def generate_streaming(
        self,
        message: str,
        temperature: float = 0.7,
        max_tokens: int = 512,
        top_p: float = 0.9,
    ):
        """Generate a reply to *message*, yielding text chunks as they arrive.

        Generation runs in a background thread that feeds a
        ``TextIteratorStreamer``; this generator drains that streamer.

        Args:
            message: User message. It is shortened if the full prompt would
                exceed MAX_INPUT_TOKENS.
            temperature: Sampling temperature. A value <= 0 selects greedy
                decoding.
            max_tokens: Maximum number of new tokens to generate.
            top_p: Nucleus-sampling threshold (unused when greedy).

        Yields:
            Decoded text fragments, excluding the prompt and special tokens.

        Raises:
            Exception: whatever ``model.generate`` raised in the worker
                thread. It is re-raised here after the stream is closed.
        """
        formatted_input = self.format_message(self._truncate_message(message))
        inputs = self.tokenizer(
            formatted_input,
            return_tensors="pt",
            truncation=True,  # safety net only; the message was already fitted
            max_length=self.MAX_INPUT_TOKENS,
        )
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)

        streamer = TextIteratorStreamer(
            self.tokenizer,
            skip_special_tokens=True,
            skip_prompt=True,
        )
        generation_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "max_new_tokens": max_tokens,
            "streamer": streamer,
            "pad_token_id": self.tokenizer.eos_token_id,
            **self._sampling_kwargs(temperature, top_p),
        }

        errors = []

        def _run_generation():
            # Worker-thread body: run generation and record any failure.
            try:
                self.model.generate(**generation_kwargs)
            except Exception as exc:
                errors.append(exc)
                # The streamer only stops iterating once it is ended. End it
                # here so the consumer loop below does not block forever.
                streamer.end()

        thread = Thread(target=_run_generation)
        thread.start()
        # Yield tokens as they arrive.
        for text in streamer:
            yield text
        thread.join()
        if errors:
            raise errors[0]

    def generate(
        self,
        message: str,
        temperature: float = 0.7,
        max_tokens: int = 512,
        top_p: float = 0.9,
    ) -> str:
        """Generate a complete reply (non-streaming, for compatibility).

        Takes the same arguments as ``generate_streaming`` and returns all of
        its chunks concatenated into one string.
        """
        return "".join(
            self.generate_streaming(message, temperature, max_tokens, top_p)
        )
# Process-wide cache: the loaded model and the normalized mode it was
# requested with.
model_instance = None
model_mode = None


def get_model(device_preference: str = "auto") -> CyberGuideModel:
    """Return the shared CyberGuideModel, loading or reloading it as needed.

    Args:
        device_preference: "auto", "gpu" or "cpu" (case and surrounding
            whitespace are ignored). None or any other value means "auto".

    Returns:
        The cached model instance.

    Reload behaviour:

    - A new model is loaded only when the requested mode resolves to a
      different device than the cached model's. "auto" and "gpu" both mean
      CUDA when it is available, and both mean CPU when it is not.
    - Before reloading, this module drops its reference to the old model,
      so it does not keep the old weights alive while the new ones load.
    - If loading then fails, nothing is cached and the next call retries.
    """
    global model_instance, model_mode
    normalized_mode = (device_preference or "auto").lower().strip()
    if normalized_mode not in {"auto", "gpu", "cpu"}:
        normalized_mode = "auto"

    if model_instance is not None and model_mode != normalized_mode:
        # Same resolution rule as CyberGuideModel._resolve_device.
        wants_cuda = normalized_mode != "cpu" and torch.cuda.is_available()
        if model_instance.device == ("cuda" if wants_cuda else "cpu"):
            model_mode = normalized_mode
            return model_instance
        # Release the old model before loading a new one, to avoid holding
        # two copies of the weights at peak.
        import gc

        model_instance = None
        model_mode = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    if model_instance is None:
        model_instance = CyberGuideModel(device_preference=normalized_mode)
        model_mode = normalized_mode
    return model_instance
def chat(
    message: str,
    temperature: float = 0.7,
    max_tokens: int = 512,
    device_preference: str = "auto",
) -> str:
    """Return a complete reply to *message* from the shared model.

    The model is obtained through ``get_model`` for *device_preference*.
    ``top_p`` is left at the wrapper's default.
    """
    cyberguide = get_model(device_preference=device_preference)
    reply = cyberguide.generate(
        message,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    return reply
def chat_streaming(
    message: str,
    temperature: float = 0.7,
    max_tokens: int = 512,
    device_preference: str = "auto",
):
    """Stream a reply to *message* chunk by chunk from the shared model.

    This is a generator. ``get_model`` runs only when iteration begins, not
    when the function is called. ``top_p`` is left at the wrapper's default.
    """
    cyberguide = get_model(device_preference=device_preference)
    chunks = cyberguide.generate_streaming(
        message,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    yield from chunks