ziffir commited on
Commit
c835198
·
verified ·
1 Parent(s): 2100009

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +176 -235
app.py CHANGED
@@ -1,239 +1,180 @@
1
- import requests, uuid, time, threading, torch, json
2
- from fastapi import FastAPI
3
- from bs4 import BeautifulSoup
4
- from urllib.parse import urljoin
5
- from datetime import datetime
6
- from transformers import AutoTokenizer, AutoModelForCausalLM
7
-
8
- # ==================================================
9
- # CONFIG
10
- # ==================================================
11
- MODEL_NAME = "UCSB-SURFI/VulnLLM-R-7B"
12
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
13
- UA = "RedTeam-Bot/1.0"
14
- MAX_HISTORY = 5 # her hedef için saklanacak en fazla tarama sayısı
15
-
16
- # ==================================================
17
- # APP + STATE
18
- # ==================================================
19
- app = FastAPI(title="Ultra Red-Team SaaS Engine")
20
- SCANS = {} # scan_id -> result
21
- SCAN_HISTORY = {} # target -> [scan_ids]
22
-
23
- # ==================================================
24
- # LOAD VULNLLM (ONCE, OFFLOAD)
25
- # ==================================================
26
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
27
- model = AutoModelForCausalLM.from_pretrained(
28
- MODEL_NAME,
29
- torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
30
- device_map="auto",
31
- offload_folder="offload", # memory-safe offload
32
- trust_remote_code=True
33
- )
34
- model.eval()
35
-
36
- def vulnllm_enrich(finding, context):
37
- prompt = f"""
38
- You are an elite red-team security researcher.
39
-
40
- Finding:
41
- {finding}
42
-
43
- Context:
44
- {json.dumps(context, indent=2)}
45
-
46
- Return STRICTLY:
47
- - CWE
48
- - Risk level (Low/Medium/High/Critical)
49
- - Realistic exploit scenario
50
- - Clear remediation
51
- """
52
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024) # daha kısa
53
- if DEVICE == "cuda":
54
- inputs = inputs.to(DEVICE)
55
-
56
- with torch.no_grad():
57
- out = model.generate(
58
- **inputs,
59
- max_new_tokens=200, # memory-safe token sayısı
60
- temperature=0.25,
61
- top_p=0.95
62
- )
63
-
64
- result = tokenizer.decode(out[0], skip_special_tokens=True)
65
- del inputs
66
- torch.cuda.empty_cache()
67
- return result
68
-
69
- # ==================================================
70
- # RECON AGENT
71
- # ==================================================
72
- def recon(url):
73
- r = requests.get(url, timeout=10, headers={"User-Agent": UA})
74
- soup = BeautifulSoup(r.text, "html.parser")
75
-
76
- js_files = [urljoin(url, s["src"]) for s in soup.find_all("script", src=True)]
77
- forms = []
78
- for f in soup.find_all("form"):
79
- forms.append({
80
- "action": urljoin(url, f.get("action", "")),
81
- "method": f.get("method", "GET").upper(),
82
- "inputs": [i.get("name") for i in f.find_all("input")]
83
- })
84
 
