| import logging | |
| import time | |
| from datetime import timedelta | |
| from typing import Dict, List | |
| import streamlit as st | |
| from llm_guard.input_scanners.anonymize import DEFAULT_ENTITY_TYPES | |
| from llm_guard.input_scanners.code import SUPPORTED_LANGUAGES as SUPPORTED_CODE_LANGUAGES | |
| from llm_guard.output_scanners import get_scanner_by_name | |
| from llm_guard.output_scanners.bias import MatchType as BiasMatchType | |
| from llm_guard.output_scanners.deanonymize import MatchingStrategy as DeanonymizeMatchingStrategy | |
| from llm_guard.output_scanners.gibberish import MatchType as GibberishMatchType | |
| from llm_guard.output_scanners.language import MatchType as LanguageMatchType | |
| from llm_guard.output_scanners.toxicity import MatchType as ToxicityMatchType | |
| from llm_guard.vault import Vault | |
| from streamlit_tags import st_tags | |
| logger = logging.getLogger("llm-guard-playground") | |
def init_settings() -> (List, Dict):
    """Render the output-scanner controls in the Streamlit sidebar.

    Draws a multiselect listing every known output scanner (all selected by
    default) and, for each selected scanner that has options, an expander
    holding its widgets.

    Returns:
        A tuple ``(enabled_scanners, settings)``: the scanner names chosen in
        the multiselect, and a dict mapping a scanner name to the keyword
        settings collected from its widgets. Selected scanners with no branch
        below (``LanguageSame``) get no entry in ``settings``.
    """
    all_scanners = [
        "BanCode",
        "BanCompetitors",
        "BanSubstrings",
        "BanTopics",
        "Bias",
        "Code",
        "Deanonymize",
        "JSON",
        "Language",
        "LanguageSame",
        "MaliciousURLs",
        "NoRefusal",
        # These two entries must stay separate items: without the comma Python
        # concatenates the adjacent literals into one bogus scanner name.
        "NoRefusalLight",
        "ReadingTime",
        "FactualConsistency",
        "Gibberish",
        "Regex",
        "Relevance",
        "Sensitive",
        "Sentiment",
        "Toxicity",
        "URLReachability",
    ]

    st_enabled_scanners = st.sidebar.multiselect(
        "Select scanners",
        options=all_scanners,
        default=all_scanners,
        help="The list can be found here: https://protectai.github.io/llm-guard/output_scanners/bias/",
    )

    settings = {}

    if "BanCode" in st_enabled_scanners:
        st_bc_expander = st.sidebar.expander(
            "Ban Code",
            expanded=False,
        )

        with st_bc_expander:
            st_bc_threshold = st.slider(
                label="Threshold",
                value=0.95,
                min_value=0.0,
                max_value=1.0,
                step=0.05,
                key="ban_code_threshold",
            )

        settings["BanCode"] = {"threshold": st_bc_threshold}

    if "BanCompetitors" in st_enabled_scanners:
        st_bc_expander = st.sidebar.expander(
            "Ban Competitors",
            expanded=False,
        )

        with st_bc_expander:
            st_bc_competitors = st_tags(
                label="List of competitors",
                text="Type and press enter",
                value=["openai", "anthropic", "deepmind", "google"],
                suggestions=[],
                maxtags=30,
                key="bc_competitors",
            )

            st_bc_threshold = st.slider(
                label="Threshold",
                value=0.5,
                min_value=0.0,
                max_value=1.0,
                step=0.05,
                key="ban_competitors_threshold",
            )

        settings["BanCompetitors"] = {
            "competitors": st_bc_competitors,
            "threshold": st_bc_threshold,
        }

    if "BanSubstrings" in st_enabled_scanners:
        st_bs_expander = st.sidebar.expander(
            "Ban Substrings",
            expanded=False,
        )

        with st_bs_expander:
            st_bs_text = st.text_area(
                "Enter substrings to ban (one per line)",
                value="test\nhello\nworld\n",
                height=200,
            )
            # Drop blank lines: the default value ends with a newline, and an
            # empty string is a substring of every text.
            st_bs_substrings = [line for line in st_bs_text.split("\n") if line.strip()]

            st_bs_match_type = st.selectbox(
                "Match type", ["str", "word"], index=0, key="bs_match_type"
            )
            st_bs_case_sensitive = st.checkbox(
                "Case sensitive", value=False, key="bs_case_sensitive"
            )
            st_bs_redact = st.checkbox("Redact", value=False, key="bs_redact")
            st_bs_contains_all = st.checkbox("Contains all", value=False, key="bs_contains_all")

        settings["BanSubstrings"] = {
            "substrings": st_bs_substrings,
            "match_type": st_bs_match_type,
            "case_sensitive": st_bs_case_sensitive,
            "redact": st_bs_redact,
            "contains_all": st_bs_contains_all,
        }

    if "BanTopics" in st_enabled_scanners:
        st_bt_expander = st.sidebar.expander(
            "Ban Topics",
            expanded=False,
        )

        with st_bt_expander:
            st_bt_topics = st_tags(
                label="List of topics",
                text="Type and press enter",
                value=["violence"],
                suggestions=[],
                maxtags=30,
                key="bt_topics",
            )

            st_bt_threshold = st.slider(
                label="Threshold",
                value=0.6,
                min_value=0.0,
                max_value=1.0,
                step=0.05,
                key="ban_topics_threshold",
            )

        settings["BanTopics"] = {"topics": st_bt_topics, "threshold": st_bt_threshold}

    if "Bias" in st_enabled_scanners:
        st_bias_expander = st.sidebar.expander(
            "Bias",
            expanded=False,
        )

        with st_bias_expander:
            st_bias_threshold = st.slider(
                label="Threshold",
                value=0.75,
                min_value=0.0,
                max_value=1.0,
                step=0.05,
                key="bias_threshold",
            )

            st_bias_match_type = st.selectbox(
                "Match type", [e.value for e in BiasMatchType], index=1, key="bias_match_type"
            )

        settings["Bias"] = {
            "threshold": st_bias_threshold,
            "match_type": BiasMatchType(st_bias_match_type),
        }

    if "Code" in st_enabled_scanners:
        st_cd_expander = st.sidebar.expander(
            "Code",
            expanded=False,
        )

        with st_cd_expander:
            st_cd_languages = st.multiselect(
                "Programming languages",
                options=SUPPORTED_CODE_LANGUAGES,
                default=["Python"],
            )

            st_cd_is_blocked = st.checkbox("Is blocked", value=False, key="cd_is_blocked")

        settings["Code"] = {
            "languages": st_cd_languages,
            "is_blocked": st_cd_is_blocked,
        }

    if "Deanonymize" in st_enabled_scanners:
        st_de_expander = st.sidebar.expander(
            "Deanonymize",
            expanded=False,
        )

        with st_de_expander:
            st_de_matching_strategy = st.selectbox(
                "Matching strategy", [e.value for e in DeanonymizeMatchingStrategy], index=0
            )

        settings["Deanonymize"] = {
            "matching_strategy": DeanonymizeMatchingStrategy(st_de_matching_strategy)
        }

    if "JSON" in st_enabled_scanners:
        st_json_expander = st.sidebar.expander(
            "JSON",
            expanded=False,
        )

        with st_json_expander:
            st_json_required_elements = st.slider(
                label="Required elements",
                value=0,
                min_value=0,
                max_value=10,
                step=1,
                key="json_required_elements",
                help="The minimum number of JSON elements that should be present",
            )

            st_json_repair = st.checkbox(
                "Repair", value=False, help="Attempt to repair the JSON", key="json_repair"
            )

        settings["JSON"] = {
            "required_elements": st_json_required_elements,
            "repair": st_json_repair,
        }

    if "Language" in st_enabled_scanners:
        st_lan_expander = st.sidebar.expander(
            "Language",
            expanded=False,
        )

        with st_lan_expander:
            st_lan_valid_language = st.multiselect(
                "Languages",
                [
                    "ar",
                    "bg",
                    "de",
                    "el",
                    "en",
                    "es",
                    "fr",
                    "hi",
                    "it",
                    "ja",
                    "nl",
                    "pl",
                    "pt",
                    "ru",
                    "sw",
                    "th",
                    "tr",
                    "ur",
                    "vi",
                    "zh",
                ],
                default=["en"],
            )

            st_lan_match_type = st.selectbox(
                "Match type", [e.value for e in LanguageMatchType], index=1, key="lan_match_type"
            )

        settings["Language"] = {
            "valid_languages": st_lan_valid_language,
            "match_type": LanguageMatchType(st_lan_match_type),
        }

    if "MaliciousURLs" in st_enabled_scanners:
        st_murls_expander = st.sidebar.expander(
            "Malicious URLs",
            expanded=False,
        )

        with st_murls_expander:
            st_murls_threshold = st.slider(
                label="Threshold",
                value=0.75,
                min_value=0.0,
                max_value=1.0,
                step=0.05,
                key="murls_threshold",
            )

        settings["MaliciousURLs"] = {"threshold": st_murls_threshold}

    if "NoRefusal" in st_enabled_scanners:
        st_no_ref_expander = st.sidebar.expander(
            "No refusal",
            expanded=False,
        )

        with st_no_ref_expander:
            st_no_ref_threshold = st.slider(
                label="Threshold",
                value=0.5,
                min_value=0.0,
                max_value=1.0,
                step=0.05,
                key="no_ref_threshold",
            )

        settings["NoRefusal"] = {"threshold": st_no_ref_threshold}

    if "NoRefusalLight" in st_enabled_scanners:
        # No widgets for this scanner; it is registered with empty settings.
        settings["NoRefusalLight"] = {}

    if "ReadingTime" in st_enabled_scanners:
        st_rt_expander = st.sidebar.expander(
            "Reading Time",
            expanded=False,
        )

        with st_rt_expander:
            st_rt_max_reading_time = st.slider(
                label="Max reading time (in minutes)",
                value=5,
                min_value=0,
                max_value=3600,
                step=5,
                key="rt_max_reading_time",
            )

            st_rt_truncate = st.checkbox(
                "Truncate",
                value=False,
                help="Truncate the text to the max reading time",
                key="rt_truncate",
            )

        settings["ReadingTime"] = {"max_time": st_rt_max_reading_time, "truncate": st_rt_truncate}

    if "FactualConsistency" in st_enabled_scanners:
        st_fc_expander = st.sidebar.expander(
            "FactualConsistency",
            expanded=False,
        )

        with st_fc_expander:
            st_fc_minimum_score = st.slider(
                label="Minimum score",
                value=0.5,
                min_value=0.0,
                max_value=1.0,
                step=0.05,
                key="fc_threshold",
            )

        settings["FactualConsistency"] = {"minimum_score": st_fc_minimum_score}

    if "Regex" in st_enabled_scanners:
        st_regex_expander = st.sidebar.expander(
            "Regex",
            expanded=False,
        )

        with st_regex_expander:
            st_regex_text = st.text_area(
                "Enter patterns to ban (one per line)",
                value="Bearer [A-Za-z0-9-._~+/]+",
                height=200,
            )
            # Drop blank lines: an empty regular expression matches any text.
            st_regex_patterns = [line for line in st_regex_text.split("\n") if line.strip()]

            st_regex_is_blocked = st.checkbox("Is blocked", value=True, key="regex_is_blocked")

            st_regex_redact = st.checkbox(
                "Redact",
                value=False,
                help="Replace the matched bad patterns with [REDACTED]",
                key="regex_redact",
            )

        settings["Regex"] = {
            "patterns": st_regex_patterns,
            "is_blocked": st_regex_is_blocked,
            "redact": st_regex_redact,
        }

    if "Relevance" in st_enabled_scanners:
        st_rele_expander = st.sidebar.expander(
            "Relevance",
            expanded=False,
        )

        with st_rele_expander:
            st_rele_threshold = st.slider(
                label="Threshold",
                value=0.5,
                min_value=0.0,
                max_value=1.0,
                step=0.05,
                key="rele_threshold",
            )

        settings["Relevance"] = {"threshold": st_rele_threshold}

    if "Sensitive" in st_enabled_scanners:
        st_sens_expander = st.sidebar.expander(
            "Sensitive",
            expanded=False,
        )

        with st_sens_expander:
            st_sens_entity_types = st_tags(
                label="Sensitive entities",
                text="Type and press enter",
                value=DEFAULT_ENTITY_TYPES,
                suggestions=DEFAULT_ENTITY_TYPES
                + ["DATE_TIME", "NRP", "LOCATION", "MEDICAL_LICENSE", "US_PASSPORT"],
                maxtags=30,
                key="sensitive_entity_types",
            )
            st.caption(
                "Check all supported entities: https://protectai.github.io/llm-guard/input_scanners/anonymize/"
            )
            st_sens_redact = st.checkbox("Redact", value=False, key="sens_redact")
            st_sens_threshold = st.slider(
                label="Threshold",
                value=0.0,
                min_value=0.0,
                max_value=1.0,
                step=0.1,
                key="sens_threshold",
            )

        settings["Sensitive"] = {
            "entity_types": st_sens_entity_types,
            "redact": st_sens_redact,
            "threshold": st_sens_threshold,
        }

    if "Sentiment" in st_enabled_scanners:
        st_sent_expander = st.sidebar.expander(
            "Sentiment",
            expanded=False,
        )

        with st_sent_expander:
            st_sent_threshold = st.slider(
                label="Threshold",
                value=-0.5,
                min_value=-1.0,
                max_value=1.0,
                step=0.1,
                key="sentiment_threshold",
                help="Negative values are negative sentiment, positive values are positive sentiment",
            )

        settings["Sentiment"] = {"threshold": st_sent_threshold}

    if "Toxicity" in st_enabled_scanners:
        st_tox_expander = st.sidebar.expander(
            "Toxicity",
            expanded=False,
        )

        with st_tox_expander:
            st_tox_threshold = st.slider(
                label="Threshold",
                value=0.5,
                min_value=0.0,
                max_value=1.0,
                step=0.05,
                key="toxicity_threshold",
            )

            st_tox_match_type = st.selectbox(
                "Match type",
                [e.value for e in ToxicityMatchType],
                index=1,
                key="toxicity_match_type",
            )

        settings["Toxicity"] = {
            "threshold": st_tox_threshold,
            "match_type": ToxicityMatchType(st_tox_match_type),
        }

    if "URLReachability" in st_enabled_scanners:
        # The expander is still drawn so the sidebar layout is unchanged, but
        # it holds no widgets; the scanner is registered with empty settings.
        st.sidebar.expander(
            "URL Reachability",
            expanded=False,
        )

        settings["URLReachability"] = {}

    if "Gibberish" in st_enabled_scanners:
        st_gib_expander = st.sidebar.expander(
            "Gibberish",
            expanded=False,
        )

        with st_gib_expander:
            st_gib_threshold = st.slider(
                label="Threshold",
                value=0.7,
                min_value=0.0,
                max_value=1.0,
                step=0.1,
                key="gib_threshold",
            )

            st_gib_match_type = st.selectbox(
                "Match type", [e.value for e in GibberishMatchType], index=1, key="gib_match_type"
            )

        # Convert the selected value back to the enum, as the Bias, Language
        # and Toxicity branches do.
        settings["Gibberish"] = {
            "match_type": GibberishMatchType(st_gib_match_type),
            "threshold": st_gib_threshold,
        }

    return st_enabled_scanners, settings
