# AYI-NEDJIMI — uploaded via huggingface_hub (commit f5b81f8, verified)
import gc
import json
import os
import threading
from pathlib import Path
from typing import Generator
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from peft import PeftModel
from sentence_transformers import SentenceTransformer
import numpy as np
from datasets import load_dataset
# ---------------------------------------------------------------------------
# Model registry
# ---------------------------------------------------------------------------
# Maps a UI display name to:
#   "base"          - base checkpoint the adapter was trained on
#   "adapter"       - PEFT/LoRA adapter repo loaded on top of the base
#   "system_prompt" - system message prepended to every conversation
#   "examples"      - sample questions surfaced in the UI
MODELS = {
    "ISO 27001 Expert (1.5B)": {
        "base": "Qwen/Qwen2.5-1.5B-Instruct",
        "adapter": "AYI-NEDJIMI/ISO27001-Expert-1.5B",
        "system_prompt": (
            "You are ISO 27001 Expert, a specialized AI assistant for "
            "ISO/IEC 27001 information security management systems. "
            "You help organizations understand, implement, and maintain "
            "ISO 27001 certification, including risk assessment, controls "
            "from Annex A, Statement of Applicability, and audit preparation."
        ),
        "examples": [
            "What are the mandatory clauses of ISO 27001:2022?",
            "How do I conduct a risk assessment according to ISO 27001?",
            "Explain the Statement of Applicability (SoA).",
            "What changed between ISO 27001:2013 and ISO 27001:2022?",
        ],
    },
    "RGPD Expert (1.5B)": {
        "base": "Qwen/Qwen2.5-1.5B-Instruct",
        "adapter": "AYI-NEDJIMI/RGPD-Expert-1.5B",
        "system_prompt": (
            "You are RGPD Expert, a specialized AI assistant for GDPR/RGPD "
            "data protection regulations. You help organizations understand "
            "their obligations under the General Data Protection Regulation, "
            "including data subject rights, Data Protection Impact Assessments, "
            "lawful bases for processing, and breach notification procedures."
        ),
        "examples": [
            "What are the 6 lawful bases for processing under GDPR?",
            "When is a Data Protection Impact Assessment (DPIA) required?",
            "Explain the right to data portability.",
            "What are the penalties for GDPR non-compliance?",
        ],
    },
    "CyberSec Assistant (3B)": {
        "base": "Qwen/Qwen2.5-3B-Instruct",
        "adapter": "AYI-NEDJIMI/CyberSec-Assistant-3B",
        "system_prompt": (
            "You are CyberSec Assistant, an expert AI specialized in "
            "cybersecurity, compliance (GDPR, NIS2, DORA, AI Act, ISO 27001), "
            "penetration testing, SOC operations, and AI security."
        ),
        "examples": [
            "What is the MITRE ATT&CK framework?",
            "How do I set up a SOC from scratch?",
            "Explain the NIS2 directive requirements.",
            "What are the OWASP Top 10 for 2024?",
        ],
    },
}
# ---------------------------------------------------------------------------
# RAG Setup
# ---------------------------------------------------------------------------
class RAGRetriever:
    """Simple RAG retriever using sentence-transformers and an in-memory index.

    Documents are pulled from a small set of Hugging Face datasets, embedded
    once with a MiniLM sentence encoder, and ranked at query time by cosine
    similarity (embeddings are L2-normalized at encode time, so a plain dot
    product is the cosine score).
    """

    def __init__(self):
        self.embedder = None    # SentenceTransformer, set by initialize()
        self.documents = []     # list of {"text": str, "source": str}
        self.embeddings = None  # ndarray (n_docs, dim), or None if no docs loaded
        self.initialized = False
        # Guards initialize(): several Gradio requests may hit an
        # uninitialized retriever at the same time.
        self._init_lock = threading.Lock()

    def initialize(self):
        """Load the embedding model and build the index from datasets.

        Idempotent and thread-safe: concurrent callers block until the first
        initialization finishes. Datasets that fail to download are skipped
        (best effort); if nothing loads, retrieve() simply returns no hits.
        """
        with self._init_lock:
            if self.initialized:
                return
            print("Initializing RAG retriever...")
            self.embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
            # Load key datasets (subset for demo - full 80 datasets would be too heavy)
            dataset_ids = [
                "AYI-NEDJIMI/iso27001",
                "AYI-NEDJIMI/rgpd-fr",
                "AYI-NEDJIMI/gdpr-en",
                "AYI-NEDJIMI/mitre-attack-fr",
                "AYI-NEDJIMI/owasp-top10-fr",
                "AYI-NEDJIMI/nis2-directive-fr",
            ]
            print(f"Loading {len(dataset_ids)} datasets for RAG...")
            for ds_id in dataset_ids:
                try:
                    ds = load_dataset(ds_id, split="train")
                    for item in ds:
                        # Records come in several schemas; extract text content.
                        if "output" in item:
                            text = f"{item.get('instruction', '')}\n{item['output']}"
                        elif "response" in item:
                            text = f"{item.get('question', '')}\n{item['response']}"
                        elif "text" in item:
                            text = item["text"]
                        else:
                            continue
                        self.documents.append({
                            "text": text[:1000],  # Limit length
                            "source": ds_id.split("/")[-1],
                        })
                except Exception as e:
                    # Best effort: a missing/private dataset must not kill the app.
                    print(f"Failed to load {ds_id}: {e}")
            if self.documents:
                print(f"Loaded {len(self.documents)} documents. Creating embeddings...")
                texts = [doc["text"] for doc in self.documents]
                # Normalize so the dot product in retrieve() is cosine similarity.
                self.embeddings = self.embedder.encode(
                    texts, show_progress_bar=True, normalize_embeddings=True
                )
            else:
                # encode([]) would otherwise produce a degenerate index.
                print("Warning: no documents loaded; RAG will return no results.")
            self.initialized = True
            print("RAG retriever ready!")

    def retrieve(self, query: str, top_k: int = 3) -> list[dict]:
        """Return the *top_k* most relevant documents for *query*.

        Each hit is {"text", "source", "score"}; empty list when the index
        is not initialized or holds no documents.
        """
        if not self.initialized or not self.documents:
            return []
        query_emb = self.embedder.encode([query], normalize_embeddings=True)[0]
        # Cosine similarity against the pre-normalized document matrix.
        similarities = np.dot(self.embeddings, query_emb)
        top_indices = np.argsort(similarities)[::-1][:top_k]
        return [
            {
                "text": self.documents[idx]["text"],
                "source": self.documents[idx]["source"],
                "score": float(similarities[idx]),
            }
            for idx in top_indices
        ]
