Che237 committed on
Commit
882fe59
Β·
verified Β·
1 Parent(s): f12e058

Update app.py - add FastAPI REST endpoints + Gemini AI

Browse files
Files changed (1) hide show
  1. app.py +609 -252
app.py CHANGED
@@ -1,6 +1,8 @@
1
  """
2
  CyberForge AI - ML Training & Inference Platform
3
- Hugging Face Spaces deployment with Notebook execution support
 
 
4
  """
5
 
6
  import gradio as gr
@@ -14,6 +16,7 @@ from pathlib import Path
14
  from datetime import datetime
15
  import logging
16
  from typing import Dict, List, Any, Optional, Tuple
 
17
 
18
  # ML Libraries
19
  from sklearn.model_selection import train_test_split, cross_val_score
@@ -26,6 +29,18 @@ import joblib
26
  # Hugging Face Hub
27
  from huggingface_hub import HfApi, hf_hub_download, upload_file
28
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  logging.basicConfig(level=logging.INFO)
30
  logger = logging.getLogger(__name__)
31
 
@@ -33,18 +48,23 @@ logger = logging.getLogger(__name__)
33
  # CONFIGURATION
34
  # ============================================================================
35
 
36
- # Get the directory where app.py is located
37
  APP_DIR = Path(__file__).parent.absolute()
38
-
39
  MODELS_DIR = APP_DIR / "trained_models"
40
  MODELS_DIR.mkdir(exist_ok=True)
41
-
42
  DATASETS_DIR = APP_DIR / "datasets"
43
  DATASETS_DIR.mkdir(exist_ok=True)
44
-
45
  NOTEBOOKS_DIR = APP_DIR / "notebooks"
 
 
 
 
 
 
 
 
 
 
46
 
47
- # Log paths for debugging
48
  logger.info(f"APP_DIR: {APP_DIR}")
49
  logger.info(f"NOTEBOOKS_DIR: {NOTEBOOKS_DIR}")
50
  logger.info(f"NOTEBOOKS_DIR exists: {NOTEBOOKS_DIR.exists()}")
@@ -57,276 +77,495 @@ MODEL_TYPES = {
57
  "Isolation Forest (Anomaly)": IsolationForest,
58
  }
59
 
60
- # Cybersecurity task categories
61
  SECURITY_TASKS = [
62
- "Malware Detection",
63
- "Phishing Detection",
64
- "Network Intrusion Detection",
65
- "Anomaly Detection",
66
- "Botnet Detection",
67
- "Web Attack Detection",
68
- "Spam Detection",
69
- "Vulnerability Assessment",
70
- "DNS Tunneling Detection",
71
  "Cryptomining Detection",
72
  ]
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  # ============================================================================
75
- # NOTEBOOK EXECUTION
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  # ============================================================================
77
 
78
  def get_available_notebooks() -> List[str]:
79
- """Get list of available notebooks"""
80
  if not NOTEBOOKS_DIR.exists():
81
  return []
82
-
83
- notebooks = sorted([
84
- f.name for f in NOTEBOOKS_DIR.glob("*.ipynb")
85
- ])
86
- return notebooks
87
 
88
  def read_notebook_content(notebook_name: str) -> str:
89
  """Read and display notebook content as markdown"""
90
  notebook_path = NOTEBOOKS_DIR / notebook_name
91
  if not notebook_path.exists():
92
  return f"Notebook not found: {notebook_name}"
93
-
94
  try:
95
- with open(notebook_path, 'r') as f:
96
  nb = json.load(f)
97
-
98
  output = f"# {notebook_name}\n\n"
99
-
100
- for i, cell in enumerate(nb.get('cells', []), 1):
101
- cell_type = cell.get('cell_type', 'code')
102
- source = ''.join(cell.get('source', []))
103
-
104
- if cell_type == 'markdown':
105
  output += f"{source}\n\n"
106
  else:
107
  output += f"### Cell {i} (Python)\n```python\n{source}\n```\n\n"
108
-
109
  return output
110
  except Exception as e:
111
  return f"Error reading notebook: {str(e)}"
112
 
 
113
  def execute_notebook(notebook_name: str, progress=gr.Progress()) -> Tuple[str, str]:
114
  """Execute a notebook and return output"""
115
  notebook_path = NOTEBOOKS_DIR / notebook_name
116
  output_path = NOTEBOOKS_DIR / f"output_{notebook_name}"
117
-
118
  if not notebook_path.exists():
119
- # Debug: list what's actually in the directory
120
  available = list(NOTEBOOKS_DIR.glob("*.ipynb")) if NOTEBOOKS_DIR.exists() else []
121
- return f"Error: Notebook not found: {notebook_path}\nAvailable: {available}\nDir exists: {NOTEBOOKS_DIR.exists()}", ""
122
-
123
  progress(0.1, desc="Starting notebook execution...")
124
-
125
  try:
126
- # Execute notebook using nbconvert with absolute paths
127
  cmd = [
128
  sys.executable, "-m", "nbconvert",
129
- "--to", "notebook",
130
- "--execute",
131
  "--output", str(output_path.absolute()),
132
  "--ExecutePreprocessor.timeout=600",
133
  "--ExecutePreprocessor.kernel_name=python3",
134
- str(notebook_path.absolute())
135
  ]
136
-
137
  progress(0.3, desc="Executing cells...")
138
-
139
- result = subprocess.run(
140
- cmd,
141
- capture_output=True,
142
- text=True,
143
- cwd=str(NOTEBOOKS_DIR),
144
- timeout=900
145
- )
146
-
147
  progress(0.8, desc="Processing output...")
148
-
149
  if result.returncode == 0:
150
- # Read executed notebook for outputs
151
  if output_path.exists():
152
- with open(output_path, 'r') as f:
153
  executed_nb = json.load(f)
154
-
155
  outputs = []
156
- for i, cell in enumerate(executed_nb.get('cells', []), 1):
157
- if cell.get('cell_type') == 'code':
158
- cell_outputs = cell.get('outputs', [])
159
- for out in cell_outputs:
160
- if 'text' in out:
161
- text = ''.join(out['text'])
162
- outputs.append(f"Cell {i}:\n{text}")
163
- elif 'data' in out:
164
- if 'text/plain' in out['data']:
165
- text = ''.join(out['data']['text/plain'])
166
- outputs.append(f"Cell {i}:\n{text}")
167
-
168
  progress(1.0, desc="Complete!")
169
  return "Notebook executed successfully!", "\n\n".join(outputs)
170
  else:
171
  return "Notebook executed but output file not found", result.stdout
172
  else:
173
  return f"Execution failed:\n{result.stderr}", result.stdout
174
-
175
  except subprocess.TimeoutExpired:
176
  return "Error: Notebook execution timed out (15 min limit)", ""
177
  except Exception as e:
178
  return f"Error executing notebook: {str(e)}", ""
179
 
 
180
def run_notebook_cell(notebook_name: str, cell_number: int) -> str:
    """Execute a single code cell from a notebook and return its output as markdown.

    Args:
        notebook_name: Filename of a notebook inside NOTEBOOKS_DIR.
        cell_number: 1-based index into the notebook's *code* cells only
            (markdown cells are not counted).

    Returns:
        Markdown-formatted cell output, or an "Error: ..." string on failure.
    """
    notebook_path = NOTEBOOKS_DIR / notebook_name

    if not notebook_path.exists():
        return f"Error: Notebook not found at {notebook_path}"

    import io
    from contextlib import redirect_stdout, redirect_stderr

    # Run from the notebooks directory so relative paths inside cells resolve.
    # The finally block guarantees the original cwd is restored on every exit
    # path (the previous code restored it manually before each return, with a
    # bare `except:` as a last resort, and could leave the process chdir'd).
    original_cwd = os.getcwd()
    os.chdir(NOTEBOOKS_DIR)
    try:
        with open(notebook_path, 'r') as f:
            nb = json.load(f)

        cells = [c for c in nb.get('cells', []) if c.get('cell_type') == 'code']

        if cell_number < 1 or cell_number > len(cells):
            return f"Error: Cell {cell_number} not found. Available: 1-{len(cells)}"

        cell = cells[cell_number - 1]
        source = ''.join(cell.get('source', []))

        # Fresh namespace emulating a __main__ module. NOTE: this dict is
        # required by the exec() call below -- it must not be removed.
        namespace = {
            '__name__': '__main__',
            '__file__': str(notebook_path),
        }

        stdout_capture = io.StringIO()
        stderr_capture = io.StringIO()

        with redirect_stdout(stdout_capture), redirect_stderr(stderr_capture):
            try:
                # SECURITY NOTE: exec of notebook code -- acceptable only
                # because the notebooks ship with the Space (not user-supplied).
                exec(source, namespace)
            except Exception as e:
                return f"Error: {str(e)}"

        output = stdout_capture.getvalue()
        errors = stderr_capture.getvalue()

        result = f"### Cell {cell_number} Output:\n"
        if output:
            result += f"```\n{output}\n```\n"
        if errors:
            result += f"\n**Warnings/Errors:**\n```\n{errors}\n```"
        if not output and not errors:
            result += "*(No output)*"

        return result

    except Exception as e:
        return f"Error: {str(e)}"
    finally:
        os.chdir(original_cwd)