85
  return {
86
- "headers": dict(r.headers),
87
- "js_files": js_files,
88
- "forms": forms,
89
- "html_size": len(r.text) # full HTML yerine boyut
90
- }
91
-
92
- # ==================================================
93
- # JS SURFACE (AST-LIKE HEURISTIC)
94
- # ==================================================
95
- def js_surface(js_url):
96
- try:
97
- r = requests.get(js_url, timeout=5, headers={"User-Agent": UA})
98
- endpoints = []
99
- for line in r.text.splitlines():
100
- if "fetch(" in line or "axios" in line:
101
- endpoints.append(line.strip()[:200])
102
- return endpoints
103
- except:
104
- return []
105
-
106
- # ==================================================
107
- # RED AGENT (ATTACK THINKING)
108
- # ==================================================
109
- def red_agent(recon_data, js_endpoints):
110
- findings = []
111
-
112
- headers = {k.lower(): v for k, v in recon_data["headers"].items()}
113
- if "x-powered-by" in headers:
114
- findings.append({
115
- "title": "Technology stack disclosure",
116
- "context": headers
117
- })
118
-
119
- if recon_data["forms"]:
120
- findings.append({
121
- "title": "User-controlled input surface",
122
- "context": recon_data["forms"]
123
- })
124
-
125
- if len(js_endpoints) > 5:
126
- findings.append({
127
- "title": "Exposed client-side API surface",
128
- "context": js_endpoints[:5]
129
- })
130
-
131
- return findings
132
-
133
- # ==================================================
134
- # RISK ENGINE
135
- # ==================================================
136
- def risk_score(enriched_findings):
137
- score = 0
138
- for e in enriched_findings:
139
- txt = e.lower()
140
- if "critical" in txt: score += 40
141
- elif "high" in txt: score += 25
142
- elif "medium" in txt: score += 10
143
- return min(score, 100)
144
-
145
- # ==================================================
146
- # ATTACK GRAPH
147
- # ==================================================
148
- def attack_graph(js_eps, forms):
149
- nodes = ["User", "Browser", "JS"]
150
- edges = [{"from":"User","to":"Browser"},{"from":"Browser","to":"JS"}]
151
-
152
- for ep in js_eps:
153
- nodes.append(ep)
154
- edges.append({"from":"JS","to":ep})
155
-
156
- if forms:
157
- nodes.append("FormInput")
158
- edges.append({"from":"User","to":"FormInput"})
159
-
160
- return {"nodes": list(set(nodes)), "edges": edges}
161
-
162
- # ==================================================
163
- # CORE SCAN PIPELINE
164
- # ==================================================
165
- def run_scan(target):
166
- scan_id = str(uuid.uuid4())
167
-
168
- recon_data = recon(target)
169
-
170
- js_eps = []
171
- for js in recon_data["js_files"]:
172
- js_eps += js_surface(js)
173
-
174
- raw_findings = red_agent(recon_data, js_eps)
175
-
176
- enriched = []
177
- for f in raw_findings:
178
- enriched.append(vulnllm_enrich(f["title"], f["context"]))
179
-
180
- risk = risk_score(enriched)
181
- graph = attack_graph(js_eps, recon_data["forms"])
182
-
183
- result = {
184
- "scan_id": scan_id,
185
- "target": target,
186
- "time": datetime.utcnow().isoformat(),
187
- "risk_score": risk,
188
- "attack_graph": graph,
189
- "findings_enriched": enriched,
190
- "html_size": recon_data["html_size"]
191
  }
192
 