def get_scanner(scanner_name: str, vault: Vault, settings: Dict):
    """Build the output scanner called ``scanner_name`` from its settings.

    Args:
        scanner_name: Name passed on to ``get_scanner_by_name``.
        vault: Vault added to the settings under ``"vault"`` when the scanner
            is ``Deanonymize``; unused for every other scanner.
        settings: Keyword settings for the scanner. The dict is not modified;
            extra keys are added to a copy.

    Returns:
        Whatever ``get_scanner_by_name`` returns for the given name and the
        augmented settings.
    """
    logger.debug("Initializing %s scanner", scanner_name)

    # Work on a shallow copy so the injected "vault" / "use_onnx" keys do not
    # leak into the caller's settings dict.
    scanner_settings = dict(settings)

    if scanner_name == "Deanonymize":
        scanner_settings["vault"] = vault

    if scanner_name in [
        "BanCode",
        "BanTopics",
        "Bias",
        "Code",
        "Gibberish",
        "Language",
        "LanguageSame",
        "MaliciousURLs",
        "NoRefusal",
        "FactualConsistency",
        "Relevance",
        "Sensitive",
        "Toxicity",
    ]:
        scanner_settings["use_onnx"] = True

    return get_scanner_by_name(scanner_name, scanner_settings)
def scan(
    vault: Vault,
    enabled_scanners: List[str],
    settings: Dict,
    prompt: str,
    text: str,
    fail_fast: bool = False,
) -> (str, List[Dict[str, any]]):
    """Run the enabled scanners over ``text`` one after another.

    Each scanner receives ``prompt`` and the output produced by the previous
    scanner, so changes made by one scanner are seen by the next. Progress is
    shown in a Streamlit status container.

    Args:
        vault: Passed to ``get_scanner`` for every scanner.
        enabled_scanners: Scanner names, run in the given order.
        settings: Maps a scanner name to its settings dict; scanners without
            an entry are built with an empty dict.
        prompt: Prompt handed to each scanner's ``scan`` call.
        text: Initial output text to scan.
        fail_fast: When true, stop after the first scanner that reports the
            output as invalid.

    Returns:
        A tuple of the final output text and a list with one dict per executed
        scanner holding ``scanner``, ``is_valid``, ``risk_score`` and
        ``took_sec`` (elapsed seconds rounded to two decimals).
    """
    status_text = "Scanning prompt (fail fast mode)..." if fail_fast else "Scanning prompt..."

    current_output = text
    reports = []

    with st.status(status_text, expanded=True) as status:
        for scanner_name in enabled_scanners:
            st.write(f"{scanner_name} scanner...")

            scanner_settings = settings.get(scanner_name, {})
            scanner = get_scanner(scanner_name, vault, scanner_settings)

            # Time only the scan call itself, not scanner construction.
            started_at = time.monotonic()
            current_output, is_valid, risk_score = scanner.scan(prompt, current_output)
            elapsed = timedelta(seconds=time.monotonic() - started_at)

            report = {
                "scanner": scanner_name,
                "is_valid": is_valid,
                "risk_score": risk_score,
                "took_sec": round(elapsed.total_seconds(), 2),
            }
            reports.append(report)

            if not is_valid and fail_fast:
                break

        status.update(label="Scanning complete", state="complete", expanded=False)

    return current_output, reports