245
 
 
246
  # ============================================================================
247
- # MODEL TRAINING (existing functionality)
248
  # ============================================================================
249
 
250
class SecurityModelTrainer:
    """Train ML models for cybersecurity tasks.

    Thin wrapper around scikit-learn that standardizes features, label-encodes
    string targets, and trains one of the estimators registered in MODEL_TYPES.
    """

    def __init__(self):
        # Fitted in prepare_data(); reused implicitly for later inference.
        self.scaler = StandardScaler()
        # Fitted only when the target column is string-typed.
        self.label_encoder = LabelEncoder()
        # NOTE(review): never written to by the visible methods -- appears unused.
        self.models = {}

    def prepare_data(self, df: pd.DataFrame, target_col: str = 'label') -> Tuple:
        """Split a labeled DataFrame into scaled train/test arrays.

        Non-numeric feature columns are silently dropped and NaNs filled with 0.

        Returns:
            (X_train, X_test, y_train, y_test) from an 80/20 split (seed 42).

        Raises:
            ValueError: if `target_col` is not a column of `df`.
        """
        if target_col not in df.columns:
            raise ValueError(f"Target column '{target_col}' not found")
        X = df.drop(columns=[target_col])
        y = df[target_col]
        # Keep numeric features only; impute missing values with zero.
        X = X.select_dtypes(include=[np.number]).fillna(0)
        if y.dtype == 'object':
            y = self.label_encoder.fit_transform(y)
        X_scaled = self.scaler.fit_transform(X)
        return train_test_split(X_scaled, y, test_size=0.2, random_state=42)

    def train_model(self, model_type: str, X_train, y_train):
        """Instantiate and fit the estimator registered under `model_type`.

        Raises:
            ValueError: if `model_type` is not a key of MODEL_TYPES.
        """
        if model_type not in MODEL_TYPES:
            raise ValueError(f"Unknown model type: {model_type}")
        model_class = MODEL_TYPES[model_type]
        # IsolationForest needs a contamination rate instead of labels' balance.
        if model_type == "Isolation Forest (Anomaly)":
            model = model_class(contamination=0.1, random_state=42)
        else:
            model = model_class(random_state=42)
        model.fit(X_train, y_train)
        return model

    def evaluate_model(self, model, X_test, y_test) -> Dict:
        """Return accuracy and weighted F1 for `model` on the held-out split."""
        y_pred = model.predict(X_test)
        metrics = {
            'accuracy': accuracy_score(y_test, y_pred),
            'f1_score': f1_score(y_test, y_pred, average='weighted', zero_division=0)
        }
        return metrics
301
 
302
  trainer = SecurityModelTrainer()
303
 
 
304
  def train_model_from_data(data_file, model_type: str, task: str, progress=gr.Progress()):
305
  """Train model from uploaded data"""
306
  if data_file is None:
307
  return "Please upload a CSV file", None, None
308
-
309
  progress(0.1, desc="Loading data...")
310
-
311
  try:
312
  df = pd.read_csv(data_file.name)
313
  progress(0.3, desc="Preparing data...")
314
-
315
  X_train, X_test, y_train, y_test = trainer.prepare_data(df)
316
-
317
  progress(0.5, desc=f"Training {model_type}...")
318
  model = trainer.train_model(model_type, X_train, y_train)
319
-
320
  progress(0.8, desc="Evaluating model...")
321
  metrics = trainer.evaluate_model(model, X_test, y_test)
322
-
323
- # Save model
324
  model_name = f"{task.lower().replace(' ', '_')}_{model_type.lower().replace(' ', '_')}"
325
  model_path = MODELS_DIR / f"{model_name}.pkl"
326
  joblib.dump(model, model_path)
327
-
328
  progress(1.0, desc="Complete!")
329
-
330
  result = f"""
331
  ## Training Complete!
332
 
@@ -340,224 +579,342 @@ def train_model_from_data(data_file, model_type: str, task: str, progress=gr.Pro
340
 
341
  **Model saved to:** {model_path}
342
  """
343
-
344
  return result, str(model_path), json.dumps(metrics, indent=2)
345
-
346
  except Exception as e:
347
  return f"Error: {str(e)}", None, None
348
 
 
349
def run_inference(model_file, features_text: str):
    """Run inference with a trained model.

    Args:
        model_file: Uploaded .pkl file object (Gradio File; `.name` is its path).
        features_text: JSON object mapping feature names to numeric values.
            NOTE(review): only the dict *values* are used, in insertion order --
            they must match the column order the model was trained on.

    Returns:
        Pretty-printed JSON string with the prediction (plus confidence when
        the model supports predict_proba), or an "Error: ..." message.
    """
    if model_file is None:
        return "Please upload a model file"
    try:
        # SECURITY NOTE: joblib.load unpickles arbitrary code -- only safe
        # for model files supplied by trusted users.
        model = joblib.load(model_file.name)
        features = json.loads(features_text)
        X = np.array([list(features.values())])
        prediction = model.predict(X)[0]
        result = {
            'prediction': int(prediction),
            'features_used': len(features)
        }
        if hasattr(model, 'predict_proba'):
            proba = model.predict_proba(X)[0]
            result['confidence'] = float(max(proba))
        return json.dumps(result, indent=2)
    except Exception as e:
        return f"Error: {str(e)}"
375
 
 
376
def list_trained_models():
    """Render a markdown bullet list of every saved .pkl model with its size."""
    saved = list(MODELS_DIR.glob("*.pkl"))
    if not saved:
        return "No trained models found"
    lines = ["## Trained Models\n"]
    for path in saved:
        kb = path.stat().st_size / 1024
        lines.append(f"- **{path.name}** ({kb:.1f} KB)")
    return "\n".join(lines) + "\n"
388
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
389
  # ============================================================================
390
- # GRADIO INTERFACE
391
  # ============================================================================
392
 
393
  def create_interface():
394
- """Create the Gradio interface"""
395
-
396
  with gr.Blocks(title="CyberForge AI", theme=gr.themes.Soft()) as demo:
397
  gr.Markdown("""
398
- # πŸ” CyberForge AI - ML Training Platform
399
 
400
- Train cybersecurity ML models and run Jupyter notebooks on Hugging Face.
401
  """)
402
-
403
  with gr.Tabs():
404
  # ============ NOTEBOOKS TAB ============
405
  with gr.TabItem("πŸ““ Notebooks"):
406
- gr.Markdown("""
407
- ### Run ML Pipeline Notebooks
408
- Execute the CyberForge ML notebooks directly in the cloud.
409
- """)
410
-
411
  with gr.Row():
412
  with gr.Column(scale=1):
413
  notebook_dropdown = gr.Dropdown(
414
  choices=get_available_notebooks(),
415
  label="Select Notebook",
416
- value=get_available_notebooks()[0] if get_available_notebooks() else None
417
  )
418
-
419
  refresh_btn = gr.Button("πŸ”„ Refresh List")
420
  view_btn = gr.Button("πŸ‘ View Content", variant="secondary")