# Global RAG instance
# Shared process-wide retriever; lazily initialized on the first RAG-enabled request.
rag_retriever = RAGRetriever()
# ---------------------------------------------------------------------------
# Global model state
# ---------------------------------------------------------------------------
# Serializes model loading across concurrent Gradio requests.
_lock = threading.Lock()
_loaded_models = {}  # Cache loaded models: display name -> (tokenizer, model)
def _load_model_cached(model_name: str):
    """Return (tokenizer, model) for *model_name*, loading and caching on first use.

    The model is the base checkpoint from MODELS[model_name]["base"] with the
    LoRA adapter applied, kept on CPU in float32 and put in eval mode.
    Callers are expected to serialize access via the module-level lock.
    """
    try:
        # Fast path: already loaded during this process's lifetime.
        return _loaded_models[model_name]
    except KeyError:
        pass

    config = MODELS[model_name]
    access_token = os.getenv("HF_TOKEN")

    tok = AutoTokenizer.from_pretrained(
        config["base"],
        trust_remote_code=True,
        token=access_token,
    )
    backbone = AutoModelForCausalLM.from_pretrained(
        config["base"],
        torch_dtype=torch.float32,
        device_map="cpu",
        trust_remote_code=True,
        token=access_token,
    )
    adapted = PeftModel.from_pretrained(
        backbone,
        config["adapter"],
        torch_dtype=torch.float32,
        token=access_token,
    )
    adapted.eval()

    _loaded_models[model_name] = (tok, adapted)
    return tok, adapted
