# Hugging Face Spaces app — runs on ZeroGPU ("Running on Zero").
| import gradio as gr | |
| import torch | |
| import spaces | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
| from threading import Thread | |
# --- Configuration ---
# Hub repo id of the DPO-aligned RedSage checkpoint to serve.
MODEL_ID = "RISys-Lab/RedSage-Qwen3-8B-DPO"
# Page heading rendered at the top of the UI.
TITLE = "🛡️ RedSage: Cybersecurity Generalist LLM"
# Intro markdown shown directly under the title.
DESCRIPTION = """
**RedSage-Qwen3-8B-DPO** is an open-source, locally deployable 8B model designed to bridge the gap between general knowledge and domain-specific security operations.
It is trained on **11.8B tokens** of cybersecurity-focused data and fine-tuned on **266K multi-turn expert workflows** (Agentic SFT). This DPO-aligned version is recommended for production-ready assistance and safe, aligned behavior.
"""
# --- Model Loading ---
# Loaded once at import time so every request reuses the same weights.
print(f"Loading {MODEL_ID}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,  # bf16 halves memory vs fp32
    device_map="auto",  # let accelerate place layers on available devices
    trust_remote_code=True,
)
print("Model loaded successfully.")
# --- Helper: Force Content to String ---
def get_text_content(content):
    """Coerce a chat message payload into a plain string.

    Handles three shapes:
      * ``str``  -- returned unchanged.
      * ``list`` -- the OpenAI/Gradio 5.x multimodal format; only the
        ``{"type": "text", "text": ...}`` parts (and bare strings) are
        kept, joined with newlines. Image/file parts are dropped since
        this is a text-only model.
      * ``None`` -- becomes the empty string.

    Anything else is stringified via ``str()`` as a last resort.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        # Keep textual fragments only: dicts tagged "text" and raw strings.
        fragments = [
            str(part.get("text", "")) if isinstance(part, dict) else part
            for part in content
            if (isinstance(part, dict) and part.get("type") == "text")
            or isinstance(part, str)
        ]
        return "\n".join(fragments)
    return str(content)
# --- Generation Logic ---
@spaces.GPU  # FIX: `spaces` was imported but unused; ZeroGPU only attaches a GPU to decorated handlers
def chat_function(message, history, system_prompt, max_tokens, temperature, top_p):
    """Build the conversation, run generation on a worker thread, and
    stream the growing reply back to the Gradio UI.

    Args:
        message: Current user turn (str or Gradio multimodal payload).
        history: Prior turns, either ``[(user, bot), ...]`` pairs or
            ``[{"role": ..., "content": ...}, ...]`` dicts.
        system_prompt: Optional system message prepended to the chat.
        max_tokens: Cap on newly generated tokens (slider value).
        temperature: Sampling temperature.
        top_p: Nucleus-sampling cutoff.

    Yields:
        str: The accumulated assistant reply, token by token.
    """
    # 1. Initialize messages, starting with the system prompt if given.
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    # 2. Parse history. Gradio may hand back tuple pairs or role dicts,
    #    and each content may be a multimodal list — reduce all to text.
    for item in history:
        if isinstance(item, (list, tuple)):
            if len(item) >= 2:
                user_msg = get_text_content(item[0])
                assistant_msg = get_text_content(item[1])
                if user_msg:
                    messages.append({"role": "user", "content": user_msg})
                if assistant_msg:
                    messages.append({"role": "assistant", "content": assistant_msg})
        elif isinstance(item, dict):
            role = item.get("role")
            content = get_text_content(item.get("content"))
            if role and content:
                messages.append({"role": role, "content": content})

    # 3. Add the current (cleaned) user message.
    messages.append({"role": "user", "content": get_text_content(message)})

    # 4. Tokenize with the model's chat template.
    encodings = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

    # apply_chat_template may return a plain tensor or a BatchEncoding
    # depending on tokenizer/transformers version — handle both.
    if isinstance(encodings, torch.Tensor):
        input_ids = encodings
        attention_mask = None
    elif hasattr(encodings, "input_ids"):
        input_ids = encodings.input_ids
        attention_mask = getattr(encodings, "attention_mask", None)
    else:
        input_ids = encodings["input_ids"]
        attention_mask = encodings.get("attention_mask", None)

    # 5. Stop tokens. FIX: convert_tokens_to_ids can return None for a
    #    token missing from the vocab, and a None inside eos_token_id
    #    makes generate() raise — filter such entries out.
    terminators = [
        tok_id
        for tok_id in (
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|im_end|>"),
            tokenizer.convert_tokens_to_ids("<|endoftext|>"),
        )
        if tok_id is not None
    ]

    # 6. Streamer feeding this generator; skips the prompt echo.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
    )

    # 7. Generation arguments.
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=int(max_tokens),  # FIX: sliders can deliver floats
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=terminators,
    )
    if attention_mask is not None:
        generate_kwargs["attention_mask"] = attention_mask
    # FIX: avoid the "pad_token_id not set" warning/fallback at runtime.
    if tokenizer.pad_token_id is None:
        generate_kwargs["pad_token_id"] = tokenizer.eos_token_id

    # 8. Run generation on a background thread so streaming is non-blocking.
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    # 9. Yield the accumulated output as tokens arrive.
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        yield partial_message
# --- UI Layout ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)
    # External resource links (paper, code, model collection).
    with gr.Row():
        gr.Button("📄 Paper (OpenReview)", link="https://openreview.net/forum?id=W4FAenIrQ2")
        gr.Button("💻 GitHub Repo", link="https://github.com/RISys-Lab/RedSage")
        gr.Button("🤗 Hugging Face Collection", link="https://huggingface.co/collections/RISys-Lab/redsage-models")
    with gr.Tabs():
        with gr.Tab("💬 Chat"):
            # Collapsed-by-default sampling controls, fed to chat_function
            # via ChatInterface.additional_inputs below.
            with gr.Accordion("⚙️ System Parameters", open=False):
                system_prompt = gr.Textbox(
                    value="You are REDSAGE, cybersecurity-tuned model developed by Khalifa University. You are a helpful assistant.",
                    label="System Prompt"
                )
                with gr.Row():
                    slider_temp = gr.Slider(minimum=0.1, maximum=1.0, value=0.5, label="Temperature")
                    slider_tokens = gr.Slider(minimum=1, maximum=4096, value=1024, step=1, label="Max New Tokens")
                    slider_top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top-P")
            # Chat Interface
            # NOTE: additional_inputs order must match chat_function's
            # trailing parameters: (system_prompt, max_tokens, temperature, top_p).
            gr.ChatInterface(
                fn=chat_function,
                additional_inputs=[system_prompt, slider_tokens, slider_temp, slider_top_p]
            )
            # --- WARNING / DISCLAIMER ---
            gr.Markdown(
                "<div style='text-align: center; font-size: 0.9em; margin-top: 10px; opacity: 0.8;'>"
                "⚠️ <b>Disclaimer:</b> AI models may hallucinate or produce inaccurate information. "
                "Please verify all security advice and code outputs before use."
                "</div>"
            )
        with gr.Tab("📝 Citation & Attribution"):
            gr.Markdown("""
            ### Authors
            **Naufal Suryanto**<sup>1</sup>, **Muzammal Naseer**<sup>1†</sup>, **Pengfei Li**<sup>1</sup>,
            **Syed Talal Wasim**<sup>2</sup>, **Jinhui Yi**<sup>2</sup>, **Juergen Gall**<sup>2</sup>,
            **Paolo Ceravolo**<sup>3</sup>, **Ernesto Damiani**<sup>3</sup>
            <small><sup>1</sup> Khalifa University | <sup>2</sup> Universität Bonn | <sup>3</sup> University of Milan | <sup>†</sup> Project Lead</small>
            ### BibTeX
            If you use RedSage in your research, please cite:
            """)
            gr.Code(
                value="""@inproceedings{suryanto2026redsage,
title={RedSage: A Cybersecurity Generalist {LLM}},
author={Suryanto, Naufal and Naseer, Muzammal and Li, Pengfei and Wasim, Syed Talal and Yi, Jinhui and Gall, Juergen and Ceravolo, Paolo and Damiani, Ernesto},
booktitle={The Fourteenth International Conference on Learning Representations},
year={2026},
url={https://openreview.net/forum?id=W4FAenIrQ2}
}""",
                language="latex",
                label="Citation"
            )

# Launch only when executed directly (Spaces runs this module as a script).
if __name__ == "__main__":
    demo.launch()