| """ |
| Utility classes and functions for the GuardBench Leaderboard display. |
| """ |
|
|
| from dataclasses import dataclass, field, fields |
| from enum import Enum, auto |
| from typing import List, Optional |
|
|
|
|
class Mode(Enum):
    """Inference modes a guard model can be evaluated in."""

    CoT = auto()
    Strict = auto()

    def __str__(self) -> str:
        """Return the bare member name (e.g. ``"CoT"``)."""
        label = self.name
        return label
|
|
|
|
class ModelType(Enum):
    """Access categories a leaderboard model can belong to."""

    Unknown = auto()
    OpenSource = auto()
    ClosedSource = auto()
    API = auto()

    def to_str(self, separator: str = "-") -> str:
        """Render the member as a label, joining two-word names with *separator*.

        ``OpenSource`` and ``ClosedSource`` become ``"Open<sep>Source"`` and
        ``"Closed<sep>Source"``; ``API`` is always ``"API"``; anything else
        (including ``Unknown``) yields ``"Unknown"``.
        """
        if self == ModelType.OpenSource:
            return f"Open{separator}Source"
        if self == ModelType.ClosedSource:
            return f"Closed{separator}Source"
        if self == ModelType.API:
            return "API"
        # Unknown and any unmatched member share the same fallback label.
        return "Unknown"
|
|
class GuardModelType(str, Enum):
    """Implementation families of guard models shown on the leaderboard.

    Members are ``str`` subclasses whose values are the lowercase identifiers;
    ``str()`` returns the upper-case member name instead.
    """

    LLAMA_GUARD = "llama_guard"
    CLASSIFIER = "classifier"
    ATLA_SELENE = "atla_selene"
    OPENAI_MODERATION = "openai_moderation"
    LLM_REGEXP = "llm_regexp"
    LLM_SO = "llm_so"
    WHITECIRCLE_GUARD = "whitecircle_guard"

    def __str__(self) -> str:
        """Return the member name (e.g. ``"LLAMA_GUARD"``), not its value."""
        member_name = self.name
        return member_name
|
|
|
|
|
|
class Precision(Enum):
    """Numeric precision a model's weights were evaluated at."""

    Unknown = auto()
    float16 = auto()
    bfloat16 = auto()
    float32 = auto()
    int8 = auto()
    int4 = auto()
    NA = auto()

    def __str__(self) -> str:
        """Return the member name verbatim (e.g. ``"bfloat16"``)."""
        label = self.name
        return label
|
|
|
|
class WeightType(Enum):
    """How a submission's weights relate to a base model."""

    Original = auto()
    Delta = auto()
    Adapter = auto()

    def __str__(self) -> str:
        """Return the member name verbatim (e.g. ``"Adapter"``)."""
        label = self.name
        return label
|
|
|
|
@dataclass
class ColumnInfo:
    """Metadata describing one leaderboard column.

    Attributes:
        name: Internal column identifier.
        display_name: Human-readable column label.
        type: Column data type; ``"text"`` unless overridden (metric columns
            in this module pass ``"number"``).
        hidden: Whether the column is flagged as hidden.
        never_hidden: Whether the column is flagged as never hideable.
        displayed_by_default: Whether the column is visible initially.
    """

    # Field order is part of the positional-constructor API; keep it stable.
    name: str
    display_name: str
    type: str = "text"
    hidden: bool = False
    never_hidden: bool = False
    displayed_by_default: bool = True