421
  execute_btn = gr.Button("β–Ά Execute Notebook", variant="primary")
422
-
423
  gr.Markdown("### Run Single Cell")
424
  cell_number = gr.Number(label="Cell Number", value=1, minimum=1)
425
  run_cell_btn = gr.Button("Run Cell")
426
-
427
  with gr.Column(scale=2):
428
  notebook_status = gr.Markdown("Select a notebook to view or execute.")
429
  notebook_output = gr.Markdown("", label="Output")
430
-
431
  def refresh_notebooks():
432
  notebooks = get_available_notebooks()
433
  return gr.update(choices=notebooks, value=notebooks[0] if notebooks else None)
434
-
435
  refresh_btn.click(refresh_notebooks, outputs=notebook_dropdown)
436
  view_btn.click(read_notebook_content, inputs=notebook_dropdown, outputs=notebook_output)
437
  execute_btn.click(execute_notebook, inputs=notebook_dropdown, outputs=[notebook_status, notebook_output])
438
  run_cell_btn.click(run_notebook_cell, inputs=[notebook_dropdown, cell_number], outputs=notebook_output)
439
-
440
  # ============ TRAIN MODEL TAB ============
441
  with gr.TabItem("🎯 Train Model"):
442
- gr.Markdown("""
443
- ### Train a Security ML Model
444
- Upload your dataset and train a model for threat detection.
445
- """)
446
-
447
  with gr.Row():
448
  with gr.Column():
449
- task_dropdown = gr.Dropdown(
450
- choices=SECURITY_TASKS,
451
- label="Security Task",
452
- value="Phishing Detection"
453
- )
454
- model_dropdown = gr.Dropdown(
455
- choices=list(MODEL_TYPES.keys()),
456
- label="Model Type",
457
- value="Random Forest"
458
- )
459
  data_upload = gr.File(label="Upload Training Data (CSV)", file_types=[".csv"])
460
  train_btn = gr.Button("πŸš€ Train Model", variant="primary")
461
-
462
  with gr.Column():
463
  train_output = gr.Markdown("Upload data and click Train to begin.")
464
  model_path_output = gr.Textbox(label="Model Path", visible=False)
465
  metrics_output = gr.Textbox(label="Metrics JSON", visible=False)
466
-
467
- train_btn.click(
468
- train_model_from_data,
469
- inputs=[data_upload, model_dropdown, task_dropdown],
470
- outputs=[train_output, model_path_output, metrics_output]
471
- )
472
-
473
  # ============ INFERENCE TAB ============
474
  with gr.TabItem("πŸ” Inference"):
475
- gr.Markdown("""
476
- ### Run Model Inference
477
- Load a trained model and make predictions.
478
- """)
479
-
480
  with gr.Row():
481
  with gr.Column():
482
  model_upload = gr.File(label="Upload Model (.pkl)")
483
- features_input = gr.Textbox(
484
- label="Features (JSON)",
485
- value='{"url_length": 50, "has_https": 1, "digit_count": 5}',
486
- lines=5
487
- )
488
  predict_btn = gr.Button("🎯 Predict", variant="primary")
489
-
490
  with gr.Column():
491
  prediction_output = gr.Textbox(label="Prediction Result", lines=10)
492
-
493
  predict_btn.click(run_inference, inputs=[model_upload, features_input], outputs=prediction_output)
494
-
495
  # ============ MODELS TAB ============
496
  with gr.TabItem("πŸ“¦ Models"):
497
  gr.Markdown("### Trained Models")
498
  models_list = gr.Markdown(list_trained_models())
499
  refresh_models_btn = gr.Button("πŸ”„ Refresh")
500
  refresh_models_btn.click(list_trained_models, outputs=models_list)
501
-
502
  # ============ API TAB ============
503
  with gr.TabItem("πŸ”Œ API"):
504
  gr.Markdown("""
505
  ## API Integration
506
 
507
- ### Python Client
508
 
509
- ```python
510
- from huggingface_hub import InferenceClient
 
 
 
 
 
 
 
511
 
512
- client = InferenceClient("Che237/cyberforge")
513
-
514
- # Make prediction
515
- result = client.predict(
516
- model_name="phishing_detection",
517
- features={"url_length": 50, "has_https": 1}
518
- )
519
- print(result)
520
- ```
521
-
522
- ### REST API
523
 
524
  ```bash
525
- curl -X POST https://che237-cyberforge.hf.space/api/predict \\
526
  -H "Content-Type: application/json" \\
527
- -d '{"model_name": "phishing_detection", "features": {"url_length": 50}}'
528
  ```
529
 
530
- ### Notebook Execution
531
 
532
- The notebooks in this Space implement the complete CyberForge ML pipeline:
533
-
534
- | # | Notebook | Purpose |
535
- |---|----------|---------|
536
- | 00 | environment_setup | System validation |
537
- | 01 | data_acquisition | Data collection |
538
- | 02 | feature_engineering | Feature extraction |
539
- | 03 | model_training | Train models |
540
- | 04 | agent_intelligence | AI reasoning |
541
- | 05 | model_validation | Testing |
542
- | 06 | backend_integration | API packaging |
543
- | 07 | deployment_artifacts | Deployment |
544
  """)
545
-
546
- gr.Markdown("""
547
- ---
548
- **CyberForge AI** | [GitHub](https://github.com/Che237/cyberforge) | [Datasets](https://huggingface.co/datasets/Che237/cyberforge-datasets)
549
- """)
550
-
551
  return demo
552
 
 
553
  # ============================================================================
554
- # MAIN
555
  # ============================================================================
556
 
 
 
 
 
 
 
 
 
 
 
557
  if __name__ == "__main__":
558
- demo = create_interface()
559
- # Docker deployment configuration
560
- demo.launch(
561
- server_name="0.0.0.0",
562
- server_port=7860
563
- )
 
1
  """
2
  CyberForge AI - ML Training & Inference Platform
3
+ Hugging Face Spaces deployment with:
4
+ 1) Gradio UI for notebook execution, training, and inference
5
+ 2) FastAPI REST endpoints for the Heroku backend (mlService.js)
6
  """
7
 
8
  import gradio as gr
 
16
  from datetime import datetime
17
  import logging
18
  from typing import Dict, List, Any, Optional, Tuple
19
+ from urllib.parse import urlparse
20
 
21
  # ML Libraries
22
  from sklearn.model_selection import train_test_split, cross_val_score
 
29
  # Hugging Face Hub
30
  from huggingface_hub import HfApi, hf_hub_download, upload_file
31
 
32
+ # FastAPI for REST endpoints
33
+ from fastapi import FastAPI, Request
34
+ from fastapi.middleware.cors import CORSMiddleware
35
+ from fastapi.responses import JSONResponse
36
+
37
+ # Gemini AI (new SDK: google-genai)
38
+ try:
39
+ from google import genai
40
+ GEMINI_AVAILABLE = True
41
+ except ImportError:
42
+ GEMINI_AVAILABLE = False
43
+
44
  logging.basicConfig(level=logging.INFO)
45
  logger = logging.getLogger(__name__)
46
 
 
48
  # CONFIGURATION
49
  # ============================================================================
50
 
 
51
  APP_DIR = Path(__file__).parent.absolute()
 
52
  MODELS_DIR = APP_DIR / "trained_models"
53
  MODELS_DIR.mkdir(exist_ok=True)
 
54
  DATASETS_DIR = APP_DIR / "datasets"
55
  DATASETS_DIR.mkdir(exist_ok=True)
 
56
  NOTEBOOKS_DIR = APP_DIR / "notebooks"
57
+ KNOWLEDGE_BASE_DIR = APP_DIR / "knowledge_base"
58
+ KNOWLEDGE_BASE_DIR.mkdir(exist_ok=True)
59
+ TRAINING_DATA_DIR = APP_DIR / "training_data"
60
+ TRAINING_DATA_DIR.mkdir(exist_ok=True)
61
+
62
+ # Environment
63
+ GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
64
+ GEMINI_MODEL = os.environ.get("GEMINI_MODEL", "gemini-2.0-flash-exp")
65
+ HF_TOKEN = os.environ.get("HF_TOKEN", "")
66
+ HF_MODEL_REPO = os.environ.get("HF_MODEL_REPO", "Che237/cyberforge-models")
67
 
 
68
  logger.info(f"APP_DIR: {APP_DIR}")
