"""Galileo platform helpers: loggers and Galileo Protect stage/ruleset invocation."""

import logging
from typing import List, Optional

from dotenv import load_dotenv
from pydantic import BaseModel

from galileo import GalileoLogger, GalileoScorers, StageType
from galileo.protect import invoke_protect
from galileo.stages import create_protect_stage, get_protect_stage
from galileo_core.schemas.protect.action import OverrideAction
from galileo_core.schemas.protect.payload import Payload
from galileo_core.schemas.protect.rule import Rule, RuleOperator
from galileo_core.schemas.protect.ruleset import Ruleset

# Load variables from a local .env file into the environment at import time
# (presumably Galileo credentials/endpoint settings read by the SDK — confirm).
load_dotenv()

_log = logging.getLogger(__name__)


class GalileoPlatformConfig(BaseModel):
    """Base configuration for Galileo platform."""

    # Galileo project that owns the Protect stage.
    protect_project_name: str
    # Name of the Protect stage to look up (or create if the lookup fails).
    protect_stage_name: str


class GalileoPlatform:
    """Implementation of Galileo Features: loggers and Protect guardrails."""

    def __init__(self, config: GalileoPlatformConfig):
        """Store *config* and resolve the Protect stage ID.

        Note that construction calls the Galileo SDK (stage lookup and,
        if that fails, stage creation) via ``get_protect_stage_id``.
        """
        self.config = config
        self.protect_stage_id = self.get_protect_stage_id()

    def get_logger(self, project_name: str, logstream_name: str) -> GalileoLogger:
        """Return a ``GalileoLogger`` bound to a project and log stream.

        A new ``GalileoLogger`` instance is constructed on every call; whether
        the SDK reuses or creates the remote project/log stream is up to
        ``GalileoLogger`` itself.

        Args:
            project_name: Galileo project name.
            logstream_name: Log stream name within that project.
        """
        return GalileoLogger(
            project=project_name,
            log_stream=logstream_name,
        )

    def get_protect_stage_id(self):
        """Return the ID of the configured Protect stage, creating it if needed.

        Looks the stage up by project and stage name; if that lookup raises
        for any reason, a new local Protect stage with that name is created
        instead. Errors raised by the creation call propagate to the caller.
        """
        try:
            protect_stage = get_protect_stage(
                project_name=self.config.protect_project_name,
                stage_name=self.config.protect_stage_name,
            )
            # Kept inside the try on purpose: if the lookup returns None for a
            # missing stage (not visible here — confirm against the SDK), the
            # AttributeError also routes us to the creation fallback.
            return protect_stage.id
        except Exception as exc:
            # The SDK's specific "not found" exception is not visible here, so
            # every lookup failure falls back to creation. Log the cause so
            # unrelated failures (auth, network) are not silently masked.
            _log.info(
                "Lookup of Protect stage %r in project %r failed (%s); creating it.",
                self.config.protect_stage_name,
                self.config.protect_project_name,
                exc,
            )
            protect_stage = create_protect_stage(
                project_name=self.config.protect_project_name,
                name=self.config.protect_stage_name,
                stage_type=StageType.local,
                description="Deutsche Bank RFP RAG Protect Stage",
            )
            return protect_stage.id

    @staticmethod
    def _default_rulesets() -> List[Ruleset]:
        """Build the default prioritized rulesets applied by ``run_protect``."""
        return [
            # Override when Luna context adherence is at or below 0.01.
            # NOTE(review): context adherence presumably scores the response,
            # yet the message says "input" — confirm the intended wording.
            Ruleset(
                rules=[
                    Rule(
                        metric=GalileoScorers.context_adherence_luna,
                        operator=RuleOperator.lte,
                        target_value=0.01,
                    ),
                ],
                action=OverrideAction(
                    choices=["Sorry, the input is hallucinatory."]
                ),
            ),
            # Override when the input PII scorer detects any listed PII type.
            # NOTE(review): the metric checks the *input*, but the message
            # says "output" — confirm which side is meant to be guarded.
            Ruleset(
                rules=[
                    Rule(
                        metric=GalileoScorers.input_pii,
                        operator=RuleOperator.any,
                        target_value=["email", "phone_number", "name"],
                    )
                ],
                action=OverrideAction(
                    choices=["Sorry, the output contains PII."]
                ),
            ),
            # Override when the custom scorer, referenced by name, is >= 0.1.
            Ruleset(
                rules=[
                    Rule(
                        metric="deutsche_bank_company_pii_scorer_0",
                        operator=RuleOperator.gte,
                        target_value=0.1,
                    )
                ],
                action=OverrideAction(
                    choices=["Sorry, the output contains PII."]
                ),
            ),
        ]

    def run_protect(
        self,
        input: str,
        output: str,
        logger: Optional[GalileoLogger] = None,
        *,
        prioritized_rulesets: Optional[List[Ruleset]] = None,
    ) -> dict:
        """Run Galileo Protect on an input/output pair.

        Args:
            input: The user input text (name kept for caller compatibility,
                although it shadows the ``input`` builtin).
            output: The model output text.
            logger: If given, a Protect span with the payload and response
                is added to it.
            prioritized_rulesets: Rulesets to evaluate, in priority order.
                Defaults to ``_default_rulesets()`` when omitted.

        Returns:
            The Protect response converted with ``dict(...)``.
        """
        if prioritized_rulesets is None:
            prioritized_rulesets = self._default_rulesets()

        # Build the payload once; the same object is sent to Protect and logged.
        payload = Payload(input=input, output=output)
        response = invoke_protect(
            payload=payload,
            prioritized_rulesets=prioritized_rulesets,
            stage_id=self.protect_stage_id,
        )
        if logger is not None:
            logger.add_protect_span(payload=payload, response=response)
        return dict(response)