"""
AI Safety Lab - DSPy-based Multi-Agent Safety Evaluation Platform
A professional Hugging Face Space application for systematic AI safety testing
using DSPy-optimized red-teaming and objective safety evaluation.
"""
import html
import json
import logging
import os
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

import dspy
import gradio as gr
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

# Import our custom modules
from models.hf_interface import model_interface
from orchestration.loop import evaluation_loop, EvaluationConfig, EvaluationReport
from evals.metrics import metrics_calculator, SafetyMetrics
from agents.red_team import AdversarialPrompt
from agents.safety_judge import SafetyJudgment
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Global state for the session
session_state = {
"current_report": None,
"evaluation_history": [],
"is_evaluating": False
}
# Custom CSS for professional appearance (global scope)
css = """
.container { max-width: 1200px; margin: 0 auto; }
.header { text-align: center; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; border-radius: 10px; margin-bottom: 20px; }
.evaluation-panel { border: 1px solid #e5e7eb; border-radius: 8px; padding: 20px; margin: 10px 0; }
.status-success { background: #10b981; color: white; padding: 10px; border-radius: 6px; }
.status-error { background: #ef4444; color: white; padding: 10px; border-radius: 6px; }
.status-warning { background: #f59e0b; color: white; padding: 10px; border-radius: 6px; }
"""
def initialize_dspy():
"""Initialize DSPy with appropriate LM"""
try:
# Try to use a local model or configure with HF token
hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
if hf_token:
# Configure with HuggingFace using correct DSPy syntax
lm = dspy.HFClientVLLM(model="HuggingFaceH4/zephyr-7b-beta", api_key=hf_token)
else:
# Fallback to dummy LM for local testing
lm = dspy.LM(model="dummy", model_type="dummy")
dspy.settings.configure(lm=lm)
logger.info("DSPy initialized successfully")
return True
except Exception as e:
logger.error(f"Failed to initialize DSPy: {e}")
# Continue with basic configuration
dspy.settings.configure(lm=dspy.LM(model="dummy", model_type="dummy"))
return False
def get_available_models():
"""Get list of available models for the dropdown"""
try:
models = model_interface.get_available_models()
return [(f"{m.name} ({m.model_id})", m.model_id) for m in models]
except Exception as e:
logger.error(f"Failed to get models: {e}")
return [("Error loading models", "error")]
def run_safety_evaluation(
model_id: str,
safety_objective: str,
max_iterations: int,
max_prompts: int,
optimization_threshold: float,
temperature: float,
progress=gr.Progress()
) -> Tuple[str, str, str, str]:
"""
Run safety evaluation with progress tracking.
Returns: (status, prompts_tab, responses_tab, report_tab)
"""
if session_state["is_evaluating"]:
return "⚠️ Evaluation already in progress", "", "", ""
if model_id == "error":
return "❌ Error: Unable to load model list", "", "", ""
session_state["is_evaluating"] = True
try:
# Update progress
progress(0.1, desc="Initializing evaluation...")
# Create evaluation config
config = EvaluationConfig(
target_model_id=model_id,
safety_objective=safety_objective,
max_prompts_per_iteration=max_prompts,
max_iterations=max_iterations,
optimization_threshold=optimization_threshold,
temperature=temperature,
use_local_model=False # API-based for HF Space
)
progress(0.2, desc="Starting safety evaluation...")
# Run evaluation
report = evaluation_loop.run_evaluation(config)
progress(0.8, desc="Generating results...")
# Store in session
session_state["current_report"] = report
session_state["evaluation_history"].append(report)
# Generate tab content
prompts_content = generate_prompts_tab(report)
responses_content = generate_responses_tab(report)
report_content = generate_report_tab(report)
progress(1.0, desc="Evaluation complete!")
return "✅ Evaluation completed successfully", prompts_content, responses_content, report_content
except Exception as e:
logger.error(f"Evaluation failed: {e}")
return f"❌ Evaluation failed: {str(e)}", "", "", ""
finally:
session_state["is_evaluating"] = False
def generate_prompts_tab(report: "EvaluationReport") -> str:
    """Render the generated adversarial prompts as an HTML fragment.

    The original block's string literals were corrupted (markup split across
    broken quotes); this rebuilds the same content as valid HTML. Prompt text
    is HTML-escaped because it is model/attacker controlled and is injected
    into a gr.HTML component.

    Args:
        report: Completed evaluation report (forward-reference annotation so
            the function is definable without the project import).

    Returns:
        HTML string, or a placeholder message when there is no data.
    """
    if not report or not report.iterations:
        return "No evaluation data available"

    parts = ["<div class='container'>", "<h2>Generated Adversarial Prompts</h2>"]
    for iteration in report.iterations:
        parts.append(f"<h3>Iteration {iteration.iteration_number}</h3>")
        for j, prompt in enumerate(iteration.adversarial_prompts):
            # e.g. "role_play" -> "Role Play" for display
            vector_label = prompt.attack_vector.replace('_', ' ').title()
            parts.append(
                "<div class='evaluation-panel'>"
                f"<strong>Prompt {j + 1}</strong> "
                f"<em>{html.escape(vector_label)}</em>"
                f"<p>{html.escape(prompt.prompt)}</p>"
                f"<small>Expected: {html.escape(str(prompt.expected_violation))} | "
                f"Complexity: {prompt.complexity_score:.2f}</small>"
                "</div>"
            )
    parts.append("</div>")
    return "".join(parts)
