Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -4,11 +4,12 @@ import pdfplumber
|
|
| 4 |
import numpy as np
|
| 5 |
import faiss
|
| 6 |
import zipfile
|
|
|
|
| 7 |
from transformers import T5Tokenizer, T5ForConditionalGeneration
|
| 8 |
from sentence_transformers import SentenceTransformer
|
| 9 |
|
| 10 |
# -------------------------
|
| 11 |
-
# Step 1: Unzip docs.zip if
|
| 12 |
# -------------------------
|
| 13 |
def unzip_docs():
|
| 14 |
if os.path.exists("docs.zip") and not os.path.exists("docs"):
|
|
@@ -18,12 +19,11 @@ def unzip_docs():
|
|
| 18 |
print("β
Extracted to /docs")
|
| 19 |
|
| 20 |
# -------------------------
|
| 21 |
-
# Step 2:
|
| 22 |
# -------------------------
|
| 23 |
def load_docs(folder="docs"):
|
| 24 |
all_text = ""
|
| 25 |
found_files = []
|
| 26 |
-
|
| 27 |
for root, _, files in os.walk(folder):
|
| 28 |
for fname in files:
|
| 29 |
if fname.lower().endswith(".pdf"):
|
|
@@ -38,17 +38,17 @@ def load_docs(folder="docs"):
|
|
| 38 |
print(f"π Found {len(found_files)} PDF files:")
|
| 39 |
for f in found_files:
|
| 40 |
print(" -", f)
|
| 41 |
-
|
|
|
|
| 42 |
return all_text
|
| 43 |
|
| 44 |
# -------------------------
|
| 45 |
-
#
|
| 46 |
# -------------------------
|
| 47 |
def chunk_text(text, max_words=200):
|
| 48 |
-
|
| 49 |
-
paras = re.split(r'\n{2,}', text)
|
| 50 |
chunks, current = [], ""
|
| 51 |
-
for para in
|
| 52 |
if len((current + " " + para).split()) < max_words:
|
| 53 |
current += " " + para
|
| 54 |
else:
|
|
@@ -56,70 +56,80 @@ def chunk_text(text, max_words=200):
|
|
| 56 |
current = para
|
| 57 |
if current:
|
| 58 |
chunks.append(current.strip())
|
|
|
|
|
|
|
|
|
|
| 59 |
return [c for c in chunks if len(c.split()) > 20]
|
| 60 |
|
| 61 |
# -------------------------
|
| 62 |
-
# Build
|
| 63 |
# -------------------------
|
| 64 |
def build_index():
|
| 65 |
unzip_docs()
|
| 66 |
raw = load_docs("docs")
|
| 67 |
global doc_chunks
|
| 68 |
doc_chunks = chunk_text(raw)
|
| 69 |
-
|
| 70 |
embeddings = embedder.encode(doc_chunks, convert_to_numpy=True, normalize_embeddings=True)
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
return
|
| 74 |
|
| 75 |
# -------------------------
|
| 76 |
-
#
|
| 77 |
# -------------------------
|
| 78 |
def generate_answer(question):
|
| 79 |
q_embed = embedder.encode([question], normalize_embeddings=True)
|
| 80 |
D, I = index.search(np.array(q_embed), top_k)
|
| 81 |
|
|
|
|
|
|
|
|
|
|
| 82 |
top_passages = [f"Passage {i+1}:\n{doc_chunks[i]}" for i in I[0]]
|
| 83 |
context = "\n\n".join(top_passages)
|
| 84 |
|
| 85 |
prompt = (
|
| 86 |
-
"You are SecurityGPT, a cybersecurity assistant. Use the
|
| 87 |
-
f"{context}\n\n"
|
| 88 |
f"Question: {question}\n\nAnswer:"
|
| 89 |
)
|
| 90 |
|
| 91 |
input_ids = tokenizer.encode(prompt, return_tensors="pt", truncation=True, max_length=1024)
|
| 92 |
output_ids = model.generate(
|
| 93 |
input_ids,
|
| 94 |
-
max_length=
|
| 95 |
-
num_beams=
|
| 96 |
-
temperature=0.
|
| 97 |
-
repetition_penalty=1.
|
| 98 |
early_stopping=True
|
| 99 |
)
|
| 100 |
|
| 101 |
-
|
|
|
|
| 102 |
|
| 103 |
# -------------------------
|
| 104 |
-
# Load
|
| 105 |
# -------------------------
|
| 106 |
embedder = SentenceTransformer("multi-qa-MiniLM-L6-cos-v1")
|
| 107 |
-
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-
|
| 108 |
-
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-
|
| 109 |
-
|
| 110 |
top_k = 5
|
| 111 |
doc_chunks = []
|
| 112 |
index = build_index()
|
| 113 |
|
| 114 |
# -------------------------
|
| 115 |
-
# Gradio App
|
| 116 |
# -------------------------
|
| 117 |
demo = gr.Interface(
|
| 118 |
fn=generate_answer,
|
| 119 |
-
inputs=gr.Textbox(label="Ask SecurityGPT", placeholder="e.g. How do I
|
| 120 |
outputs=gr.Textbox(label="Answer"),
|
| 121 |
title="π SecurityGPT",
|
| 122 |
-
description="Ask cybersecurity questions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
)
|
| 124 |
|
| 125 |
demo.launch()
|
|
|
|
| 4 |
import numpy as np
|
| 5 |
import faiss
|
| 6 |
import zipfile
|
| 7 |
+
import re
|
| 8 |
from transformers import T5Tokenizer, T5ForConditionalGeneration
|
| 9 |
from sentence_transformers import SentenceTransformer
|
| 10 |
|
| 11 |
# -------------------------
|
| 12 |
+
# Step 1: Unzip docs.zip if needed
|
| 13 |
# -------------------------
|
| 14 |
def unzip_docs():
|
| 15 |
if os.path.exists("docs.zip") and not os.path.exists("docs"):
|
|
|
|
| 19 |
print("β
Extracted to /docs")
|
| 20 |
|
| 21 |
# -------------------------
|
| 22 |
+
# Step 2: Extract and log PDF text
|
| 23 |
# -------------------------
|
| 24 |
def load_docs(folder="docs"):
|
| 25 |
all_text = ""
|
| 26 |
found_files = []
|
|
|
|
| 27 |
for root, _, files in os.walk(folder):
|
| 28 |
for fname in files:
|
| 29 |
if fname.lower().endswith(".pdf"):
|
|
|
|
| 38 |
print(f"π Found {len(found_files)} PDF files:")
|
| 39 |
for f in found_files:
|
| 40 |
print(" -", f)
|
| 41 |
+
print(f"β
Total raw text size: {len(all_text)} characters")
|
| 42 |
+
print(f"π§Ύ Sample Text:\n{all_text[:300]}")
|
| 43 |
return all_text
|
| 44 |
|
| 45 |
# -------------------------
|
| 46 |
+
# Step 3: Chunk into paragraphs
|
| 47 |
# -------------------------
|
| 48 |
def chunk_text(text, max_words=200):
|
| 49 |
+
paragraphs = re.split(r'\n{2,}', text)
|
|
|
|
| 50 |
chunks, current = [], ""
|
| 51 |
+
for para in paragraphs:
|
| 52 |
if len((current + " " + para).split()) < max_words:
|
| 53 |
current += " " + para
|
| 54 |
else:
|
|
|
|
| 56 |
current = para
|
| 57 |
if current:
|
| 58 |
chunks.append(current.strip())
|
| 59 |
+
print(f"β
Total Chunks Created: {len(chunks)}")
|
| 60 |
+
for i, c in enumerate(chunks[:3]):
|
| 61 |
+
print(f"πΉ Chunk {i+1} Preview:\n{c[:250]}\n")
|
| 62 |
return [c for c in chunks if len(c.split()) > 20]
|
| 63 |
|
| 64 |
# -------------------------
|
| 65 |
+
# Step 4: Build FAISS Index
|
| 66 |
# -------------------------
|
| 67 |
def build_index():
    """Build an in-memory FAISS index over the PDF corpus.

    Side effects: extracts docs.zip if needed (via unzip_docs) and rebinds
    the module-level ``doc_chunks`` list so retrieval can map index hits
    back to their source text.

    Returns:
        faiss.IndexFlatIP: inner-product index over normalized embeddings
        (inner product of unit vectors == cosine similarity).

    Raises:
        RuntimeError: if no usable text chunks were extracted — otherwise
        this surfaces later as an opaque shape error on ``embeddings.shape[1]``.
    """
    unzip_docs()
    raw = load_docs("docs")

    global doc_chunks
    doc_chunks = chunk_text(raw)
    if not doc_chunks:
        raise RuntimeError(
            "No text chunks extracted from 'docs' - check that docs.zip "
            "contains readable PDF files."
        )

    # normalize_embeddings=True so IndexFlatIP scores are cosine similarities.
    embeddings = embedder.encode(doc_chunks, convert_to_numpy=True, normalize_embeddings=True)
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)
    return index
|
| 76 |
|
| 77 |
# -------------------------
|
| 78 |
+
# Step 5: Generate Answer
|
| 79 |
# -------------------------
|
| 80 |
def generate_answer(question):
    """Answer *question* with retrieval-augmented generation.

    Embeds the question, retrieves the ``top_k`` most similar chunks from
    the module-level FAISS ``index``, and feeds them as context to FLAN-T5.

    Args:
        question (str): user question from the Gradio textbox.

    Returns:
        str: the generated answer text.
    """
    q_embed = embedder.encode([question], normalize_embeddings=True)
    D, I = index.search(np.array(q_embed), top_k)

    print("π Top similarity scores:", D[0])
    print("π§ Retrieved indices:", I[0])

    # FAISS pads I with -1 when fewer than top_k vectors exist; drop those
    # instead of letting doc_chunks[-1] silently inject the last chunk.
    retrieved = [doc_chunks[i] for i in I[0] if 0 <= i < len(doc_chunks)]
    # Number passages by retrieval rank (1..k), not by their corpus index.
    top_passages = [
        f"Passage {rank}:\n{chunk}"
        for rank, chunk in enumerate(retrieved, start=1)
    ]
    context = "\n\n".join(top_passages)

    prompt = (
        "You are SecurityGPT, a cybersecurity expert assistant. Use ONLY the context below to answer the question clearly in multiple paragraphs.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {question}\n\nAnswer:"
    )

    # Truncate to the encoder limit; FLAN-T5's relative attention tolerates 1024.
    input_ids = tokenizer.encode(prompt, return_tensors="pt", truncation=True, max_length=1024)
    output_ids = model.generate(
        input_ids,
        max_length=400,
        num_beams=4,
        repetition_penalty=1.2,
        early_stopping=True,
        # NOTE(review): the original also passed temperature=0.7, which is
        # ignored (with a warning) under beam search unless do_sample=True,
        # so dropping it does not change the generated output.
    )

    answer = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return answer
|
| 108 |
|
| 109 |
# -------------------------
|
| 110 |
+
# Step 6: Load Model & Index
|
| 111 |
# -------------------------
|
| 112 |
# Embedding model for retrieval; tuned for cosine-similarity QA search.
embedder = SentenceTransformer("multi-qa-MiniLM-L6-cos-v1")

# Seq2seq model for answer generation; the small checkpoint keeps startup
# and CPU inference fast on Spaces hardware.
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small")

# Retrieval settings and shared state. build_index() repopulates doc_chunks
# as a side effect so generate_answer can map FAISS hits back to text.
top_k = 5
doc_chunks = []
index = build_index()
|
| 118 |
|
| 119 |
# -------------------------
|
| 120 |
+
# Step 7: Launch Gradio App
|
| 121 |
# -------------------------
|
| 122 |
# -------------------------
# Gradio UI wiring
# -------------------------
_EXAMPLE_QUESTIONS = [
    "How can I secure my home network?",
    "What are best practices for using public Wi-Fi?",
    "What should I know about password managers?",
]

demo = gr.Interface(
    fn=generate_answer,
    inputs=gr.Textbox(label="Ask SecurityGPT", placeholder="e.g. How do I protect against phishing?"),
    outputs=gr.Textbox(label="Answer"),
    title="π SecurityGPT",
    description="Ask cybersecurity questions based on embedded content from your PDF documents.",
    examples=_EXAMPLE_QUESTIONS,
)

demo.launch()
|