# ---------------------------------------------------------------------------
# Response generation
# ---------------------------------------------------------------------------
def generate_response(
    message: str,
    history: list[dict],
    model_name: str,
    use_rag: bool,
    temperature: float,
    max_tokens: int,
) -> Generator[str, None, None]:
    """Generate a streamed reply to *message* with the selected model.

    Yields progressively longer strings (status lines, then the accumulating
    model output) so the Gradio front end can render token-by-token.

    Args:
        message: Current user message.
        history: Prior chat turns as {"role", "content"} dicts (messages format).
        model_name: Key into MODELS.
        use_rag: When True, prepend retrieved knowledge-base excerpts.
        temperature: Sampling temperature; <= 0 means greedy decoding.
        max_tokens: Maximum number of new tokens to generate.
    """
    if not message.strip():
        yield ""
        return
    cfg = MODELS[model_name]
    # RAG retrieval
    rag_context = ""
    docs: list[dict] = []  # defined unconditionally: also consulted when building the system prompt
    if use_rag:
        yield "🔍 Searching knowledge base...\n\n"
        if not rag_retriever.initialized:
            rag_retriever.initialize()
        docs = rag_retriever.retrieve(message, top_k=3)
        if docs:
            rag_context = "\n\n**Relevant context from knowledge base:**\n"
            for i, doc in enumerate(docs, 1):
                rag_context += f"\n[{i}] From {doc['source']} (relevance: {doc['score']:.2f}):\n{doc['text'][:300]}...\n"
            rag_context += "\n---\n\n"
    # Load model (serialized: loading two large models at once could exhaust CPU RAM)
    yield f"{rag_context}Loading {model_name}...\n\n"
    with _lock:
        try:
            tokenizer, model = _load_model_cached(model_name)
        except Exception as e:
            yield f"{rag_context}**Error loading model:** {e}"
            return
    # Build chat prompt: system message (+ optional RAG hint), prior turns, current message.
    system_msg = cfg["system_prompt"]
    if use_rag and docs:
        system_msg += "\n\nYou have access to relevant excerpts from the knowledge base. Use them to provide accurate, detailed answers."
    messages = [{"role": "system", "content": system_msg}]
    for entry in history:
        messages.append({"role": entry["role"], "content": entry["content"]})
    messages.append({"role": "user", "content": message})
    input_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(input_text, return_tensors="pt").to("cpu")
    # Stream generation: generate() runs on a worker thread and feeds the
    # streamer, which this generator drains token by token.
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
    generation_kwargs = {
        **inputs,
        "max_new_tokens": max_tokens,
        "pad_token_id": tokenizer.eos_token_id,
        "streamer": streamer,
    }
    if temperature > 0:
        # Only pass sampling parameters when actually sampling: supplying
        # temperature alongside do_sample=False makes transformers warn
        # (and temperature=0 is rejected outright).
        generation_kwargs.update(do_sample=True, temperature=temperature, top_p=0.9)
    else:
        generation_kwargs["do_sample"] = False  # greedy decoding
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    # Yield the growing response, prefixed with any RAG context block.
    response = rag_context
    for new_text in streamer:
        response += new_text
        yield response
    thread.join()