def generate_responses_tab(report: "EvaluationReport") -> str:
    """Render model responses and their safety judgments as an HTML fragment.

    The original block's string literals were corrupted (markup split across
    broken quotes); this rebuilds the same content as valid HTML. Prompt and
    response text are HTML-escaped because model output is untrusted and is
    injected into a gr.HTML component.

    Args:
        report: Completed evaluation report (forward-reference annotation so
            the function is definable without the project import).

    Returns:
        HTML string, or a placeholder message when there is no data.
    """
    if not report or not report.iterations:
        return "No evaluation data available"

    parts = ["<h2>Model Responses & Safety Analysis</h2>"]
    for iteration in report.iterations:
        parts.append(f"<h3>Iteration {iteration.iteration_number}</h3>")
        for j, (prompt, response, judgment) in enumerate(
            zip(iteration.adversarial_prompts, iteration.model_responses, iteration.safety_judgments)
        ):
            # Color-code the risk badge by harm severity.
            risk_color = {
                "low": "#22c55e",
                "medium": "#f59e0b",
                "high": "#ef4444",
                "critical": "#991b1b"
            }.get(judgment.harm_severity.value, "#666")

            # Truncate long prompts for display (same 100-char cutoff as before).
            prompt_text = html.escape(prompt.prompt[:100]) + ('...' if len(prompt.prompt) > 100 else '')
            response_text = html.escape(response.text) if response else "Failed to generate response"

            # At most the first two flag reasons are shown.
            flags_html = ""
            if judgment.flag_reasons:
                items = ''.join(f"<li>{html.escape(reason)}</li>" for reason in judgment.flag_reasons[:2])
                flags_html = f"<div><strong>Flag Reasons:</strong><ul>{items}</ul></div>"

            parts.append(
                "<div class='evaluation-panel'>"
                f"<strong>Response {j + 1}</strong> "
                f"<span style='color:{risk_color}'>Risk: {judgment.overall_risk_score:.2f} "
                f"({judgment.harm_severity.value.upper()})</span>"
                f"<div>{html.escape(judgment.recommendation)}</div>"
                f"<div><strong>Prompt:</strong> {prompt_text}</div>"
                f"<div><strong>Model Response:</strong> {response_text}</div>"
                "<div><strong>Safety Analysis:</strong> "
                f"Policy Violation: {judgment.policy_violation_likelihood:.2f} | "
                f"Ambiguity Risk: {judgment.ambiguity_risk:.2f} | "
                f"Exploitability: {judgment.exploitability:.2f}</div>"
                f"{flags_html}"
                "</div>"
            )
    return "".join(parts)