193
- SCANS[scan_id] = result
194
- SCAN_HISTORY.setdefault(target, []).append(scan_id)
195
-
196
- # memory-safe: eski taramaları sil
197
- if len(SCAN_HISTORY[target]) > MAX_HISTORY:
198
- oldest = SCAN_HISTORY[target].pop(0)
199
- del SCANS[oldest]
200
-
201
- return result
202
-
203
- # ==================================================
204
- # CONTINUOUS SCAN
205
- # ==================================================
206
- def continuous(target, interval):
207
- while True:
208
- run_scan(target)
209
- time.sleep(interval)
210
-
211
- # ==================================================
212
- # API
213
- # ==================================================
214
- @app.post("/scan")
215
- def scan(target: str):
216
- return run_scan(target)
217
-
218
- @app.post("/scan/continuous")
219
- def scan_continuous(target: str, interval: int = 3600):
220
- threading.Thread(target=continuous, args=(target,interval), daemon=True).start()
221
- return {"status":"scheduled","interval":interval}
222
-
223
- @app.get("/dashboard/{target}")
224
- def dashboard(target: str):
225
- ids = SCAN_HISTORY.get(target, [])
226
- return [SCANS[i] for i in ids]
227
-
228
- @app.get("/")
229
- def health():
230
- return {
231
- "status":"running",
232
- "mode":"ULTRA B-MODE",
233
- "model":MODEL_NAME,
234
- "features":[
235
- "recon","js-surface","red-agent",
236
- "vulnllm-enrichment","attack-graph",
237
- "risk-engine","continuous-scan"
238
- ]
239
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ... (önceki import'ların sonuna ekle)
2
+ import networkx as nx
3
+ import plotly.graph_objects as go
4
+ import matplotlib.pyplot as plt
5
+ from io import BytesIO
6
+ import base64
7
+
8
+ # ────────────────────────────────────────────────
9
+ # Attack Graph Görselleştirme Fonksiyonları
10
+ # ────────────────────────────────────────────────
11
+ def create_attack_graph_data(recon_data: Dict) -> Dict:
12
+ """Graph verisini hazırlar (nodes, edges)"""
13
+ nodes = ["User", "Browser"]
14
+ edges = [("User", "Browser")]
15
+
16
+ forms_count = recon_data.get("forms_count", 0)
17
+ js_count = recon_data.get("js_files_count", 0)
18
+
19
+ if forms_count > 0:
20
+ nodes.append("Form Submission")
21
+ edges.append(("User", "Form Submission"))
22
+ edges.append(("Form Submission", "Backend"))
23
+
24
+ if js_count > 0:
25
+ nodes.append("Client JS")
26
+ edges.append(("Browser", "Client JS"))
27
+
28
+ # Örnek endpoint'ler (gerçekte recon'dan gelebilir)
29
+ for i in range(min(js_count, 4)): # max 4 örnek göster
30
+ ep_name = f"API/Endpoint {i+1}"
31
+ nodes.append(ep_name)
32
+ edges.append(("Client JS", ep_name))
33
+
34
+ # Riskli noktaları vurgula (örnek)
35
+ risky_nodes = []
36
+ if forms_count > 2:
37
+ risky_nodes.append("Form Submission")
38
+ if js_count > 8:
39
+ risky_nodes.append("Client JS")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  return {
42
+ "nodes": nodes,
43
+ "edges": edges,
44
+ "risky": risky_nodes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  }
46
 
47
+ def visualize_attack_graph_plotly(graph_data: Dict) -> go.Figure:
48
+ """Plotly ile interaktif graph"""
49
+ G = nx.DiGraph()
50
+ G.add_edges_from(graph_data["edges"])
51
+
52
+ pos = nx.spring_layout(G, seed=42) # reproducible layout
53
+
54
+ edge_x = []
55
+ edge_y = []
56
+ for edge in G.edges():
57
+ x0, y0 = pos[edge[0]]
58
+ x1, y1 = pos[edge[1]]
59
+ edge_x.extend([x0, x1, None])
60
+ edge_y.extend([y0, y1, None])
61
+
62
+ edge_trace = go.Scatter(
63
+ x=edge_x, y=edge_y,
64
+ line=dict(width=2, color='#888'),
65
+ hoverinfo='none',
66
+ mode='lines'
67
+ )
68
+
69
+ node_x = []
70
+ node_y = []
71
+ node_text = []
72
+ node_color = []
73
+ for node in G.nodes():
74
+ x, y = pos[node]
75
+ node_x.append(x)
76
+ node_y.append(y)
77
+ node_text.append(node)
78
+ if node in graph_data["risky"]:
79
+ node_color.append('#ff4444') # kırmızı = riskli
80
+ else:
81
+ node_color.append('#1f77b4') # mavi = normal
82
+
83
+ node_trace = go.Scatter(
84
+ x=node_x, y=node_y,
85
+ mode='markers+text',
86
+ hoverinfo='text',
87
+ text=node_text,
88
+ textposition="top center",
89
+ marker=dict(
90
+ showscale=False,
91
+ color=node_color,
92
+ size=30,
93
+ line_width=2
94
+ )
95
+ )
96
+
97
+ fig = go.Figure(data=[edge_trace, node_trace],
98
+ layout=go.Layout(
99
+ title='Attack Graph Visualization',
100
+ titlefont_size=16,
101
+ showlegend=False,
102
+ hovermode='closest',
103
+ margin=dict(b=20, l=5, r=5, t=40),
104
+ xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
105
+ yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)
106
+ ))
107
+ return fig
108
+
109
+ def visualize_attack_graph_matplotlib(graph_data: Dict) -> str:
110
+ """Fallback: Matplotlib → base64 PNG"""
111
+ G = nx.DiGraph()
112
+ G.add_edges_from(graph_data["edges"])
113
+
114
+ fig, ax = plt.subplots(figsize=(8, 6))
115
+ pos = nx.spring_layout(G, seed=42)
116
+
117
+ node_colors = ['red' if n in graph_data["risky"] else 'lightblue' for n in G.nodes()]
118
+
119
+ nx.draw(G, pos, with_labels=True,
120
+ node_color=node_colors,
121
+ node_size=2200,
122
+ font_size=10,
123
+ font_weight='bold',
124
+ arrows=True,
125
+ arrowstyle='->',
126
+ arrowsize=20,
127
+ ax=ax)
128
+
129
+ ax.set_title("Attack Graph (Static)")
130
+
131
+ buf = BytesIO()
132
+ plt.savefig(buf, format='png', bbox_inches='tight')
133
+ buf.seek(0)
134
+ img_base64 = base64.b64encode(buf.read()).decode('utf-8')
135
+ plt.close(fig)
136
+ return f"data:image/png;base64,{img_base64}"
137
+
138
+ # ────────────────────────────────────────────────
139
+ # full_vuln_scan fonksiyonunu güncelle (graph kısmı)
140
+ # ────────────────────────────────────────────────
141
+ def full_vuln_scan(target_url: str, progress=gr.Progress(track_tqdm=True)):
142
+ # ... (önceki kod aynı, recon_data kısmından sonra ekle)
143
+
144
+ progress(0.75, desc="Attack Graph oluşturuluyor...")
145
+
146
+ graph_data = create_attack_graph_data(recon_data)
147
+ plotly_fig = visualize_attack_graph_plotly(graph_data)
148
+ # matplotlib_fallback = visualize_attack_graph_matplotlib(graph_data) # istersen fallback ekle
149
+
150
+ # ... (diğer sonuçlar aynı)
151
+
152
+ return (
153
+ result_summary,
154
+ json.dumps(enriched_findings, indent=2, ensure_ascii=False),
155
+ plotly_fig, # ← Plotly Figure direkt Plot component'e gider
156
+ history_md
157
+ )
158
+
159
+ # ────────────────────────────────────────────────
160
+ # Gradio Blocks güncellemesi (Attack Graph Tab)
161
+ # ────────────────────────────────────────────────
162
+ with gr.Blocks(...) as demo:
163
+ # ... önceki kısımlar aynı
164
+
165
+ with gr.Tabs():
166
+ # ... diğer tab'lar aynı
167
+ with gr.Tab("Attack Graph"):
168
+ gr.Markdown("### Potansiyel Saldırı Yolu Görselleştirmesi")
169
+ gr.Markdown("(Kırmızı node'lar yüksek riskli alanları gösterir)")
170
+ graph_plot = gr.Plot(label="Interactive Attack Graph (Plotly)")
171
+
172
+ # Events güncelle
173
+ scan_button.click(
174
+ fn=full_vuln_scan,
175
+ inputs=target_input,
176
+ outputs=[summary_output, json_output, graph_plot, history_output],
177
+ # ...
178
+ )
179
+
180
+ # ... kalan kısım aynı