|
|
|
|
@dataclass
class GuardBenchColumn:
    """Columns for the GuardBench leaderboard.

    Each field holds a ``ColumnInfo`` describing one column. The module-level
    lists (``COLS``, ``DISPLAY_COLS``, ``METRIC_COLS``, ...) are built by
    iterating ``dataclasses.fields`` of an instance, so field declaration
    order here determines column order.
    """

    # --- Model identity / metadata columns ---
    model_name: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="model_name",
        display_name="Model",
        never_hidden=True,  # the only column flagged never_hidden
        displayed_by_default=True
    ))
    mode: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="mode",
        display_name="Mode",
        displayed_by_default=True
    ))
    model_type: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="model_type",
        display_name="Access_Type",
        displayed_by_default=True
    ))
    submission_date: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="submission_date",
        display_name="Submission_Date",
        displayed_by_default=False
    ))
    version: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="version",
        display_name="Version",
        displayed_by_default=False
    ))
    guard_model_type: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="guard_model_type",
        display_name="Type",
        displayed_by_default=False
    ))
    base_model: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="base_model",
        display_name="Base Model",
        displayed_by_default=False
    ))
    revision: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="revision",
        display_name="Revision",
        displayed_by_default=False
    ))
    precision: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="precision",
        display_name="Precision",
        displayed_by_default=False
    ))
    weight_type: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="weight_type",
        display_name="Weight Type",
        displayed_by_default=False
    ))

    # --- Per-test-type metrics: default prompts (all hidden by default) ---
    # NOTE(review): recall/precision labels drop the "_Binary" suffix that the
    # F1 label keeps; confirm whether this is intentional before relying on
    # display names.
    default_prompts_f1_binary: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="default_prompts_f1_binary",
        display_name="Default_Prompts_F1_Binary",
        type="number",
        displayed_by_default=False
    ))
    default_prompts_f1: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="default_prompts_f1",
        display_name="Default_Prompts_F1",
        type="number",
        displayed_by_default=False
    ))
    default_prompts_recall_binary: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="default_prompts_recall_binary",
        display_name="Default_Prompts_Recall",
        type="number",
        displayed_by_default=False
    ))
    default_prompts_precision_binary: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="default_prompts_precision_binary",
        display_name="Default_Prompts_Precision",
        type="number",
        displayed_by_default=False
    ))
    default_prompts_error_ratio: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="default_prompts_error_ratio",
        display_name="Default_Prompts_Error_Ratio",
        type="number",
        displayed_by_default=False
    ))
    default_prompts_avg_runtime_ms: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="default_prompts_avg_runtime_ms",
        display_name="Default_Prompts_Avg_Runtime_ms",
        type="number",
        displayed_by_default=False
    ))

    # --- Per-test-type metrics: jailbreaked prompts (hidden by default) ---
    jailbreaked_prompts_f1_binary: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="jailbreaked_prompts_f1_binary",
        display_name="Jailbreaked_Prompts_F1_Binary",
        type="number",
        displayed_by_default=False
    ))
    jailbreaked_prompts_f1: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="jailbreaked_prompts_f1",
        display_name="Jailbreaked_Prompts_F1",
        type="number",
        displayed_by_default=False
    ))
    jailbreaked_prompts_recall_binary: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="jailbreaked_prompts_recall_binary",
        display_name="Jailbreaked_Prompts_Recall",
        type="number",
        displayed_by_default=False
    ))
    jailbreaked_prompts_precision_binary: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="jailbreaked_prompts_precision_binary",
        display_name="Jailbreaked_Prompts_Precision",
        type="number",
        displayed_by_default=False
    ))
    jailbreaked_prompts_error_ratio: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="jailbreaked_prompts_error_ratio",
        display_name="Jailbreaked_Prompts_Error_Ratio",
        type="number",
        displayed_by_default=False
    ))
    jailbreaked_prompts_avg_runtime_ms: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="jailbreaked_prompts_avg_runtime_ms",
        display_name="Jailbreaked_Prompts_Avg_Runtime_ms",
        type="number",
        displayed_by_default=False
    ))

    # --- Per-test-type metrics: default answers (hidden by default) ---
    default_answers_f1_binary: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="default_answers_f1_binary",
        display_name="Default_Answers_F1_Binary",
        type="number",
        displayed_by_default=False
    ))
    default_answers_f1: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="default_answers_f1",
        display_name="Default_Answers_F1",
        type="number",
        displayed_by_default=False
    ))
    default_answers_recall_binary: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="default_answers_recall_binary",
        display_name="Default_Answers_Recall",
        type="number",
        displayed_by_default=False
    ))
    default_answers_precision_binary: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="default_answers_precision_binary",
        display_name="Default_Answers_Precision",
        type="number",
        displayed_by_default=False
    ))
    default_answers_error_ratio: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="default_answers_error_ratio",
        display_name="Default_Answers_Error_Ratio",
        type="number",
        displayed_by_default=False
    ))
    default_answers_avg_runtime_ms: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="default_answers_avg_runtime_ms",
        display_name="Default_Answers_Avg_Runtime_ms",
        type="number",
        displayed_by_default=False
    ))

    # --- Per-test-type metrics: jailbreaked answers (hidden by default) ---
    jailbreaked_answers_f1_binary: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="jailbreaked_answers_f1_binary",
        display_name="Jailbreaked_Answers_F1_Binary",
        type="number",
        displayed_by_default=False
    ))
    jailbreaked_answers_f1: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="jailbreaked_answers_f1",
        display_name="Jailbreaked_Answers_F1",
        type="number",
        displayed_by_default=False
    ))
    jailbreaked_answers_recall_binary: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="jailbreaked_answers_recall_binary",
        display_name="Jailbreaked_Answers_Recall",
        type="number",
        displayed_by_default=False
    ))
    jailbreaked_answers_precision_binary: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="jailbreaked_answers_precision_binary",
        display_name="Jailbreaked_Answers_Precision",
        type="number",
        displayed_by_default=False
    ))
    jailbreaked_answers_error_ratio: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="jailbreaked_answers_error_ratio",
        display_name="Jailbreaked_Answers_Error_Ratio",
        type="number",
        displayed_by_default=False
    ))
    jailbreaked_answers_avg_runtime_ms: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="jailbreaked_answers_avg_runtime_ms",
        display_name="Jailbreaked_Answers_Avg_Runtime_ms",
        type="number",
        displayed_by_default=False
    ))
    # Overall score column, shown by default.
    integral_score: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="integral_score",
        display_name="Integral_Score",
        type="number",
        displayed_by_default=True
    ))

    # --- Macro-averaged metrics ---
    # NOTE(review): "Macro Precision" uses a space where sibling labels use
    # underscores; confirm whether intentional.
    macro_accuracy: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="macro_accuracy",
        display_name="Macro_Accuracy",
        type="number",
        displayed_by_default=True
    ))
    macro_recall: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="macro_recall",
        display_name="Macro_Recall",
        type="number",
        displayed_by_default=True
    ))
    macro_precision: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="macro_precision",
        display_name="Macro Precision",
        type="number",
        displayed_by_default=False
    ))

    # --- Micro-averaged metrics and totals (shown by default) ---
    micro_avg_error_ratio: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="micro_avg_error_ratio",
        display_name="Micro_Error",
        type="number",
        displayed_by_default=True
    ))
    micro_avg_runtime_ms: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="micro_avg_runtime_ms",
        display_name="Micro_Avg_time_ms",
        type="number",
        displayed_by_default=True
    ))
    total_evals_count: ColumnInfo = field(default_factory=lambda: ColumnInfo(
        name="total_evals_count",
        display_name="Total_Count",
        type="number",
        displayed_by_default=True
    ))