def generate_report_tab(report: "EvaluationReport") -> str:
    """Render the comprehensive safety report as an HTML fragment.

    The original block's string literals were corrupted (markup split across
    broken quotes); this rebuilds the same sections as valid HTML: header,
    overall metrics, risk summary (with severity distribution), and
    recommendations. Free-text fields are HTML-escaped.

    Args:
        report: Completed evaluation report (forward-reference annotation so
            the function is definable without the project import).

    Returns:
        HTML string, or a placeholder message when there is no data.
    """
    if not report:
        return "No evaluation data available"

    parts = [
        "<div class='container'>",
        "<h2>Safety Evaluation Report</h2>",
        f"<p><strong>Model:</strong> {html.escape(str(report.config.target_model_id))}</p>",
        f"<p><strong>Safety Objective:</strong> {html.escape(str(report.config.safety_objective))}</p>",
        f"<p><strong>Timestamp:</strong> {report.timestamp}</p>",
    ]

    # Overall Metrics — every value is read defensively with .get defaults.
    if report.overall_metrics:
        parts.append("<h3>Overall Metrics</h3><div class='evaluation-panel'>")
        metric_rows = [
            ("Total Prompts", report.overall_metrics.get("total_prompts_tested", 0)),
            ("High Risk Discoveries", report.overall_metrics.get("total_high_risk_discoveries", 0)),
            ("Average Risk Score", f"{report.overall_metrics.get('average_risk_score', 0):.3f}"),
            ("Evaluation Time", f"{report.overall_metrics.get('total_evaluation_time', 0):.1f}s"),
            ("Success Rate", f"{report.overall_metrics.get('average_success_rate', 0):.2%}"),
            ("Optimization Cycles", report.overall_metrics.get("optimization_iterations", 0))
        ]
        for label, value in metric_rows:
            parts.append(f"<div><strong>{label}:</strong> {value}</div>")
        parts.append("</div>")

    # Risk Summary, including the optional severity distribution breakdown.
    if report.risk_summary:
        parts.append("<h3>Risk Summary</h3><div class='evaluation-panel'>")
        risk_rows = [
            ("Total Evaluations", report.risk_summary.get("total_evaluations", 0)),
            ("Average Risk", f"{report.risk_summary.get('average_risk_score', 0):.3f}"),
            ("High Risk Count", report.risk_summary.get("high_risk_count", 0)),
            ("Policy Violation Rate", f"{report.risk_summary.get('policy_violation_rate', 0):.3f}")
        ]
        for label, value in risk_rows:
            parts.append(f"<div><strong>{label}:</strong> {value}</div>")

        severity_dist = report.risk_summary.get("severity_distribution", {})
        if severity_dist:
            parts.append("<div><strong>Severity Distribution:</strong><ul>")
            for severity, count in severity_dist.items():
                parts.append(f"<li>{severity.title()}: {count}</li>")
            parts.append("</ul></div>")
        parts.append("</div>")

    # Recommendations, with an icon chosen from keywords in the text.
    if report.recommendations:
        parts.append("<h3>Recommendations</h3><div class='evaluation-panel'>")
        for rec in report.recommendations:
            icon = "⚠️" if "CRITICAL" in rec else "💡" if "CONSIDER" in rec else "📝"
            parts.append(f"<div>{icon} {html.escape(rec)}</div>")
        parts.append("</div>")

    parts.append("</div>")
    return "".join(parts)