69
  logger.info(f"NOTEBOOKS_DIR: {NOTEBOOKS_DIR}")
70
  logger.info(f"NOTEBOOKS_DIR exists: {NOTEBOOKS_DIR.exists()}")
 
77
  "Isolation Forest (Anomaly)": IsolationForest,
78
  }
79
 
 
80
  SECURITY_TASKS = [
81
+ "Malware Detection", "Phishing Detection", "Network Intrusion Detection",
82
+ "Anomaly Detection", "Botnet Detection", "Web Attack Detection",
83
+ "Spam Detection", "Vulnerability Assessment", "DNS Tunneling Detection",
 
 
 
 
 
 
84
  "Cryptomining Detection",
85
  ]
86
 
87
+
88
+ # ============================================================================
89
+ # GEMINI SERVICE (for REST API)
90
+ # ============================================================================
91
+
92
class GeminiService:
    """Google Gemini AI for cybersecurity chat and analysis.

    Wraps the google-genai client with a fixed cybersecurity system prompt,
    optional on-disk knowledge/few-shot files, and a canned fallback payload
    so REST callers always get a well-formed response dict.
    """

    # Instruction block prepended to every prompt sent to Gemini.
    SYSTEM_PROMPT = """You are CyberForge AI, an advanced cybersecurity expert. You specialize in:
- Real-time threat detection and analysis
- Malware and phishing identification
- Network security assessment
- Browser security monitoring
- Risk assessment and mitigation

When analyzing security queries, provide:
1. Risk Level (Critical/High/Medium/Low)
2. Threat Types identified
3. Confidence Score (0.0-1.0)
4. Detailed technical analysis
5. Specific actionable recommendations

Always be precise, professional, and actionable."""

    def __init__(self):
        # genai.Client instance; stays None until initialize() succeeds.
        self.client = None
        # True only after a successful round-trip smoke test.
        self.ready = False
        # {file-stem: parsed JSON} loaded from KNOWLEDGE_BASE_DIR/*.json.
        self.custom_knowledge = {}
        # Few-shot examples from TRAINING_DATA_DIR/*.json (last 3 used).
        self.training_examples = []

    def initialize(self):
        """Create the Gemini client and verify it with a test generation.

        Leaves self.ready False (degraded mode) when the SDK is missing or
        GEMINI_API_KEY is unset; knowledge files are loaded only after the
        smoke test succeeds.
        """
        if not GEMINI_AVAILABLE:
            logger.warning("google-genai not installed, Gemini unavailable")
            return
        if not GEMINI_API_KEY:
            logger.warning("GEMINI_API_KEY not set, Gemini unavailable")
            return
        try:
            self.client = genai.Client(api_key=GEMINI_API_KEY)
            # Cheap round-trip to confirm the key and model actually work.
            resp = self.client.models.generate_content(
                model=GEMINI_MODEL,
                contents="Test. Respond with OK."
            )
            if resp.text:
                self.ready = True
                logger.info(f"βœ… Gemini initialized (model: {GEMINI_MODEL})")
                self._load_knowledge()
        except Exception as e:
            logger.error(f"❌ Gemini init failed: {e}")
            self.ready = False

    def _load_knowledge(self):
        """Load optional JSON knowledge-base and training-example files.

        Both stores are best-effort extras: failures are logged and ignored.
        """
        try:
            for f in KNOWLEDGE_BASE_DIR.glob("*.json"):
                with open(f) as fh:
                    self.custom_knowledge[f.stem] = json.load(fh)
            logger.info(f"Loaded {len(self.custom_knowledge)} knowledge files")
        except Exception as e:
            logger.warning(f"Knowledge load error: {e}")
        try:
            for f in TRAINING_DATA_DIR.glob("*.json"):
                with open(f) as fh:
                    data = json.load(fh)
                # A file may hold a list of examples or a single example object.
                if isinstance(data, list):
                    self.training_examples.extend(data)
                else:
                    self.training_examples.append(data)
            logger.info(f"Loaded {len(self.training_examples)} training examples")
        except Exception as e:
            logger.warning(f"Training data load error: {e}")

    def analyze(self, query: str, context: Dict = None, history: List = None) -> Dict:
        """Run a cybersecurity analysis of `query` through Gemini.

        Args:
            query: Free-form question or indicator to analyze.
            context: Optional extra data serialized into the prompt (truncated
                to 3000 chars of JSON).
            history: Accepted but currently unused -- TODO confirm intended.

        Returns:
            Dict with response text, confidence, risk_level/risk_score,
            insights, recommendations, model_used and timestamp. Returns the
            _fallback payload when Gemini is not ready or the call fails.
        """
        if not self.ready:
            return self._fallback(query)
        try:
            # Truncate each prompt component to keep total prompt size bounded.
            knowledge_str = json.dumps(self.custom_knowledge, indent=1)[:2000] if self.custom_knowledge else "None"
            examples_str = "\n".join(
                f"Q: {ex.get('input','')}\nA: {ex.get('output','')}"
                for ex in self.training_examples[-3:]
            ) if self.training_examples else "None"
            context_str = f"\nCONTEXT:\n{json.dumps(context, indent=1)[:3000]}\n" if context else ""

            prompt = f"""{self.SYSTEM_PROMPT}

KNOWLEDGE BASE (summary):
{knowledge_str}

TRAINING EXAMPLES:
{examples_str}
{context_str}
USER QUERY:
{query}

Provide a comprehensive cybersecurity analysis:"""

            response = self.client.models.generate_content(
                model=GEMINI_MODEL,
                contents=prompt,
                config={"temperature": 0.3, "max_output_tokens": 2048},
            )
            text = response.text if response.text else "No response generated"
            # Crude risk scoring: substring scan over the whole response,
            # checked in severity order so "critical" wins over "high".
            text_lower = text.lower()
            if "critical" in text_lower:
                risk_level, risk_score = "Critical", 9.0
            elif "high" in text_lower:
                risk_level, risk_score = "High", 7.0
            elif "medium" in text_lower:
                risk_level, risk_score = "Medium", 5.0
            else:
                risk_level, risk_score = "Low", 3.0

            return {
                "response": text,
                "confidence": 0.85,
                "risk_level": risk_level,
                "risk_score": risk_score,
                "insights": [f"Analysis performed by Gemini ({GEMINI_MODEL})"],
                "recommendations": [],
                "model_used": GEMINI_MODEL,
                "timestamp": datetime.utcnow().isoformat(),
            }
        except Exception as e:
            logger.error(f"Gemini analysis error: {e}")
            return self._fallback(query)

    def _fallback(self, query: str) -> Dict:
        """Degraded-mode payload returned when Gemini cannot be used."""
        return {
            "response": (
                "CyberForge AI is temporarily running in limited mode. "
                "The Gemini AI service could not process your request. "
                "Please check that the GEMINI_API_KEY is set correctly in the Space secrets."
            ),
            "confidence": 0.1,
            "risk_level": "Unknown",
            "risk_score": 0,
            "insights": [],
            "recommendations": ["Verify GEMINI_API_KEY is set", "Check Space logs"],
            "model_used": "fallback",
            "timestamp": datetime.utcnow().isoformat(),
        }
227
+
228
+
229
+ # Singleton
230
+ gemini_service = GeminiService()
231
+
232
+
233
  # ============================================================================