|
|
|
|
| |
# Single shared instance whose fields describe every leaderboard column.
GUARDBENCH_COLUMN = GuardBenchColumn()

# Dataclass attribute names of all columns, in declaration order.
COLS = [fld.name for fld in fields(GUARDBENCH_COLUMN)]

# ColumnInfo.name of every column flagged displayed_by_default
# (reordered further below by reorder_display_cols).
DISPLAY_COLS = [
    info.name
    for info in (getattr(GUARDBENCH_COLUMN, fld.name) for fld in fields(GUARDBENCH_COLUMN))
    if info.displayed_by_default
]
|
|
| |
def reorder_display_cols(cols=None):
    """Return display columns with ``mode`` moved to directly follow ``model_name``.

    Args:
        cols: Column names to reorder. Defaults to the module-level
            ``DISPLAY_COLS``. The input list is never modified.

    Returns:
        A new list with the same names. If both ``model_name`` and ``mode``
        are present, ``mode`` is placed immediately after ``model_name``;
        otherwise the order is unchanged.
    """
    # Copy so callers' lists (including the module-level DISPLAY_COLS) are not
    # mutated as a side effect.
    result = list(DISPLAY_COLS if cols is None else cols)
    if 'model_name' in result and 'mode' in result:
        result.remove('mode')
        # Look up the index after removal so a 'mode' that preceded
        # 'model_name' does not shift the insertion point.
        result.insert(result.index('model_name') + 1, 'mode')
    return result