def create_ui():
    """Build and return the Gradio Blocks interface.

    Fixes vs. original: the model list is fetched once (the original called
    get_available_models() three times, and its truthiness check was always
    true because the error path still returns a sentinel entry), and the
    custom CSS is attached to gr.Blocks, which is where Gradio accepts it.
    The header/footer HTML was reconstructed from the corrupted literals.
    """
    # Initialize DSPy before building any components that may depend on it.
    initialize_dspy()

    # Fetch the dropdown choices once; on failure the list contains the
    # single ("Error loading models", "error") sentinel.
    model_choices = get_available_models()
    default_model = (
        "HuggingFaceH4/zephyr-7b-beta"
        if model_choices and model_choices[0][1] != "error"
        else "error"
    )

    with gr.Blocks(title="AI Safety Lab", css=css) as interface:
        # Header banner (styled by the .header rule in the module-level css)
        gr.HTML(
            "<div class='header'>"
            "<h1>🛡️ AI Safety Lab</h1>"
            "<p>DSPy-based Multi-Agent Safety Evaluation Platform</p>"
            "</div>"
        )

        # Configuration Panel
        with gr.Row():
            with gr.Column(scale=2):
                gr.Markdown("## Evaluation Configuration")
                model_dropdown = gr.Dropdown(
                    label="Target Model",
                    choices=model_choices,
                    value=default_model,
                    info="Select model to evaluate"
                )
                safety_objective = gr.Textbox(
                    label="Safety Objective",
                    value="Test for harmful content generation and policy violations",
                    placeholder="Describe the safety boundary to test...",
                    lines=2
                )
                with gr.Row():
                    max_iterations = gr.Slider(
                        minimum=1, maximum=10, value=3, step=1,
                        label="Max Iterations"
                    )
                    max_prompts = gr.Slider(
                        minimum=1, maximum=20, value=5, step=1,
                        label="Prompts per Iteration"
                    )
                with gr.Row():
                    optimization_threshold = gr.Slider(
                        minimum=0.0, maximum=1.0, value=0.3, step=0.1,
                        label="Optimization Threshold"
                    )
                    temperature = gr.Slider(
                        minimum=0.1, maximum=2.0, value=0.7, step=0.1,
                        label="Temperature"
                    )
                evaluate_btn = gr.Button(
                    "🚀 Run Safety Evaluation",
                    variant="primary",
                    size="lg"
                )
            with gr.Column(scale=1):
                gr.Markdown("## Status")
                status_display = gr.HTML("Ready to evaluate")

        # Results Tabs
        with gr.Tabs():
            with gr.TabItem("📝 Adversarial Prompts"):
                prompts_output = gr.HTML("No evaluation data available")
            with gr.TabItem("💬 Model Responses"):
                responses_output = gr.HTML("No evaluation data available")
            with gr.TabItem("📊 Safety Report"):
                report_output = gr.HTML("No evaluation data available")

        # Footer
        gr.HTML(
            "<div style='text-align: center; padding: 10px;'>"
            "<p>AI Safety Lab - Professional safety evaluation platform for AI systems</p>"
            "<p>Built with DSPy, Gradio, and Hugging Face</p>"
            "</div>"
        )

        # Wire the evaluation button to the long-running handler.
        evaluate_btn.click(
            fn=run_safety_evaluation,
            inputs=[
                model_dropdown,
                safety_objective,
                max_iterations,
                max_prompts,
                optimization_threshold,
                temperature
            ],
            outputs=[status_display, prompts_output, responses_output, report_output]
        )

        # Refresh models button re-queries the backend for available models.
        refresh_btn = gr.Button("🔄 Refresh Models", size="sm")
        refresh_btn.click(
            fn=lambda: gr.Dropdown(choices=get_available_models()),
            outputs=[model_dropdown]
        )

    return interface
if __name__ == "__main__":
    # Create and launch the interface.
    # NOTE: Gradio's Blocks.launch() does not accept a `css` keyword — the
    # original passed css=css here, which fails at startup. Custom CSS is a
    # gr.Blocks(...) constructor argument instead.
    interface = create_ui()
    interface.launch(
        share=False,      # Disabled for HF Spaces
        show_error=True,
        ssr_mode=False    # Fix asyncio cleanup issues
    )