Spaces:
Paused
Paused
File size: 5,027 Bytes
7d04911 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 | import streamlit as st
import asyncio
import re
import os
from llama_cpp import Llama
import requests
from bs4 import BeautifulSoup
# Page-level settings: browser tab title/icon and a wide layout.
st.set_page_config(
    page_title="Security Assistant",
    page_icon="🔒",
    layout="wide",
)

# Styles for the chat bubbles (user / assistant) and the tool-output panel.
# Injected as raw HTML, hence unsafe_allow_html.
_CUSTOM_CSS = """
<style>
.user-message { background-color: #DCF8C6; padding: 10px; border-radius: 10px; margin: 5px 0; }
.assistant-message { background-color: #E9ECEF; padding: 10px; border-radius: 10px; margin: 5px 0; }
.tool-output { background-color: #F8F9FA; padding: 10px; border-radius: 10px; border: 1px solid #DEE2E6; }
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Cache the model loading
@st.cache_resource
def _load_llama(model_path: str):
    """Construct the llama.cpp model for ``model_path`` and cache it.

    Errors are deliberately *not* caught here. Only a successfully built
    model gets stored by ``st.cache_resource``; a failure propagates to the
    caller instead of being cached.
    """
    return Llama(model_path=model_path, n_ctx=2048, n_threads=4, verbose=False)


def load_model():
    """Return the shared ``Llama`` model, or ``None`` if it cannot be loaded.

    The model file is expected at ``models/pentest_ai.Q4_0.gguf`` relative to
    the current working directory. Problems are reported with ``st.error``.

    The previous version returned ``None`` from inside the cached function,
    so the failure itself was cached. Keeping the error handling outside the
    cached loader means a rerun after fixing the problem (e.g. placing the
    model file) retries the load.
    """
    model_path = os.path.join("models", "pentest_ai.Q4_0.gguf")
    if not os.path.exists(model_path):
        st.error(f"Model file not found at {model_path}. Please ensure it’s placed correctly.")
        return None
    try:
        return _load_llama(model_path)
    except Exception as e:  # UI boundary: report any load failure instead of crashing the page
        st.error(f"Failed to load model: {e}")
        return None
# Execute tools asynchronously (directly, without a shell)
async def run_tool(command: str) -> str:
    """Run ``command`` as a subprocess and return its output as text.

    The command line is split with ``shlex.split`` (POSIX quoting rules) and
    executed directly rather than through a shell. Shell metacharacters in
    model-generated text (``;``, ``|``, ``&&``, ``$(...)``, redirections)
    therefore reach the program as literal arguments and cannot start extra
    commands.

    Returns the decoded stdout if the process wrote any, otherwise the
    decoded stderr (possibly ``""``). Bytes that are not valid UTF-8 are
    replaced with U+FFFD instead of raising. If the process cannot be
    started, an ``"Error executing tool: ..."`` string is returned rather
    than raised. Causes include an unknown program, unbalanced quotes, or an
    empty command.
    """
    import shlex

    try:
        argv = shlex.split(command)
        if not argv:
            raise ValueError("empty command")
        process = await asyncio.create_subprocess_exec(
            *argv, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
        )
        stdout, stderr = await process.communicate()
    except Exception as e:
        return f"Error executing tool: {str(e)}"
    # Same precedence as before: stdout wins whenever it is non-empty.
    return (stdout or stderr).decode(errors="replace")
# Fetch vulnerability info via web scraping (no API keys)
def get_vulnerability_info(query: str, limit: int = 5, timeout: float = 10) -> str:
    """Search MITRE's CVE keyword page for ``query`` and summarize the hits.

    Args:
        query: Keyword to search for; URL-encoded by ``requests``.
        limit: Maximum number of result lines to return (default 5).
        timeout: Request timeout in seconds (default 10).

    Returns:
        Up to ``limit`` lines of the form ``"<first cell>: <second cell>"``,
        one per table row that has at least two ``<td>`` cells.
        ``"No vulnerabilities found."`` if there are no such rows.
        ``"Error fetching vulnerability data: ..."`` on any network, HTTP
        status, or parsing failure; nothing is raised.

    NOTE(review): this scrapes the legacy cve.mitre.org HTML page and assumes
    the CVE ID is in the first ``<td>`` of a result row and the description
    in the second. Confirm against the live page.
    """
    try:
        response = requests.get(
            "https://cve.mitre.org/cgi-bin/cvekey.cgi",
            params={"keyword": query},  # encodes spaces, '&', '#', ... safely
            timeout=timeout,
        )
        response.raise_for_status()  # don't parse an error page as results
        soup = BeautifulSoup(response.text, "html.parser")

        vulns = []
        for row in soup.find_all("tr"):
            if len(vulns) >= limit:
                break
            cells = row.find_all("td")
            # Header rows (<th>) and layout rows lack two data cells; skipping
            # them avoids the IndexError that used to abort the whole lookup.
            if len(cells) < 2:
                continue
            cve_id = cells[0].get_text(" ", strip=True)
            description = cells[1].get_text(" ", strip=True)
            vulns.append(f"{cve_id}: {description}")
        return "\n".join(vulns) if vulns else "No vulnerabilities found."
    except Exception as e:
        return f"Error fetching vulnerability data: {str(e)}"
# Session state management: create an empty chat history only when this
# session does not already have one (an existing history is left as-is).
st.session_state.setdefault("messages", [])
# Add message to chat history
def add_message(content: str, is_user: bool):
    """Append one entry to ``st.session_state.messages``.

    Each entry is a dict holding the message text under ``"content"`` and
    whether the user (``True``) or the assistant (``False``) sent it under
    ``"is_user"``.
    """
    entry = {"content": content, "is_user": is_user}
    history = st.session_state.messages
    history.append(entry)
# Render chat history
def render_chat():
for msg in st.session_state.messages:
bubble_class = "user-message" if msg["is_user"] else "assistant-message"
st.markdown(f'<div class="{bubble_class}">{msg["content"]}</div>', unsafe_allow_html=True)
# Main application

# Prompt prepended to every question. It tells the model to request tools
# with the [TOOL: name ARGS: "args"] syntax that _TOOL_PATTERN parses.
_SYSTEM_PROMPT = """
You are a cybersecurity assistant with expertise in penetration testing.
Provide concise, actionable insights. Use [TOOL: tool_name ARGS: "args"] for tool suggestions.
"""

_TOOL_PATTERN = re.compile(r"\[TOOL: (\w+) ARGS: \"(.*?)\"\]")

# Programs the model may launch. Model output can be steered by whoever types
# into the chat, so any other program it names is refused, not executed.
# NOTE(review): adjust to the tools actually installed for this deployment;
# allowed tools can still be misused through their own options (e.g. nmap
# output-file or script flags).
_ALLOWED_TOOLS = frozenset(
    {"nmap", "nikto", "whois", "dig", "nslookup", "host", "ping", "traceroute"}
)

# Shell metacharacters. Arguments containing any of them are refused, so a
# tool call cannot chain, substitute or redirect commands even if the command
# line ends up being interpreted by a shell.
_SHELL_METACHARS = re.compile(r"[;&|`$<>\\(){}\n\r]")

_CVE_ID_PATTERN = re.compile(r"CVE-\d{4}-\d{4,}", re.IGNORECASE)


def _is_allowed_tool_call(tool_name: str, args: str) -> bool:
    """Return True if a model-requested tool call may be executed."""
    return tool_name in _ALLOWED_TOOLS and not _SHELL_METACHARS.search(args)


def _vulnerability_query(user_input: str) -> str:
    """Pick the keyword for a vulnerability lookup (a deliberately simple heuristic).

    An explicit CVE identifier anywhere in the question is used first.
    Otherwise the last word is used with surrounding punctuation removed, so
    "...in OpenSSH?" searches for "OpenSSH".
    """
    cve = _CVE_ID_PATTERN.search(user_input)
    if cve:
        return cve.group(0).upper()
    words = user_input.split()
    if not words:
        return ""
    return words[-1].strip(".,;:!?\"'()[]{}") or words[-1]


def _generate_reply(model, user_input: str, max_tokens: int) -> str:
    """Answer one question and append tool / CVE results when applicable.

    Only the current question is sent to the model; earlier turns are not
    included. If the answer contains a [TOOL: ...] request, the first one is
    run via run_tool when _is_allowed_tool_call permits it; otherwise a
    refusal note is appended. If the question mentions "vulnerability", CVE
    search results are appended.
    """
    full_prompt = f"{_SYSTEM_PROMPT}\nUser: {user_input}\nAssistant:"
    response = model.create_completion(
        full_prompt, max_tokens=max_tokens, temperature=0.7, stop=["User:"]
    )
    reply = response["choices"][0]["text"].strip()

    # Parse for tool execution
    match = _TOOL_PATTERN.search(reply)
    if match:
        tool_name, args = match.groups()
        if _is_allowed_tool_call(tool_name, args):
            tool_output = asyncio.run(run_tool(f"{tool_name} {args}"))
        else:
            tool_output = (
                f"Refused to run {tool_name!r}: it is not an allowed tool "
                "or its arguments contain shell metacharacters."
            )
        # Appended as plain text: no HTML wrapper around untrusted tool output.
        reply += f"\n\nTool Output:\n{tool_output}"

    # Handle vulnerability lookups
    if "vulnerability" in user_input.lower():
        vulns = get_vulnerability_info(_vulnerability_query(user_input))
        reply += f"\n\nVulnerability Data:\n{vulns}"
    return reply


def main():
    """Build the page and handle at most one chat submission per script run.

    Shows the title and sidebar settings, then loads the model; if loading
    fails, it shows a warning and stops. It next processes the submitted
    question, if any, and only then draws the chat history. The history goes
    into a container created above the input form, so a new question and
    its answer appear in the same run that produced them. Previously the
    history was drawn before the submission was handled, so new messages
    only showed up after the next rerun.
    """
    st.title("🔒 Open-Source Security Assistant")
    st.markdown("Powered by pentest_ai.Q4_0.gguf. Runs locally or on Hugging Face Spaces.")

    # Sidebar for settings
    with st.sidebar:
        max_tokens = st.slider("Max Tokens", 128, 1024, 256)
        if st.button("Clear Chat"):
            st.session_state.messages = []

    # Load model (returns None on failure)
    model = load_model()
    if model is None:
        st.warning("Model loading failed. Check logs or ensure the model file is available.")
        return

    # Reserve the history area above the form; it is filled in at the end.
    chat_area = st.container()

    # Chat input form
    with st.form("chat_form", clear_on_submit=True):
        user_input = st.text_area("Ask a security question...", height=100)
        submit = st.form_submit_button("Send")

    question = (user_input or "").strip()  # ignore whitespace-only submissions
    if submit and question:
        add_message(question, True)
        with st.spinner("Processing..."):
            try:
                reply = _generate_reply(model, question, max_tokens)
            except Exception as e:  # UI boundary: show the failure in the chat
                reply = f"Error generating response: {e}"
        add_message(reply, False)

    with chat_area:
        render_chat()
# Script entry point (the trailing " |" scrape residue was a syntax error).
if __name__ == "__main__":
    main()