234
+ # ML MODEL LOADER (loads .pkl from HF Hub for REST predictions)
235
+ # ============================================================================
236
+
237
class MLModelLoader:
    """Load trained .pkl models (HF Hub or local disk) and serve predictions
    for the REST API, falling back to a rule-based heuristic when no model
    is available for a given name."""

    # Model names probed at startup (Hub sub-directory / filename stems).
    MODEL_NAMES = [
        "phishing_detection", "malware_detection",
        "anomaly_detection", "web_attack_detection",
    ]

    def __init__(self):
        self.models: Dict[str, Any] = {}   # name -> fitted estimator
        self.scalers: Dict[str, Any] = {}  # name -> fitted scaler (optional)
        self.ready = False                 # True once at least one model loaded

    def initialize(self):
        """Attempt to load every known model.

        Order per model: Hub "<name>/best_model.pkl" layout, then Hub flat
        "<name>_model.pkl", then a local copy under MODELS_DIR. Finally any
        loose *.pkl files in MODELS_DIR are picked up. All failures are
        logged and ignored; sets self.ready when at least one model loaded.
        """
        loaded = 0
        for name in self.MODEL_NAMES:
            try:
                model_file = f"{name}/best_model.pkl"
                scaler_file = f"{name}/scaler.pkl"
                try:
                    model_path = hf_hub_download(
                        repo_id=HF_MODEL_REPO, filename=model_file,
                        token=HF_TOKEN or None, cache_dir=str(MODELS_DIR),
                    )
                    scaler_path = hf_hub_download(
                        repo_id=HF_MODEL_REPO, filename=scaler_file,
                        token=HF_TOKEN or None, cache_dir=str(MODELS_DIR),
                    )
                    self.models[name] = joblib.load(model_path)
                    self.scalers[name] = joblib.load(scaler_path)
                    loaded += 1
                    logger.info(f"βœ… Loaded model from Hub: {name}")
                except Exception:
                    # Hub directory layout failed -- try flat filename pattern.
                    try:
                        model_path = hf_hub_download(
                            repo_id=HF_MODEL_REPO, filename=f"{name}_model.pkl",
                            token=HF_TOKEN or None, cache_dir=str(MODELS_DIR),
                        )
                        self.models[name] = joblib.load(model_path)
                        loaded += 1
                        logger.info(f"βœ… Loaded model (flat): {name}")
                    except Exception:
                        pass
                    # Finally, try a local copy shipped with the Space.
                    local_model = MODELS_DIR / name / "best_model.pkl"
                    if local_model.exists() and name not in self.models:
                        self.models[name] = joblib.load(local_model)
                        local_scaler = MODELS_DIR / name / "scaler.pkl"
                        if local_scaler.exists():
                            self.scalers[name] = joblib.load(local_scaler)
                        loaded += 1
                        logger.info(f"βœ… Loaded model from local: {name}")
            except Exception as e:
                logger.warning(f"Error loading model {name}: {e}")

        # Also pick up any loose .pkl files dropped into trained_models/.
        for pkl in MODELS_DIR.glob("*.pkl"):
            stem = pkl.stem.replace("_model", "").replace("_best", "")
            if stem not in self.models:
                try:
                    self.models[stem] = joblib.load(pkl)
                    loaded += 1
                    logger.info(f"βœ… Loaded local model: {stem}")
                except Exception:
                    pass

        self.ready = loaded > 0
        # BUG FIX: previously referenced the module-level singleton `ml_loader`
        # here instead of `self`, which raises NameError when initialize() runs
        # on any other instance (or before the singleton name is bound).
        logger.info(f"ML Models: {loaded} loaded ({list(self.models.keys()) if loaded else 'none'})")

    def predict(self, model_name: str, features: Dict) -> Dict:
        """Predict with the named model, or fall back to the heuristic when
        the model is missing or inference raises.

        NOTE: feature dict values are used in insertion order -- callers must
        supply the same ordering used at training time.
        """
        if model_name not in self.models:
            return self._heuristic_predict(model_name, features)
        try:
            model = self.models[model_name]
            scaler = self.scalers.get(model_name)
            X = np.array([list(features.values())])
            if scaler:
                X = scaler.transform(X)
            prediction = int(model.predict(X)[0])
            confidence = 0.5  # default when the model has no predict_proba
            if hasattr(model, "predict_proba"):
                proba = model.predict_proba(X)[0]
                confidence = float(max(proba))
            return {
                "model": model_name,
                "prediction": prediction,
                "prediction_label": "threat" if prediction == 1 else "benign",
                "confidence": confidence,
                "inference_source": "ml_model",
                "timestamp": datetime.utcnow().isoformat(),
            }
        except Exception as e:
            logger.error(f"Prediction error {model_name}: {e}")
            return self._heuristic_predict(model_name, features)

    def _heuristic_predict(self, model_name: str, features: Dict) -> Dict:
        """Rule-based URL threat scoring used when no trained model exists.

        Sums fixed weights for known risk signals; score >= 0.5 is a threat.
        """
        score = 0.0
        reasons = []
        is_https = features.get("is_https", features.get("has_https", 1))
        has_ip = features.get("has_ip_address", 0)
        suspicious_tld = features.get("has_suspicious_tld", 0)
        url_length = features.get("url_length", 0)
        special_chars = features.get("special_char_count", 0)
        if not is_https:
            score += 0.3
            reasons.append("No HTTPS")
        if has_ip:
            score += 0.25
            reasons.append("IP address in URL")
        if suspicious_tld:
            score += 0.2
            reasons.append("Suspicious TLD")
        if url_length > 100:
            score += 0.15
            reasons.append("Very long URL")
        if special_chars > 10:
            score += 0.1
            reasons.append("Many special characters")
        is_threat = score >= 0.5
        return {
            "model": model_name,
            "prediction": 1 if is_threat else 0,
            "prediction_label": "threat" if is_threat else "benign",
            "confidence": min(score + 0.3, 0.95) if is_threat else max(0.6, 1.0 - score),
            "threat_score": score,
            "reasons": reasons,
            "inference_source": "heuristic",
            "timestamp": datetime.utcnow().isoformat(),
        }
360
+
361
+
362
+ ml_loader = MLModelLoader()
363
+
364
+
365
+ # ============================================================================
366
+ # URL FEATURE EXTRACTION
367
+ # ============================================================================
368
+
369
def extract_url_features(url: str) -> Dict:
    """Derive the numeric lexical features the security models expect from a raw URL."""
    parsed = urlparse(url)
    host = parsed.hostname or ""
    path = parsed.path or ""
    labels = host.split(".")
    looks_like_ipv4 = len(labels) == 4 and all(part.isdigit() for part in labels)
    risky_tlds = (".xyz", ".tk", ".ml", ".ga", ".cf", ".top", ".buzz")
    specials = "!@#$%^&*()+={}[]|\\:;<>?,"

    features = {}
    features["url_length"] = len(url)
    features["hostname_length"] = len(host)
    features["path_length"] = len(path)
    features["is_https"] = 1 if parsed.scheme == "https" else 0
    features["has_ip_address"] = 1 if looks_like_ipv4 else 0
    features["has_suspicious_tld"] = 1 if host.endswith(risky_tlds) else 0
    features["subdomain_count"] = max(0, len(labels) - 2)
    features["has_port"] = 1 if parsed.port else 0
    features["query_params_count"] = len(parsed.query.split("&")) if parsed.query else 0
    features["has_at_symbol"] = 1 if "@" in url else 0
    features["has_double_slash"] = 1 if "//" in path else 0
    features["special_char_count"] = sum(ch in specials for ch in url)
    return features
386
+
387
+
388
+ # ============================================================================
389
+ # NOTEBOOK EXECUTION (existing Gradio functionality)
390
  # ============================================================================
391
 
392
def get_available_notebooks() -> List[str]:
    """Return the sorted filenames of every .ipynb in NOTEBOOKS_DIR."""
    if not NOTEBOOKS_DIR.exists():
        return []
    return sorted(nb.name for nb in NOTEBOOKS_DIR.glob("*.ipynb"))
396
+
 
 
 
397
 