# Apply the preferred ordering: 'mode' directly after 'model_name'.
DISPLAY_COLS = reorder_display_cols()

# Column-name lists filtered by ColumnInfo flags, in field-declaration order.
# Numeric (metric) columns:
METRIC_COLS = [
    info.name
    for info in (getattr(GUARDBENCH_COLUMN, fld.name) for fld in fields(GUARDBENCH_COLUMN))
    if info.type == "number"
]
# Columns flagged ``hidden``:
HIDDEN_COLS = [
    info.name
    for info in (getattr(GUARDBENCH_COLUMN, fld.name) for fld in fields(GUARDBENCH_COLUMN))
    if info.hidden
]
# Columns flagged ``never_hidden``:
NEVER_HIDDEN_COLS = [
    info.name
    for info in (getattr(GUARDBENCH_COLUMN, fld.name) for fld in fields(GUARDBENCH_COLUMN))
    if info.never_hidden
]
|
|
| |
# Harm categories reported on the leaderboard, plus the benign "Safe Prompts".
# NOTE(review): several entries use non-ASCII dash characters (e.g. in
# "Self–Harm"); presumably they must match category keys in the results data
# verbatim — confirm before normalizing to ASCII hyphens.
CATEGORIES = [
    "Political Corruption and Legal Evasion",
    "Financial Fraud and Unethical Business",
    "AI Manipulation and Jailbreaking",
    "Child Exploitation and Abuse",
    "Hate Speech, Extremism, and Discrimination",
    "Labor Exploitation and Human Trafficking",
    "Manipulation, Deception, and Misinformation",
    "Environmental and Industrial Harm",
    "Academic Dishonesty and Cheating",
    "Self–Harm and Suicidal Ideation",
    "Animal Cruelty and Exploitation",
    "Criminal, Violent, and Terrorist Activity",
    "Drug– and Substance–Related Activities",
    "Sexual Content and Violence",
    "Weapon, Explosives, and Hazardous Materials",
    "Cybercrime, Hacking, and Digital Exploits",
    "Creative Content Involving Illicit Themes",
    "Safe Prompts",
]

# Test types; these match the prefixes of the per-test-type GuardBenchColumn fields.
TEST_TYPES = [
    "default_prompts",
    "jailbreaked_prompts",
    "default_answers",
    "jailbreaked_answers",
]

# Metric name suffixes.
# NOTE(review): "accuracy" has no per-test-type column and "f1" (which does)
# is absent here; confirm against the consumers of this list.
METRICS = [
    "f1_binary",
    "recall_binary",
    "precision_binary",
    "error_ratio",
    "avg_runtime_ms",
    "accuracy",
]
|
|
def get_all_column_choices():
    """
    Get the column choices for the multiselect dropdown.

    Despite the function name, this does not return every column: columns
    listed by ``get_default_visible_columns()`` are excluded, so only the
    optional columns remain. Order follows the field declaration order of
    ``GUARDBENCH_COLUMN``.

    Returns:
        List of ``(column_name, display_name)`` tuples for each column that is
        not displayed by default.
    """
    # Build the exclusion set once; membership tests are then O(1) each.
    default_visible = set(get_default_visible_columns())
    return [
        (info.name, info.display_name)
        for info in (getattr(GUARDBENCH_COLUMN, fld.name) for fld in fields(GUARDBENCH_COLUMN))
        if info.name not in default_visible
    ]
|
|
def get_default_visible_columns():
    """
    List the columns shown before the user customizes the view.

    Returns:
        ``ColumnInfo.name`` of every ``GUARDBENCH_COLUMN`` entry whose
        ``displayed_by_default`` flag is set, in field-declaration order.
    """
    column_infos = (getattr(GUARDBENCH_COLUMN, col.name) for col in fields(GUARDBENCH_COLUMN))
    return [info.name for info in column_infos if info.displayed_by_default]
|
|