# Author: Cesar017 — commit 98e6c43 ("Update app.py")
import os
import numpy as np
import pandas as pd
import gradio as gr
import tensorflow as tf
import joblib
# 1. Loading of models and artifacts
# Paths are resolved relative to this file (not the current working directory)
# so the "models" folder is found regardless of where the app is launched from.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODELS_DIR = os.path.join(BASE_DIR, "models")
def load_artifacts():
    """Load the two Keras models and the fitted preprocessing objects.

    Files are read from ``MODELS_DIR``: ``modelo_base.keras`` (binary model),
    ``modelo_transfer.keras`` (multiclass model), ``scaler.joblib`` and
    ``encoders.joblib``, in that order.

    Returns:
        ``(model_bin, model_multi, scaler, encoders)``. If any of the four
        loads raises, the error is printed and ``(None, None, None, None)`` is
        returned instead, so importing this module never fails on a missing
        or corrupt artifact.
    """
    def _artifact(filename):
        return os.path.join(MODELS_DIR, filename)

    try:
        binary_model = tf.keras.models.load_model(_artifact("modelo_base.keras"))
        multiclass_model = tf.keras.models.load_model(_artifact("modelo_transfer.keras"))
        # NOTE: joblib.load unpickles its input; only point it at trusted, bundled files.
        fitted_scaler = joblib.load(_artifact("scaler.joblib"))
        fitted_encoders = joblib.load(_artifact("encoders.joblib"))
    except Exception as err:
        print(f"Error cargando modelos: {err}")
        return None, None, None, None
    return binary_model, multiclass_model, fitted_scaler, fitted_encoders
# Load every artifact once, at import time. If any file fails to load,
# load_artifacts() prints the error and all four names are None, in which case
# predict_batch cannot produce predictions.
model_bin, model_multi, scaler, encoders = load_artifacts()
# Multiclass model output index (argmax position) -> NSL-KDD category name.
MULTI_CLASSES = dict(enumerate(("Normal", "DoS", "Probe", "R2L", "U2R")))
# 2. Batch inference


def _prepare_features(csv_path):
    """Read a headerless CSV and build the feature matrix fed to both models.

    Only the first 41 columns are kept; any extra trailing columns are dropped.
    When the encoders are loaded, columns 1-3 (``protocol_type``, ``service``,
    ``flag``) are converted with them. When the scaler is loaded, the resulting
    float matrix is transformed with it.
    """
    frame = pd.read_csv(csv_path, header=None).iloc[:, :41].copy()
    if encoders:
        for col, pos in [("protocol_type", 1), ("service", 2), ("flag", 3)]:
            frame[pos] = encoders[col].transform(frame[pos].astype(str))
    X = frame.values.astype(float)
    if scaler is not None:
        X = scaler.transform(X)
    return X


def _build_rows(preds_bin, preds_multi, threshold=0.5):
    """Turn raw model outputs into result-table rows and count predicted attacks.

    A record is flagged as an attack when its binary score (column 0 of
    ``preds_bin``) is strictly greater than ``threshold``. The reported binary
    confidence is the probability of the chosen side. The category is the
    argmax of the multiclass output mapped through ``MULTI_CLASSES``; indices
    with no mapping are shown as "Otros".

    Returns:
        ``(rows, attack_count)`` where each row is
        ``[row_number, binary_label, binary_conf, category, category_conf]``.
    """
    # .tolist() yields plain Python floats/ints, so formatting and dict lookups
    # behave exactly as with per-element float() conversion.
    attack_probs = np.asarray(preds_bin)[:, 0].tolist()
    multi = np.asarray(preds_multi)
    class_idx = multi.argmax(axis=1).tolist()
    class_conf = multi.max(axis=1).tolist()

    rows = []
    attack_count = 0
    for row_num, (prob, idx, conf) in enumerate(
        zip(attack_probs, class_idx, class_conf), start=1
    ):
        is_attack = prob > threshold
        if is_attack:
            attack_count += 1
        rows.append([
            row_num,
            "🛑 Ataque" if is_attack else "✅ Normal",
            f"{prob:.2%}" if is_attack else f"{(1 - prob):.2%}",
            MULTI_CLASSES.get(idx, "Otros"),
            f"{conf:.2%}",
        ])
    return rows, attack_count


def predict_batch(file):
    """Run the binary and multiclass IDS models over every row of a CSV upload.

    Args:
        file: The uploaded file, either an object exposing ``.name`` (the
            Gradio upload) or a plain path string. ``None`` when nothing was
            uploaded.

    Returns:
        ``(rows, summary)``: ``rows`` is a list of
        ``[row_number, binary_label, binary_conf, category, category_conf]``
        and ``summary`` is a multi-line text report. On any failure ``rows``
        is ``None`` and ``summary`` carries the error message.
    """
    if file is None:
        return None, "Cargue un archivo CSV."
    # load_artifacts() yields all-None on failure; say so explicitly instead of
    # surfacing "'NoneType' object has no attribute 'predict'".
    if model_bin is None or model_multi is None:
        return None, "Error en el proceso: los modelos no se cargaron correctamente (revise los logs)."
    try:
        X = _prepare_features(getattr(file, "name", file))
        rows, attack_count = _build_rows(
            model_bin.predict(X, verbose=0),
            model_multi.predict(X, verbose=0),
        )
    except Exception as e:
        # UI boundary: report any failure (malformed CSV, preprocessing or
        # shape errors, ...) to the user instead of crashing the handler.
        return None, f"Error en el proceso: {str(e)}"

    total = len(rows)
    normal_count = total - attack_count
    summary = (f"📊 Análisis Completado\n"
               f"-------------------\n"
               f"Total registros: {total}\n"
               f"✅ Tráfico Normal: {normal_count}\n"
               f"🛑 Ataques: {attack_count}")
    return rows, summary
# 3. User interface
with gr.Blocks(theme=gr.themes.Soft(), title="IDS Demo") as demo:
    gr.Markdown("# 🛡️ IDS con Redes Neuronales (NSL-KDD)")
    with gr.Row():
        file_input = gr.File(label="Subir test_heterogeneo.csv", file_types=[".csv"])
        run_btn = gr.Button("🚀 Analizar en Batch", variant="primary")
    with gr.Row():
        summary_out = gr.Textbox(label="Resumen del Análisis", lines=6)
        # Column order matches the rows built by predict_batch:
        # [row #, binary label, binary confidence, category, category confidence].
        # The two confidence columns get distinct names so they can be told apart.
        table_out = gr.DataFrame(
            headers=["#", "Clasif. Binaria", "Confianza Binaria", "Categoría", "Confianza Categoría"],
            label="Resultados Detallados",
        )
    # predict_batch returns (rows, summary), hence this output order.
    run_btn.click(fn=predict_batch, inputs=file_input, outputs=[table_out, summary_out])

if __name__ == "__main__":
    demo.launch()