398
def read_notebook_content(notebook_name: str) -> str:
    """Render a notebook as markdown for display in the Gradio UI.

    Markdown cells are emitted verbatim; code cells are wrapped in fenced
    ``python`` blocks. Never raises — errors come back as a plain string,
    which keeps it safe as a Gradio callback.

    Fix: open the file with an explicit ``encoding="utf-8"`` — .ipynb is
    UTF-8 JSON by spec, and the platform's locale default codec (e.g. on
    Windows) could otherwise fail on non-ASCII notebooks.
    """
    notebook_path = NOTEBOOKS_DIR / notebook_name
    if not notebook_path.exists():
        return f"Notebook not found: {notebook_name}"
    try:
        with open(notebook_path, "r", encoding="utf-8") as f:
            nb = json.load(f)
        output = f"# {notebook_name}\n\n"
        for i, cell in enumerate(nb.get("cells", []), 1):
            cell_type = cell.get("cell_type", "code")
            source = "".join(cell.get("source", []))  # nbformat stores source as a list of lines
            if cell_type == "markdown":
                output += f"{source}\n\n"
            else:
                output += f"### Cell {i} (Python)\n```python\n{source}\n```\n\n"
        return output
    except Exception as e:
        return f"Error reading notebook: {str(e)}"
417
 
418
+
419
def execute_notebook(notebook_name: str, progress=gr.Progress()) -> Tuple[str, str]:
    """Execute a whole notebook via nbconvert and return (status, outputs).

    Runs ``python -m nbconvert --execute`` as a subprocess so a crashing
    notebook cannot take down the app. Per-cell timeout is 600s and the
    subprocess itself is killed after 900s. Text outputs of all code cells
    are concatenated into the second return value.

    NOTE: ``progress=gr.Progress()`` as a default arg is the documented
    Gradio pattern for progress tracking, not an accidental mutable default.
    """
    notebook_path = NOTEBOOKS_DIR / notebook_name
    # nbconvert writes the executed copy to a separate file, leaving the
    # original notebook untouched.
    output_path = NOTEBOOKS_DIR / f"output_{notebook_name}"
    if not notebook_path.exists():
        available = list(NOTEBOOKS_DIR.glob("*.ipynb")) if NOTEBOOKS_DIR.exists() else []
        return f"Error: Notebook not found: {notebook_path}\nAvailable: {available}", ""
    progress(0.1, desc="Starting notebook execution...")
    try:
        cmd = [
            sys.executable, "-m", "nbconvert",
            "--to", "notebook", "--execute",
            "--output", str(output_path.absolute()),
            "--ExecutePreprocessor.timeout=600",
            "--ExecutePreprocessor.kernel_name=python3",
            str(notebook_path.absolute()),
        ]
        progress(0.3, desc="Executing cells...")
        # cwd=NOTEBOOKS_DIR so notebooks can use relative paths to data files.
        result = subprocess.run(cmd, capture_output=True, text=True, cwd=str(NOTEBOOKS_DIR), timeout=900)
        progress(0.8, desc="Processing output...")
        if result.returncode == 0:
            if output_path.exists():
                with open(output_path, "r") as f:
                    executed_nb = json.load(f)
                # Collect stream text and text/plain display data from code cells.
                outputs = []
                for i, cell in enumerate(executed_nb.get("cells", []), 1):
                    if cell.get("cell_type") == "code":
                        for out in cell.get("outputs", []):
                            if "text" in out:
                                outputs.append(f"Cell {i}:\n{''.join(out['text'])}")
                            elif "data" in out and "text/plain" in out["data"]:
                                outputs.append(f"Cell {i}:\n{''.join(out['data']['text/plain'])}")
                progress(1.0, desc="Complete!")
                return "Notebook executed successfully!", "\n\n".join(outputs)
            else:
                return "Notebook executed but output file not found", result.stdout
        else:
            return f"Execution failed:\n{result.stderr}", result.stdout
    except subprocess.TimeoutExpired:
        return "Error: Notebook execution timed out (15 min limit)", ""
    except Exception as e:
        return f"Error executing notebook: {str(e)}", ""
461
 
462
+
463
def run_notebook_cell(notebook_name: str, cell_number: int) -> str:
    """Execute a single code cell from a notebook and return its output.

    Code cells are numbered from 1 in notebook order (markdown cells are
    skipped). The cell runs in-process with stdout/stderr captured and the
    working directory switched to NOTEBOOKS_DIR so relative paths resolve.
    Never raises — all failures come back as "Error: ..." strings.

    Fixes vs. original:
    - cwd is restored in a ``finally`` instead of manually on every exit
      path (the manual version silently breaks when a new return is added);
    - notebook file opened with explicit UTF-8 (the .ipynb spec encoding).
    """
    notebook_path = NOTEBOOKS_DIR / notebook_name
    if not notebook_path.exists():
        return f"Error: Notebook not found at {notebook_path}"
    original_cwd = os.getcwd()
    try:
        os.chdir(NOTEBOOKS_DIR)
        with open(notebook_path, "r", encoding="utf-8") as f:
            nb = json.load(f)
        cells = [c for c in nb.get("cells", []) if c.get("cell_type") == "code"]
        if cell_number < 1 or cell_number > len(cells):
            return f"Error: Cell {cell_number} not found. Available: 1-{len(cells)}"
        # Gradio Number inputs arrive as floats; int() normalizes the index.
        source = "".join(cells[int(cell_number) - 1].get("source", []))

        import io
        from contextlib import redirect_stdout, redirect_stderr

        # Fresh namespace per run — cells do NOT share state across calls.
        namespace = {"__name__": "__main__", "__file__": str(notebook_path)}
        stdout_capture = io.StringIO()
        stderr_capture = io.StringIO()
        with redirect_stdout(stdout_capture), redirect_stderr(stderr_capture):
            try:
                # SECURITY: exec of arbitrary notebook code. Acceptable only
                # because the Space owner controls the notebooks directory;
                # never expose this to untrusted uploads.
                exec(source, namespace)
            except Exception as e:
                return f"Error: {str(e)}"

        output = stdout_capture.getvalue()
        errors = stderr_capture.getvalue()
        result_text = f"### Cell {int(cell_number)} Output:\n"
        if output:
            result_text += f"```\n{output}\n```\n"
        if errors:
            result_text += f"\n**Warnings/Errors:**\n```\n{errors}\n```"
        if not output and not errors:
            result_text += "*(No output)*"
        return result_text
    except Exception as e:
        return f"Error: {str(e)}"
    finally:
        # Always restore the working directory, on every exit path.
        os.chdir(original_cwd)
508
 
509
+
510
  # ============================================================================
511
+ # MODEL TRAINING (existing Gradio functionality)
512
  # ============================================================================
513
 
514
class SecurityModelTrainer:
    """Helpers for preparing data, fitting, and scoring security models.

    Holds a StandardScaler and LabelEncoder as instance state so the same
    transforms fitted in prepare_data() can later be reused on new data.
    """

    def __init__(self):
        self.scaler = StandardScaler()
        self.label_encoder = LabelEncoder()

    def prepare_data(self, df: pd.DataFrame, target_col: str = "label") -> Tuple:
        """Split *df* into scaled train/test arrays.

        Keeps only numeric feature columns (NaNs -> 0), label-encodes a
        string target, standard-scales features, and returns the usual
        (X_train, X_test, y_train, y_test) 80/20 split (seed 42).
        """
        if target_col not in df.columns:
            raise ValueError(f"Target column '{target_col}' not found")
        y = df[target_col]
        features = df.drop(columns=[target_col]).select_dtypes(include=[np.number]).fillna(0)
        if y.dtype == "object":  # string labels -> integer codes
            y = self.label_encoder.fit_transform(y)
        scaled = self.scaler.fit_transform(features)
        return train_test_split(scaled, y, test_size=0.2, random_state=42)

    def train_model(self, model_type: str, X_train, y_train):
        """Fit and return a model of the given registered type."""
        if model_type not in MODEL_TYPES:
            raise ValueError(f"Unknown model type: {model_type}")
        # Isolation Forest is the one unsupervised estimator and needs a
        # contamination rate; all estimators get a fixed seed.
        extra = {"contamination": 0.1} if model_type == "Isolation Forest (Anomaly)" else {}
        model = MODEL_TYPES[model_type](random_state=42, **extra)
        model.fit(X_train, y_train)
        return model

    def evaluate_model(self, model, X_test, y_test) -> Dict:
        """Return accuracy and weighted F1 for *model* on the test split."""
        predicted = model.predict(X_test)
        return {
            "accuracy": accuracy_score(y_test, predicted),
            "f1_score": f1_score(y_test, predicted, average="weighted", zero_division=0),
        }