def generate_comparison(
    message: str,
    use_rag: bool,
    temperature: float,
    max_tokens: int,
) -> tuple[str, str, str]:
    """Run *message* through all three models and return their final answers.

    Drains each model's streaming generator to its last chunk (the complete
    response) and returns the answers in fixed UI order:
    (ISO 27001 Expert, RGPD Expert, CyberSec Assistant).
    """

    def _final_answer(name: str) -> str:
        # Consume the stream; the last yielded chunk is the full response.
        text = ""
        for partial in generate_response(message, [], name, use_rag, temperature, max_tokens):
            text = partial
        return text

    answers = {name: _final_answer(name) for name in MODELS}
    return (
        answers["ISO 27001 Expert (1.5B)"],
        answers["RGPD Expert (1.5B)"],
        answers["CyberSec Assistant (3B)"],
    )
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
# Markdown rendered at the top of the demo page.
DESCRIPTION = """\
## 🛡️ Advanced CyberSec AI Models Demo
Interactive demo with **3 fine-tuned models**, **RAG on 80+ datasets**, and **streaming responses**.
| Model | Specialty | Size |
|-------|-----------|------|
| **ISO 27001 Expert** | ISO/IEC 27001 ISMS | 1.5B |
| **RGPD Expert** | GDPR / RGPD compliance | 1.5B |
| **CyberSec Assistant** | General cybersecurity | 3B |
**Features:**
- 💬 Single-model chat or compare all 3 models side-by-side
- 🔍 RAG (Retrieval-Augmented Generation) on 80+ cybersecurity datasets
- ⚡ Streaming responses (token-by-token)
- 🎛️ Adjustable temperature & max tokens
"""
# Dark theme: red/purple accents over near-black backgrounds (light and dark modes).
theme = gr.themes.Soft(
    primary_hue="red",
    secondary_hue="purple",
    neutral_hue="slate",
    font=gr.themes.GoogleFont("Inter"),
).set(
    body_background_fill="#0f1117",
    body_background_fill_dark="#0f1117",
    block_background_fill="#1a1c25",
    block_background_fill_dark="#1a1c25",
)
# UI layout: two tabs (single-model chat, three-way comparison) plus a footer.
with gr.Blocks(theme=theme, title="CyberSec AI Models - Advanced Demo") as demo:
    gr.Markdown("# 🛡️ CyberSec AI Models - Advanced Demo")
    gr.Markdown(DESCRIPTION)
    with gr.Tabs() as tabs:
        # Tab 1: Single Model Chat
        with gr.Tab("💬 Chat"):
            with gr.Row():
                with gr.Column(scale=3):
                    model_selector = gr.Dropdown(
                        choices=list(MODELS.keys()),
                        value="ISO 27001 Expert (1.5B)",
                        label="Select Model",
                    )
                with gr.Column(scale=1):
                    use_rag_chat = gr.Checkbox(
                        label="Enable RAG",
                        value=True,
                        info="Retrieve context from 80+ datasets",
                    )
            with gr.Row():
                temperature_chat = gr.Slider(0, 1, value=0.7, label="Temperature")
                max_tokens_chat = gr.Slider(128, 1024, value=512, step=128, label="Max Tokens")
            chatbot = gr.Chatbot(type="messages", height=500)
            msg = gr.Textbox(
                label="Your message",
                placeholder="Ask a cybersecurity question...",
                lines=2,
            )
            with gr.Row():
                submit = gr.Button("Send", variant="primary")
                clear = gr.Button("Clear")

            def user(user_message, history):
                """Append the user's turn to the history and clear the textbox."""
                return "", history + [{"role": "user", "content": user_message}]

            def bot(history, model_name, use_rag, temp, max_tok):
                """Stream the assistant reply, yielding the updated history per chunk."""
                user_message = history[-1]["content"]
                history_context = history[:-1]  # prior turns; current message is passed separately
                bot_message = ""
                for chunk in generate_response(user_message, history_context, model_name, use_rag, temp, max_tok):
                    bot_message = chunk
                    yield history + [{"role": "assistant", "content": bot_message}]

            # Send on Enter or the Send button: user() runs unqueued for
            # snappy echo, then bot() streams the reply into the chatbot.
            msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
                bot, [chatbot, model_selector, use_rag_chat, temperature_chat, max_tokens_chat], chatbot
            )
            submit.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
                bot, [chatbot, model_selector, use_rag_chat, temperature_chat, max_tokens_chat], chatbot
            )
            clear.click(lambda: [], None, chatbot, queue=False)
            # Examples
            gr.Examples(
                examples=[
                    ["What are the key principles of ISO 27001?"],
                    ["Explain GDPR data subject rights in detail."],
                    ["How does the MITRE ATT&CK framework work?"],
                    ["What are the main requirements of the NIS2 directive?"],
                ],
                inputs=msg,
            )
        # Tab 2: Compare Models
        with gr.Tab("⚖️ Compare Models"):
            gr.Markdown("### Compare responses from all 3 models side-by-side")
            with gr.Row():
                use_rag_compare = gr.Checkbox(label="Enable RAG", value=True)
                temperature_compare = gr.Slider(0, 1, value=0.7, label="Temperature")
                max_tokens_compare = gr.Slider(128, 1024, value=512, step=128, label="Max Tokens")
            question_compare = gr.Textbox(
                label="Question",
                placeholder="Ask a question to all 3 models...",
                lines=2,
            )
            compare_btn = gr.Button("Compare Models", variant="primary")
            with gr.Row():
                output_iso = gr.Textbox(label="ISO 27001 Expert (1.5B)", lines=15)
                output_rgpd = gr.Textbox(label="RGPD Expert (1.5B)", lines=15)
                output_cyber = gr.Textbox(label="CyberSec Assistant (3B)", lines=15)
            compare_btn.click(
                generate_comparison,
                inputs=[question_compare, use_rag_compare, temperature_compare, max_tokens_compare],
                outputs=[output_iso, output_rgpd, output_cyber],
            )
            gr.Examples(
                examples=[
                    ["What is a Data Protection Impact Assessment?"],
                    ["Explain the concept of Zero Trust security."],
                    ["What are the penalties for GDPR non-compliance?"],
                ],
                inputs=question_compare,
            )
    # Footer: author, model, and portfolio links.
    gr.HTML("""
<div style="text-align:center; margin-top:2rem; padding-top:1rem; border-top:1px solid #444; color:#888; font-size:0.85rem;">
<p>Built by <a href="https://huggingface.co/AYI-NEDJIMI" style="color:#6d9eeb;">Ayi NEDJIMI</a>
| Models: <a href="https://huggingface.co/AYI-NEDJIMI/ISO27001-Expert-1.5B" style="color:#6d9eeb;">ISO27001</a>,
<a href="https://huggingface.co/AYI-NEDJIMI/RGPD-Expert-1.5B" style="color:#6d9eeb;">RGPD</a>,
<a href="https://huggingface.co/AYI-NEDJIMI/CyberSec-Assistant-3B" style="color:#6d9eeb;">CyberSec-3B</a>
| <a href="https://huggingface.co/collections/AYI-NEDJIMI/cybersec-ai-portfolio-datasets-models-and-spaces-699224074a478ec0feeac493" style="color:#6d9eeb;">Full Portfolio</a></p>
<p style="font-size:0.75rem; color:#666;">Fine-tuned with QLoRA on Qwen 2.5 | RAG powered by sentence-transformers</p>
</div>
""")
if __name__ == "__main__":
    # queue() enables Gradio's request queue, which the streaming
    # generator handlers rely on.
    demo.queue().launch()