547
+
 
548
 
549
  trainer = SecurityModelTrainer()
550
 
551
+
552
  def train_model_from_data(data_file, model_type: str, task: str, progress=gr.Progress()):
553
  """Train model from uploaded data"""
554
  if data_file is None:
555
  return "Please upload a CSV file", None, None
 
556
  progress(0.1, desc="Loading data...")
 
557
  try:
558
  df = pd.read_csv(data_file.name)
559
  progress(0.3, desc="Preparing data...")
 
560
  X_train, X_test, y_train, y_test = trainer.prepare_data(df)
 
561
  progress(0.5, desc=f"Training {model_type}...")
562
  model = trainer.train_model(model_type, X_train, y_train)
 
563
  progress(0.8, desc="Evaluating model...")
564
  metrics = trainer.evaluate_model(model, X_test, y_test)
 
 
565
  model_name = f"{task.lower().replace(' ', '_')}_{model_type.lower().replace(' ', '_')}"
566
  model_path = MODELS_DIR / f"{model_name}.pkl"
567
  joblib.dump(model, model_path)
 
568
  progress(1.0, desc="Complete!")
 
569
  result = f"""
570
  ## Training Complete!
571
 
 
579
 
580
  **Model saved to:** {model_path}
581
  """
 
582
  return result, str(model_path), json.dumps(metrics, indent=2)
 
583
  except Exception as e:
584
  return f"Error: {str(e)}", None, None
585
 
586
+
587
def run_inference(model_file, features_text: str):
    """Load an uploaded .pkl model and predict on a JSON feature dict.

    Feature values are fed to the model in dict insertion order; the JSON
    key order must therefore match the training column order. Returns a
    pretty-printed JSON string, or an "Error: ..." string on failure.
    """
    if model_file is None:
        return "Please upload a model file"
    try:
        model = joblib.load(model_file.name)
        features = json.loads(features_text)
        vector = np.array([list(features.values())])
        label = model.predict(vector)[0]
        result = {"prediction": int(label), "features_used": len(features)}
        if hasattr(model, "predict_proba"):  # classifiers only; e.g. IsolationForest lacks it
            proba = model.predict_proba(vector)[0]
            result["confidence"] = float(max(proba))
            result["probabilities"] = {str(i): float(p) for i, p in enumerate(proba)}
        return json.dumps(result, indent=2)
    except Exception as e:
        return f"Error: {str(e)}"
604
 
605
+
606
def list_trained_models():
    """Return a markdown listing (name + size) of saved .pkl models."""
    saved = list(MODELS_DIR.glob("*.pkl"))
    if not saved:
        return "No trained models found"
    lines = [
        f"- **{path.name}** ({path.stat().st_size / 1024:.1f} KB)\n"
        for path in saved
    ]
    return "## Trained Models\n\n" + "".join(lines)
615
 
616
+
617
+ # ============================================================================
618
+ # FASTAPI APP (REST endpoints for Heroku backend mlService.js)
619
+ # ============================================================================
620
+
621
# FastAPI application fronting the Gradio UI; the REST endpoints below are
# consumed by the Heroku backend (mlService.js).
api = FastAPI(title="CyberForge AI API", version="1.0.0")

# NOTE(review): wildcard origins combined with allow_credentials=True is the
# most permissive CORS posture; tighten allow_origins to the known backend
# host if credentialed browser requests are ever used.
api.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
630
+
631
+
632
+ @api.get("/health")
633
+ async def api_health():
634
+ return {
635
+ "status": "healthy",
636
+ "timestamp": datetime.utcnow().isoformat(),
637
+ "services": {
638
+ "gemini": gemini_service.ready,
639
+ "ml_models": ml_loader.ready,
640
+ "models_loaded": list(ml_loader.models.keys()),
641
+ "gradio_ui": True,
642
+ },
643
+ "version": "1.0.0",
644
+ }
645
+
646
+
647
+ @api.post("/analyze")
648
+ async def api_analyze(request: Request):
649
+ """Main analysis endpoint – called by backend mlService.chatWithAI()"""
650
+ body = await request.json()
651
+ query = body.get("query", "")
652
+ context = body.get("context", {})
653
+ history = body.get("conversation_history", [])
654
+ result = gemini_service.analyze(query, context, history)
655
+ return result
656
+
657
+
658
+ @api.post("/analyze-url")
659
+ async def api_analyze_url(request: Request):
660
+ """URL analysis – called by backend mlService.analyzeWebsite()"""
661
+ body = await request.json()
662
+ url = body.get("url", "")
663
+ if not url:
664
+ return JSONResponse(status_code=400, content={"detail": "URL required"})
665
+ features = extract_url_features(url)
666
+ predictions = {}
667
+ for model_name in ml_loader.MODEL_NAMES:
668
+ predictions[model_name] = ml_loader.predict(model_name, features)
669
+ scores = [
670
+ p.get("threat_score", p.get("confidence", 0.5) if p.get("prediction", 0) == 1 else 0.2)
671
+ for p in predictions.values()
672
+ ]
673
+ avg_score = sum(scores) / len(scores) if scores else 0
674
+ return {
675
+ "url": url,
676
+ "aggregate": {
677
+ "average_threat_score": round(avg_score, 3),
678
+ "overall_risk_level": (
679
+ "critical" if avg_score > 0.8
680
+ else "high" if avg_score > 0.6
681
+ else "medium" if avg_score > 0.4
682
+ else "low"
683
+ ),
684
+ },
685
+ "model_predictions": predictions,
686
+ "features_analyzed": features,
687
+ "timestamp": datetime.utcnow().isoformat(),
688
+ }
689
+
690
+
691
+ @api.post("/scan-threats")
692
+ async def api_scan_threats(request: Request):
693
+ """Threat scanning – called by backend mlService.scanForThreats()"""
694
+ body = await request.json()
695
+ query = json.dumps(body.get("data", body), indent=1)[:3000]
696
+ result = gemini_service.analyze(f"Scan these indicators for threats:\n{query}")
697
+ return result
698
+
699
+
700
+ @api.post("/api/insights/generate")
701
+ async def api_generate_insights(request: Request):
702
+ """AI insights – called by backend mlService.getAIInsights()"""
703
+ body = await request.json()
704
+ query = body.get("query", "")
705
+ context = body.get("context", {})
706
+ result = gemini_service.analyze(query, context)
707
+ return {
708
+ "insights": result.get("response", ""),
709
+ "confidence": result.get("confidence", 0),
710
+ "timestamp": datetime.utcnow().isoformat(),
711
+ }
712
+
713
+
714
+ @api.post("/api/models/predict")
715
+ async def api_model_predict(request: Request):
716
+ """Model prediction – called by backend mlService.getModelPrediction()"""
717
+ body = await request.json()
718
+ model_type = body.get("model_type", "phishing_detection")
719
+ input_data = body.get("input_data", {})
720
+ result = ml_loader.predict(model_type, input_data)
721
+ return result
722
+
723
+
724
+ @api.get("/api/models/list")
725
+ @api.get("/models")
726
+ async def api_list_models():
727
+ result = []
728
+ for name in ml_loader.MODEL_NAMES:
729
+ result.append({
730
+ "name": name,
731
+ "loaded": name in ml_loader.models,
732
+ "source": "ml_model" if name in ml_loader.models else "heuristic",
733
+ })
734
+ return {"models": result, "total": len(ml_loader.MODEL_NAMES), "loaded": len(ml_loader.models)}
735
+
736
+
737
+ @api.post("/api/analysis/network")
738
+ async def api_analyze_network(request: Request):
739
+ """Network traffic analysis – called by backend mlService.analyzeNetworkTraffic()"""
740
+ body = await request.json()
741
+ traffic_data = body.get("traffic_data", {})
742
+ result = gemini_service.analyze(
743
+ f"Analyze this network traffic for security threats:\n{json.dumps(traffic_data, indent=1)[:3000]}"
744
+ )
745
+ return result
746
+
747
+
748
+ @api.post("/api/ai/execute-task")
749
+ async def api_execute_task(request: Request):
750
+ """AI task execution – called by backend mlService.executeAITask()"""
751
+ body = await request.json()
752
+ task_type = body.get("task_type", "")
753
+ task_data = body.get("task_data", {})
754
+ result = gemini_service.analyze(
755
+ f"Execute cybersecurity task '{task_type}':\n{json.dumps(task_data, indent=1)[:3000]}"
756
+ )
757
+ return result
758
+
759
+
760
+ @api.post("/api/browser/analyze")
761
+ async def api_analyze_browser(request: Request):
762
+ """Browser session analysis"""
763
+ body = await request.json()
764
+ session = body.get("session_data", body)
765
+ result = gemini_service.analyze(
766
+ f"Analyze this browser session for security threats:\n{json.dumps(session, indent=1)[:3000]}"
767
+ )
768
+ return result
769
+
770
+
771
+ @api.post("/api/threat-feeds/analyze")
772
+ async def api_analyze_threat_feed(request: Request):
773
+ """Threat feed analysis"""
774
+ body = await request.json()
775
+ result = gemini_service.analyze(
776
+ f"Analyze this threat feed:\n{json.dumps(body, indent=1)[:3000]}"
777
+ )
778
+ return result
779
+
780
+
781
+ @api.post("/api/datasets/process")
782
+ async def api_process_dataset(request: Request):
783
+ body = await request.json()
784
+ return {
785
+ "status": "processed",
786
+ "dataset_name": body.get("dataset_name", ""),
787
+ "timestamp": datetime.utcnow().isoformat(),
788
+ }
789
+
790
+
791
  # ============================================================================
792
+ # GRADIO INTERFACE (UI tabs – notebooks, training, inference, models, API)
793
  # ============================================================================
794
 
795
def create_interface():
    """Build the Gradio Blocks UI.

    Five tabs: Notebooks (list/view/execute/run-cell), Train Model (CSV
    upload -> trained .pkl), Inference (upload .pkl + JSON features),
    Models (list saved .pkl files), and API (static endpoint docs).
    Returns the Blocks object so it can be mounted on the FastAPI app.
    """
    with gr.Blocks(title="CyberForge AI", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # πŸ” CyberForge AI - ML Training Platform

        Train cybersecurity ML models and run Jupyter notebooks on Hugging Face.
        """)

        with gr.Tabs():
            # ============ NOTEBOOKS TAB ============
            with gr.TabItem("πŸ““ Notebooks"):
                gr.Markdown("### Run ML Pipeline Notebooks\nExecute the CyberForge ML notebooks directly in the cloud.")
                with gr.Row():
                    with gr.Column(scale=1):
                        notebook_dropdown = gr.Dropdown(
                            choices=get_available_notebooks(),
                            label="Select Notebook",
                            value=get_available_notebooks()[0] if get_available_notebooks() else None,
                        )
                        refresh_btn = gr.Button("πŸ”„ Refresh List")
                        view_btn = gr.Button("πŸ‘ View Content", variant="secondary")
                        execute_btn = gr.Button("β–Ά Execute Notebook", variant="primary")
                        gr.Markdown("### Run Single Cell")
                        cell_number = gr.Number(label="Cell Number", value=1, minimum=1)
                        run_cell_btn = gr.Button("Run Cell")
                    with gr.Column(scale=2):
                        notebook_status = gr.Markdown("Select a notebook to view or execute.")
                        notebook_output = gr.Markdown("", label="Output")

                # Local closure so the dropdown choices can be rebuilt on demand.
                def refresh_notebooks():
                    notebooks = get_available_notebooks()
                    return gr.update(choices=notebooks, value=notebooks[0] if notebooks else None)

                refresh_btn.click(refresh_notebooks, outputs=notebook_dropdown)
                view_btn.click(read_notebook_content, inputs=notebook_dropdown, outputs=notebook_output)
                execute_btn.click(execute_notebook, inputs=notebook_dropdown, outputs=[notebook_status, notebook_output])
                run_cell_btn.click(run_notebook_cell, inputs=[notebook_dropdown, cell_number], outputs=notebook_output)

            # ============ TRAIN MODEL TAB ============
            with gr.TabItem("🎯 Train Model"):
                gr.Markdown("### Train a Security ML Model\nUpload your dataset and train a model for threat detection.")
                with gr.Row():
                    with gr.Column():
                        task_dropdown = gr.Dropdown(choices=SECURITY_TASKS, label="Security Task", value="Phishing Detection")
                        model_dropdown = gr.Dropdown(choices=list(MODEL_TYPES.keys()), label="Model Type", value="Random Forest")
                        data_upload = gr.File(label="Upload Training Data (CSV)", file_types=[".csv"])
                        train_btn = gr.Button("πŸš€ Train Model", variant="primary")
                    with gr.Column():
                        train_output = gr.Markdown("Upload data and click Train to begin.")
                        # Hidden outputs: consumed programmatically, not shown to the user.
                        model_path_output = gr.Textbox(label="Model Path", visible=False)
                        metrics_output = gr.Textbox(label="Metrics JSON", visible=False)
                train_btn.click(train_model_from_data, inputs=[data_upload, model_dropdown, task_dropdown], outputs=[train_output, model_path_output, metrics_output])

            # ============ INFERENCE TAB ============
            with gr.TabItem("πŸ” Inference"):
                gr.Markdown("### Run Model Inference\nLoad a trained model and make predictions.")
                with gr.Row():
                    with gr.Column():
                        model_upload = gr.File(label="Upload Model (.pkl)")
                        features_input = gr.Textbox(label="Features (JSON)", value='{"url_length": 50, "has_https": 1, "digit_count": 5}', lines=5)
                        predict_btn = gr.Button("🎯 Predict", variant="primary")
                    with gr.Column():
                        prediction_output = gr.Textbox(label="Prediction Result", lines=10)
                predict_btn.click(run_inference, inputs=[model_upload, features_input], outputs=prediction_output)

            # ============ MODELS TAB ============
            with gr.TabItem("πŸ“¦ Models"):
                gr.Markdown("### Trained Models")
                models_list = gr.Markdown(list_trained_models())
                refresh_models_btn = gr.Button("πŸ”„ Refresh")
                refresh_models_btn.click(list_trained_models, outputs=models_list)

            # ============ API TAB ============
            with gr.TabItem("πŸ”Œ API"):
                gr.Markdown("""
                ## API Integration

                ### REST Endpoints (for Backend)

                | Method | Endpoint | Description |
                |--------|----------|-------------|
                | GET | `/health` | Health check |
                | POST | `/analyze` | AI chat (Gemini) |
                | POST | `/analyze-url` | URL threat analysis |
                | POST | `/scan-threats` | Threat scanning |
                | POST | `/api/insights/generate` | AI insights |
                | POST | `/api/models/predict` | ML model prediction |
                | GET | `/models` | List available models |

                ### Example

                ```bash
                curl -X POST https://che237-cyberforge.hf.space/analyze \\
                  -H "Content-Type: application/json" \\
                  -d '{"query": "Is this URL safe: http://example.com/login"}'
                ```

                ### Gradio API

                The Gradio interface also exposes API endpoints for notebook execution and model training.
                See the API tab at the bottom of this page.
                """)

        gr.Markdown("---\n**CyberForge AI** | [GitHub](https://github.com/Che237/cyberforge) | [Datasets](https://huggingface.co/datasets/Che237/cyberforge-datasets)")

    return demo
901
 
902
+
903
  # ============================================================================
904
+ # MAIN – Mount Gradio on FastAPI
905
  # ============================================================================
906
 
907
# Initialize services at startup — runs at import time so the app is ready
# before uvicorn starts serving traffic.
logger.info("πŸš€ Initializing CyberForge AI services...")
gemini_service.initialize()
ml_loader.initialize()
logger.info("βœ… Services initialized")

# Create Gradio app and mount it on FastAPI: the UI is served at "/" while
# the REST endpoints defined on `api` keep their own paths.
demo = create_interface()
app = gr.mount_gradio_app(api, demo, path="/")

if __name__ == "__main__":
    # Direct-run entrypoint; honors the PORT env var if set
    # (7860 is presumably the Hugging Face Spaces default — confirm).
    import uvicorn
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)